In [None]:
### Installing tensorboard_logger

!pip install tensorboard_logger

In [None]:
### Importing Libraries

from __future__ import print_function

import argparse
import math
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorboard_logger as tb_logger

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision
import torchvision.models as models
from torchvision import transforms, datasets

from PIL import Image
from scipy.sparse import csr_matrix
from sklearn.metrics import roc_auc_score, f1_score

## Arguments

In [None]:
### from config.config_supcon import parse_option

# ------------------------------------------------Model Name-----------------------------------------#
def parse_option(args=None):
    """Parse command-line style arguments into the training configuration.

    Besides parsing, this derives several fields on the returned namespace:
    ``lr_decay_epochs`` (list of ints parsed from the comma-separated string),
    ``model_path``/``tb_path``, ``model_name``, ``tb_folder`` and ``save_folder``
    (the last two are created on disk). When ``--warm`` is set it also derives
    ``warmup_to``, which ``warmup_learning_rate`` reads together with
    ``warm_epochs`` and ``warmup_from``.

    Args:
        args: sequence of argument strings (e.g. the tuple built in the notebook's
            configuration cell). ``None`` makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed and derived options.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=2,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')
    parser.add_argument('--device', type=str, default='cuda:0')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05,
                        help='learning rate')
    parser.add_argument('--patient_lambda', type=float, default=1,
                        help='patient lambda (not read anywhere in this notebook)')
    parser.add_argument('--cluster_lambda', type=float, default=1,
                        help='cluster lambda (not read anywhere in this notebook)')
    parser.add_argument('--lr_decay_epochs', type=str, default='100',
                        help='comma-separated epochs after which lr is decayed, e.g. "60,80"')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')
    # NOTE: the path defaults below are placeholders, not real locations.
    parser.add_argument('--train_csv_path', type=str, default='train data csv')
    parser.add_argument('--test_csv_path', type=str, default='test data csv')
    parser.add_argument('--train_image_path', type=str, default='train data csv')
    parser.add_argument('--test_image_path', type=str, default='test data csv')
    # NOTE(review): hard-coded absolute local path; not read anywhere in this notebook.
    parser.add_argument('--results_dir_contrastive', type=str, default='/home/kiran/Desktop/Dev/SupCon_OCT_Clinical/results.txt')
    parser.add_argument('--percentage', type=int, default=10,
                        help='percentage (not read anywhere in this notebook)')
    parser.add_argument('--discrete_level', type=int, default=10,
                        help='discretization Level')
    parser.add_argument('--parallel', type=int, default=1, help='data parallel')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='TREX_DME',
                        choices=[ 'OCT', 'OCT_Cluster', 'Prime', 'PrimeBio',
                                 'Recovery_Compressed',
                                 'Prime_Comb_Bio', 'CombinedBio', 'CombinedBio_Modfied', 'Patient_Split_2_Prime_TREX',
                                 'Patient_Split_3_Prime_TREX', 'Alpha',
                                 'Prime_Compressed', 'Prime_TREX_DME_Fixed', 'Prime_TREX_Alpha',
                                 'Prime_TREX_DME_Discrete',
                                 'Recovery', 'TREX_DME'], help='dataset')
    parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
    parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
    parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
    parser.add_argument('--size', type=int, default=128, help='parameter for RandomResizedCrop')

    # method: which label field each contrastive loss term uses
    parser.add_argument('--num_methods', type=int, default=0,
                        help='number of method1..method5 loss terms to combine '
                             '(0 = unsupervised SimCLR-style loss)')
    for term in range(1, 6):
        parser.add_argument('--method{}'.format(term), type=str, default='n',
                            help='label used by loss term {} (patient, bcva, cst or eye_id; '
                                 'see train_Combined)'.format(term))
    parser.add_argument('--alpha', type=float, default=1,
                        help='alpha (only recorded in model_name)')
    parser.add_argument('--gradcon_labels', type=str, default='',
                        help='gradcon labels (only recorded in model_name)')
    parser.add_argument('--patient_split', type=int, default=1,
                        help='patient split id (only recorded in model_name)')
    parser.add_argument('--quantized', type=int, default=0,
                        help='quantized flag (only recorded in model_name)')

    # temperature
    parser.add_argument('--temp', type=float, default=0.07,
                        help='temperature for loss function')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--syncBN', action='store_true',
                        help='using synchronized batch normalization')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')
    parser.add_argument('--warm_epochs', type=int, default=10,
                        help='number of warm-up epochs (used with --warm)')
    parser.add_argument('--warmup_from', type=float, default=0.01,
                        help='learning rate at the start of warm-up (used with --warm)')
    parser.add_argument('--trial', type=str, default='0',
                        help='id for recording multiple runs')

    opt = parser.parse_args(args)

    # Custom dataset sanity check. NOTE: 'path' is not among --dataset choices,
    # so argparse rejects it before this line can be reached.
    if opt.dataset == 'path':
        assert opt.data_folder is not None \
               and opt.mean is not None \
               and opt.std is not None

    # Output locations
    if opt.data_folder is None:
        opt.data_folder = './datasets/'
    opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
    opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)

    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    # Run name encodes the main hyper-parameters
    opt.model_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}_lr_{}_{}_decay_{}_bsz_{}_temp_{}_trial_{}_{}_{}'. \
        format(opt.method1, opt.method2, opt.method3, opt.method4, opt.method5, opt.alpha, opt.patient_split, opt.discrete_level,
               opt.dataset, opt.model, opt.learning_rate,
               opt.weight_decay, opt.batch_size, opt.temp, opt.trial, opt.gradcon_labels, opt.quantized)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warmup_learning_rate reads warmup_to, so it must exist whenever --warm is set.
    # The target is the lr the regular schedule would use at the end of warm-up.
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        if opt.cosine:
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    # TensorBoard logs
    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)

    # Model checkpoints
    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt

## Data loader

In [None]:
## from utils.utils_supcon import set_loader,set_model_contrast

# from datasets.prime_trex_combined import CombinedDataset


# NVIDIA apex is optional. `apex.parallel` is only used by set_model_contrast when
# --syncBN is set; `amp` and `optimizers` are not referenced elsewhere in this notebook.
# If apex is not installed, these names stay undefined.
try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

# ---------------------------------Configuring contrastive model---------------------------------------#

def set_model_contrast(opt):
    """Build the SupCon model and its SupConLoss criterion.

    Args:
        opt: options namespace; reads ``model``, ``temp``, ``device``, ``syncBN``
            and ``parallel``.

    Returns:
        ``(model, criterion)``. When CUDA is available, the model is wrapped in
        ``DataParallel`` and moved with ``.cuda()`` if ``opt.parallel == 1``.
        Otherwise it is moved to ``opt.device``.

    Raises:
        ImportError: if ``opt.syncBN`` is set but NVIDIA apex is not installed.
    """
    model = SupConResNet_Original(name=opt.model)
    cuda_available = torch.cuda.is_available()
    # The criterion places its helper tensors on this device, so fall back to CPU
    # when CUDA is absent instead of pointing it at a missing GPU.
    device = opt.device if cuda_available else 'cpu'
    criterion = SupConLoss(temperature=opt.temp, device=device)

    # Enable synchronized Batch Normalization (requires NVIDIA apex).
    if opt.syncBN:
        try:
            import apex
        except ImportError as exc:
            raise ImportError('--syncBN requires NVIDIA apex, which is not installed') from exc
        model = apex.parallel.convert_syncbn_model(model)

    if cuda_available:
        if opt.parallel == 1:
            model = torch.nn.DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            model = model.to(device)
            criterion = criterion.to(device)
        cudnn.benchmark = True

    return model, criterion

# --------------------------------------Set Loader configurations----------------------------------#

def set_loader(opt):
    """Build the shuffled training DataLoader of two-view augmented images.

    Only ``opt.dataset == 'Prime_TREX_DME_Fixed'`` is supported.

    Args:
        opt: options namespace; reads ``dataset``, ``train_csv_path``,
            ``train_image_path``, ``batch_size`` and ``num_workers``.

    Returns:
        torch.utils.data.DataLoader over ``CombinedDataset``. Each image item is a
        two-element list of independently augmented views (``TwoCropTransform``).
        Incomplete last batches are dropped.

    Raises:
        ValueError: if ``opt.dataset`` is not supported.
    """
    if opt.dataset != 'Prime_TREX_DME_Fixed':
        raise ValueError('dataset not supported: {}'.format(opt.dataset))

    # Single-channel statistics for the grayscale images (one value per channel).
    normalize = transforms.Normalize(mean=(.1706,), std=(.2112,))

    # NOTE: crop size is fixed at 224; opt.size is not read here.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    # Default CSV location on Kaggle (repo layout equivalent:
    # './final_csvs_<patient_split>/datasets_combined/prime_trex_compressed.csv').
    # An existing file passed via --train_csv_path takes precedence.
    csv_path_train = '/kaggle/input/supcon-oct-clinical-master/SupCon_OCT_Clinical-master/final_csvs_1/datasets_combined/prime_trex_compressed.csv'
    custom_csv = getattr(opt, 'train_csv_path', None)
    if custom_csv and os.path.isfile(custom_csv):
        csv_path_train = custom_csv

    train_dataset = CombinedDataset(csv_path_train, opt.train_image_path,
                                    transforms=TwoCropTransform(train_transform))

    return torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.num_workers, pin_memory=True, sampler=None, drop_last=True)

## Utils

In [None]:
### from utils.utils import set_optimizer, adjust_learning_rate,save_model

# --------------------------Create two crops of the same image------------------------------#
class TwoCropTransform:
    """Apply one (typically random) transform twice to the same input.

    Calling an instance returns ``[view_1, view_2]``, the two results of
    ``transform(x)`` in call order.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]

