# Imports

In [None]:
import copy
import gc
import os
import random
import time
import warnings; warnings.simplefilter('ignore')
from collections import defaultdict

import albumentations as A
import cv2
import numpy as np
import torch
import torchvision
import torchvision.models as models
import tqdm
import wandb
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import f1_score,precision_score,recall_score,confusion_matrix,ConfusionMatrixDisplay,accuracy_score,roc_auc_score
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader

In [None]:
# Directory of pre-extracted patches for every (split, label) pair. Each path
# ends with "/" so file names can be appended by plain string concatenation.
patch_paths = {
    f'{split}_{label}': f'../data/{split}_patient_level/{label}_patch_overlapped_v4/'
    for split in ('train', 'val', 'test')
    for label in ('positive', 'negative')
}

In [None]:
# Experiment configuration. It is passed to wandb.init and read throughout the notebook.
# The *_patches_* entries are placeholders (-1) that the data-setup cell overwrites.
config={
    'random_seed':48,
    'IM_W':64,
    'IM_H':64,
    'Batch':10,
    'device':torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'ds_type':'binary_cls',
    'LR':0.0001,
    'epoch':50,
    'local_contrastive_epoch':50,
    'global_contrastive_epoch':50,
    'contrastive_early_stopping':True,
    'early_stop_patience':5,
    'model':'resnet18',
    'train_type':'ffcl',
    'imgnet_pretrained':True,
    'loss_fn':'focal',
    'positive_patches_train':-1,
    'negative_patches_train':-1,
    'positive_patches_val':-1,
    'negative_patches_val':-1,
    'positive_patches_test':-1,
    'negative_patches_test':-1,
    'focal_alpha':0.8,
    'focal_gamma':3.0,
    'scheduler_warmup':1,
    'scheduler':'cos',
    'scheduler_step':15
}

# Checkpoints are written under saved_models/. Create the directory so that a fresh
# checkout does not fail on the listdir below.
os.makedirs('saved_models', exist_ok=True)
# Checkpoint name prefix: <ds_type>_<model>_<train_type>_<N>, where N = number of
# entries already in saved_models/ + 1.
config['saved_model_name']=config['ds_type']+'_'+config['model']+'_'+config['train_type']+'_'+str(len(os.listdir('saved_models'))+1)

# Seed every RNG the pipeline draws from. Python's `random` is included because
# (depending on version) albumentations may use it for its random augmentations.
random.seed(config['random_seed'])
torch.manual_seed(config['random_seed'])
np.random.seed(config['random_seed'])
config['patch_paths']=patch_paths

In [None]:
# Start the Weights & Biases run for this experiment. If the environment is not
# already authenticated, run `wandb.login()` once beforehand.
run_settings = dict(
    entity='atik',
    group=config['ds_type'],
    name="FFCL: step=3 (positive), Focal, ResNet-18, random samples (New 4 data training)",
    project="margin",
    config=config,
)
wandb.init(**run_settings)

# Data Setup

In [None]:
def _list_patches(directory):
    """Return the full path of every entry in ``directory``, sorted by file name.

    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order. Sorting
    makes the sample order, and with it the seeded shuffling and pair sampling,
    reproducible. ``directory`` must end with a path separator, as every
    ``patch_paths`` value does.
    """
    return [directory + name for name in sorted(os.listdir(directory))]


def _paths_and_labels(positive, negative):
    """Concatenate positive then negative paths into numpy arrays of paths and labels.

    NOTE(review): positive patches get label 0 and negative patches get label 1.
    This keeps the original convention. Confirm it is intended before reading
    "positive class" precision/recall downstream.
    """
    paths = np.array(positive + negative)
    labels = np.array([0] * len(positive) + [1] * len(negative))
    return paths, labels


if config['ds_type']=='binary_cls':
    print("Dealing with binary_cls dataset")
    # Module-level lists are kept (later cells may reference them).
    positive_patch_train=_list_patches(patch_paths['train_positive'])
    negative_patch_train=_list_patches(patch_paths['train_negative'])
    positive_patch_val=_list_patches(patch_paths['val_positive'])
    negative_patch_val=_list_patches(patch_paths['val_negative'])
    positive_patch_test=_list_patches(patch_paths['test_positive'])
    negative_patch_test=_list_patches(patch_paths['test_negative'])

    x_train, y_train = _paths_and_labels(positive_patch_train, negative_patch_train)
    x_val, y_val = _paths_and_labels(positive_patch_val, negative_patch_val)
    x_test, y_test = _paths_and_labels(positive_patch_test, negative_patch_test)

    # Ratio of positive to negative training patches.
    config['pos_weight']=len(positive_patch_train)/len(negative_patch_train)

    # NOTE(review): wandb.init received `config` in an earlier cell. These updated
    # counts are presumably not synced to W&B unless wandb.config.update(...) is
    # called. Confirm.
    config['positive_patches_train']=len(positive_patch_train)
    config['negative_patches_train']=len(negative_patch_train)
    config['positive_patches_val']=len(positive_patch_val)
    config['negative_patches_val']=len(negative_patch_val)
    config['positive_patches_test']=len(positive_patch_test)
    config['negative_patches_test']=len(negative_patch_test)

config["classes"]=len(np.unique(y_train))

