In [None]:
import math
import os
import random

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
from PIL import Image
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchcontrib.optim import SWA
from torchmetrics import Metric
from torchvision import transforms as tsfm
from tqdm import tqdm

# Classification

In [None]:
class CFG:
    """Experiment configuration: data paths, label maps, loss/optimizer settings, device.

    Every setting is a class attribute, so it can be read from ``CFG`` directly or
    from an instance such as ``cfg = CFG()``.
    """
    # --- data locations ---
    data_path = '/scratch/ISIC2019/train/isic_train.csv'
    train_imgs_dir = '/scratch/ISIC2019/train/'

    # --- label info: class index <-> diagnosis code ---
    label_num2str = {
        0: 'NV',
        1: 'MEL',
        2: 'BCC',
        3: 'BKL',
        4: 'AK',
        5: 'SCC',
        6: 'VASC',
        7: 'DF',
    }
    # Inverse table, derived from label_num2str so the two can never drift apart.
    label_str2num = {code: idx for idx, code in label_num2str.items()}

    # --- focal loss ---
    fl_alpha = 1.0  # alpha of focal_loss
    fl_gamma = 2.0  # gamma of focal_loss
    # Per-class loss weights, indexed by class index (same order as label_num2str).
    cls_weight = [1.0, 0.4717294571343815, 0.39523385741125283, 0.3535449421536636,
                  0.2405023237417186, 0.2242064669237615, 0.20134480371798677, 0.2]

    # --- model info (timm model names) ---
    cnn_name = 'resnet50'
    vit_name = 'vit_base_patch16_384'

    # --- training ---
    seed = 77
    num_classes = 8
    batch_size = 16
    t_max = 16
    lr = 1e-3
    min_lr = 1e-6
    n_fold = 6
    num_workers = 8
    accum_grad_batch = 1
    early_stop_delta = 1e-7

    # --- hardware ---
    gpu_idx = 0
    device = torch.device(f'cuda:{gpu_idx}' if torch.cuda.is_available() else 'cpu')
    gpu_list = [gpu_idx]

In [None]:
# Reuse the configured device and seed so they cannot drift from CFG.
device = CFG.device
seed_everything(CFG.seed)
cfg = CFG()

In [None]:
"""
Define train & valid image transformation
"""
# Input resolution shared by both pipelines; matches the 384 in CFG.vit_name.
IMAGE_SIZE = (384, 384)
# Standard ImageNet per-channel mean / std used for the pretrained timm backbones.
DATASET_IMAGE_MEAN = (0.485, 0.456, 0.406)
DATASET_IMAGE_STD = (0.229, 0.224, 0.225)

# Training: resize, random photometric/geometric augmentation, flips, then normalize.
train_transform = tsfm.Compose([
    tsfm.Resize(IMAGE_SIZE),
    tsfm.RandomApply([tsfm.ColorJitter(0.2, 0.2, 0.2), tsfm.RandomPerspective(distortion_scale=0.2)], p=0.3),
    tsfm.RandomApply([tsfm.ColorJitter(0.2, 0.2, 0.2), tsfm.RandomAffine(degrees=10)], p=0.3),
    tsfm.RandomVerticalFlip(p=0.3),
    tsfm.RandomHorizontalFlip(p=0.3),
    tsfm.ToTensor(),
    tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD),
])

# Validation: deterministic resize + normalize only.
valid_transform = tsfm.Compose([
    tsfm.Resize(IMAGE_SIZE),
    tsfm.ToTensor(),
    tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD),
])