# ----------------------Computes and stores the average and current value----------------------#
class AverageMeter(object):
    """Track the most recent value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

# ---------------------------calculate the accuracy of model predictions-------------------------#

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor indexed as ``(batch, num_classes)`` (``topk`` runs
            over dim 1).
        target: tensor of class indices with ``target.size(0) == batch``.
        topk: iterable of k values.

    Returns:
        list with one 1-element tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape instead of view: `correct` can inherit the transposed,
            # non-contiguous layout of `pred`, on which view(-1) fails for k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

# -----------------------------------Adjust the learning rate-----------------------------------#
def adjust_learning_rate(args, optimizer, epoch):
    """Set every param group's lr for ``epoch`` and return that lr.

    With ``args.cosine``, the lr follows a half-cosine from ``args.learning_rate``
    down to ``learning_rate * lr_decay_rate**3`` over ``args.epochs``. Otherwise
    it is step-decayed by ``lr_decay_rate`` once per milestone in
    ``args.lr_decay_epochs`` that ``epoch`` has passed (strictly greater than).

    Returns:
        The applied learning rate as a Python float. The lr is also printed,
        as before.
    """
    lr = args.learning_rate
    if args.cosine:
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
                1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)

    # np.sum yields a numpy integer, which makes lr a numpy scalar; store a plain float.
    lr = float(lr)
    print(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# -----------------------------------Warm up the learning rate-----------------------------------#
def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the lr from ``args.warmup_from`` to ``args.warmup_to``.

    Only active while ``args.warm`` is set and ``epoch <= args.warm_epochs``.
    Epochs are counted from 1. Outside that window the optimizer is left untouched.
    """
    if not args.warm or epoch > args.warm_epochs:
        return
    seen_batches = batch_id + (epoch - 1) * total_batches
    progress = seen_batches / (args.warm_epochs * total_batches)
    lr = args.warmup_from + progress * (args.warmup_to - args.warmup_from)
    for group in optimizer.param_groups:
        group['lr'] = lr