In [None]:
class Generate_Contrastive_data(Dataset):
    """Dataset of random image pairs for contrastive training.

    Item ``idx`` is ``[image1, image2, label1, label2]``:
    - ``image1`` is patch ``idx``.
    - ``image2`` is a patch drawn uniformly at random (with ``np.random``) from the
      same split. For ``binary_cls`` it is guaranteed to differ from ``idx`` when
      the split has more than one item.

    Paths and labels are read at access time from the module-level
    ``x_train``/``y_train`` (``data_type='train'``) or ``x_val``/``y_val``
    (``data_type='val'``) arrays.

    NOTE(review): partners are re-drawn on every access, so validation pairs (and
    the validation loss) differ from epoch to epoch.
    """
    def __init__(self,length,data_type,transform=None):
        self.transform = transform
        self.length=length
        self.data_type=data_type

    def __len__(self):
        return self.length

    @staticmethod
    def _min_max_normalize(image):
        """Rescale a tensor to [0, 1]. A constant image maps to all zeros instead of NaN."""
        image = image.float()
        lo = torch.min(image)
        span = torch.max(image) - lo
        if span.item() == 0:
            return torch.zeros_like(image)
        return (image - lo) / span

    def _split_arrays(self):
        """Return the (paths, labels) arrays for this dataset's split."""
        if self.data_type == 'train':
            return x_train, y_train
        if self.data_type == 'val':
            return x_val, y_val
        raise ValueError("unsupported data_type: {!r}".format(self.data_type))

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a single-channel image. cv2 returns None on failure, so raise instead."""
        image = cv2.imread(path, 0)
        if image is None:
            raise FileNotFoundError("could not read image: {}".format(path))
        return image

    def __getitem__(self, idx):
        paths, labels = self._split_arrays()
        image1 = self._read_gray(paths[idx])

        another_idx = np.random.randint(0, self.length)
        # Re-draw until the partner differs from idx. With a single item that is
        # impossible, so skip the loop instead of spinning forever.
        if config['ds_type'] == 'binary_cls' and self.length > 1:
            while another_idx == idx:
                another_idx = np.random.randint(0, self.length)
        image2 = self._read_gray(paths[another_idx])

        if self.transform:
            image1 = self._min_max_normalize(self.transform(image=image1)["image"])
            image2 = self._min_max_normalize(self.transform(image=image2)["image"])

        return [image1, image2, labels[idx], labels[another_idx]]


# Contrastive-stage augmentation: resize, then random flips and 90-degree rotations.
transform_train = A.Compose([
    A.Resize(width=config['IM_W'], height=config['IM_H']),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    ToTensorV2(),
])

# Deterministic evaluation transform: resize only.
transform_test = A.Compose([
    A.Resize(width=config['IM_W'], height=config['IM_H'], p=1),
    ToTensorV2(),
])

contrastive_train_set = Generate_Contrastive_data(length=len(x_train), data_type='train', transform=transform_train)
contrastive_valid_set = Generate_Contrastive_data(length=len(x_val), data_type='val', transform=transform_test)

# Only the training pairs are shuffled; both loaders load in the main process.
ffa_dataloader = {
    split: DataLoader(dataset, batch_size=config['Batch'], shuffle=(split == 'train'), num_workers=0)
    for split, dataset in (('train', contrastive_train_set), ('val', contrastive_valid_set))
}

In [None]:
def calc_loss(preds,labels,metrics,loss_fn,class_weights=None):
    """Compute ``loss_fn(preds, labels)`` and add the batch-size-weighted loss to ``metrics['loss']``.

    Args:
        preds: model outputs. A trailing singleton dimension (e.g. ``(B, 1)`` from
            a one-logit binary head) is dropped so the shape matches ``labels``.
        labels: targets with the batch on dimension 0.
        metrics: mapping accumulating ``'loss'``. Divide by the total sample count
            to get the mean.
        loss_fn: callable ``(preds, labels) -> scalar tensor``.
        class_weights: unused; kept for call compatibility.

    Returns:
        The scalar loss tensor (still attached to the graph, for ``backward()``).
    """
    # Only squeeze the trailing class axis. A bare squeeze() would also remove the
    # batch axis when the last batch has a single sample, which breaks the shape match.
    if preds.dim() > 1 and preds.size(-1) == 1:
        preds = preds.squeeze(-1)
    loss=loss_fn(preds,labels)
    metrics['loss'] += loss.item() * labels.size(0)
    return loss

def cosine_similarity(x1,x2,y1,y2,metrics=None):
    """Cosine-embedding loss between two batches of features, with targets derived from labels.

    Pairs with equal labels get target +1 (pulled together) and pairs with different
    labels get target -1 (pushed apart), as in ``nn.CosineEmbeddingLoss``.

    Args:
        x1, x2: feature tensors with the batch on dim 0; flattened from dim 1.
        y1, y2: per-sample labels (one element per sample).
        metrics: optional mapping; ``metrics['loss']`` is increased by loss * batch size.

    Returns:
        Scalar loss tensor.
    """
    cos_sim=nn.CosineEmbeddingLoss()
    x1_fl=torch.flatten(x1,start_dim=1)
    x2_fl=torch.flatten(x2,start_dim=1)
    # Vectorised +1/-1 target on the features' device. This avoids a per-sample
    # .item() loop and any dependency on a global device setting.
    same_label = y1.reshape(-1) == y2.reshape(-1)
    y_vectors = same_label.to(device=x1_fl.device, dtype=torch.float32) * 2 - 1
    cos_sim_loss=cos_sim(x1_fl,x2_fl,y_vectors)
    if metrics is not None:
        metrics['loss'] = metrics['loss'] + cos_sim_loss.item() * y1.size(0)
    return cos_sim_loss

class FocalLoss(nn.Module):
    """Multi-class focal loss on logits: ``(1 - p_t) ** gamma * CE``.

    Implemented From https://www.kaggle.com/code/hmendonca/efficientnetb4-fastai-blindness-detection?scriptVersionId=19310748&cellId=16

    Args:
        gamma: focusing exponent. ``None`` (default) uses ``config['focal_gamma']``,
            looked up when the loss is constructed.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).

    Raises:
        ValueError: for an unsupported ``reduction``.
    """
    def __init__(self, gamma=None, reduction='mean'):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError("unsupported reduction: {!r}".format(reduction))
        self.gamma = config['focal_gamma'] if gamma is None else gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        CE_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-CE_loss)  # model probability of the true class
        F_loss = ((1 - pt)**self.gamma) * CE_loss
        if self.reduction == 'sum':
            return F_loss.sum()
        if self.reduction == 'mean':
            return F_loss.mean()
        return F_loss

class FocalLoss_BCE(nn.Module):
    """Binary focal loss on logits: mean over elements of ``alpha * (1 - p_t) ** gamma * BCE``.

    Adapted from https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch.
    Unlike that source, the focal factor is computed per element. Applying it to the
    batch-mean BCE would down-weight whole batches instead of easy examples.
    ``alpha`` is a uniform scale (not a per-class weight).
    """
    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted for signature compatibility but unused.
        super(FocalLoss_BCE, self).__init__()

    def forward(self, inputs, targets, alpha=None, gamma=None, smooth=1):
        """Return the scalar focal loss.

        Args:
            inputs: raw logits (no sigmoid applied), any shape.
            targets: 0/1 targets with the same number of elements as ``inputs``.
            alpha: scale factor. ``None`` uses ``config['focal_alpha']``.
            gamma: focusing exponent. ``None`` uses ``config['focal_gamma']``.
            smooth: unused; kept for call compatibility.
        """
        if alpha is None:
            alpha = config['focal_alpha']
        if gamma is None:
            gamma = config['focal_gamma']
        # Flatten predictions and targets; reshape also copes with non-contiguous input.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1).to(inputs.dtype)
        bce = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce)  # probability assigned to the true label, per element
        focal_loss = alpha * (1 - pt) ** gamma * bce
        return focal_loss.mean()

def calc_acc(preds,labels,metrics,num_classes=None):
    """Add the number of correct predictions in this batch to ``metrics['Acc.']``.

    The accumulated value divided by the total sample count is the accuracy.

    Args:
        preds: logits. For a binary problem there is one logit per sample,
            thresholded at sigmoid >= 0.5. Otherwise ``(B, C)`` logits, reduced with argmax.
        labels: ground-truth labels (one per sample).
        metrics: mapping accumulating ``'Acc.'``.
        num_classes: number of classes. ``None`` uses ``config["classes"]``.
    """
    if num_classes is None:
        num_classes = config["classes"]
    with torch.no_grad():
        if num_classes == 2:
            hard = (torch.sigmoid(preds) >= 0.5).long()
        else:
            # argmax of the logits equals argmax of their softmax.
            hard = preds.argmax(dim=1)
    hard = hard.cpu().numpy().reshape(-1)
    truth = labels.detach().cpu().numpy().reshape(-1)
    metrics['Acc.'] += float(np.sum(hard == truth))

def calc_other_metric(preds,labels,metrics,num_classes=None):
    """Accumulate batch-size-weighted macro F1, precision and recall for one batch.

    Each score is computed on this batch alone and multiplied by the batch size.
    Dividing the accumulated value by the total sample count therefore gives a
    size-weighted mean of per-batch scores, not the epoch-level score.

    Args:
        preds: logits. Binary problems use sigmoid >= 0.5; otherwise argmax over dim 1.
        labels: ground-truth labels (one per sample).
        metrics: mapping accumulating ``'F1_Score'``, ``'Precision'`` and ``'Recall'``.
        num_classes: number of classes. ``None`` uses ``config["classes"]``.
    """
    if num_classes is None:
        num_classes = config["classes"]
    with torch.no_grad():
        if num_classes == 2:
            hard = (torch.sigmoid(preds) >= 0.5).to(preds.dtype)
        else:
            hard = preds.argmax(dim=1)
    hard = hard.cpu().numpy().flatten()
    truth = labels.detach().cpu().numpy().flatten()
    n = truth.shape[0]

    # zero_division=0 yields the same value sklearn falls back to, without the warning,
    # when a batch lacks predictions or samples for a class.
    metrics['F1_Score'] += f1_score(truth, hard, average='macro', zero_division=0) * n
    metrics['Precision'] += precision_score(truth, hard, average='macro', zero_division=0) * n
    metrics['Recall'] += recall_score(truth, hard, average='macro', zero_division=0) * n
    
def print_metrics(metrics, epoch_samples, phase):
    """Print every metric divided by ``epoch_samples`` on a single line prefixed with ``phase``."""
    summary = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples)
        for name, total in metrics.items()
    )
    print("{}: {}".format(phase, summary))

In [None]:
def _grayscale_resnet_stem():
    """Return a 1-input-channel replacement for a torchvision ResNet/ResNeXt ``conv1``.

    The patches are read as single-channel images (``cv2.imread(path, 0)``), while
    torchvision ResNet stems expect 3 channels. The geometry (64 filters, 7x7,
    stride 2, padding 3, no bias) matches the original stem.
    """
    return nn.Conv2d(in_channels=1,out_channels=64, kernel_size=(7,7),stride=(2,2),padding=(3,3),bias=False)


# Build the backbone selected by config["model"], then move it to the configured device.
if config["model"]=="resnet18":
    model = models.resnet18(pretrained=config['imgnet_pretrained'])
    model.conv1=_grayscale_resnet_stem()
elif config["model"]=="resnet50":
    print("loading ResNet50")
    model = models.resnet50(pretrained=False)
    model.conv1=_grayscale_resnet_stem()
elif config["model"]=="resnet34":
    print("loading resnet34")
    model = models.resnet34(pretrained=False)
    model.conv1=_grayscale_resnet_stem()
elif config["model"]=="resnext50":
    print("loading resnext 50")
    model=models.resnext50_32x4d()
    model.conv1=_grayscale_resnet_stem()
# NOTE(review): the VGG / EfficientNet branches keep their 3-channel first conv,
# so 1-channel patches will fail at the first layer unless that conv is adapted too.
elif config["model"]=="vgg19":
    print("loading VGG-19")
    model=models.vgg19(pretrained=False)
elif config["model"]=="effnetv2s":
    print("loading effnetv2s")
    model=models.efficientnet_v2_s()
elif config["model"]=="vgg16":
    print("loading VGG-16")
    model=models.vgg16(pretrained=False)
else:
    raise ValueError("unsupported config['model']: {!r}".format(config["model"]))
model=model.to(config['device'])

# Flat list of the model's leaf modules (modules with no children), in definition order.
all_layers = []
def get_individual_layer(network):
    """Depth-first walk of ``network`` that appends every leaf module to the global ``all_layers``."""
    for child in network.children():
        grandchildren = list(child.children())
        if grandchildren:
            get_individual_layer(child)
        else:
            all_layers.append(child)
get_individual_layer(model)

In [None]:
# Per-epoch mean losses of the local (layer-wise) contrastive stage; reset whenever this cell runs.
train_loss_list=[]
val_loss_list=[]

def train_ffa_contrastive_model(model,opt, scheduler, num_epochs=25):
    """Layer-wise ("local") contrastive training over the leaf modules in the global ``all_layers``.

    For every batch of image pairs from ``ffa_dataloader[phase]``:
    - Both inputs are pushed through ``all_layers`` one module at a time. These are
      the leaf modules collected from ``model`` by ``get_individual_layer``.
    - After each learnable module, a ``CosineEmbeddingLoss`` between the two
      flattened activations is computed. The target is +1 for same-label pairs and
      -1 otherwise.
    - In the train phase that loss is back-propagated and ``opt`` is stepped
      immediately.
    - The activations are then detached, so no gradient flows from one module back
      into earlier ones.

    Args:
        model: network whose leaf modules populate ``all_layers``. Here it is only
            switched between train()/eval() and saved via ``state_dict()``; the
            forward pass goes through ``all_layers``, not ``model(...)``.
        opt: optimizer stepped after every learnable module's local loss.
        scheduler: LR scheduler stepped once per epoch.
        num_epochs: maximum number of epochs. Early stopping may end training sooner.

    Returns:
        ``model`` with the weights of the LAST epoch run. The weights with the best
        validation loss are written to ``saved_models/<saved_model_name>ffa_pretrained``,
        and callers reload that file to use them.

    Side effects:
        Appends to the global ``train_loss_list``/``val_loss_list`` and logs
        "train ffa loss"/"val ffa loss" to wandb.
    """
    best_model_wts = copy.deepcopy(model.state_dict())  # NOTE(review): unused; restoring it is commented out below
    best_loss = 1e10

    tolerance_count=0  # consecutive validation epochs without improvement
    loss_fn=torch.nn.CosineEmbeddingLoss()
    for epoch in range(num_epochs):
        # Early stopping: checked at the start of each epoch.
        if config['contrastive_early_stopping']==True:
            if tolerance_count>=config['early_stop_patience']:
                print("EARLY STOPPED")
                break

        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('~' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':

                for param_group in opt.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)  # NOTE(review): unused in this function
            epoch_samples = 0             # NOTE(review): unused in this function
            batch_loss=0  # sum over batches of each batch's mean per-module loss
            num_batch=0
            for btch,feed_dict in enumerate(tqdm.tqdm(ffa_dataloader[phase])):

                # feed_dict is [image1, image2, label1, label2] from Generate_Contrastive_data.
                inputs1=feed_dict[0]
                inputs2=feed_dict[1]
                labels1=feed_dict[2]
                labels2=feed_dict[3]


                inputs1 = inputs1.to(config['device'])
                inputs2 = inputs2.to(config['device'])

                labels1 = labels1.type(torch.FloatTensor)
                labels2 = labels2.type(torch.FloatTensor)


                labels1 = labels1.to(config['device'])
                labels2 = labels2.to(config['device'])
                # Cosine-embedding targets: +1 for same-label pairs, -1 for different-label pairs.
                y_vectors=torch.zeros(size=(labels1.shape[0],))
                y_vectors=y_vectors.to(config['device'])
                for i in range(labels1.shape[0]):
                    if labels1[i].item()==labels2[i].item():
                        y_vectors[i]=1
                    else:
                        y_vectors[i]=-1

                # Running activations of the two branches, advanced module by module.
                pred1=inputs1
                pred2=inputs2

                overall_loss=0  # sum of local losses over the learnable modules of this batch
                num_layers=0    # number of modules that contributed a loss
                with torch.set_grad_enabled(phase == 'train'):
                    # ---------------- Local contrastive learning ----------------
                    for i in range(len(all_layers)):
                        if isinstance(all_layers[i],torch.nn.Conv2d): # check if current layer is a CNN
                            # A Conv2d whose in_channels does not match the current activations is
                            # skipped entirely (intended for ResNet downsample convs). Residual
                            # additions are not reproduced in this sequential pass.
                            if pred1.shape[1]==all_layers[i].in_channels:
                                # A ReLU is applied after every Conv2d here, in addition to any
                                # ReLU modules that appear later in all_layers.
                                pred1=nn.ReLU()(all_layers[i].forward(pred1)) # Forward Pass
                                pred2=nn.ReLU()(all_layers[i].forward(pred2)) # Forward Pass

                                x1_fl=torch.flatten(pred1,start_dim=1)
                                x2_fl=torch.flatten(pred2,start_dim=1)

                                loss=loss_fn(x1_fl,x2_fl,y_vectors) # calculate loss for this module
                                # Learn this module locally. NOTE(review): `opt` covers all model
                                # parameters. Whether parameters without a gradient are left untouched
                                # on each step depends on zero_grad's set_to_none behaviour in the
                                # installed torch version; confirm.
                                if phase == 'train':
                                    opt.zero_grad()
                                    loss.backward(retain_graph=False)
                                    opt.step()
                                overall_loss+=loss.item()
                                num_layers+=1
                                pred1=pred1.detach() # Detaching from the gradient-computation graph
                                pred2=pred2.detach() # Detaching from the gradient-computation graph



                        else: # any non-Conv2d module (e.g. BatchNorm, Linear, ReLU, pooling)
                            if isinstance(all_layers[i],torch.nn.Linear):
                                # Linear layers need flat input; a ReLU is applied to their output here.
                                pred1=torch.flatten(pred1,start_dim=1)
                                pred2=torch.flatten(pred2,start_dim=1)
                                pred1=nn.ReLU()(all_layers[i].forward(pred1)) # Forward pass
                                pred2=nn.ReLU()(all_layers[i].forward(pred2)) # Forward pass
                                x1_fl=pred1
                                x2_fl=pred2
                            else: # not a Linear layer (e.g. batch norm, activation, pooling)
                                pred1=all_layers[i].forward(pred1) # Forward pass
                                pred2=all_layers[i].forward(pred2) # Forward pass
                                x1_fl=torch.flatten(pred1,start_dim=1)
                                x2_fl=torch.flatten(pred2,start_dim=1)

                            if isinstance(all_layers[i],(torch.nn.Dropout,torch.nn.ReLU,torch.nn.MaxPool2d,torch.nn.AdaptiveAvgPool2d,torch.nn.SiLU,torch.nn.Sigmoid,torchvision.ops.StochasticDepth))==False:
                                # Only modules outside this parameter-free list get a local loss and
                                # an update. The listed ones (ReLU, pooling, dropout, ...) are
                                # forwarded through but not trained.
                                loss=loss_fn(x1_fl,x2_fl,y_vectors) # calculate loss for this module
                                if phase=='train':
                                    opt.zero_grad()
                                    loss.backward(retain_graph=False)
                                    opt.step()
                                overall_loss+=loss.item()
                                num_layers+=1

                            pred1=pred1.detach()
                            pred2=pred2.detach()


                # Mean local loss over the learnable modules for this batch.
                overall_loss=overall_loss/num_layers
                batch_loss+=overall_loss
                num_batch+=1
            epoch_loss=batch_loss/num_batch  # mean over batches of the per-batch mean local loss
            if phase=='train':
                print(phase,"Loss: ",epoch_loss)
                train_loss_list.append(epoch_loss)
                wandb.log({"train ffa loss":epoch_loss})
            if phase=='val':
                print(phase,"Loss: ",epoch_loss)
                val_loss_list.append(epoch_loss)
                # Checkpoint on strict improvement of validation loss; otherwise count towards early stopping.
                if epoch_loss<best_loss:
                    torch.save(model.state_dict(), 'saved_models/'+config['saved_model_name']+'ffa_pretrained')
                    best_loss=epoch_loss
                    tolerance_count=0
                    print("Saving best model")
                else:
                    tolerance_count+=1
                wandb.log({"val ffa loss":epoch_loss})

        scheduler.step()
        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))


    # val_loss_list is the global list, so this reports the minimum over everything appended since it was last reset.
    print('Best val loss: {:4f}'.format(min(val_loss_list)))
    # NOTE(review): best weights are NOT restored here; the caller reloads the saved checkpoint.
    #model.load_state_dict(best_model_wts)
    return model

"""
Model is trained with adam optimizer 
and saved everytime if current epoch loss is less than the observed best validation loss
"""
local_epochs = config['local_contrastive_epoch']
# Adam over every model parameter. In the train phase the local stage steps it
# once per learnable module per batch.
optimizer_ft = torch.optim.Adam(model.parameters(), lr=config['LR'], weight_decay=1e-5)
# Cosine annealing of the LR from config['LR'] down to 0 over the local stage.
cos_lr_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer_ft,
    T_max=local_epochs,
    eta_min=0,
    verbose=True,
)
model = train_ffa_contrastive_model(model, optimizer_ft, cos_lr_scheduler, num_epochs=local_epochs)

In [None]:
# Restore the best (lowest validation loss) weights from the local contrastive stage.
ffa_checkpoint_path = 'saved_models/' + config['saved_model_name'] + 'ffa_pretrained'
model.load_state_dict(torch.load(ffa_checkpoint_path))

In [None]:
# Per-epoch mean losses of the global contrastive stage; reset whenever this cell runs.
train_loss_list=[]
val_loss_list=[]

def train_global_contrastive_model(model,opt, scheduler, num_epochs=25):
    """End-to-end ("global") contrastive training of ``model`` on image pairs.

    Both images of each pair from ``ffa_dataloader[phase]`` go through the full
    ``model``. A ``CosineEmbeddingLoss`` on the two outputs uses target +1 for
    same-label pairs and -1 otherwise.

    Args:
        model: network trained in place.
        opt: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: maximum number of epochs. When
            ``config['contrastive_early_stopping']`` is set, training stops after
            ``config['early_stop_patience']`` non-improving validation epochs.

    Returns:
        ``model`` with the weights of the LAST epoch run. The weights with the lowest
        validation loss are saved to ``saved_models/<saved_model_name>global_contrastive``.

    Raises:
        RuntimeError: if a phase's dataloader yields no batches.

    Side effects:
        Appends to the global ``train_loss_list``/``val_loss_list`` and logs
        "train Global loss"/"val Global loss" to wandb.
    """
    best_loss = 1e10
    tolerance_count = 0  # consecutive validation epochs without improvement
    loss_fn = torch.nn.CosineEmbeddingLoss()
    for epoch in range(num_epochs):
        if config['contrastive_early_stopping'] == True and tolerance_count >= config['early_stop_patience']:
            print("EARLY STOPPED")
            break
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('~' * 10)

        since = time.time()

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in opt.param_groups:
                    print("LR", param_group['lr'])
                model.train()
            else:
                model.eval()

            batch_loss = 0
            num_batch = 0
            for inputs1, inputs2, labels1, labels2 in tqdm.tqdm(ffa_dataloader[phase]):
                inputs1 = inputs1.to(config['device'])
                inputs2 = inputs2.to(config['device'])
                # +1 for same-label pairs, -1 otherwise. This is vectorised, so there is no
                # per-sample .item() device synchronisation.
                same_label = labels1.reshape(-1) == labels2.reshape(-1)
                y_vectors = same_label.to(device=config['device'], dtype=torch.float32) * 2 - 1

                with torch.set_grad_enabled(phase == 'train'):
                    pred1 = model(inputs1)
                    pred2 = model(inputs2)
                    loss_global = loss_fn(pred1, pred2, y_vectors)
                    if phase == 'train':
                        opt.zero_grad()
                        loss_global.backward()
                        opt.step()
                batch_loss += loss_global.item()
                num_batch += 1

            if num_batch == 0:
                raise RuntimeError("ffa_dataloader['{}'] yielded no batches".format(phase))
            epoch_loss = batch_loss / num_batch
            print(phase,"Loss: ",epoch_loss)
            if phase == 'train':
                train_loss_list.append(epoch_loss)
                wandb.log({"train Global loss":epoch_loss})
            else:
                val_loss_list.append(epoch_loss)
                # Checkpoint on strict improvement; otherwise count towards early stopping.
                if epoch_loss < best_loss:
                    torch.save(model.state_dict(), 'saved_models/'+config['saved_model_name']+'global_contrastive')
                    best_loss = epoch_loss
                    tolerance_count = 0
                else:
                    tolerance_count += 1
                wandb.log({"val Global loss":epoch_loss})

        scheduler.step()
        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(min(val_loss_list)))
    return model

"""
Model is trained with adam optimizer 

