In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.utils.data as utils
from torchvision import transforms, models
import pretrainedmodels
from torchvision.datasets import ImageFolder
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
import os
import copy
import cv2
import shutil
import json
import pandas as pd
import time
from PIL import Image
from sklearn.metrics import confusion_matrix, roc_curve, auc
import random
import logging



In [None]:
def count_into_lst(lst):
    """Count how many times each element occurs in ``lst``.

    Args:
        lst: iterable of hashable items.

    Returns:
        dict mapping each distinct item to its number of occurrences, with
        keys in order of first appearance.
    """
    tally = {}
    for item in lst:
        # .get() supplies 0 the first time an item is seen.
        tally[item] = tally.get(item, 0) + 1
    return tally

In [None]:
def instance_data_sequence(json_file):
    """Shuffle the instance records and flatten them into one entry per file.

    Args:
        json_file: list of instance records. Each record is read through the
            keys ``'file_dir'`` (indexable collection of paths),
            ``'file_num'`` (how many entries of ``'file_dir'`` to use) and
            ``'class number'``. The list is shuffled IN PLACE.

    Returns:
        list of dicts with keys ``'file_dir'``, ``'class'`` and
        ``'instance_num'``; ``'instance_num'`` is the record's position in the
        shuffled list, so files of one instance stay adjacent.
    """
    random.shuffle(json_file)
    return [
        {
            'file_dir': record['file_dir'][file_idx],
            'class': record['class number'],
            'instance_num': position,
        }
        for position, record in enumerate(json_file)
        for file_idx in range(record['file_num'])
    ]

