# Classification

In [1]:
import os
import numpy as np
import pytorch_lightning as pl
import torch
import pandas as pd
import timm
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import KFold
from torchvision import transforms as tsfm
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchcontrib.optim import SWA
from torchmetrics import Metric
from torch.utils.tensorboard import SummaryWriter

In [2]:
class CFG:
    # data path
    train_csv_path = '/scratch/ps4364/BTMRI/data/train.csv'
    train_imgs_dir = '/scratch/ps4364/BTMRI/data/Training/'
    # model info
    # label info
    label_num2str = {0: 'glioma',
                     1: 'pituitary',
                     2:'notumor',
                     3:'meningioma'
                     }
    label_str2num = {'glioma': 0,
                     'pituitary':1,
                     'notumor':2,
                     'meningioma':3
                     }
    fl_alpha = 1.0  # alpha of focal_loss
    fl_gamma = 2.0  # gamma of focal_loss
    cls_weight =  [0.2, 0.5970802919708029, 1.0, 0.25255474452554744]
    cnn_name='resnet50'
    vit_name='vit_base_patch16_384'
    seed = 77
    num_classes = 4
    batch_size = 16
    t_max = 16
    lr = 1e-3
    min_lr = 1e-6
    n_fold = 6
    num_workers = 8
    gpu_idx = 0
    device = torch.device(f'cuda:{gpu_idx}' if torch.cuda.is_available() else 'cpu')
    gpu_list = [gpu_idx]

In [3]:
def normalize(arr, t_min, t_max):
    norm_arr = []
    diff = t_max - t_min
    diff_arr = max(arr) - min(arr)
    for i in arr:
        temp = (((i - min(arr))*diff)/diff_arr) + t_min
        norm_arr.append(temp)
    return norm_arr
  
# assign array and range
array_1d = [1321,1457,1595,1339]
range_to_normalize = (0.2, 1)
normalized_array_1d = normalize(
    array_1d, range_to_normalize[0], 
  range_to_normalize[1])
  
# display original and normalized array
print("Original Array = ", array_1d)
print("Normalized Array = ", normalized_array_1d)

Original Array =  [1321, 1457, 1595, 1339]
Normalized Array =  [0.2, 0.5970802919708029, 1.0, 0.25255474452554744]


In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
seed_everything(77)

Global seed set to 77


77

In [5]:
"""
Define train & valid image transformation
"""
DATASET_IMAGE_MEAN = (0.485, 0.456, 0.406)
DATASET_IMAGE_STD = (0.229, 0.224, 0.225)

train_transform = tsfm.Compose([tsfm.Resize((384,384)),
                                tsfm.RandomApply([tsfm.ColorJitter(0.2, 0.2, 0.2),tsfm.RandomPerspective(distortion_scale=0.2),], p=0.3),
                                tsfm.RandomApply([tsfm.RandomAffine(degrees=10),], p=0.3),
                                tsfm.RandomVerticalFlip(p=0.3),
                                tsfm.RandomHorizontalFlip(p=0.3),
                                tsfm.ToTensor(),
                                tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), ])

valid_transform = tsfm.Compose([tsfm.Resize((384,384)),
                                tsfm.ToTensor(),
                                tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), ])

In [6]:
"""
Define dataset class
"""
class Dataset(Dataset):
    def __init__(self, cfg, img_names: list, labels: list, transform=None):
        self.img_dir = cfg.train_imgs_dir
        self.img_names = img_names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = self.img_names[idx]
        img = Image.open(img_path).convert('RGB')
        img_ts = self.transform(img)
        label_ts = self.labels[idx]
        return img_ts, label_ts

In [7]:
"""
Define Focal-Loss
"""

class FocalLoss(nn.Module):
    """
    The focal loss for fighting against class-imbalance
    """
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = 1e-12  # prevent training from Nan-loss error
        self.cls_weights = torch.tensor([CFG.cls_weight],dtype=torch.float, requires_grad=False, device=CFG.device)

    def forward(self, logits, target):
        """
        logits & target should be tensors with shape [batch_size, num_classes]
        """
        probs = torch.sigmoid(logits)
        one_subtract_probs = 1.0 - probs
        # add epsilon
        probs_new = probs + self.epsilon
        one_subtract_probs_new = one_subtract_probs + self.epsilon
        # calculate focal loss
        log_pt = target * torch.log(probs_new) + (1.0 - target) * torch.log(one_subtract_probs_new)
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt
        focal_loss = focal_loss * self.cls_weights
        return torch.mean(focal_loss)