and saved everytime if current epoch loss is less than the observed best validation loss
"""
global_epochs = wandb.config['global_contrastive_epoch']
# Fresh Adam optimizer (same LR and weight decay as the local stage) for end-to-end training.
optimizer_ft = torch.optim.Adam(model.parameters(), lr=config['LR'], weight_decay=1e-5)
# Cosine annealing of the LR down to 0 over the global stage.
cos_lr_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer_ft,
    T_max=global_epochs,
    eta_min=0,
    verbose=True,
)
model = train_global_contrastive_model(model, optimizer_ft, cos_lr_scheduler, num_epochs=global_epochs)


In [None]:
# Restore the best (lowest validation loss) weights from the global contrastive stage.
global_checkpoint_path = 'saved_models/' + wandb.config['saved_model_name'] + 'global_contrastive'
model.load_state_dict(torch.load(global_checkpoint_path))

In [None]:
class Generate_data(Dataset):
    """Single-image dataset for the supervised stage.

    Item ``idx`` is ``[image, label]``, read from the module-level
    ``x_train``/``y_train``, ``x_val``/``y_val`` or ``x_test``/``y_test`` arrays
    according to ``data_type`` (``'train'``, ``'val'`` or ``'test'``). When a
    transform is given, its output image is min-max rescaled to [0, 1].
    """
    def __init__(self,length,data_type,transform=None):
        self.transform = transform
        self.length=length
        self.data_type=data_type

    def __len__(self):
        return self.length

    @staticmethod
    def _min_max_normalize(image):
        """Rescale a tensor to [0, 1]. A constant image maps to all zeros instead of NaN."""
        image = image.float()
        lo = torch.min(image)
        span = torch.max(image) - lo
        if span.item() == 0:
            return torch.zeros_like(image)
        return (image - lo) / span

    def __getitem__(self, idx):
        if self.data_type == 'train':
            paths, labels = x_train, y_train
        elif self.data_type == 'val':
            paths, labels = x_val, y_val
        elif self.data_type == 'test':
            paths, labels = x_test, y_test
        else:
            raise ValueError("unsupported data_type: {!r}".format(self.data_type))

        image = cv2.imread(paths[idx], 0)  # 0 -> read as single-channel grayscale
        if image is None:
            raise FileNotFoundError("could not read image: {}".format(paths[idx]))

        # The same transform handling applies to every split, so a test set built
        # without a transform no longer crashes.
        if self.transform:
            image = self._min_max_normalize(self.transform(image=image)["image"])
        return [image, labels[idx]]




# Supervised-stage augmentation: resize, then random horizontal/vertical flips and 90-degree rotations.
transform_train = A.Compose([
    A.Resize(width=config['IM_W'], height=config['IM_H']),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    ToTensorV2(),
])

# Deterministic transform shared by the validation and test sets: resize only.
transform_test = A.Compose([
    A.Resize(width=config['IM_W'], height=config['IM_H']),
    ToTensorV2(),
])

train_set = Generate_data(length=len(x_train), data_type='train', transform=transform_train)
valid_set = Generate_data(length=len(x_val), data_type='val', transform=transform_test)
test_set = Generate_data(length=len(x_test), data_type='test', transform=transform_test)

In [None]:
image_datasets = {
    'train': train_set, 'val': valid_set, 'test': test_set
}

# Only the training loader shuffles, matching the contrastive loaders.
# calc_other_metric scores each batch separately, so a shuffled validation loader
# would make per-epoch validation averages vary between epochs for identical weights
# (assuming train_model uses it for the 'val' phase).
dataloader = {
    'train': DataLoader(train_set, batch_size=config['Batch'], shuffle=True, num_workers=0),
    'val': DataLoader(valid_set, batch_size=config['Batch'], shuffle=False, num_workers=0),
    'test': DataLoader(test_set, batch_size=config['Batch'], shuffle=False, num_workers=0)
}

In [None]:
train_acc_list=[]
val_acc_list=[]

train_f1_list=[]
val_f1_list=[]

train_precision_list=[]
val_precision_list=[]

train_recall_list=[]
val_recall_list=[]

train_loss_list=[]
val_loss_list=[]


train_kappa_list=[]
def train_model(optimizer, scheduler, loss_fn, num_epochs=25):
    """Train the notebook-global ``model`` and return it with its best validation weights.

    Every epoch runs a ``'train'`` phase (forward, backward, ``optimizer.step()``,
    then one ``scheduler.step()``) followed by a ``'val'`` phase with gradients
    disabled. Whenever the mean validation loss improves, the weights are saved to
    ``'saved_models/' + wandb.config['saved_model_name']`` and also kept in memory.

    Relies on globals defined elsewhere in the notebook: ``model``, ``dataloader``,
    ``config``, ``calc_loss``, ``calc_acc``, ``calc_other_metric``, ``print_metrics``
    and the ``train_*`` / ``val_*`` history lists.

    Args:
        optimizer: torch optimizer over ``model``'s trainable parameters.
        scheduler: LR scheduler, stepped once per epoch after the train phase.
        loss_fn: criterion passed through to ``calc_loss``.
        num_epochs: number of train/val epochs.

    Returns:
        ``model`` with the lowest-validation-loss weights loaded (the initial
        weights if validation loss never improved, e.g. it was NaN throughout).
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('~' * 10)
        since = time.time()

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()
            else:
                model.eval()

            # The calc_* helpers accumulate into this dict; the totals are divided by
            # epoch_samples below (presumably they add batch-size-weighted values -- confirm).
            metrics = defaultdict(float)
            epoch_samples = 0

            for feed_dict in tqdm.tqdm(dataloader[phase]):
                inputs = feed_dict[0].to(config['device'])
                # Multi-class heads take class-index (long) targets; the single-logit
                # binary head takes float targets.
                if config['classes'] > 2:
                    labels = feed_dict[1].type(torch.LongTensor)
                else:
                    labels = feed_dict[1].type(torch.FloatTensor)
                labels = labels.to(config['device'])

                optimizer.zero_grad()

                # Build the autograd graph only in the train phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics, loss_fn)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    calc_acc(outputs, labels, metrics)
                    calc_other_metric(outputs, labels, metrics)
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples
            epoch_acc = metrics['Acc.'] / epoch_samples
            epoch_f1 = metrics['F1_Score'] / epoch_samples
            epoch_prec = metrics['Precision'] / epoch_samples
            epoch_recall = metrics['Recall'] / epoch_samples

            if phase == 'val':
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    # Keep the in-memory copy in sync with the checkpoint so the model
                    # returned below actually carries the best weights.
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(model.state_dict(), 'saved_models/'+wandb.config['saved_model_name'])
                    print("saving model: ", wandb.config['saved_model_name'])

                # float() works whether the accumulated metric is a Python number or a 1-element tensor.
                val_acc_list.append(float(epoch_acc))
                val_loss_list.append(epoch_loss)
                val_recall_list.append(epoch_recall)
                val_precision_list.append(epoch_prec)
                val_f1_list.append(epoch_f1)

                wandb.log({"val loss": epoch_loss,
                           "val acc": epoch_acc,
                           "val f1": epoch_f1,
                           "val recall": epoch_recall,
                           "val prec": epoch_prec
                           })
            if phase == 'train':
                train_acc_list.append(float(epoch_acc))
                train_loss_list.append(epoch_loss)
                train_recall_list.append(epoch_recall)
                train_precision_list.append(epoch_prec)
                train_f1_list.append(epoch_f1)

                wandb.log({"train loss": epoch_loss,
                           "train acc": epoch_acc,
                           "train f1": epoch_f1,
                           "train recall": epoch_recall,
                           "train prec": epoch_prec
                           })
                # Per-epoch schedule: step once after the optimizer updates of this epoch.
                scheduler.step()

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # Load the best (lowest validation loss) weights before returning.
    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Swap the backbone's final linear layer for one sized to the task:
# config["classes"] logits for multi-class problems, a single logit otherwise (binary).
n_out = config["classes"] if config["classes"] > 2 else 1

if config["model"] in ("vgg19", "vgg16"):
    # torchvision VGG: the last classifier entry is the output Linear layer.
    num_ftrs = model.classifier[6].in_features
    model.classifier[6] = torch.nn.Linear(num_ftrs, n_out)
elif config["model"] == "effnetv2s":
    # EfficientNetV2: classifier = (Dropout, Linear).
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, n_out)
else:
    # ResNet-style backbones expose the head as `fc`.
    num_ftrs = model.fc.in_features
    model.fc = torch.nn.Linear(num_ftrs, n_out)

model = model.to(config['device'])

In [None]:
"""
Model is trained with adam optimizer 
and saved everytime if current epoch loss is less than the observed best validation loss
"""
# Define optimizer
optimizer_ft = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config['LR'],weight_decay=1e-5)
# Define LR scheduler
cos_lr_scheduler = lr_scheduler.CosineAnnealingLR(optimizer_ft, T_max=config['epoch'], eta_min=0,verbose=True)
# step_lr=lr_scheduler.StepLR(optimizer=optimizer_ft,step_size=30,verbose=True,gamma=0.1)
# Define loss function
# cls_weights=torch.cuda.FloatTensor(cls_weights)
# loss_fn=nn.MSELoss()
# loss_fn = FocalLoss()
# loss_fn=torch.nn.L1Loss()
# loss_fn=loss_fn.cuda()
# 
if config['loss_fn']=='focal':
    print("training with focal loss")
    loss_fn=FocalLoss_BCE()