# --------------------------------------Optimizer initializing--------------------------------------#
def set_optimizer(opt, model):
    """Return an SGD optimizer over ``model``'s parameters.

    Uses ``opt.learning_rate``, ``opt.momentum`` and ``opt.weight_decay``.
    """
    sgd_settings = dict(
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
    )
    return optim.SGD(model.parameters(), **sgd_settings)

# ---------------------------------------------Saving the Model-------------------------------------#
def save_model(model, optimizer, opt, epoch, save_file):
    """Write a checkpoint to ``save_file`` with ``torch.save``.

    The checkpoint dict has the keys ``'opt'``, ``'model'`` (state_dict),
    ``'optimizer'`` (state_dict) and ``'epoch'``.
    """
    print('==> Saving...')
    torch.save(
        {
            'opt': opt,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        },
        save_file,
    )
    
# ----------------------------------ROC-AUC of model scores (one-vs-rest)--------------------------------#
def accuracy_multilabel(output, target):
    """Print and return the ROC-AUC of ``output`` scores against ``target``.

    Args:
        output: score tensor; detached and moved to CPU before scoring.
        target: ground-truth tensor; detached and moved to CPU before scoring.

    Returns:
        The value of ``sklearn.metrics.roc_auc_score(target, output, multi_class='ovr')``
        (previously only printed, now also returned so callers can log it).
    """
    output = output.detach().cpu().numpy()
    target = target.detach().cpu().numpy()
    r = roc_auc_score(target, output, multi_class='ovr')
    print(r)
    return r
    
