# CS-7643 Deep Learning Final Project

### by Zhiyin (Steven) Lu and Yixuan (Elliot) Xie

In [None]:
# Mount Google Drive into the Colab filesystem so the project files are reachable.
from google.colab import drive
drive.mount('/content/drive/')

# Move into the project folder on Drive and list its contents.
# %cd drive/MyDrive/cs7643_final_project
%cd drive/MyDrive/cs7643
%ls

# Report the NVIDIA GPU(s) visible to this runtime.
!nvidia-smi

In [None]:
# Install the package that is imported below as `segmentation_models_pytorch`.
pip install segmentation-models-pytorch

## 0. Environment Setup

In [None]:
import os
import time
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from tqdm.notebook import tqdm

from sklearn.model_selection import StratifiedGroupKFold # need scikit-learn 1.1.1
import cv2
import albumentations as A
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

## 1. Build Modules

### 1.1 Data Transformation

In [None]:
class Transform():
    """Factory for the albumentations pipelines applied to scans and masks.

    Every public method returns a dict with two entries, ``'train'`` and
    ``'valid'``, each an ``A.Compose`` pipeline.
    """

    def __init__(self, img_size):
        """Remember the target size.

        Args:
            img_size: two-element sequence unpacked positionally into
                ``A.Resize`` (i.e. height first, then width).
        """
        self.img_size = img_size

    def _resize_only(self):
        """Build a fresh pipeline that only resizes to ``self.img_size``."""
        return A.Compose([
            A.Resize(*self.img_size)
        ])

    def baseline(self):
        """Return resize-only pipelines for both training and validation."""
        return {'train': self._resize_only(), 'valid': self._resize_only()}

    def augmentation(self, crop_size=(255, 255)):
        """Return an augmenting training pipeline and a resize-only validation one.

        The training pipeline resizes, applies a random scale/rotate (no shift)
        with probability 0.5, then takes a random crop.

        Args:
            crop_size: ``(height, width)`` of the random crop. Defaults to the
                previously hard-coded ``(255, 255)``.

        NOTE(review): the default 255x255 crop is not a multiple of 32, while
        encoder/decoder segmentation networks commonly require input sides
        divisible by 32 - confirm the intended crop size before training with
        this pipeline. The crop also makes training samples a different size
        from validation samples.
        """
        train_transform = A.Compose([
            A.Resize(*self.img_size),
            # border_mode=0 is cv2.BORDER_CONSTANT: exposed borders are filled
            # with a constant value instead of reflected image content.
            A.ShiftScaleRotate(shift_limit=0., scale_limit=0.1, rotate_limit=10, border_mode=0, p=0.5),
            A.RandomCrop(*crop_size),
        ])
        return {'train': train_transform, 'valid': self._resize_only()}

### 1.2 Model

In [None]:
class Model():
    """Builds segmentation networks from ``segmentation_models_pytorch``.

    The constructor only records the configuration; a network is created
    (and moved to ``device``) each time one of the builder methods is called.
    """

    def __init__(self, backbone, pretrained_weights, num_class, device):
        """Store the encoder name, encoder weights, class count and device."""
        self.backbone, self.pretrained_weights = backbone, pretrained_weights
        self.num_class, self.device = num_class, device

    def _instantiate(self, architecture):
        """Create ``architecture`` with the stored settings and move it to the device."""
        network = architecture(
            encoder_name=self.backbone,
            encoder_weights=self.pretrained_weights,
            in_channels=3,
            classes=self.num_class,
            # No output activation: the network returns raw logits.
            activation=None
        )
        network.to(self.device)
        return network

    def UNet(self):
        """Return a new ``smp.Unet`` placed on ``self.device``."""
        return self._instantiate(smp.Unet)

    def UNetPlusPlus(self):
        """Return a new ``smp.UnetPlusPlus`` placed on ``self.device``."""
        return self._instantiate(smp.UnetPlusPlus)