In [8]:
"""
Define F1 score metric
"""
class MyF1Score(Metric):
    def __init__(self, cfg, threshold: float = 0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.cfg = cfg
        self.threshold = threshold
        self.add_state("tp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fn", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        preds_str_batch = self.num_to_str(torch.sigmoid(preds))
        target_str_batch = self.num_to_str(target)
        tp, fp, fn = 0, 0, 0
        for pred_str_list, target_str_list in zip(preds_str_batch, target_str_batch):
            for pred_str in pred_str_list:
                if pred_str in target_str_list:
                    tp += 1
                if pred_str not in target_str_list:
                    fp += 1

            for target_str in target_str_list:
                if target_str not in pred_str_list:
                    fn += 1
        self.tp += tp
        self.fp += fp
        self.fn += fn

    def compute(self):
        f1 = 2.0 * self.tp / (2.0 * self.tp + self.fn + self.fp)
        return f1
    
    def num_to_str(self, ts: torch.Tensor) -> list:
        batch_bool_list = (ts > self.threshold).detach().cpu().numpy().tolist()
        batch_str_list = []
        for one_sample_bool in batch_bool_list:
            lb_str_list = [self.cfg.label_num2str[lb_idx] for lb_idx, bool_val in enumerate(one_sample_bool) if bool_val]
            batch_str_list.append(lb_str_list)
        return batch_str_list

In [9]:
TRAIN_DF = pd.read_csv(CFG.train_csv_path)

In [10]:
TRAIN_DF

Unnamed: 0.1,Unnamed: 0,path,target
0,0,/scratch/ps4364/BTMRI/data/Training/glioma/Tr-...,glioma
1,1,/scratch/ps4364/BTMRI/data/Training/glioma/Tr-...,glioma
2,2,/scratch/ps4364/BTMRI/data/Training/glioma/Tr-...,glioma
3,3,/scratch/ps4364/BTMRI/data/Training/glioma/Tr-...,glioma
4,4,/scratch/ps4364/BTMRI/data/Training/glioma/Tr-...,glioma
...,...,...,...
5707,5707,/scratch/ps4364/BTMRI/data/Training/meningioma...,meningioma
5708,5708,/scratch/ps4364/BTMRI/data/Training/meningioma...,meningioma
5709,5709,/scratch/ps4364/BTMRI/data/Training/meningioma...,meningioma
5710,5710,/scratch/ps4364/BTMRI/data/Training/meningioma...,meningioma


In [11]:

"""
Split train & validation into Cross-Validation Folds
"""

all_img_names: list = TRAIN_DF["path"].values.tolist()
all_img_labels: list = TRAIN_DF["target"].values.tolist()

In [12]:
cfg=CFG()
all_img_labels_ts = []
for tmp_lb in all_img_labels:
    tmp_label = torch.zeros([CFG.num_classes], dtype=torch.float)
    place=cfg.label_str2num.get(tmp_lb)
    k=int(place)
    tmp_label[k] = 1.0
    all_img_labels_ts.append(tmp_label)
    
k_fold = KFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)

In [13]:
print(k_fold.split(all_img_names))

<generator object _BaseKFold.split at 0x1511a8db6900>


In [14]:
len(all_img_labels_ts)

5712

In [15]:
for fold_idx, (train_indices, valid_indices) in enumerate(k_fold.split(all_img_names)):
    break

In [16]:
fold_idx

0

In [17]:
import timm
cfg=CFG()

In [16]:
import timm
cfg=CFG()
model_cnn = timm.create_model(cfg.cnn_name, pretrained=True)
model_vit = timm.create_model(cfg.vit_name, pretrained=True)
model_cnn.to(device)
model_vit.to(device)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,),