# ------------------------------------------------------------------------------------------------#    
# def set_model_student_teacher(opt):

#     if(opt.multi == 0):
#         model_student = SupConResNet_Original_Headless(name=opt.model,use_head=False)
#         model = SupConResNet_Original_Headless(name=opt.model,use_head=False)
#         classifier = LinearClassifier(name=opt.model, num_classes=2)
#         model_teacher = nn.Sequential(model,classifier)
#         model_student = nn.Sequential(model_student,classifier)
        
#     criterion = torch.nn.CrossEntropyLoss()
#     ckpt = torch.load(opt.ckpt, map_location='cpu')
#     state_dict = ckpt['model']
#     device = opt.device
    
#     if torch.cuda.is_available():
#         if opt.parallel == 0:
#             model.encoder = torch.nn.DataParallel(model.encoder)
#         else:
#             new_state_dict = {}
#             for k, v in state_dict.items():
#                 k = k.replace("module.", "")
#                 k = '0.encoder.' + k[2:]
#                 if(k=='0.encoder.fc.weight'):
#                     k = '1.fc.weight'
#                 if (k == '0.encoder.fc.bias'):
#                     k = '1.fc.bias'
#                 #k = k.replace("encoder.", "")
#                 new_state_dict[k] = v
#             state_dict = new_state_dict
#         model_teacher = model_teacher.to(device)
#         model_student = model_student.to(device)
#         criterion = criterion.to(device)
#         cudnn.benchmark = True

#         model_teacher.load_state_dict(state_dict)

#     return model_teacher, model_student, criterion

## Training

In [None]:
### from training_supcon.training_one_epoch_prime_trex_combined import train_Combined


# -------------------------------------One epoch training--------------------------------------#
# Option attributes naming the label field of each loss term, in loss order.
_METHOD_ATTRS = ('method1', 'method2', 'method3', 'method4', 'method5')


def _method_labels(method, batch_labels, device):
    # Return the batch label tensor selected by `method`, moved to `device`.
    if method not in batch_labels:
        raise ValueError(
            'method {!r} does not name a label field; expected one of {}'.format(
                method, sorted(batch_labels)))
    return batch_labels[method].to(device)


def _combined_loss(features, f1, f2, methods, batch_labels, criterion, superclass_criterion, device):
    # Combine the loss terms selected by the active `methods` (first opt.num_methods names).
    if not methods:
        # No label methods: the criterion is called without labels (SimCLR-style).
        return criterion(features)
    if len(methods) == 1 and methods[0] == 'HCL':
        return simclr_loss_func_hard(features, f1, f2)
    labels1 = _method_labels(methods[0], batch_labels, device)
    if len(methods) == 2 and methods[1] == 'SuperClass':
        return superclass_criterion(features, super_labels=labels1)
    if len(methods) == 2 and methods[1] == 'SuperClass_Combined':
        return criterion(features) + superclass_criterion(features, super_labels=labels1)
    # Default: one supervised-contrastive term per active label field, summed in order.
    loss = criterion(features, labels1)
    for method in methods[1:]:
        loss = loss + criterion(features, _method_labels(method, batch_labels, device))
    return loss


def train_Combined(train_loader, model, criterion, optimizer, epoch, opt):
    """Train ``model`` for one epoch with the label-driven contrastive loss.

    The loss combines ``opt.num_methods`` terms. Term i uses the label field named
    by ``opt.method<i>`` (``patient``, ``bcva``, ``cst`` or ``eye_id``). Special
    values ``HCL`` (method1, one term) and ``SuperClass``/``SuperClass_Combined``
    (method2, two terms) select alternative losses.
    NOTE: ``simclr_loss_func_hard`` and ``SupConLoss_SuperClassDistance`` are not
    defined in this notebook.

    Args:
        train_loader: yields ``(images, bcva, cst, eye_id, patient)``, where
            ``images`` is a pair of augmented views.
        model: maps stacked views to per-image feature vectors.
        criterion: contrastive loss called as ``criterion(features[, labels])``.
        optimizer: optimizer stepped once per batch.
        epoch: 1-based epoch index (used for warm-up and logging).
        opt: options namespace.

    Returns:
        Sample-weighted average loss over the epoch.

    Raises:
        ValueError: if ``opt.num_methods`` is outside 0..5, or an active method
            does not name a label field.
    """
    if not 0 <= opt.num_methods <= len(_METHOD_ATTRS):
        raise ValueError('num_methods must be between 0 and {}, got {}'.format(
            len(_METHOD_ATTRS), opt.num_methods))
    active_methods = [getattr(opt, attr) for attr in _METHOD_ATTRS[:opt.num_methods]]

    superclass_criterion = None
    if len(active_methods) == 2 and active_methods[1] in ('SuperClass', 'SuperClass_Combined'):
        # Built once per epoch rather than on every batch.
        superclass_criterion = SupConLoss_SuperClassDistance(temperature=.07)

    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    device = opt.device if torch.cuda.is_available() else 'cpu'
    end = time.time()

    for idx, (images, bcva, cst, eye_id, patient) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # Stack the two augmented views along the batch axis: [view_0; view_1].
        images = torch.cat([images[0], images[1]], dim=0).to(device)
        bsz = bcva.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # Regroup features as (bsz, n_views=2, ...) for the contrastive criterion.
        features = model(images)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

        batch_labels = {'patient': patient, 'bcva': bcva, 'cst': cst, 'eye_id': eye_id}
        loss = _combined_loss(features, f1, f2, active_methods, batch_labels,
                              criterion, superclass_criterion, device)

        # update metric
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg




## Dataset

In [None]:
# from datasets.prime_trex_combined import CombinedDataset

# --------------------------------Combined dataset indexed by a CSV file--------------------------------#
class CombinedDataset(data.Dataset):
    """Dataset of grayscale images listed in a CSV file.

    CSV column 0 holds an image path appended to ``img_dir``. Columns 1-4 are
    returned with the image as ``bcva``, ``cst``, ``eye`` and ``patient``.
    """

    def __init__(self, df, img_dir, transforms):
        """
        Args:
            df: path of the CSV file (read with ``pandas.read_csv``).
            img_dir: root directory prepended to the CSV image paths.
            transforms: callable applied to each PIL image (mode "L").
        """
        self.img_dir = img_dir
        self.transforms = transforms
        self.df = pd.read_csv(df)

    def __len__(self):
        return len(self.df)

    def _image_path(self, idx):
        """Build the on-disk path of row ``idx``.

        Paths starting with "/TREX DME" sit under an extra "/TREX_DME" directory.
        """
        rel_path = self.df.iloc[idx, 0]
        if rel_path[0:9] == "/TREX DME":
            return self.img_dir + "/TREX_DME" + rel_path
        return self.img_dir + rel_path

    def __getitem__(self, idx):
        """Return ``(image, bcva, cst, eye, patient)`` for row ``idx``."""
        # The context manager closes the file handle; convert() returns an in-memory image.
        with Image.open(self._image_path(idx)) as img:
            image = img.convert("L")
        image = self.transforms(image)

        bcva = self.df.iloc[idx, 1]
        cst = self.df.iloc[idx, 2]
        eye = self.df.iloc[idx, 3]
        patient = self.df.iloc[idx, 4]
        return image, bcva, cst, eye, patient

## Model

In [None]:
# from models.resnet import SupConResNet, SupConResNet_Original


# -------------------Fundamental building block within a neural network architecture--------------------#
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out

# -------------------------Building block for deep convolutional neural networks-----------------------#
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out

# ---------------------------------Resnet model initializing-----------------------------------#
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channel=1, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        #self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,bias=False)
        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, layer=100):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)

        out = self.layer2(out)

        out = self.layer3(out)

        out = self.layer4(out)

        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return out

# -------------------------------Resnet model information--------------------------------#
def resnet18(**kwargs):
    """ResNet-18: BasicBlock stages with depths (2, 2, 2, 2)."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    """ResNet-34: BasicBlock stages with depths (3, 4, 6, 3)."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    """ResNet-50: Bottleneck stages with depths (3, 4, 6, 3)."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    """ResNet-101: Bottleneck stages with depths (3, 4, 23, 3)."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3], **kwargs)


# architecture name -> [constructor, feature dimension returned by forward()]
model_dict = dict(
    resnet18=[resnet18, 512],
    resnet34=[resnet34, 512],
    resnet50=[resnet50, 2048],
    resnet101=[resnet101, 2048],
)

