In [1]:
import os
import numpy as np
import pytorch_lightning as pl
import torch
import pandas as pd
import timm
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import KFold
from torchvision import transforms as tsfm
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchcontrib.optim import SWA
from torchmetrics import Metric
from torch.utils.tensorboard import SummaryWriter

In [2]:
class CFG:
    # data path
    test_csv_path = '/scratch/ps4364/BTMRI/data/test.csv'
    test_imgs_dir = '/scratch/ps4364/BTMRI/data/Testing/'
    # model info
    # label info
    label_num2str = {0: 'glioma',
                     1: 'pituitary',
                     2:'notumor',
                     3:'meningioma'
                     }
    label_str2num = {'glioma': 0,
                     'pituitary':1,
                     'notumor':2,
                     'meningioma':3
                     }
    fl_alpha = 1.0  # alpha of focal_loss
    fl_gamma = 2.0  # gamma of focal_loss
    cls_weight =  [0.2, 0.5970802919708029, 1.0, 0.25255474452554744]
    cnn_name='resnet50'
    vit_name='vit_base_patch16_384'
    seed = 77
    num_classes = 4
    batch_size = 16
    t_max = 16
    lr = 1e-3
    min_lr = 1e-6
    n_fold = 6
    num_workers = 8
    gpu_idx = 0
    device = torch.device(f'cuda:{gpu_idx}' if torch.cuda.is_available() else 'cpu')
    gpu_list = [gpu_idx]

In [3]:
#!pip show timm

In [4]:
def normalize(arr, t_min, t_max):
    norm_arr = []
    diff = t_max - t_min
    diff_arr = max(arr) - min(arr)
    for i in arr:
        temp = (((i - min(arr))*diff)/diff_arr) + t_min
        norm_arr.append(temp)
    return norm_arr
  
# assign array and range
array_1d = [1321,1457,1595,1339]
range_to_normalize = (0.2, 1)
normalized_array_1d = normalize(
    array_1d, range_to_normalize[0], 
  range_to_normalize[1])
  
# display original and normalized array
print("Original Array = ", array_1d)
print("Normalized Array = ", normalized_array_1d)

Original Array =  [1321, 1457, 1595, 1339]
Normalized Array =  [0.2, 0.5970802919708029, 1.0, 0.25255474452554744]


In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
seed_everything(77)

Global seed set to 77


77

In [6]:
"""
Define train & valid image transformation
"""
DATASET_IMAGE_MEAN = (0.485, 0.456, 0.406)
DATASET_IMAGE_STD = (0.229, 0.224, 0.225)

train_transform = tsfm.Compose([tsfm.Resize((384,384)),
                                tsfm.RandomApply([tsfm.ColorJitter(0.2, 0.2, 0.2),tsfm.RandomPerspective(distortion_scale=0.2),], p=0.3),
                                tsfm.RandomApply([tsfm.RandomAffine(degrees=10),], p=0.3),
                                tsfm.RandomVerticalFlip(p=0.3),
                                tsfm.RandomHorizontalFlip(p=0.3),
                                tsfm.ToTensor(),
                                tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), ])

valid_transform = tsfm.Compose([tsfm.Resize((384,384)),
                                tsfm.ToTensor(),
                                tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), ])

test_transform = tsfm.Compose([tsfm.Resize((384,384)),
                                tsfm.ToTensor(),
                                tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), ])

In [22]:
"""
Define dataset class
"""
class Dataset(Dataset):
    def __init__(self, cfg, img_names: list, labels: list, transform=None):
        self.img_dir = cfg.test_imgs_dir
        self.img_names = img_names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = self.img_names[idx]
        img = Image.open(img_path).convert('RGB')
        img_ts = self.transform(img)
        label_ts = self.labels[idx]
        return img_ts, label_ts

In [23]:
"""
Define Focal-Loss
"""

class FocalLoss(nn.Module):
    """
    The focal loss for fighting against class-imbalance
    """
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = 1e-12  # prevent training from Nan-loss error
        self.cls_weights = torch.tensor([CFG.cls_weight],dtype=torch.float, requires_grad=False, device=CFG.device)

    def forward(self, logits, target):
        """
        logits & target should be tensors with shape [batch_size, num_classes]
        """
        probs = torch.sigmoid(logits)
        one_subtract_probs = 1.0 - probs
        # add epsilon
        probs_new = probs + self.epsilon
        one_subtract_probs_new = one_subtract_probs + self.epsilon
        # calculate focal loss
        log_pt = target * torch.log(probs_new) + (1.0 - target) * torch.log(one_subtract_probs_new)
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt
        focal_loss = focal_loss * self.cls_weights
        return torch.mean(focal_loss)