In [None]:
"""
Define dataset class
"""
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset returning ``(image, label)`` pairs read from ``<img_dir>/<name>.jpg``.

    Subclasses ``torch.utils.data.Dataset`` explicitly rather than the bare name
    ``Dataset`` (which this class shadows), so re-running the cell does not make the
    class inherit from its own previous definition.

    Args:
        cfg: config object; only ``cfg.train_imgs_dir`` is read.
        img_names: image file stems, without the ``.jpg`` extension.
        labels: per-image labels, returned unchanged and aligned with ``img_names``.
        transform: optional callable applied to the RGB ``PIL.Image``; when ``None``
            the PIL image itself is returned.
    """

    def __init__(self, cfg, img_names: list, labels: list, transform=None):
        self.img_dir = cfg.train_imgs_dir
        self.img_names = img_names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx] + '.jpg')
        # convert() returns a fully loaded copy, so the file handle can be closed right away.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]

In [None]:
"""
Define Focal-Loss
"""

class FocalLoss(nn.Module):
    """Sigmoid (multi-label) focal loss with per-class weights, for fighting class imbalance.

    Element-wise, with p = sigmoid(logit) and p_t = p for target 1, 1 - p for target 0:

        loss = -cls_weight[c] * alpha * (1 - p_t) ** gamma * log(p_t)

    and the result is the mean over all batch x class elements.

    Args:
        alpha: global scale factor.
        gamma: focusing exponent; ``gamma == 0`` gives class-weighted binary cross-entropy.
        cls_weight: optional per-class weights (length num_classes). Defaults to
            ``CFG.cls_weight`` so existing ``FocalLoss(alpha, gamma)`` calls are unchanged.
    """

    def __init__(self, alpha=1, gamma=2, cls_weight=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = 1e-12  # keeps log() finite when a probability saturates at 0 or 1
        if cls_weight is None:
            cls_weight = CFG.cls_weight
        # Shape [1, num_classes] so the weights broadcast over the batch dimension.
        # A buffer is never trained and follows the module through .to()/.cuda().
        self.register_buffer('cls_weights', torch.tensor([list(cls_weight)], dtype=torch.float))

    def forward(self, logits, target):
        """Return the scalar loss.

        logits & target should be tensors with shape [batch_size, num_classes];
        target holds the 0/1 (one-hot) labels.
        """
        probs = torch.sigmoid(logits)
        log_p = torch.log(probs + self.epsilon)
        log_one_minus_p = torch.log(1.0 - probs + self.epsilon)
        log_pt = target * log_p + (1.0 - target) * log_one_minus_p
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt
        # The criterion is not necessarily on the same device as the model, so bring the
        # weights to wherever the logits live (a no-op when they already match).
        weights = self.cls_weights.to(device=focal_loss.device, dtype=focal_loss.dtype)
        return torch.mean(focal_loss * weights)

In [None]:
"""
Define F1 score metric
"""
class MyF1Score(Metric):
    """Thresholded multi-label recall (default) or F1, accumulated over ``update`` calls.

    A class counts as predicted when ``sigmoid(pred) > threshold`` and as present when
    ``target > threshold``. True positives, false positives and false negatives are
    summed across batches until ``reset()``.

    Args:
        cfg: config object; ``cfg.label_num2str`` is used by ``num_to_str``.
        threshold: decision threshold for predictions (after sigmoid) and targets.
        dist_sync_on_step: forwarded to ``torchmetrics.Metric``.
        mode: ``'recall'`` (default, historical behavior) or ``'f1'``.
    """

    def __init__(self, cfg, threshold: float = 0.5, dist_sync_on_step=False, mode: str = 'recall'):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        if mode not in ('recall', 'f1'):
            raise ValueError(f"mode must be 'recall' or 'f1', got {mode!r}")
        self.cfg = cfg
        self.threshold = threshold
        self.mode = mode
        self.add_state("tp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fn", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Add the TP/FP/FN counts of one batch of logits ``preds`` vs labels ``target``."""
        assert preds.shape == target.shape
        pred_pos = torch.sigmoid(preds) > self.threshold
        true_pos = target > self.threshold
        # Vectorized form of comparing per-sample predicted/true label-name lists.
        # .item() yields Python ints, so the state update works whatever device preds is on.
        self.tp += int((pred_pos & true_pos).sum().item())
        self.fp += int((pred_pos & ~true_pos).sum().item())
        self.fn += int((~pred_pos & true_pos).sum().item())

    def compute(self):
        """Return recall ``tp / (tp + fn)`` (or F1 in ``'f1'`` mode); 0 when undefined."""
        if self.mode == 'f1':
            numer = 2.0 * self.tp
            denom = 2 * self.tp + self.fn + self.fp
        else:
            numer = self.tp
            denom = self.tp + self.fn
        if denom == 0:
            # No positives seen yet: avoid returning 0/0 = nan.
            return torch.tensor(0.0)
        return numer / denom

    def num_to_str(self, ts: torch.Tensor) -> list:
        """Map a per-sample, per-class score tensor to lists of label names above threshold."""
        batch_bool_list = (ts > self.threshold).detach().cpu().numpy().tolist()
        return [
            [self.cfg.label_num2str[lb_idx] for lb_idx, is_on in enumerate(sample) if is_on]
            for sample in batch_bool_list
        ]