# --------------------------------------Used Resnet Model-------------------------------------#
class SupConResNet_Original(nn.Module):
    """torchvision ResNet encoder with a single-channel stem and a projection head.

    ``forward`` returns L2-normalised projections (dim=1) when ``use_head`` is true.
    When it is false, it returns the raw pooled encoder features.
    """
    # name -> (torchvision constructor name, encoder output dimension)
    _ENCODERS = {
        'resnet18': ('resnet18', 512),
        'resnet34': ('resnet34', 512),
        'resnet50': ('resnet50', 2048),
        'resnet101': ('resnet101', 2048),
    }

    def __init__(self, name='resnet50', head='mlp', feat_dim=128, use_head=True):
        super(SupConResNet_Original, self).__init__()
        self.use_head = use_head
        # Unrecognised names keep the previous fallback to resnet18.
        arch, dim_in = self._ENCODERS.get(name, ('resnet18', 512))
        self.encoder = getattr(torchvision.models, arch)(zero_init_residual=True)
        # Single-channel stem: the dataset loads images in PIL mode "L".
        self.encoder.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Drop the classifier so the encoder outputs pooled features.
        self.encoder.fc = nn.Identity()

        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feat = self.encoder(x)
        if not self.use_head:
            return feat
        return F.normalize(self.head(feat), dim=1)
    

# class LinearBatchNorm(nn.Module):
#     """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose"""
#     def __init__(self, dim, affine=True):
#         super(LinearBatchNorm, self).__init__()
#         self.dim = dim
#         self.bn = nn.BatchNorm2d(dim, affine=affine)

#     def forward(self, x):
#         x = x.view(-1, self.dim, 1, 1)
#         x = self.bn(x)
#         x = x.view(-1, self.dim)
#         return x

## Loss

In [None]:
# from loss.loss import SupConLoss

class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, device='cuda:0', temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        # Kept for backward compatibility only. forward() builds its helper tensors
        # on features.device, since they must share a device with the logits.
        self.device = device

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar. Anchors without any positive pair contribute 0
            instead of NaN.
        Raises:
            ValueError: on fewer than 3 feature dims, both labels and mask
                given, a label/feature count mismatch, or an unknown contrast_mode.
        """
        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)
        device = features.device

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)

        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over positives. Anchors with no positive would divide
        # by zero (NaN loss), so their positive count is clamped to 1.
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(pos_per_anchor < 1e-6,
                                     torch.ones_like(pos_per_anchor), pos_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

## Main

In [None]:
# from training_supcon.training_one_epoch_trex import train_TREX

def main(cli_args=None):
    """Train the SupCon model end to end and write checkpoints.

    Builds the loader, model, criterion and optimizer, then runs ``opt.epochs``
    epochs (1-based). Loss and lr are logged to TensorBoard, a checkpoint is saved
    every ``opt.save_freq`` epochs, and ``last.pth`` is written at the end.

    Args:
        cli_args: argument strings for ``parse_option``. ``None`` (the default)
            uses the notebook-level ``args`` tuple, so the configuration cell must
            run before ``main()``.

    Raises:
        ValueError: if ``opt.dataset`` has no one-epoch training routine.
    """
    opt = parse_option(args if cli_args is None else cli_args)

    # Resolve the per-dataset epoch routine up front. Previously an unsupported
    # dataset left `loss` unbound and failed only at the logging call.
    epoch_trainers = {'Prime_TREX_DME_Fixed': train_Combined}
    if opt.dataset not in epoch_trainers:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    train_one_epoch = epoch_trainers[opt.dataset]

    # build data loader
    train_loader = set_loader(opt)

    # build model and criterion
    model, criterion = set_model_contrast(opt)

    # build optimizer
    optimizer = set_optimizer(opt, model)

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # training routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        loss = train_one_epoch(train_loader, model, criterion, optimizer, epoch, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('loss', loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % opt.save_freq == 0:
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)

    save_file = os.path.join(opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)

## Run configuration (arguments passed to `parse_option`)

In [None]:
# Command-line style arguments consumed by main() -> parse_option().
# Two loss terms: BCVA labels and eye-id labels.
args = (
    '--batch_size', '128',
    '--patient_split', '1',
    '--model', 'resnet50',
    '--num_methods', '2',
    '--method1', 'bcva',
    '--method2', 'eye_id',
    '--dataset', 'Prime_TREX_DME_Fixed',
    '--epochs', '25',
    '--device', 'cuda:0',
    '--train_image_path',
    '/kaggle/input/olives-vip-cup-2023/2023 IEEE SPS Video and Image Processing (VIP) Cup - Ophthalmic Biomarker Detection/TRAIN/OLIVES',
)

## Run the code

In [None]:
# Launch training with the notebook-level `args` configuration. Checkpoints and
# TensorBoard logs are written under ./save/SupCon/.
lets_run_the_code = main()