### 1.3 Loss

In [None]:
class Loss():
    """Factory for the loss modules used during training.

    Each method returns a newly constructed loss object from
    ``segmentation_models_pytorch``.
    """

    def __init__(self):
        """Nothing to configure."""

    def BCE(self):
        """Return a new ``SoftBCEWithLogitsLoss`` with default settings."""
        bce = smp.losses.SoftBCEWithLogitsLoss()
        return bce

    def Tversky(self):
        """Return a new multilabel ``TverskyLoss`` with ``log_loss`` disabled."""
        tversky = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
        return tversky

### 1.4 Metric

In [None]:
class Metric():
    """Thresholded overlap metrics (Dice and IoU) for segmentation outputs.

    Both metrics sum over the last two dimensions and then average over
    dimensions 0 and 1, so inputs need at least four dimensions
    (presumably (batch, class, height, width) - confirm against the caller).
    """

    def __init__(self, threshold, epsilon):
        """Store the binarisation threshold and the smoothing constant."""
        self.threshold, self.epsilon = threshold, epsilon

    def _prepare(self, y_pred, y_true):
        """Binarise predictions at ``self.threshold`` and cast both tensors to float32."""
        hard_pred = (y_pred > self.threshold).float()
        return hard_pred, y_true.float()

    def dice(self, y_pred, y_true):
        """Return the mean smoothed Dice coefficient.

        ``epsilon`` is added to numerator and denominator, so a map that is
        empty in both prediction and target scores 1.
        """
        pred, target = self._prepare(y_pred, y_true)
        overlap = (pred * target).sum(dim=(-2, -1))
        denominator = target.sum(dim=(-2, -1)) + pred.sum(dim=(-2, -1)) + self.epsilon
        per_map = (2 * overlap + self.epsilon) / denominator
        return per_map.mean(dim=(1, 0))

    def IOU(self, y_pred, y_true):
        """Return the mean smoothed intersection-over-union."""
        pred, target = self._prepare(y_pred, y_true)
        product = pred * target
        overlap = product.sum(dim=(-2, -1))
        # Inclusion-exclusion: |A or B| = |A| + |B| - |A and B|, evaluated per pixel.
        union = (pred + target - product).sum(dim=(-2, -1))
        per_map = (overlap + self.epsilon) / (union + self.epsilon)
        return per_map.mean(dim=(1, 0))

### 1.5 Dataset and Dataloader

In [None]:
# class Dataset(Dataset):
#     def __init__(self, df, transform):
#         self.df = df
#         self.transform = transform
        
#     def __len__(self):
#         return len(self.df)
    
#     def __getitem__(self, index):
#         scan_path = self.df['scan_path'][index]
#         scan = cv2.imread(scan_path, -1)
#         # 3 channels for RGB
#         scan = np.tile(scan[..., None], [1, 1, 3]).astype('float32')
#         # normalize
#         mx = np.max(scan)
#         if mx:
#             scan /= mx
        
#         mask_path = self.df['mask_path'][index]
#         mask = np.load(mask_path).astype('float32')
#         mask /= 255.0
        
#         transformed = self.transform(image=scan, mask=mask)
#         scan, mask = transformed['image'], transformed['mask']
#         scan = torch.tensor(np.transpose(scan, (2, 0, 1)))
#         mask = torch.tensor(np.transpose(mask, (2, 0, 1)))
        
#         return scan, mask