In [21]:
def ssl_train_model(train_loader,model_vit,criterion_vit,optimizer_vit,scheduler_vit,model_cnn,criterion_cnn,optimizer_cnn,scheduler_cnn,num_epochs):
    writer = SummaryWriter()
    phase = 'train'
    model_cnn.train()
    model_vit.train()
    f1_score_cnn=0
    f1_score_vit=0
    for i in tqdm(range(num_epochs)):
        with torch.set_grad_enabled(phase == 'train'):
            for img,_ in train_loader:
                f1_score_cnn=0
                f1_score_vit=0
                img = img.to(device)
                pred_vit = model_vit(img)
                pred_cnn = model_cnn(img)
                model_sim_loss=loss_fn(pred_vit,pred_cnn)
                loss = model_sim_loss.mean()
                loss.backward()
                optimizer_cnn.step()
                optimizer_vit.step()
                scheduler_cnn.step()
                scheduler_vit.step()
            print('For -',i,'Loss:',loss) 
            writer.add_scalar("Self-Supervised Loss/train", loss, i)
    writer.flush()

In [22]:
optimizer_cnn = SWA(torch.optim.Adam(model_cnn.parameters(), lr= 1e-3))
optimizer_vit = SWA(torch.optim.Adam(model_vit.parameters(), lr= 1e-3))
scheduler_cnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_cnn,
                                                                    T_max=16,
                                                                    eta_min=1e-6)
scheduler_vit = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vit,
                                                                    T_max=16,
                                                                    eta_min=1e-6)


fl_alpha = 1.0  # alpha of focal_loss
fl_gamma = 2.0  # gamma of focal_loss
cls_weight = [0.9475164011246484, 0.4934395501405811, 0.5029053420805999, 0.2, 1.0]
criterion_vit = FocalLoss(fl_alpha, fl_gamma)
criterion_cnn = FocalLoss(fl_alpha, fl_gamma)

In [23]:
def loss_fn(x, y):
    x =  torch.nn.functional.normalize(x, dim=-1, p=2)
    y =  torch.nn.functional.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

In [24]:
for fold_idx, (train_indices, valid_indices) in enumerate(k_fold.split(all_img_names)):
    print(fold_idx)
    print(train_indices)
    print(valid_indices)

0
[   0    1    2 ... 5709 5710 5711]
[  28   34   38   42   43   44   48   49   51   65   75   77   87   92
  100  102  105  130  131  133  134  139  140  155  161  168  171  178
  182  185  190  195  197  204  216  228  253  260  267  291  297  299
  303  307  310  313  318  329  335  339  344  347  356  361  363  366
  377  387  397  402  408  410  413  414  420  441  443  444  448  450
  451  464  467  468  476  480  487  489  497  498  499  515  524  525
  528  534  545  548  549  558  563  566  568  595  600  605  606  610
  612  616  624  627  629  632  634  635  636  637  665  666  679  680
  702  706  708  710  720  722  732  750  756  767  771  773  774  780
  786  796  797  803  810  812  813  822  825  830  836  846  848  861
  862  866  867  871  873  875  888  891  892  894  899  912  919  921
  923  925  929  932  934  938  944  960  961  964  965  971  975  977
  981  985  993  996  998 1002 1003 1004 1005 1006 1007 1017 1018 1021
 1028 1047 1052 1055 1060 1061 1075 107

In [None]:
for fold_idx, (train_indices, valid_indices) in enumerate(k_fold.split(all_img_names)):
    model_cnn = timm.create_model(cfg.cnn_name, pretrained=True)
    model_vit = timm.create_model(cfg.vit_name, pretrained=True)
    model_cnn.to(device)
    model_vit.to(device)
    print('*'*10)
    print('For',fold_idx)
    fold_train_img_names = [all_img_names[idx] for idx in train_indices]
    fold_valid_img_names = [all_img_names[idx] for idx in valid_indices]
    fold_train_img_labels = [all_img_labels_ts[idx] for idx in train_indices]
    fold_valid_img_labels = [all_img_labels_ts[idx] for idx in valid_indices]
    train_dataset = Dataset(CFG, fold_train_img_names, fold_train_img_labels, train_transform)
    valid_dataset = Dataset(CFG, fold_valid_img_names, fold_valid_img_labels, valid_transform)
    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)
    
    #Train SSL
    print('Training Cov-T')
    ssl_train_model(train_loader,model_vit,criterion_vit,optimizer_vit,scheduler_vit,model_cnn,criterion_cnn,optimizer_cnn,scheduler_cnn,num_epochs=100)
    #Saving SSL Models
    print('Saving Cov-T')
    
    torch.save(model_cnn,'/scratch/ps4364/BTMRI/code/cass/cass-r50-{}-bMRI-with-valid.pt'.format(fold_idx))
    torch.save(model_vit,'/scratch/ps4364/BTMRI/code/cass/cass-vit-{}-bMRI-with-valid.pt'.format(fold_idx))