In [None]:
df = pd.read_csv(cfg.data_path)
# Columns required by the train/validation split below.
assert {'image_name', 'diagnosis'} <= set(df.columns), 'unexpected CSV schema'

In [None]:
# 80/20 split; stratifying on the diagnosis keeps every class represented in both
# splits in proportion despite the class imbalance. train_test_split comes from the
# top import cell.
X_train, X_val, y_train, y_val = train_test_split(
    df['image_name'], df['diagnosis'],
    test_size=0.2, random_state=CFG.seed, stratify=df['diagnosis'],
)

In [None]:
# Training image file stems, aligned with y_train.
all_img_names: list = X_train.tolist()

In [None]:
# Validation image file stems, aligned with y_val.
all_img_names_valid: list = X_val.tolist()

In [None]:
# Sanity check: exactly one diagnosis per training image name.
assert len(all_img_names) == len(y_train)
len(all_img_names)

In [None]:
# One-hot encode the training diagnoses as float tensors of length CFG.num_classes.
all_img_labels_ts = []
for diagnosis in y_train:
    one_hot = torch.zeros([CFG.num_classes], dtype=torch.float)
    # Direct indexing raises a KeyError naming an unknown diagnosis, instead of
    # dict.get() returning None and int(None) failing with an unrelated TypeError.
    one_hot[CFG.label_str2num[diagnosis]] = 1.0
    all_img_labels_ts.append(one_hot)

In [None]:
# Peek at a few one-hot training labels instead of dumping one tensor per image.
all_img_labels_ts[:5]

In [None]:
# One-hot encode the validation diagnoses as float tensors of length CFG.num_classes.
all_img_labels_val_ts = []
for diagnosis in y_val:
    one_hot = torch.zeros([CFG.num_classes], dtype=torch.float)
    # Direct indexing raises a KeyError naming an unknown diagnosis, instead of
    # dict.get() returning None and int(None) failing with an unrelated TypeError.
    one_hot[CFG.label_str2num[diagnosis]] = 1.0
    all_img_labels_val_ts.append(one_hot)

In [None]:
# Peek at a few one-hot validation labels instead of dumping one tensor per image.
all_img_labels_val_ts[:5]

In [None]:
# Pretrained timm backbones (names from CFG); their heads are replaced before fine-tuning.
# Module.to() returns the module, so chaining keeps the huge model repr out of the output.
model_cnn = timm.create_model(cfg.cnn_name, pretrained=True).to(device)
model_vit = timm.create_model(cfg.vit_name, pretrained=True).to(device)