In [None]:
class CustomDatset(object):
    """Map-style dataset over flattened per-file records.

    Each record is read through the keys ``'file_dir'`` (image path),
    ``'class'`` (label) and ``'instance_num'`` (instance id).
    """

    def __init__(self, all_json, transform):
        """Store the record list and the transform applied to every image."""
        self.transform = transform
        self.all_json = all_json

    def __len__(self):
        """Number of records (one per image file)."""
        return len(self.all_json)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_tensor, instance_id)`` for ``idx``.

        The image is opened from ``'file_dir'`` and converted to RGB before
        the transform runs; the label becomes a ``torch.LongTensor``.
        """
        record = self.all_json[idx]

        rgb_image = Image.open(record['file_dir']).convert("RGB")
        imgs = self.transform(rgb_image)

        label_array = np.array(record['class'])
        targets = torch.from_numpy(label_array).type(torch.LongTensor)

        return imgs, targets, record['instance_num']

In [None]:
import yaml


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` and optionally keep a "best" copy.

    Args:
        state: object passed to ``torch.save``.
        is_best: when truthy, the saved file is also copied to
            ``'model_best.pth.tar'`` (a relative path, so it lands in the
            current working directory).
        filename: destination of the checkpoint.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')


def save_config_file(model_checkpoints_folder, args):
    """Write ``args`` as ``config.yml`` inside ``model_checkpoints_folder``.

    The folder is created when missing. The config is written whether or not
    the folder already existed; previously the write was nested under the
    "folder does not exist" branch, so an existing folder meant no config
    file at all.

    Args:
        model_checkpoints_folder: directory that receives ``config.yml``.
        args: object handed to ``yaml.dump``.
    """
    os.makedirs(model_checkpoints_folder, exist_ok=True)
    config_path = os.path.join(model_checkpoints_folder, 'config.yml')
    with open(config_path, 'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each ``k`` in ``topk``.

    Args:
        output: 2-D score tensor; the top-k search runs along dim 1.
        target: 1-D tensor of true class indices, one per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        list of 1-element float tensors, one per ``k`` (same order as
        ``topk``), each holding ``100 * hits / batch_size``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the largest_k best-scoring classes per row, best first,
        # transposed so that row r holds every sample's rank-r prediction.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]

In [None]:

import sys
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

# Seed PyTorch's global random number generator with a fixed value when this cell runs.
torch.manual_seed(0)


class SimCLR(object):
    """Trainer combining a class-weighted cross-entropy loss with a
    similarity loss computed on the model's feature output.

    The model is expected to return ``(class_logits, features)`` from its
    forward pass. All tensors are moved with ``.cuda()``.
    """

    # Per-class training folders (in class-index order) whose file counts
    # drive the cross-entropy class weights when ``class_dirs`` is not given.
    DEFAULT_CLASS_DIRS = (
        r'C:\Users\yonsei\Desktop\4 scar types\Resize/Train/Adhesive',
        r'C:\Users\yonsei\Desktop\4 scar types\Resize/Train/Bulge',
        r'C:\Users\yonsei\Desktop\4 scar types\Resize/Train/Hypertrophy',
        r'C:\Users\yonsei\Desktop\4 scar types\Resize/Train/Linear',
    )

    def __init__(self, *args, **kwargs):
        """Create the trainer.

        Keyword Args:
            args: namespace providing ``num_class``, ``epochs``,
                ``batch_size``, ``fp16_precision``, ``log_every_n_steps``,
                ``disable_cuda``, ``arch`` and the three
                ``instance_loss_alpha/beta/gamma`` weights.
            model: network returning ``(class_logits, features)``.
            optimizer: optimizer stepped once per training batch.
            scheduler: LR scheduler stepped once per epoch after the warm-up.
            class_dirs (optional): iterable of directories, one per class in
                class-index order; the number of entries in each directory is
                used as that class's sample count. Defaults to
                ``DEFAULT_CLASS_DIRS``.

        Raises:
            KeyError: if ``args``, ``model``, ``optimizer`` or ``scheduler``
                is missing.
            ValueError: if a class directory is empty (its weight would be
                infinite).
        """
        class_dirs = kwargs.get('class_dirs', self.DEFAULT_CLASS_DIRS)
        counts = [len(os.listdir(class_dir)) for class_dir in class_dirs]
        weights = self._inverse_frequency_weights(counts)

        self.args = kwargs['args']
        self.model = kwargs['model'].cuda()
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.writer = SummaryWriter()
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
        self.criterion = torch.nn.CrossEntropyLoss(weight=weights.cuda()).cuda()

    @staticmethod
    def _inverse_frequency_weights(counts):
        """Turn per-class sample counts into normalised inverse-frequency weights.

        Frequencies are ``count / total``; the weights are their reciprocals,
        rescaled to sum to 1, so rarer classes receive larger weights.
        """
        weights = torch.tensor(list(counts), dtype=torch.float32)
        if weights.numel() == 0 or bool((weights <= 0).any()):
            raise ValueError(f"Every class needs at least one sample, got counts={list(counts)}")
        weights = weights / weights.sum()
        weights = 1.0 / weights
        return weights / weights.sum()

    def instance_loss(self, features, targets, instances):
        """Similarity loss over one batch of features.

        With ``sim`` the cosine-similarity matrix of the L2-normalised
        features, three terms are combined:

        * instance term (weight ``instance_loss_alpha``): for every instance
          with at least two samples in the batch, ``mean(1 - sim)`` over that
          instance's own pairwise block; the per-instance values are averaged
          within each class and the class averages are summed.
        * positive term (weight ``instance_loss_beta``): per class,
          ``mean(1 - sim)`` over all same-class pairs, summed over classes.
        * negative term (weight ``instance_loss_gamma``): per class,
          ``mean(1 / (1 - sim) - 0.5)`` over (same-class, other-class) pairs,
          summed over classes. A class with no other-class sample in the
          batch is skipped.

        Fixes relative to the previous implementation:

        * the instance term sliced rows ``[first:last)`` against every
          column, which dropped the last member and averaged against the
          whole batch;
        * the instance term overwrote the class slot, so only one instance
          per class contributed;
        * a single-class batch produced the mean of an empty tensor, i.e. a
          NaN loss.

        Args:
            features: 2-D float tensor, one row per sample.
            targets: per-sample class labels, convertible with ``np.array``
                (i.e. a CPU tensor or sequence).
            instances: per-sample instance ids, convertible with ``np.array``.

        Returns:
            0-dim float32 tensor on the device of ``features``.
        """
        targets = np.array(targets).reshape(-1)
        instances = np.array(instances).reshape(-1)

        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)
        device = similarity_matrix.device

        def pairwise(rows, cols):
            # Sub-matrix sim[rows][:, cols], upcast so the reductions run in fp32.
            row_idx = torch.as_tensor(rows, dtype=torch.long, device=device)
            col_idx = torch.as_tensor(cols, dtype=torch.long, device=device)
            return similarity_matrix[row_idx][:, col_idx].float()

        zero = torch.zeros((), dtype=torch.float32, device=device)

        # --- instance term ------------------------------------------------
        per_class_instance_terms = {}
        instance_ids, instance_sizes = np.unique(instances, return_counts=True)
        for instance_id, size in zip(instance_ids, instance_sizes):
            if size < 2:
                # A lone sample has no intra-instance pair to pull together.
                continue
            members = np.flatnonzero(instances == instance_id)
            label = int(targets[members[0]])
            term = torch.mean(1.0 - pairwise(members, members))
            per_class_instance_terms.setdefault(label, []).append(term)

        instance_term = zero
        for terms in per_class_instance_terms.values():
            instance_term = instance_term + torch.stack(terms).mean()

        # --- class-level positive / negative terms -------------------------
        positive_term = zero
        negative_term = zero
        for label in np.unique(targets):
            same = np.flatnonzero(targets == label)
            other = np.flatnonzero(targets != label)
            positive_term = positive_term + torch.mean(1.0 - pairwise(same, same))
            if other.size:
                negative_term = negative_term + torch.mean(1.0 / (1.0 - pairwise(same, other)) - 0.5)

        return (self.args.instance_loss_alpha * instance_term
                + self.args.instance_loss_beta * positive_term
                + self.args.instance_loss_gamma * negative_term)

    def train(self, json_Train, valid_loader, transforms):
        """Run the full training loop, validating after every epoch.

        Each epoch re-shuffles the instances of ``json_Train`` (in place),
        rebuilds the training loader without per-sample shuffling so the
        files of one instance stay adjacent, and optimises
        ``instance_loss + cross_entropy``. After the last epoch a single
        checkpoint is written to the TensorBoard log directory.

        Args:
            json_Train: list of instance records accepted by
                ``instance_data_sequence``.
            valid_loader: iterable yielding ``(images, targets, instances)``.
            transforms: transform applied to every training image.
        """
        scaler = GradScaler(enabled=self.args.fp16_precision)

        # save config file
        save_config_file(self.writer.log_dir, self.args)

        n_iter = 0
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        # The flag is "disable", so it is negated to report GPU usage.
        logging.info(f"Training with gpu: {not self.args.disable_cuda}.")

        for epoch_counter in range(self.args.epochs):
            # Defined up front so the epoch log below cannot hit an unbound
            # name when the loader yields no batch.
            loss1 = loss2 = None

            with torch.set_grad_enabled(True):
                self.model.train()

                json_shuffle = instance_data_sequence(json_Train)
                train_dataset = CustomDatset(all_json=json_shuffle, transform=transforms)
                train_loader = torch.utils.data.DataLoader(
                    train_dataset, batch_size=self.args.batch_size, shuffle=False,
                    num_workers=0, pin_memory=True, drop_last=True)

                print(epoch_counter+1)
                for images, targets, instances in tqdm(train_loader):
                    images = images.cuda()

                    with autocast(enabled=self.args.fp16_precision):
                        pred, features = self.model(images)
                        # instance_loss converts the labels with np.array, so it
                        # must see them before they are moved to the GPU.
                        loss1 = self.instance_loss(features, targets, instances)
                        targets = targets.cuda()
                        loss2 = self.criterion(pred, targets)
                        loss = loss1 + loss2

                    self.optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(self.optimizer)
                    scaler.update()

                    if n_iter % self.args.log_every_n_steps == 0:
                        self.writer.add_scalar('loss', loss, global_step=n_iter)
                        # get_last_lr() is the supported accessor for the current LR.
                        self.writer.add_scalar('learning_rate', self.scheduler.get_last_lr()[0], global_step=n_iter)

                    n_iter += 1

                # warmup for the first 10 epochs
                if epoch_counter >= 10:
                    self.scheduler.step()
                logging.debug(f"Epoch: {epoch_counter}\tinstance_Loss: {loss1}\tentrophy_Loss: {loss2}")

            # ---- validation: mean cross-entropy over the validation batches ----
            self.model.eval()
            valid_loss = 0.0
            valid_batches = 0
            with torch.no_grad():
                for images, targets, instances in tqdm(valid_loader):
                    images = images.cuda()
                    pred, _ = self.model(images)
                    valid_loss += self.criterion(pred, targets.cuda()).item()
                    valid_batches += 1

            if valid_batches:
                self.writer.add_scalar('Val_loss', valid_loss / valid_batches, global_step=n_iter)
            else:
                logging.warning("Validation loader produced no batches; skipping Val_loss.")

        logging.info("Training has finished.")
        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
        save_checkpoint({
            'epoch': self.args.epochs,
            'arch': self.args.arch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
        logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")

In [None]:
import torch.nn as nn
import torchvision.models as models


class InvalidBackboneError(ValueError):
    """Raised when an unsupported backbone name is requested."""


class ResNetSimCLR(nn.Module):
    """ImageNet-pretrained ResNet backbone with two heads.

    The backbone's original ``fc`` layer is replaced by ``nn.Identity`` and
    two heads are attached to the backbone module:

    * ``backbone.fc1`` - linear classifier producing ``num_class`` logits;
    * ``backbone.fc2`` - two-layer MLP producing ``out_dim`` features.
    """

    # Supported name -> (torchvision constructor name, weights enum name).
    # Resolved lazily so only the requested network is ever built.
    _BACKBONES = {
        "resnet50": ("resnet50", "ResNet50_Weights"),
        "resnet152": ("resnet152", "ResNet152_Weights"),
    }

    def __init__(self, base_model, out_dim, num_class):
        """
        Args:
            base_model: ``"resnet50"`` or ``"resnet152"``.
            out_dim: size of the feature head's output.
            num_class: number of classification logits.

        Raises:
            InvalidBackboneError: if ``base_model`` is not supported.
        """
        super(ResNetSimCLR, self).__init__()

        self.backbone = self._get_basemodel(base_model)
        # Kept as an attribute for compatibility; it now holds only the
        # selected backbone instead of every supported network.
        self.resnet_dict = {base_model: self.backbone}

        dim_mlp = self.backbone.fc.in_features

        self.backbone.fc = nn.Identity()
        self.backbone.fc1 = nn.Sequential(nn.Linear(dim_mlp, num_class))
        self.backbone.fc2 = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, out_dim))

    def _get_basemodel(self, model_name):
        """Instantiate the requested pretrained backbone (IMAGENET1K_V2 weights)."""
        try:
            constructor_name, weights_name = self._BACKBONES[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet50 or resnet152")
        weights = getattr(models, weights_name).IMAGENET1K_V2
        return getattr(models, constructor_name)(weights=weights)

    def forward(self, x):
        """Return ``(class_logits, features)`` for the batch ``x``.

        The backbone runs once and both heads share its output.
        """
        embedding = self.backbone(x)
        return self.backbone.fc1(embedding), self.backbone.fc2(embedding)

In [None]:
import argparse
import torch.backends.cudnn as cudnn

# Names of the lowercase public callables exposed by torchvision.models;
# used as the accepted values of --arch.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

# Command-line options for the training run.
parser = argparse.ArgumentParser(description='PyTorch SimCLR')
parser.add_argument('-data', metavar='DIR', default='./datasets',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet50)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N',
                    help='mini-batch size (default: 128), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.00001, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
parser.add_argument('--fp16-precision', action='store_true',
                    help='Whether or not to use 16-bit precision GPU training.')

parser.add_argument('--out_dim', default=128, type=int,
                    help='feature dimension (default: 128)')
parser.add_argument('--log-every-n-steps', default=100, type=int,
                    help='Log every n steps')
parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
# Weights of the three terms of the instance loss.
parser.add_argument('--instance_loss_alpha', default=0.7, type=float, help='Instance loss alpha.')
parser.add_argument('--instance_loss_beta', default=0.3, type=float, help='Instance loss beta.')
parser.add_argument('--instance_loss_gamma', default=0.3, type=float, help='Instance loss gamma.')
parser.add_argument('--num_class', default=4, type=int, help='number of class.')


def main():
    """Parse the options, build data, model and optimizer, then train.

    Reads ``./instancelearning_Train.json`` and
    ``./instancelearning_Valid.json`` from the working directory. Unknown
    command-line arguments are ignored (``parse_known_args``). When
    ``--seed`` is given, Python's, NumPy's and PyTorch's RNGs are seeded with
    it before anything is shuffled or initialised.
    """
    args, _ = parser.parse_known_args()

    # --seed was parsed but never applied before; honour it when provided.
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    cudnn.deterministic = True
    # NOTE(review): benchmark autotuning may pick different kernels between
    # runs - confirm this is wanted together with deterministic=True.
    cudnn.benchmark = True

    with open("./instancelearning_Train.json", 'r') as file:
        json_Train = json.load(file)
    with open("./instancelearning_Valid.json", 'r') as file:
        json_Valid = json.load(file)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    valid_dataset = CustomDatset(all_json=instance_data_sequence(json_Valid), transform=transform)

    # Keep the final partial batch: dropping it would silently exclude
    # validation samples and leaves zero batches when the set is smaller
    # than one batch.
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, drop_last=False)

    # This loader is only used for its length (scheduler T_max); the trainer
    # rebuilds its own loader every epoch.
    train_dataset = CustomDatset(all_json=instance_data_sequence(json_Train), transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, drop_last=True)

    model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim, num_class=args.num_class)

    optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

    # NOTE(review): T_max is the number of batches per epoch, while the
    # trainer steps the scheduler once per epoch - confirm the intended period.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                           last_epoch=-1)
    with torch.cuda.device(args.gpu_index):
        simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
        simclr.train(json_Train, valid_loader, transform)


if __name__ == "__main__":
    main()