**********
For 0
Training Cov-T


  1%|          | 1/100 [04:16<7:03:05, 256.42s/it]

For - 0 Loss: tensor(1.9725, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|▏         | 2/100 [08:35<7:01:02, 257.78s/it]

For - 1 Loss: tensor(1.9616, device='cuda:0', grad_fn=<MeanBackward0>)


  3%|▎         | 3/100 [12:53<6:57:29, 258.24s/it]

For - 2 Loss: tensor(1.9689, device='cuda:0', grad_fn=<MeanBackward0>)


  4%|▍         | 4/100 [17:12<6:53:29, 258.44s/it]

For - 3 Loss: tensor(1.9479, device='cuda:0', grad_fn=<MeanBackward0>)


  5%|▌         | 5/100 [21:31<6:49:37, 258.71s/it]

For - 4 Loss: tensor(1.9665, device='cuda:0', grad_fn=<MeanBackward0>)


  6%|▌         | 6/100 [25:50<6:45:31, 258.84s/it]

For - 5 Loss: tensor(1.9574, device='cuda:0', grad_fn=<MeanBackward0>)


  7%|▋         | 7/100 [30:09<6:41:13, 258.86s/it]

For - 6 Loss: tensor(1.9592, device='cuda:0', grad_fn=<MeanBackward0>)


  8%|▊         | 8/100 [34:28<6:36:41, 258.71s/it]

For - 7 Loss: tensor(1.9728, device='cuda:0', grad_fn=<MeanBackward0>)


  9%|▉         | 9/100 [38:48<6:33:02, 259.15s/it]

For - 8 Loss: tensor(1.9666, device='cuda:0', grad_fn=<MeanBackward0>)


 10%|█         | 10/100 [43:07<6:28:52, 259.25s/it]

For - 9 Loss: tensor(1.9604, device='cuda:0', grad_fn=<MeanBackward0>)


 11%|█         | 11/100 [47:27<6:24:33, 259.25s/it]

For - 10 Loss: tensor(1.9619, device='cuda:0', grad_fn=<MeanBackward0>)


 12%|█▏        | 12/100 [51:46<6:20:22, 259.34s/it]

For - 11 Loss: tensor(1.9574, device='cuda:0', grad_fn=<MeanBackward0>)


 13%|█▎        | 13/100 [56:05<6:15:55, 259.26s/it]

For - 12 Loss: tensor(1.9582, device='cuda:0', grad_fn=<MeanBackward0>)


 14%|█▍        | 14/100 [1:00:24<6:11:32, 259.22s/it]

For - 13 Loss: tensor(1.9573, device='cuda:0', grad_fn=<MeanBackward0>)


 15%|█▌        | 15/100 [1:04:44<6:07:35, 259.48s/it]

For - 14 Loss: tensor(1.9628, device='cuda:0', grad_fn=<MeanBackward0>)


 16%|█▌        | 16/100 [1:09:04<6:03:17, 259.50s/it]

For - 15 Loss: tensor(1.9345, device='cuda:0', grad_fn=<MeanBackward0>)


 17%|█▋        | 17/100 [1:13:23<5:58:39, 259.27s/it]

For - 16 Loss: tensor(1.9573, device='cuda:0', grad_fn=<MeanBackward0>)


 18%|█▊        | 18/100 [1:17:42<5:54:12, 259.18s/it]

For - 17 Loss: tensor(1.9628, device='cuda:0', grad_fn=<MeanBackward0>)


 19%|█▉        | 19/100 [1:22:00<5:49:33, 258.93s/it]

For - 18 Loss: tensor(1.9625, device='cuda:0', grad_fn=<MeanBackward0>)


 20%|██        | 20/100 [1:26:19<5:45:16, 258.95s/it]

For - 19 Loss: tensor(1.9536, device='cuda:0', grad_fn=<MeanBackward0>)


 21%|██        | 21/100 [1:30:38<5:40:58, 258.96s/it]

For - 20 Loss: tensor(1.9648, device='cuda:0', grad_fn=<MeanBackward0>)


 22%|██▏       | 22/100 [1:34:57<5:36:42, 259.01s/it]

For - 21 Loss: tensor(1.9461, device='cuda:0', grad_fn=<MeanBackward0>)


 23%|██▎       | 23/100 [1:39:16<5:32:22, 259.00s/it]

For - 22 Loss: tensor(1.9605, device='cuda:0', grad_fn=<MeanBackward0>)


 24%|██▍       | 24/100 [1:43:35<5:28:03, 258.99s/it]

For - 23 Loss: tensor(1.9658, device='cuda:0', grad_fn=<MeanBackward0>)


 25%|██▌       | 25/100 [1:47:53<5:23:23, 258.72s/it]

For - 24 Loss: tensor(1.9569, device='cuda:0', grad_fn=<MeanBackward0>)


 26%|██▌       | 26/100 [1:52:11<5:18:54, 258.57s/it]

For - 25 Loss: tensor(1.9669, device='cuda:0', grad_fn=<MeanBackward0>)


 27%|██▋       | 27/100 [1:56:30<5:14:26, 258.45s/it]

For - 26 Loss: tensor(1.9591, device='cuda:0', grad_fn=<MeanBackward0>)


 28%|██▊       | 28/100 [2:00:48<5:10:05, 258.41s/it]

For - 27 Loss: tensor(1.9592, device='cuda:0', grad_fn=<MeanBackward0>)


 29%|██▉       | 29/100 [2:05:06<5:05:42, 258.34s/it]

For - 28 Loss: tensor(1.9520, device='cuda:0', grad_fn=<MeanBackward0>)


 30%|███       | 30/100 [2:09:25<5:01:30, 258.43s/it]

For - 29 Loss: tensor(1.9557, device='cuda:0', grad_fn=<MeanBackward0>)


 31%|███       | 31/100 [2:13:43<4:57:10, 258.41s/it]

For - 30 Loss: tensor(1.9552, device='cuda:0', grad_fn=<MeanBackward0>)


 32%|███▏      | 32/100 [2:18:02<4:52:52, 258.43s/it]

For - 31 Loss: tensor(1.9531, device='cuda:0', grad_fn=<MeanBackward0>)


 33%|███▎      | 33/100 [2:22:20<4:48:33, 258.41s/it]

For - 32 Loss: tensor(1.9611, device='cuda:0', grad_fn=<MeanBackward0>)


 34%|███▍      | 34/100 [2:26:38<4:44:17, 258.45s/it]

For - 33 Loss: tensor(1.9462, device='cuda:0', grad_fn=<MeanBackward0>)


 35%|███▌      | 35/100 [2:30:57<4:39:59, 258.46s/it]

For - 34 Loss: tensor(1.9536, device='cuda:0', grad_fn=<MeanBackward0>)


 36%|███▌      | 36/100 [2:35:15<4:35:41, 258.46s/it]

For - 35 Loss: tensor(1.9758, device='cuda:0', grad_fn=<MeanBackward0>)


 37%|███▋      | 37/100 [2:39:34<4:31:28, 258.54s/it]

For - 36 Loss: tensor(1.9606, device='cuda:0', grad_fn=<MeanBackward0>)


 38%|███▊      | 38/100 [2:43:52<4:27:01, 258.42s/it]

For - 37 Loss: tensor(1.9589, device='cuda:0', grad_fn=<MeanBackward0>)


 39%|███▉      | 39/100 [2:48:11<4:22:49, 258.52s/it]

For - 38 Loss: tensor(1.9639, device='cuda:0', grad_fn=<MeanBackward0>)


 40%|████      | 40/100 [2:52:30<4:18:37, 258.62s/it]

For - 39 Loss: tensor(1.9715, device='cuda:0', grad_fn=<MeanBackward0>)


 41%|████      | 41/100 [2:56:49<4:14:27, 258.77s/it]

For - 40 Loss: tensor(1.9473, device='cuda:0', grad_fn=<MeanBackward0>)


 42%|████▏     | 42/100 [3:01:08<4:10:10, 258.81s/it]

For - 41 Loss: tensor(1.9699, device='cuda:0', grad_fn=<MeanBackward0>)


 43%|████▎     | 43/100 [3:05:27<4:05:52, 258.81s/it]

For - 42 Loss: tensor(1.9643, device='cuda:0', grad_fn=<MeanBackward0>)


 44%|████▍     | 44/100 [3:09:46<4:01:34, 258.83s/it]

For - 43 Loss: tensor(1.9643, device='cuda:0', grad_fn=<MeanBackward0>)


 45%|████▌     | 45/100 [3:14:04<3:57:16, 258.85s/it]

For - 44 Loss: tensor(1.9458, device='cuda:0', grad_fn=<MeanBackward0>)


 46%|████▌     | 46/100 [3:18:24<3:53:03, 258.95s/it]

For - 45 Loss: tensor(1.9561, device='cuda:0', grad_fn=<MeanBackward0>)


 47%|████▋     | 47/100 [3:22:43<3:48:46, 259.00s/it]

For - 46 Loss: tensor(1.9421, device='cuda:0', grad_fn=<MeanBackward0>)


 48%|████▊     | 48/100 [3:27:02<3:44:27, 258.98s/it]

For - 47 Loss: tensor(1.9696, device='cuda:0', grad_fn=<MeanBackward0>)


 49%|████▉     | 49/100 [3:31:21<3:40:11, 259.05s/it]

For - 48 Loss: tensor(1.9624, device='cuda:0', grad_fn=<MeanBackward0>)


 50%|█████     | 50/100 [3:35:40<3:35:50, 259.01s/it]

For - 49 Loss: tensor(1.9516, device='cuda:0', grad_fn=<MeanBackward0>)


 51%|█████     | 51/100 [3:39:59<3:31:35, 259.09s/it]

For - 50 Loss: tensor(1.9460, device='cuda:0', grad_fn=<MeanBackward0>)


 52%|█████▏    | 52/100 [3:44:18<3:27:09, 258.95s/it]

For - 51 Loss: tensor(1.9652, device='cuda:0', grad_fn=<MeanBackward0>)


 53%|█████▎    | 53/100 [3:48:37<3:22:51, 258.96s/it]

For - 52 Loss: tensor(1.9621, device='cuda:0', grad_fn=<MeanBackward0>)


 54%|█████▍    | 54/100 [3:52:56<3:18:30, 258.92s/it]

For - 53 Loss: tensor(1.9519, device='cuda:0', grad_fn=<MeanBackward0>)


 55%|█████▌    | 55/100 [3:57:15<3:14:13, 258.96s/it]

For - 54 Loss: tensor(1.9614, device='cuda:0', grad_fn=<MeanBackward0>)


 56%|█████▌    | 56/100 [4:01:34<3:09:56, 259.02s/it]

For - 55 Loss: tensor(1.9715, device='cuda:0', grad_fn=<MeanBackward0>)


 57%|█████▋    | 57/100 [4:05:53<3:05:37, 259.01s/it]

For - 56 Loss: tensor(1.9602, device='cuda:0', grad_fn=<MeanBackward0>)


 58%|█████▊    | 58/100 [4:10:12<3:01:20, 259.06s/it]

For - 57 Loss: tensor(1.9618, device='cuda:0', grad_fn=<MeanBackward0>)


 59%|█████▉    | 59/100 [4:14:31<2:57:00, 259.03s/it]

For - 58 Loss: tensor(1.9586, device='cuda:0', grad_fn=<MeanBackward0>)


 60%|██████    | 60/100 [4:18:50<2:52:38, 258.97s/it]

For - 59 Loss: tensor(1.9680, device='cuda:0', grad_fn=<MeanBackward0>)


 61%|██████    | 61/100 [4:23:09<2:48:24, 259.09s/it]

For - 60 Loss: tensor(1.9526, device='cuda:0', grad_fn=<MeanBackward0>)


 62%|██████▏   | 62/100 [4:27:28<2:44:04, 259.07s/it]

For - 61 Loss: tensor(1.9429, device='cuda:0', grad_fn=<MeanBackward0>)


 63%|██████▎   | 63/100 [4:31:47<2:39:42, 258.98s/it]

For - 62 Loss: tensor(1.9515, device='cuda:0', grad_fn=<MeanBackward0>)


 64%|██████▍   | 64/100 [4:36:06<2:35:28, 259.12s/it]

For - 63 Loss: tensor(1.9759, device='cuda:0', grad_fn=<MeanBackward0>)


 65%|██████▌   | 65/100 [4:40:25<2:31:09, 259.13s/it]

For - 64 Loss: tensor(1.9572, device='cuda:0', grad_fn=<MeanBackward0>)


 66%|██████▌   | 66/100 [4:44:45<2:26:51, 259.16s/it]

For - 65 Loss: tensor(1.9665, device='cuda:0', grad_fn=<MeanBackward0>)


 67%|██████▋   | 67/100 [4:49:03<2:22:21, 258.83s/it]

For - 66 Loss: tensor(1.9420, device='cuda:0', grad_fn=<MeanBackward0>)


 68%|██████▊   | 68/100 [4:53:23<2:18:17, 259.28s/it]

For - 67 Loss: tensor(1.9512, device='cuda:0', grad_fn=<MeanBackward0>)


 69%|██████▉   | 69/100 [4:57:41<2:13:45, 258.89s/it]

For - 68 Loss: tensor(1.9560, device='cuda:0', grad_fn=<MeanBackward0>)


 70%|███████   | 70/100 [5:01:59<2:09:19, 258.64s/it]

For - 69 Loss: tensor(1.9647, device='cuda:0', grad_fn=<MeanBackward0>)


 71%|███████   | 71/100 [5:06:17<2:04:51, 258.33s/it]

For - 70 Loss: tensor(1.9590, device='cuda:0', grad_fn=<MeanBackward0>)


 72%|███████▏  | 72/100 [5:10:35<2:00:30, 258.25s/it]

For - 71 Loss: tensor(1.9678, device='cuda:0', grad_fn=<MeanBackward0>)


 73%|███████▎  | 73/100 [5:14:53<1:56:11, 258.19s/it]

For - 72 Loss: tensor(1.9535, device='cuda:0', grad_fn=<MeanBackward0>)


In [None]:
from torch.utils.tensorboard import SummaryWriter

In [None]:
for fold_idx, (train_indices, valid_indices) in enumerate(k_fold.split(all_img_names)):
    model_vit=torch.load('/scratch/ps4364/BTMRI/code/Cov-T/cov-t/covt-vit-{}-bMRI-with-valid.pt'.format(fold_idx))
    model_cnn=torch.load('/scratch/ps4364/BTMRI/code/Cov-T/cov-t/covt-r50-{}-bMRI-with-valid.pt'.format(fold_idx))
    last_loss=999999999
    val_loss_arr=[]
    train_loss_arr=[]
    counter=0
    
    model_cnn.to(device)
    model_vit.to(device)
    print('*'*10)
    print('For',fold_idx)
    onep_train_indices = np.random.choice(train_indices, int(len(train_indices)*0.1), replace=False) 
    fold_train_img_names = [all_img_names[idx] for idx in onep_train_indices]
    fold_train_img_labels = [all_img_labels_ts[idx] for idx in onep_train_indices] 
    fold_valid_img_names = [all_img_names[idx] for idx in valid_indices]
    
    fold_valid_img_labels = [all_img_labels_ts[idx] for idx in valid_indices]
    train_dataset = Dataset(CFG, fold_train_img_names, fold_train_img_labels, train_transform)
    valid_dataset = Dataset(CFG, fold_valid_img_names, fold_valid_img_labels, valid_transform)
    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)
    
    
    
    
    
    #Train Correspong Supervised CNN
    print('Fine tunning Cov-T')
    writer = SummaryWriter()
    model_cnn.fc=nn.Linear(in_features=2048, out_features=4, bias=True)
    criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
    metric = MyF1Score(cfg)
    val_metric=MyF1Score(cfg)
    optimizer = torch.optim.Adam(model_cnn.parameters(), lr = 3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=cfg.t_max,eta_min=cfg.min_lr)
    model_cnn.train()
    from torch.autograd import Variable
    best=0
    best_val=0
    for epoch in tqdm(range(50)):
        total_loss = 0
        for images,label in train_loader:
            model_cnn.train()
            images = images.to(device)
            label = label.to(device)
            model_cnn.to(device)
            pred_ts=model_cnn(images)
            loss = criterion(pred_ts, label)
            score = metric(pred_ts,label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            total_loss += loss.detach()
        avg_loss=total_loss/ len(train_loader)
        train_score=metric.compute()
        logs = {'train_loss': avg_loss, 'train_f1': train_score, 'lr': optimizer.param_groups[0]['lr']}
        writer.add_scalar("CNN Supervised Loss/train", loss, epoch)
        writer.add_scalar("CNN Supervised F1/train", train_score, epoch)
        print(logs)
        if best < train_score:
            best=train_score
            model_cnn.eval()
            total_loss = 0
            for images,label in valid_loader:
                images = images.to(device)
                label = label.to(device)
                model_cnn.to(device)
                pred_ts=model_cnn(images)
                score_val = val_metric(pred_ts,label)
                val_loss = criterion(pred_ts, label)
                total_loss += val_loss.detach()
            avg_loss=total_loss/ len(train_loader)   
            print('Val Loss:',avg_loss)
            val_score=val_metric.compute()
            print('CNN Validation Score:',val_score)
            writer.add_scalar("CNN Supervised F1/Validation", val_score, epoch)
            if avg_loss > last_loss:
                counter+=1
            else:
                counter=0
                
            last_loss = avg_loss
            if counter > 5:
                print('Early Stopping!')
                break
            else:
                if val_score > best_val:
                    best_val=val_score
                    print('Saving')
                    torch.save(model_cnn,
                    '/scratch/ps4364/BTMRI/code/1p-data/Cov-t/covt-r50-label-bMRI-10p-es-{}.pt'.format(fold_idx))
    writer.flush()
    last_loss=999999999
    val_loss_arr=[]
    train_loss_arr=[]
    counter=0
    # Training the Corresponding ViT
    model_vit.head=nn.Linear(in_features=768, out_features=4, bias=True)
    criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
    metric = MyF1Score(cfg)
    optimizer = torch.optim.Adam(model_vit.parameters(), lr = 3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=cfg.t_max,eta_min=cfg.min_lr)
    model_vit.train()
    val_metric=MyF1Score(cfg)
    writer = SummaryWriter()
    from torch.autograd import Variable
    best=0
    best_val=0
    for epoch in tqdm(range(50)):
        total_loss = 0
        for images,label in train_loader:
            model_vit.train()
            images = images.to(device)
            label = label.to(device)
            model_vit.to(device)
            pred_ts=model_vit(images)
            loss = criterion(pred_ts, label)
            score = metric(pred_ts,label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            total_loss += loss.detach()
        avg_loss=total_loss/ len(train_loader)
        train_score=metric.compute()
        logs = {'train_loss': loss, 'train_f1': train_score, 'lr': optimizer.param_groups[0]['lr']}
        writer.add_scalar("ViT Supervised Loss/train", loss, epoch)
        writer.add_scalar("ViT Supervised F1/train", train_score, epoch)
        print(logs)
        if best < train_score:
            best=train_score
            model_vit.eval()
            total_loss = 0
            for images,label in valid_loader:
                images = images.to(device)
                label = label.to(device)
                model_vit.to(device)
                pred_ts=model_vit(images)
                score_val = val_metric(pred_ts,label)
                val_loss = criterion(pred_ts, label)
                total_loss += val_loss.detach()
            avg_loss=total_loss/ len(train_loader)
            val_score=val_metric.compute()
            print('ViT Validation Score:',val_score)
            print('Val Loss:',avg_loss)
            writer.add_scalar("ViT Supervised F1/Validation", val_score, epoch)
            if avg_loss > last_loss:
                counter+=1
            else:
                counter=0
                
            last_loss = avg_loss
            if counter > 5:
                print('Early Stopping!')
                break
            else:
                if val_score > best_val:
                    best_val=val_score
                    print('Saving')
                    torch.save(model_vit,
                                   '/scratch/ps4364/BTMRI/code/1p-data/Cov-t/covt-vit-bMRI-10p-es-{}.pt'.format(fold_idx))
                        
        writer.flush()                
        print('*'*10)