In [None]:
def ssl_train_model(train_loader, model_vit, criterion_vit, optimizer_vit, scheduler_vit,
                    model_cnn, criterion_cnn, optimizer_cnn, scheduler_cnn, num_epochs):
    """Self-supervised co-training of a ViT and a CNN on the loader's images.

    Every batch goes through both networks. The two outputs are pulled together with
    ``loss_fn``, which computes 2 - 2 * cosine similarity. One backward pass through
    that shared loss updates both models. Labels yielded by ``train_loader`` are ignored.

    Args:
        train_loader: iterable of ``(images, labels)`` batches; labels are unused.
        model_vit, model_cnn: the two networks, expected to be on the global ``device``.
        criterion_vit, criterion_cnn: accepted for interface compatibility; unused.
        optimizer_vit, optimizer_cnn: optimizers for the respective models.
        scheduler_vit, scheduler_cnn: LR schedulers, stepped once per batch.
        num_epochs: number of passes over ``train_loader``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.

    Side effects: logs the mean loss of each epoch to a TensorBoard ``SummaryWriter``.
    """
    writer = SummaryWriter()
    model_cnn.train()
    model_vit.train()
    try:
        for epoch in tqdm(range(num_epochs)):
            epoch_loss, n_batches = 0.0, 0
            for img, _ in train_loader:
                img = img.to(device)
                # Clear last step's gradients; otherwise they accumulate across every batch.
                optimizer_cnn.zero_grad()
                optimizer_vit.zero_grad()
                pred_vit = model_vit(img)
                pred_cnn = model_cnn(img)
                loss = loss_fn(pred_vit, pred_cnn).mean()
                loss.backward()
                optimizer_cnn.step()
                optimizer_vit.step()
                scheduler_cnn.step()
                scheduler_vit.step()
                epoch_loss += loss.item()
                n_batches += 1
            if n_batches == 0:
                raise ValueError('train_loader yielded no batches')
            epoch_loss /= n_batches
            print('For -', epoch, 'Loss:', epoch_loss)
            writer.add_scalar("Self-Supervised Loss/train", epoch_loss, epoch)
    finally:
        writer.flush()
        writer.close()

In [None]:
# Optimizers / cosine schedules for self-supervised co-training. Hyper-parameters come
# from CFG (lr=1e-3, t_max=16, min_lr=1e-6) instead of duplicated literals.
# NOTE(review): SWA wraps Adam without swa_start, and update_swa()/swap_swa_sgd() are
# never called in this notebook, so no weight averaging seems to happen — confirm intent.
optimizer_cnn = SWA(torch.optim.Adam(model_cnn.parameters(), lr=cfg.lr))
optimizer_vit = SWA(torch.optim.Adam(model_vit.parameters(), lr=cfg.lr))
scheduler_cnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_cnn, T_max=cfg.t_max, eta_min=cfg.min_lr)
scheduler_vit = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vit, T_max=cfg.t_max, eta_min=cfg.min_lr)

# Passed to ssl_train_model, whose self-supervised objective is loss_fn rather than these.
criterion_vit = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
criterion_cnn = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)

In [None]:
def loss_fn(x, y):
    """Return ``2 - 2 * cos(x, y)`` along the last dimension.

    Both inputs are L2-normalized first, so the result is the squared Euclidean
    distance between the unit vectors: 0 when aligned, 2 when orthogonal, 4 when opposite.
    """
    unit_x, unit_y = (torch.nn.functional.normalize(t, p=2, dim=-1) for t in (x, y))
    cosine = torch.sum(unit_x * unit_y, dim=-1)
    return 2 - 2 * cosine

In [None]:
# Reduced-label training: keep a reproducible random subset of the training split.
LABEL_FRACTION = 0.1  # currently set to use 10% of the labels for reduced label training
# A private RNG seeded with CFG.seed draws the same indices as random.seed(77) followed by
# random.sample(...), without resetting Python's global random state.
subset_rng = random.Random(CFG.seed)
subset_idx = subset_rng.sample(range(len(X_train)), int(len(X_train) * LABEL_FRACTION))
all_img_names_train = [all_img_names[idx] for idx in subset_idx]
all_img_labels_ts_train = [all_img_labels_ts[idx] for idx in subset_idx]

In [None]:
# Labeled training subset with augmentation; full validation split with fixed transforms.
train_dataset = Dataset(CFG, img_names=all_img_names_train, labels=all_img_labels_ts_train,
                        transform=train_transform)