# For 2.5d data
class Dataset(Dataset):
    """Dataset of (scan, mask) pairs stored as ``.npy`` files.

    NOTE: this class deliberately keeps the name ``Dataset`` for backward
    compatibility, which shadows the ``torch.utils.data.Dataset`` base class
    it derives from.

    Args:
        df: DataFrame with a ``'scan_path'`` and a ``'mask_path'`` column.
        transform: callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, df, transform):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Load, normalise and transform the sample at position ``index``.

        Returns:
            ``(scan, mask)`` tensors with the last array axis moved to the
            front (channels-last on disk -> channels-first tensors).
        """
        # Positional lookup: ``index`` is 0..len-1, which only coincides with
        # the DataFrame's labels when its index has been reset.
        scan_path = self.df['scan_path'].iloc[index]
        scan = np.load(scan_path).astype('float32')
        # Scale by the maximum; skipped for all-zero scans to avoid dividing by 0.
        peak = np.max(scan)
        if peak:
            scan /= peak

        mask_path = self.df['mask_path'].iloc[index]
        mask = np.load(mask_path).astype('float32')
        # Presumably masks are stored with values in 0..255 - confirm.
        mask /= 255.0

        transformed = self.transform(image=scan, mask=mask)
        scan, mask = transformed['image'], transformed['mask']
        scan = torch.tensor(np.transpose(scan, (2, 0, 1)))
        mask = torch.tensor(np.transpose(mask, (2, 0, 1)))

        return scan, mask
    