elif config['loss_fn']=='bce_weight':
    loss_fn=torch.nn.BCEWithLogitsLoss(pos_weight=torch.from_numpy(np.array(config['pos_weight'])))
elif config['loss_fn']=='bce':
    print("Loading Regular BCE")
    loss_fn=torch.nn.BCEWithLogitsLoss()
    

# wandb.watch(model) #for watching convergence of the model
model = train_model(optimizer_ft, cos_lr_scheduler,loss_fn, num_epochs=config['epoch'])


In [None]:
# Reload the best checkpoint written by train_model. map_location remaps tensors onto
# the current device, so a GPU-saved checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load('saved_models/'+wandb.config['saved_model_name'], map_location=config['device']))
model.eval();  # trailing ';' suppresses the module repr in the notebook output

In [None]:
def get_test_results():
    """Run the global ``model`` over ``dataloader['test']`` and collect predictions.

    The binary/multi-class split follows the rule used when building the head:
    ``config["classes"] > 2`` means one logit per class; otherwise there is a single logit.

    Returns:
        ``(y_test_pred, y_test_gt, y_test_pred_logits)`` as flat float64 numpy arrays:
        predicted labels, ground-truth labels, and -- for the single-logit head only --
        sigmoid probabilities (despite the name, not raw logits). For multi-class
        heads the third array is empty.
    """
    model.eval()
    binary = config["classes"] <= 2
    pred_chunks, gt_chunks, prob_chunks = [], [], []

    # Inference only: no autograd graph, which also keeps GPU memory flat across batches.
    with torch.no_grad():
        for inputs, labels in dataloader['test']:
            outputs = model(inputs.to(config['device']))
            if binary:
                probs = torch.sigmoid(outputs)
                prob_chunks.append(probs.cpu().numpy().ravel())
                preds = (probs >= 0.5).float()
            else:
                # Softmax is monotonic, so argmax over logits equals argmax over probabilities.
                preds = torch.argmax(outputs, dim=1)
            pred_chunks.append(preds.cpu().numpy().ravel())
            gt_chunks.append(labels.cpu().numpy().ravel())

    def _flatten(chunks):
        # Concatenate once (np.append per batch copies everything each time) and use
        # float64 so predictions and labels share a dtype.
        return np.concatenate(chunks).astype(np.float64) if chunks else np.empty(0)

    return _flatten(pred_chunks), _flatten(gt_chunks), _flatten(prob_chunks)