valid_dataset = Dataset(CFG, img_names=all_img_names_valid, labels=all_img_labels_val_ts,
                        transform=valid_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=CFG.batch_size,
                          shuffle=True, num_workers=CFG.num_workers)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=CFG.batch_size,
                          shuffle=False, num_workers=CFG.num_workers)

In [None]:
# Sanity check: one label per validation image.
assert len(valid_dataset) == len(all_img_labels_val_ts)
len(valid_dataset)

In [None]:
# Sanity check: one label per image in the reduced-label training subset.
assert len(train_dataset) == len(all_img_labels_ts_train)
len(train_dataset)

In [None]:
# Self-supervised co-training of the CNN/ViT pair, then save both whole (pickled) models.
print('Training Cov-T')
ssl_train_model(
    train_loader=train_loader,
    model_vit=model_vit, criterion_vit=criterion_vit,
    optimizer_vit=optimizer_vit, scheduler_vit=scheduler_vit,
    model_cnn=model_cnn, criterion_cnn=criterion_cnn,
    optimizer_cnn=optimizer_cnn, scheduler_cnn=scheduler_cnn,
    num_epochs=100,
)
print('Saving Cov-T')
torch.save(obj=model_cnn, f='./cass-r50-isic.pt')
torch.save(obj=model_vit, f='./cass-r50-vit-isic.pt')

In [None]:
# Reload the self-supervised models; the paths must match those passed to torch.save
# in the SSL training cell. map_location lets CUDA-saved checkpoints load on this device.
# NOTE(review): these are whole pickled models — only load trusted files; newer torch
# versions default torch.load to weights_only=True, which may refuse them.
model_cnn = torch.load('./cass-r50-isic.pt', map_location=device)
model_vit = torch.load('./cass-r50-vit-isic.pt', map_location=device)