def get_dataloader(df, fold, data_transform, train_batch_size, valid_batch_size, num_workers):
    """Build the training and validation loaders for one cross-validation fold.

    Rows whose ``'fold'`` value equals ``fold`` are the held-out validation
    set; all remaining rows are used for training. (The previous version had
    the two subsets swapped, so the model was trained on the single held-out
    fold and validated on everything else.)

    Args:
        df: DataFrame with a ``'fold'`` column plus the columns ``Dataset`` reads.
        fold: identifier of the fold to hold out for validation.
        data_transform: dict with ``'train'`` and ``'valid'`` transforms.
        train_batch_size: batch size of the (shuffled) training loader.
        valid_batch_size: batch size of the (unshuffled) validation loader.
        num_workers: worker process count for both loaders.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    held_out = df['fold'] == fold
    # Reset the index so each subset is labelled 0..n-1.
    train_df = df[~held_out].reset_index(drop=True)
    valid_df = df[held_out].reset_index(drop=True)
    train_dataset = Dataset(train_df, transform=data_transform['train'])
    valid_dataset = Dataset(valid_df, transform=data_transform['valid'])
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True, pin_memory=True, drop_last=False)
    valid_loader = DataLoader(valid_dataset, batch_size=valid_batch_size, num_workers=num_workers, shuffle=False, pin_memory=True)

    return train_loader, valid_loader

### 1.6 Training and Validation Loop

In [None]:
def train_one_epoch(train_loader, model, optimizer, loss_fn, device):
    """Run one mixed-precision training pass over ``train_loader``.

    The optimised loss is ``0.5 * BCE + 0.5 * Tversky`` computed on the raw
    model outputs.

    Args:
        train_loader: iterable yielding ``(scans, masks)`` batches.
        model: network to train; called as ``model(scans)``.
        optimizer: optimizer stepping the model parameters.
        loss_fn: object exposing ``BCE()`` and ``Tversky()`` factories.
        device: device the batches are moved to.

    Returns:
        The accumulated value of ``loss / batch_size`` over all batches.
    """
    model.train()
    scaler = amp.GradScaler()
    # Build the loss modules once per call instead of once per batch.
    bce_criterion = loss_fn.BCE()
    tversky_criterion = loss_fn.Tversky()
    total_loss = 0
    pbar = tqdm(train_loader, total=len(train_loader), desc='Training ...')
    for scans, masks in pbar:
        optimizer.zero_grad()
        scans, masks = scans.to(device, dtype=torch.float), masks.to(device, dtype=torch.float)
        with amp.autocast():
            y_preds = model(scans)
            bce_loss = bce_criterion(y_preds, masks)
            tversky_loss = tversky_criterion(y_preds, masks)
            loss = 0.5 * bce_loss + 0.5 * tversky_loss
        # Scale the loss before backward so small gradients survive reduced
        # precision; the scaler undoes the scaling inside step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # NOTE(review): the batch loss is additionally divided by the batch
        # size before being summed - confirm this is the intended quantity to
        # report and plot.
        total_loss += loss.item() / scans.shape[0]
    curr_lr = optimizer.param_groups[0]['lr']
    print(f'CURRENT LEARNING RATE: {curr_lr:.3f}, TOTAL LOSS: {total_loss:.3f}', flush=True)

    return total_loss

@torch.no_grad()
def valid_one_epoch(valid_loader, model, metric, device):
    """Evaluate ``model`` on ``valid_loader`` with gradients disabled.

    Model outputs are passed through a sigmoid before being scored.

    Args:
        valid_loader: iterable yielding ``(scans, masks)`` batches.
        model: network to evaluate.
        metric: object exposing ``dice(y_pred, y_true)`` and ``IOU(y_pred, y_true)``.
        device: device the batches are moved to.

    Returns:
        Array ``[dice, iou]`` holding the unweighted mean of the per-batch scores.
    """
    model.eval()
    to_probability = nn.Sigmoid()
    per_batch = []
    progress = tqdm(enumerate(valid_loader), total=len(valid_loader), desc='Validating ...')
    for _, (scans, masks) in progress:
        scans = scans.to(device, dtype=torch.float)
        masks = masks.to(device, dtype=torch.float)
        y_preds = to_probability(model(scans))
        batch_dice = metric.dice(y_preds, masks).cpu().detach().numpy()
        batch_iou = metric.IOU(y_preds, masks).cpu().detach().numpy()
        per_batch.append([batch_dice, batch_iou])
    # Column-wise mean: index 0 is Dice, index 1 is IoU.
    scores = np.mean(per_batch, axis=0)
    print(f'DICE SCORE: {scores[0]:.3f}, IOU SCORE: {scores[1]:.3f}', flush=True)

    return scores

## 2. Execution

### 2.1 Configuration

In [None]:
# Training
NUM_FOLD = 3                # number of cross-validation folds
EPOCH = 10                  # epochs per fold
IMG_SIZE = (352, 352)       # unpacked into the resize transform
TRAIN_BATCH_SIZE = 64
VALID_BATCH_SIZE = 128
NUM_WORKERS = 2             # DataLoader worker processes

# Model
BACKBONE = 'efficientnet-b1'
PRETRAINED_WEIGHTS = 'imagenet'
NUM_CLASS = 3

# Optimizer
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.00001
# LEARNING_RATE_DROP = 2

# Metric
THRESHOLD = 0.45            # predictions above this count as foreground
EPSILON = 0.001             # smoothing term of the Dice / IoU scores

# Device and Directory
EXPERIMENT_NAME = 'UNet_2.5d'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
CKPT_SAVE_PATH = f'./checkpoints/{EXPERIMENT_NAME}_batchsize{TRAIN_BATCH_SIZE}_fold{NUM_FOLD}'
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
for _ckpt_dir in (CKPT_SAVE_PATH, f'{CKPT_SAVE_PATH}/trained', f'{CKPT_SAVE_PATH}/best'):
    os.makedirs(_ckpt_dir, exist_ok=True)
print(f'EXPERIMENT NAME: {EXPERIMENT_NAME}')
print(f'CURRENT DEVICE: {DEVICE}, CHECKPOINT SAVE PATH: {CKPT_SAVE_PATH}')

### 2.2 Load Data

In [None]:
# Load the sample index from a .csv file into the module-level DataFrame `data`.
# Exactly one `load_dir` is active; the commented-out lines are alternative CSVs
# (judging by the paths, a non-2.5d data set and a debug subset - confirm).
# load_dir = './data/data.csv'
load_dir = './2.5d_data/data.csv'
# load_dir = './data/debug.csv'
data = pd.read_csv(load_dir)

### 2.3 Divide for K-fold Cross-Validation 

In [None]:
# Divide the Dataset for Cross-Validation
# Splits are stratified on the 'empty' column and grouped by 'case', so all
# rows sharing a 'case' value land in the same fold.
cv = StratifiedGroupKFold(n_splits=NUM_FOLD, shuffle=True, random_state=14)
# Create the column up front as integers; assigning into a missing column
# would first fill it with NaN and leave the fold ids as floats.
data['fold'] = -1
for fold, (_, val_idx) in enumerate(cv.split(data, data['empty'], data['case'])):
    # split() yields positional indices; translate them to index labels so the
    # label-based .loc assignment is right even without a default RangeIndex.
    data.loc[data.index[val_idx], 'fold'] = fold

display(data)

### 2.4 Train and Evaluate

In [None]:
RESUME = False

# Defaults for a fresh run: begin at fold 0, epoch 0 and train normally.
LAST_FOLD, LAST_EPOCH, SKIP_TRAIN = 0, 0, False

if RESUME:
    # Resuming before a training step: keep SKIP_TRAIN False and point
    # LAST_FOLD / LAST_EPOCH at the epoch AFTER the loaded checkpoint
    # (e.g. LAST_FOLD=2, LAST_EPOCH=6 when loading trained_fold2_epoch5.pth).
    # Resuming before a validation step: set SKIP_TRAIN True and use the
    # fold / epoch numbers that appear in the checkpoint's file name.
    CKPT_LOAD_PATH = './checkpoints/UNet_batchsize4_fold4/trained/trained_fold3_epoch6.pth'
    LAST_FOLD, LAST_EPOCH, SKIP_TRAIN = 3, 6, True

In [None]:
# Cross-validation driver: for every fold build a fresh model and optimizer,
# train for EPOCH epochs, validate after each epoch and keep the weights of
# every epoch that improves the fold's best Dice score.
# fold_loss / fold_dice / fold_iou collect one per-epoch list per fold.
fold_loss, fold_dice, fold_iou = [], [], []
for fold in range(LAST_FOLD, NUM_FOLD):
    # Initialize Modules
    DATA_TRANSFORM = Transform(IMG_SIZE).baseline()
    MODEL = Model(BACKBONE, PRETRAINED_WEIGHTS, NUM_CLASS, DEVICE).UNet()
    LOSS_FN = Loss()
    METRIC = Metric(THRESHOLD, EPSILON)
    OPTIMIZER = torch.optim.AdamW(MODEL.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # LR_SCHEDULER = torch.optim.lr_scheduler.StepLR(OPTIMIZER, LEARNING_RATE_DROP)
    if RESUME:
        print()
        print(f'---------------- RESUME FROM {CKPT_LOAD_PATH}, LAST FOLD: {LAST_FOLD}, LAST EPOCH: {LAST_EPOCH} ----------------', flush=True)
        # load checkpoint; map_location lets a checkpoint written on one device
        # be restored on whatever DEVICE is available now
        CKPT = torch.load(CKPT_LOAD_PATH, map_location=DEVICE)
        MODEL.load_state_dict(CKPT['model_state_dict'])
        OPTIMIZER.load_state_dict(CKPT['optimizer_state_dict'])
        # LR_SCHEDULER.load_state_dict(CKPT['lr_scheduler_state_dict'])
        # Only the first fold of this run is restored from the checkpoint.
        RESUME = False
    print()
    print(f'----------------------------- FOLD {fold} STARTS -----------------------------', flush=True)
    train_loader, valid_loader = get_dataloader(data, fold, DATA_TRANSFORM, TRAIN_BATCH_SIZE, VALID_BATCH_SIZE, NUM_WORKERS)
    best_score, best_epoch = 0, 0
    epoch_loss, epoch_dice, epoch_iou = [], [], []
    # LAST_EPOCH is an offset into the resumed fold only; every later fold
    # must start again from epoch 0.
    start_epoch = LAST_EPOCH if fold == LAST_FOLD else 0
    for epoch in range(start_epoch, EPOCH):
        print()
        print(f'------------------------- EPOCH {epoch} STARTS -------------------------', flush=True)
        start_time = time.time()
        # Defined before the branch so the clean-up below also works when the
        # training step is skipped on resume.
        save_trained_path = f'{CKPT_SAVE_PATH}/trained/trained_fold{fold}_epoch{epoch}.pth'
        # training
        if not SKIP_TRAIN:
            loss = train_one_epoch(train_loader, MODEL, OPTIMIZER, LOSS_FN, DEVICE)
            epoch_loss.append(loss)
            # LR_SCHEDULER.step()
            # Temporary checkpoint so a crash during validation can be resumed.
            torch.save({'model_state_dict': MODEL.state_dict(),
                        'optimizer_state_dict': OPTIMIZER.state_dict(),
                        # 'lr_scheduler_state_dict': LR_SCHEDULER.state_dict()
                        }, save_trained_path)
        else:
            print('------------------ SKIPPED TRAINING ------------------', flush=True)
            # Skip training once only: the resumed epoch was already trained.
            SKIP_TRAIN = False

        # validation
        curr_scores = valid_one_epoch(valid_loader, MODEL, METRIC, DEVICE)
        dice_score, iou_score = curr_scores[0], curr_scores[1]
        epoch_dice.append(dice_score)
        epoch_iou.append(iou_score)
        if dice_score > best_score:
            best_score = dice_score
            best_epoch = epoch
            save_best_path = f'{CKPT_SAVE_PATH}/best/best_fold{fold}_epoch{epoch}_score{best_score:.3f}.pth'
            torch.save(MODEL.state_dict(), save_best_path)
        # The temporary checkpoint is no longer needed once validation is done;
        # it may legitimately be absent when the training step was skipped.
        if os.path.isfile(save_trained_path):
            os.remove(save_trained_path)

        epoch_time = time.time() - start_time
        print(f'-------------------------- EPOCH {epoch} ENDS --------------------------')
        print(f'FOLD: {fold}, EPOCH: {epoch}, TIME ELAPSED: {epoch_time:.3f}, BEST SCORE: {best_score:.3f} BEST EPOCH: {best_epoch}', flush=True)
        print()
    fold_loss.append(epoch_loss)
    fold_dice.append(epoch_dice)
    fold_iou.append(epoch_iou)
    print(f'------------------------------ FOLD {fold} ENDS ------------------------------', flush=True)

### 2.5 Plot Loss and Scores

In [None]:
IMAGE_SAVE_PATH = f'./images/{EXPERIMENT_NAME}_batchsize{TRAIN_BATCH_SIZE}_fold{NUM_FOLD}'
# exist_ok=True keeps the cell re-runnable without a separate existence check.
os.makedirs(IMAGE_SAVE_PATH, exist_ok=True)


def _plot_per_fold(series_per_fold, title, ylabel, filename):
    """Draw one curve per fold against the epoch axis and save it as a PDF.

    Args:
        series_per_fold: list with one sequence of per-epoch values per fold.
        title: figure title.
        ylabel: label of the y axis.
        filename: output file name inside ``IMAGE_SAVE_PATH``.
    """
    for i, series in enumerate(series_per_fold):
        plt.plot(series, label=f'Fold {i}')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.grid()
    plt.legend()
    plt.savefig(f'{IMAGE_SAVE_PATH}/{filename}')
    # Close the figure so the next plot starts from an empty canvas.
    plt.close()


# Plot Loss
_plot_per_fold(fold_loss, 'Change of Loss For Each Fold', 'Loss', 'loss.pdf')

# Plot Dice Coefficients
_plot_per_fold(fold_dice, 'Change of Dice Coefficient For Each Fold', 'Score', 'dice_scores.pdf')

# Plot IOU Coefficients
_plot_per_fold(fold_iou, 'Change of IOU Coefficient For Each Fold', 'Score', 'IOU_scores.pdf')