y_test_pred, y_test_gt, y_test_pred_logits = get_test_results()

# confusion_matrix orders its rows/columns by the sorted union of true and predicted
# labels; display_labels must use exactly that order (a set has no guaranteed order
# and would miss classes that only occur in the predictions).
cm_labels = np.union1d(y_test_gt, y_test_pred)
conf_matrix = confusion_matrix(y_test_gt, y_test_pred, labels=cm_labels)
conf_disp = ConfusionMatrixDisplay(conf_matrix, display_labels=cm_labels)

wandb.log({"test conf": wandb.plot.confusion_matrix(probs=None,
                        y_true=y_test_gt, preds=y_test_pred, title="test conf")
})

In [None]:
# Print the headline test metrics (macro-averaged). ROC-AUC is only added for the
# single-logit head, where y_test_pred_logits holds the sigmoid probabilities.
metrics_summary = {
    "test acc": float(accuracy_score(y_test_gt, y_test_pred)),
    "test f1": float(f1_score(y_test_gt, y_test_pred, average='macro')),
    "test recall": float(recall_score(y_test_gt, y_test_pred, average='macro')),
    "test prec": float(precision_score(y_test_gt, y_test_pred, average='macro')),
}
if config['classes'] <= 2:
    metrics_summary["test auc "] = float(roc_auc_score(y_test_gt, y_test_pred_logits))
print(metrics_summary)

In [None]:
# Log the test metrics to W&B as a one-row table (macro-averaged F1/precision/recall).
test_acc = float(accuracy_score(y_test_gt, y_test_pred))
test_f1 = float(f1_score(y_test_gt, y_test_pred, average='macro'))
test_recall = float(recall_score(y_test_gt, y_test_pred, average='macro'))
test_prec = float(precision_score(y_test_gt, y_test_pred, average='macro'))

columns = ["Test Acc", "Test F1", "Test Prec", "Test Recall"]
data = [[test_acc, test_f1, test_prec, test_recall]]

# Binary means "not more than two classes", the same rule used to size the model head
# (single logit) and in the metric print cell; AUC is only defined there.
if config['classes'] <= 2:
    test_auc = float(roc_auc_score(y_test_gt, y_test_pred_logits))
    columns.append("Test Auc")
    data[0].append(test_auc)

test_table = wandb.Table(columns=columns, data=data)
wandb.log({"Test Table": test_table})

In [None]:
# Mark the W&B run as finished; wandb uploads any still-pending logged data (metrics, plots, tables) before closing.
wandb.finish()