In [None]:
import torch 
import numpy as np
import torch.utils.data as dt
from proteindataset import ProteinDataset
from sklearn.metrics import average_precision_score

In [None]:
def set_all_seeds(SEED):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's built-in ``random`` module, NumPy's legacy global
    generator and PyTorch, so that repeated runs with the same ``SEED``
    produce the same random draws.

    Args:
        SEED: Integer seed applied to all three generators.
    """
    # Local import keeps the function self-contained; `random` is not
    # imported in the notebook's import cell.
    import random

    # REPRODUCIBILITY
    random.seed(SEED)
    np.random.seed(SEED)
    # Per the PyTorch docs, manual_seed seeds the generators on all devices,
    # so no separate CUDA seeding call is needed.
    torch.manual_seed(SEED)

In [None]:
# Locations of the labels file and of the train / test image folders.
LABELS = '../input/hpa-dataset-models/train_upsampled.csv'
TRAIN_DIR = '../input/human-protein-atlas-image-classification/train/'
TEST_DIR = '../input/human-protein-atlas-image-classification/test/'

# Dataset built from the training images and the labels CSV above.
dataset = ProteinDataset(
    TRAIN_DIR,
    LABELS,
    image_size=512,
)

In [None]:
!pip install timm

In [None]:
import timm
from tqdm import tqdm

# parameters
SEED = 123
lr = 1e-3
weight_decay = 1e-5
batch_size = 10

# Seed BEFORE splitting: random_split draws from torch's global generator, so
# seeding afterwards would give a different train/validation partition on
# every run and make validation scores incomparable across runs.
set_all_seeds(SEED)

# Create dataloaders (80 % train / 20 % validation)
dataset_len = len(dataset)
train_size = int(dataset_len * 0.8)
# Move one sample into the training split when the last training batch would
# otherwise hold exactly one sample.
# NOTE(review): presumably guards against a size-1 batch in train mode
# (batch-norm layers cannot normalise a single sample) -- confirm.
if train_size % batch_size == 1:
    train_size += 1

val_size = dataset_len - train_size
# NOTE(review): the labels file is named "train_upsampled"; if it contains
# duplicated rows, a random split can place copies of one image in both
# subsets and inflate the validation score -- confirm.
train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])

# Reshuffle the training subset every epoch; validation order is irrelevant.
trainloader = dt.DataLoader(train_set, batch_size=batch_size, shuffle=True)
testloader = dt.DataLoader(val_set, batch_size=batch_size)

# model
# pretrained=False: every parameter and buffer is overwritten by the strict
# load_state_dict call below, so fetching the ImageNet weights first would
# only cost a download (and fail without internet access).
model = timm.create_model('efficientnet_b4', pretrained=False, num_classes=28, in_chans=4)
# map_location='cpu' loads the checkpoint independently of the device it was
# saved from; the model is moved to the GPU afterwards.
model.load_state_dict(
    torch.load(
        '../input/hpa-dataset-models/modified_pretrained_model_kaggle_2_1_last.pth',
        map_location='cpu',
    )
)
model = model.cuda()

# define loss & optimizer
criterion = torch.nn.BCEWithLogitsLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# The learning rate is multiplied by 0.7 after every epoch. The deprecated
# `verbose` flag is replaced by an explicit print of get_last_lr() below.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.7)

# training
best_val_score = 0
for epoch in range(10):
    model.train()
    epoch_loss = []
    for train_data, train_labels in tqdm(trainloader):
        train_data, train_labels = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        loss = criterion(y_pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())

    # Checkpoint once per epoch (previously rewritten after every batch).
    torch.save(model.state_dict(), f'model_kaggle_{epoch}.pth')

    # validation
    test_loss = []
    test_pred = []
    test_true = []
    model.eval()
    with torch.no_grad():
        for test_data, test_labels in tqdm(testloader):
            test_data = test_data.cuda()
            test_labels = test_labels.cuda()
            y_pred = model(test_data)
            test_loss.append(criterion(y_pred, test_labels).item())
            # The model outputs logits; convert to probabilities for the metric.
            y_pred = torch.sigmoid(y_pred)
            test_pred.append(y_pred.cpu().numpy())
            test_true.append(test_labels.cpu().numpy())

    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    val_metric = average_precision_score(test_true, test_pred, average='macro')

    # Keep a separate copy of the weights with the best validation score.
    if best_val_score < val_metric:
        best_val_score = val_metric
        torch.save(model.state_dict(), f'model_kaggle_best.pth')

    print ('Epoch=%s, Val_average_precision_score=%.4f, Best_Val_score=%.4f'%(epoch, val_metric, best_val_score ))
    scheduler.step()
    print (f"Learning rate for next epoch: {scheduler.get_last_lr()[0]:.4e}")
    print (f"Epoch loss {np.mean(epoch_loss)}     Val loss {np.mean(test_loss)}")

In [None]:
# Second training stage: continue training the same model with a fresh Adam
# optimizer at a fixed learning rate of 1e-5 (no scheduler is stepped here).
# Relies on model, criterion, trainloader, testloader, weight_decay and tqdm
# from the previous training cell.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=weight_decay)

# training
# The best score restarts at 0, so the "best" checkpoint of this stage is
# chosen among this stage's epochs only.
best_val_score = 0
for epoch in range(5):
    model.train()
    epoch_loss = []
    for train_data, train_labels in tqdm(trainloader):
        train_data, train_labels = train_data.cuda(), train_labels.cuda()
        y_pred = model(train_data)
        loss = criterion(y_pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())

    # Checkpoint once per epoch (previously rewritten after every batch).
    torch.save(model.state_dict(), f'modified_pretrained_model_kaggle_2_{epoch}.pth')

    # validation
    test_loss = []
    test_pred = []
    test_true = []
    model.eval()
    with torch.no_grad():
        for test_data, test_labels in tqdm(testloader):
            test_data = test_data.cuda()
            test_labels = test_labels.cuda()
            y_pred = model(test_data)
            test_loss.append(criterion(y_pred, test_labels).item())
            # The model outputs logits; convert to probabilities for the metric.
            y_pred = torch.sigmoid(y_pred)
            test_pred.append(y_pred.cpu().numpy())
            test_true.append(test_labels.cpu().numpy())

    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    val_metric = average_precision_score(test_true, test_pred, average='macro')

    # Keep a separate copy of the weights with the best validation score.
    if best_val_score < val_metric:
        best_val_score = val_metric
        torch.save(model.state_dict(), f'modified_pretrained_model_kaggle_best.pth')

    print ('Epoch=%s, Val_average_precision_score=%.4f, Best_Val_score=%.4f'%(epoch, val_metric, best_val_score ))
    print (f"Epoch loss {np.mean(epoch_loss)}     Val loss {np.mean(test_loss)}")