In [None]:
# Train the corresponding supervised CNN: fine-tune the self-supervised ResNet-50 with a
# fresh classification head on the reduced-label training subset.
print('Fine tunning Cov-T')
model_cnn.fc = nn.Linear(in_features=model_cnn.fc.in_features, out_features=cfg.num_classes, bias=True)
model_cnn.to(device)  # once, before building the optimizer, not on every batch
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
val_metric = MyF1Score(cfg)
optimizer = torch.optim.Adam(model_cnn.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.t_max, eta_min=cfg.min_lr, verbose=True)
best = 0              # best training score so far (gates validation below)
best_val = 0          # best validation score; a checkpoint is written when it improves
last_loss = math.inf
counter = 0           # consecutive validations whose loss increased (early stopping)
writer = SummaryWriter()
for epoch in range(50):
    model_cnn.train()
    metric.reset()  # score each epoch on its own batches, not a running total
    epoch_loss, n_batches = 0.0, 0
    for images, label in train_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model_cnn(images)
        loss = criterion(pred_ts, label)
        metric.update(pred_ts, label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        epoch_loss += loss.item()
        n_batches += 1
    epoch_loss /= max(n_batches, 1)
    train_score = metric.compute()
    logs = {'train_loss': epoch_loss, 'Recall': train_score, 'lr': optimizer.param_groups[0]['lr']}
    writer.add_scalar("Supervised-CNN Loss/train", epoch_loss, epoch)
    writer.add_scalar("Supervised-CNN Recall/train", train_score, epoch)
    for name, weight in model_cnn.named_parameters():
        writer.add_histogram(name, weight, epoch)
        # zero_grad() may leave .grad as None, which add_histogram cannot handle.
        if weight.grad is not None:
            writer.add_histogram(f'{name}.grad', weight.grad, epoch)
    print(logs)
    # NOTE(review): validation, early stopping and checkpointing only run in epochs where
    # the training score improves — confirm this gating is intended.
    if best < train_score:
        best = train_score
        model_cnn.eval()
        val_metric.reset()
        total_loss = 0.0
        with torch.no_grad():
            for images, label in valid_loader:
                images = images.to(device)
                label = label.to(device)
                pred_ts = model_cnn(images)
                val_metric.update(pred_ts, label)
                total_loss += criterion(pred_ts, label).item()
        avg_loss = total_loss / max(len(valid_loader), 1)  # mean over validation batches
        print('Val Loss:', avg_loss)
        val_score = val_metric.compute()
        print('CNN Validation Score:', val_score)
        writer.add_scalar("CNN Supervised F1/Validation", val_score, epoch)
        counter = counter + 1 if avg_loss > last_loss else 0
        last_loss = avg_loss
        if counter > 5:
            print('Early Stopping!')
            break
        if val_score > best_val:
            best_val = val_score
            print('Saving')
            torch.save(model_cnn, './CASS-CNN-part-ft.pt')
writer.flush()
writer.close()

In [None]:
# Fine-tune the self-supervised ViT with a fresh classification head on the
# reduced-label training subset. The path must match the SSL training cell's torch.save.
model_vit = torch.load('./cass-r50-vit-isic.pt', map_location=device)
model_vit.head = nn.Linear(in_features=model_vit.head.in_features, out_features=cfg.num_classes, bias=True)
model_vit.to(device)  # once, before building the optimizer, not on every batch
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
val_metric = MyF1Score(cfg)
optimizer = torch.optim.Adam(model_vit.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.t_max, eta_min=cfg.min_lr, verbose=True)
best = 0              # best training score so far (gates validation below)
best_val = 0          # best validation score; a checkpoint is written when it improves
last_loss = math.inf
counter = 0           # consecutive validations whose loss increased (early stopping)
writer = SummaryWriter()
for epoch in range(50):
    model_vit.train()
    metric.reset()  # score each epoch on its own batches, not a running total
    epoch_loss, n_batches = 0.0, 0
    for images, label in train_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model_vit(images)
        loss = criterion(pred_ts, label)
        metric.update(pred_ts, label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        epoch_loss += loss.item()
        n_batches += 1
    epoch_loss /= max(n_batches, 1)
    train_score = metric.compute()
    logs = {'train_loss': epoch_loss, 'Recall': train_score, 'lr': optimizer.param_groups[0]['lr']}
    writer.add_scalar("Supervised-ViT Loss/train", epoch_loss, epoch)
    writer.add_scalar("Supervised-ViT Recall/train", train_score, epoch)
    for name, weight in model_vit.named_parameters():
        writer.add_histogram(name, weight, epoch)
        # zero_grad() may leave .grad as None, which add_histogram cannot handle.
        if weight.grad is not None:
            writer.add_histogram(f'{name}.grad', weight.grad, epoch)
    print(logs)
    # NOTE(review): validation, early stopping and checkpointing only run in epochs where
    # the training score improves — confirm this gating is intended.
    if best < train_score:
        best = train_score
        # Validate and checkpoint the ViT being fine-tuned (not the CNN).
        model_vit.eval()
        val_metric.reset()
        total_loss = 0.0
        with torch.no_grad():
            for images, label in valid_loader:
                images = images.to(device)
                label = label.to(device)
                pred_ts = model_vit(images)
                val_metric.update(pred_ts, label)
                total_loss += criterion(pred_ts, label).item()
        avg_loss = total_loss / max(len(valid_loader), 1)  # mean over validation batches
        print('Val Loss:', avg_loss)
        val_score = val_metric.compute()
        print('ViT Validation Score:', val_score)
        writer.add_scalar("ViT Supervised F1/Validation", val_score, epoch)
        counter = counter + 1 if avg_loss > last_loss else 0
        last_loss = avg_loss
        if counter > 5:
            print('Early Stopping!')
            break
        if val_score > best_val:
            best_val = val_score
            print('Saving')
            torch.save(model_vit, './CASS-ViT-part-ft.pt')
writer.flush()
writer.close()