In [24]:
"""
Define F1 score metric
"""
class MyF1Score(Metric):
    def __init__(self, cfg, threshold: float = 0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.cfg = cfg
        self.threshold = threshold
        self.add_state("tp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fn", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        preds_str_batch = self.num_to_str(torch.sigmoid(preds))
        target_str_batch = self.num_to_str(target)
        tp, fp, fn = 0, 0, 0
        for pred_str_list, target_str_list in zip(preds_str_batch, target_str_batch):
            for pred_str in pred_str_list:
                if pred_str in target_str_list:
                    tp += 1
                if pred_str not in target_str_list:
                    fp += 1

            for target_str in target_str_list:
                if target_str not in pred_str_list:
                    fn += 1
        self.tp += tp
        self.fp += fp
        self.fn += fn

    def compute(self):
        f1 = 2.0 * self.tp / (2.0 * self.tp + self.fn + self.fp)
        return f1
    
    def num_to_str(self, ts: torch.Tensor) -> list:
        batch_bool_list = (ts > self.threshold).detach().cpu().numpy().tolist()
        batch_str_list = []
        for one_sample_bool in batch_bool_list:
            lb_str_list = [self.cfg.label_num2str[lb_idx] for lb_idx, bool_val in enumerate(one_sample_bool) if bool_val]
            if len(lb_str_list) == 0:
                lb_str_list = ['healthy']
            batch_str_list.append(lb_str_list)
        return batch_str_list

In [25]:
Test_DF = pd.read_csv(CFG.test_csv_path)

In [26]:
Test_DF

Unnamed: 0.1,Unnamed: 0,path,target
0,0,/scratch/ps4364/BTMRI/data/Testing/glioma/Te-g...,glioma
1,1,/scratch/ps4364/BTMRI/data/Testing/glioma/Te-g...,glioma
2,2,/scratch/ps4364/BTMRI/data/Testing/glioma/Te-g...,glioma
3,3,/scratch/ps4364/BTMRI/data/Testing/glioma/Te-g...,glioma
4,4,/scratch/ps4364/BTMRI/data/Testing/glioma/Te-g...,glioma
...,...,...,...
1306,1306,/scratch/ps4364/BTMRI/data/Testing/meningioma/...,meningioma
1307,1307,/scratch/ps4364/BTMRI/data/Testing/meningioma/...,meningioma
1308,1308,/scratch/ps4364/BTMRI/data/Testing/meningioma/...,meningioma
1309,1309,/scratch/ps4364/BTMRI/data/Testing/meningioma/...,meningioma


In [27]:

"""
Split train & validation into Cross-Validation Folds
"""

all_img_names: list = Test_DF["path"].values.tolist()
all_img_labels: list = Test_DF["target"].values.tolist()

In [28]:
cfg=CFG()
all_img_labels_ts = []
for tmp_lb in all_img_labels:
    tmp_label = torch.zeros([CFG.num_classes], dtype=torch.float)
    place=cfg.label_str2num.get(tmp_lb)
    k=int(place)
    tmp_label[k] = 1.0
    all_img_labels_ts.append(tmp_label)

In [29]:
len(all_img_labels_ts)

1311

In [30]:
import timm
cfg=CFG()

In [33]:
test_dataset = Dataset(CFG, all_img_names, all_img_labels_ts, test_transform)

test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, drop_last=True)



In [34]:
model_vit=torch.load('/scratch/ps4364/BTMRI/code/1p-data/Cov-t/covt-r50-label-bMRI-100p-es-0.pt')



In [35]:
model_vit.eval()
model_vit.to(device)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
val_metric=MyF1Score(cfg)
optimizer = torch.optim.Adam(model_vit.parameters(), lr = 3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=cfg.t_max,eta_min=cfg.min_lr,verbose=True)
for images,label in test_loader:
    images = images.to(device)
    label = label.to(device)
    pred_ts=model_vit(images)
    loss = criterion((pred_ts), label)
    score = metric((pred_ts),label)
train_score=metric.compute()
logs = {'test_loss': loss, 'test_f1': train_score, 'lr': optimizer.param_groups[0]['lr']}
print(logs)

Adjusting learning rate of group 0 to 3.0000e-04.
{'test_loss': tensor(0.0300, device='cuda:0', grad_fn=<MeanBackward0>), 'test_f1': tensor(0.9122), 'lr': 0.0003}
