<a href="https://www.kaggle.com/code/vovanquangnbk/uwmgi-train?scriptVersionId=150483418" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

## Initial Setting

In [1]:
!pip install -q segmentation_models_pytorch

In [2]:
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
import random
import glob
import os, shutil
import re
from tqdm import tqdm
tqdm.pandas()
import time
import copy
import joblib
from collections import defaultdict
import gc
from IPython import display as ipd

# visualization
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Sklearn
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

import rasterio
from joblib import Parallel, delayed

# For colored terminal text
from colorama import Fore, Back, Style
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"



In [3]:
class CFG:
    seed          = 101
    debug         = True # set debug=False for Full Training
    exp_name      = '2.5D'
    comment       = 'unet-efficientnet_b0-160x192-ep=5'
    model_name    = 'Unet'
    backbone      = 'efficientnet-b0'
    train_bs      = 64
    valid_bs      = train_bs*2
    img_size      = [320, 320]
    epochs        = 10
    lr            = 1e-4
    scheduler     = 'CosineAnnealingLR'
    min_lr        = 1e-6
    T_max         = int(30000/train_bs*epochs)+50
    T_0           = 25
    warmup_epochs = 0
    wd            = 1e-6
    n_accumulate  = max(1, 32//train_bs)
    n_fold        = 5
    folds         = [0]
    num_workers   = 2
    num_classes   = 3
    device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
def set_seed(seed = 42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')

## Meta Data

In [5]:
## Slice infomation
def create_csv(data_dir):
    df = {
        'id':[], 'f_path':[], 
        'slice_h':[], 'slice_w':[], 'px_spacing_h':[], 'px_spacing_w':[], 
        'case_id_str':[], 'case_id':[],
        'day_num_str':[], 'day_num':[], 
        'slice_id':[]
    }
    for case in glob.glob(data_dir + '/*'):
        case_id_str = case[case.rfind("/")+1:]
        case_id = int(re.findall(r'(\d+)', case_id_str)[0])
        
        for day in glob.glob(os.path.join(data_dir, case_id_str) + '/*'):
            case_day_id_str = day[day.rfind("/")+1:]
            day_num_str = case_day_id_str.split("_")[1]
            day_num = int(re.findall(r'(\d+)', day_num_str)[0])
            
            for f in glob.glob(os.path.join(data_dir, case, case_day_id_str, "scans") + '/*'):
                f_name = f[f.rfind("/")+1:f.rfind(".")]
                slice_id = '_'.join(f_name.split('_')[:2])
                df['slice_id'].append(slice_id)
                df['id'].append('_'.join([case_id_str, day_num_str, slice_id]))
                df['f_path'].append(f)
                df['slice_w'].append(int(f_name.split('_')[2]))
                df['slice_h'].append(int(f_name.split('_')[3]))
                df['px_spacing_w'].append(float(f_name.split('_')[4]))
                df['px_spacing_h'].append(float(f_name.split('_')[5]))
                df['case_id_str'].append(case_id_str)
                df['case_id'].append(case_id)
                df['day_num_str'].append(day_num_str)
                df['day_num'].append(day_num)
                
    df = pd.DataFrame(df)
    return df

data_dir = '/kaggle/input/uw-madison-gi-tract-image-segmentation/train'
slice_info = create_csv(data_dir)
print(len(slice_info))

38496


In [6]:
## Segmentation infomation
segment_info = pd.read_csv('/kaggle/input/uw-madison-gi-tract-image-segmentation/train.csv')
segment_info = segment_info.groupby('id').agg(list).reset_index()
segment_info['n_segs'] = segment_info['segmentation'].apply(lambda row: sum(~pd.isnull(row)))
segment_info['lb_seg_flag'] = segment_info['segmentation'].apply(lambda row: ~pd.isnull([row[0]])[0])
segment_info['lb_seg_rle'] = segment_info['segmentation'].apply(lambda row: row[0])
segment_info['sb_seg_flag'] = segment_info['segmentation'].apply(lambda row: ~pd.isnull([row[1]])[0])
segment_info['sb_seg_rle'] = segment_info['segmentation'].apply(lambda row: row[1])
segment_info['st_seg_flag'] = segment_info['segmentation'].apply(lambda row: ~pd.isnull([row[2]])[0])
segment_info['st_seg_rle'] = segment_info['segmentation'].apply(lambda row: row[2])
segment_info.drop(['segmentation'], axis=1, inplace=True)

print(len(segment_info))

38496


In [7]:
train = segment_info.merge(slice_info, how='inner', on='id')

## Remove Faults
fault1 = 'case7_day0'
fault2 = 'case81_day30'
train = train[~train['id'].str.contains(fault1) & ~train['id'].str.contains(fault2)].reset_index(drop=True)
train.head()

Unnamed: 0,id,class,n_segs,lb_seg_flag,lb_seg_rle,sb_seg_flag,sb_seg_rle,st_seg_flag,st_seg_rle,f_path,slice_h,slice_w,px_spacing_h,px_spacing_w,case_id_str,case_id,day_num_str,day_num,slice_id
0,case101_day20_slice_0001,"[large_bowel, small_bowel, stomach]",0,False,,False,,False,,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,1.5,1.5,case101,101,day20,20,slice_0001
1,case101_day20_slice_0002,"[large_bowel, small_bowel, stomach]",0,False,,False,,False,,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,1.5,1.5,case101,101,day20,20,slice_0002
2,case101_day20_slice_0003,"[large_bowel, small_bowel, stomach]",0,False,,False,,False,,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,1.5,1.5,case101,101,day20,20,slice_0003
3,case101_day20_slice_0004,"[large_bowel, small_bowel, stomach]",0,False,,False,,False,,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,1.5,1.5,case101,101,day20,20,slice_0004
4,case101_day20_slice_0005,"[large_bowel, small_bowel, stomach]",0,False,,False,,False,,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,1.5,1.5,case101,101,day20,20,slice_0005


## Helper functions

In [8]:
## RLE ENCODING AND DECODING
# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def mask2rle(img):
    """
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formatted
    """
    pixels = img.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rle2mask(mask_rle: str, label=1, shape=(266, 266)):
    """
    mask_rle: run-length as string formatted (start length)
    shape: (height,width) of array to return
    Returns numpy array, 1 - mask, 0 - background
    """
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = label
    return img.reshape(shape)  # Needed to align to RLE direction

In [9]:
def center_pad_crop(img, new_size):
    # Height, Width, and Channel
    new_h, new_w = new_size
    
    if len(img.shape) == 3:
        old_h, old_w, c = img.shape
        max_h = np.max([old_h, new_h])
        max_w = np.max([old_w, new_w])
        max_img = np.zeros((max_h, max_w, c), dtype="float32")
    else:
        old_h, old_w = img.shape
        max_h = np.max([old_h, new_h])
        max_w = np.max([old_w, new_w])
        max_img = np.zeros((max_h, max_w), dtype="float32")
    
    # pad image first
    offset_h_max = (max_h - old_h) // 2
    offset_w_max = (max_w - old_w) // 2
    max_img[offset_h_max : offset_h_max + old_h, 
            offset_w_max : offset_w_max + old_w] = img
    
    # crop image
    offset_h_new = (max_h - new_h) // 2
    offset_w_new = (max_w - new_w) // 2
    
    return max_img[offset_h_new : offset_h_new + new_h, 
                   offset_w_new : offset_w_new + new_w]

if CFG.debug:
    a = np.array([[[1,2,3,4], [5,6,7,8], [9,10,11,12]],
                 [[1,2,3,4], [5,6,7,8], [9,10,11,12]],
                 [[1,2,3,4], [5,6,7,8], [9,10,11,12]]])
#     a = np.array([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
    print(a)
    print(a.shape)
    print(center_pad_crop(a, (2,8)))
    print(center_pad_crop(a, (2,8)).shape)

[[[ 1  2  3  4]
  [ 5  6  7  8]
  [ 9 10 11 12]]

 [[ 1  2  3  4]
  [ 5  6  7  8]
  [ 9 10 11 12]]

 [[ 1  2  3  4]
  [ 5  6  7  8]
  [ 9 10 11 12]]]
(3, 3, 4)
[[[ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 1.  2.  3.  4.]
  [ 5.  6.  7.  8.]
  [ 9. 10. 11. 12.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]]

 [[ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 1.  2.  3.  4.]
  [ 5.  6.  7.  8.]
  [ 9. 10. 11. 12.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]
  [ 0.  0.  0.  0.]]]
(2, 8, 4)


## Dataset

In [10]:
from albumentations.pytorch import ToTensorV2
data_transforms = {
    "train": A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
        A.OneOf([
            A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
            A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
        ], p=0.25),
        A.CoarseDropout(max_holes=8, max_height=CFG.img_size[0]//20, max_width=CFG.img_size[1]//20,
                         min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
        A.Normalize(p=1.0),
        ToTensorV2(p=1.0),
        ], p=1.0),
    
    "valid": A.Compose([
        A.Normalize(p=1.0),
        ToTensorV2(p=1.0),
        ], p=1.0)
}

In [11]:
class UWMDataset(Dataset):
    def __init__(self, df, label=True, transforms=None):
        self.df = df.copy()
        self.label = label
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        f_path = self.df.iloc[index]['f_path']
        img = cv2.imread(f_path)
        img = img.astype('float32')
        img_shape = img.shape[:2]
        
        if self.label:
            binary_masks = self.create_mask(index, img_shape)
            # Resize
            img = center_pad_crop(img, CFG.img_size)
            binary_masks = [center_pad_crop(mask, CFG.img_size) for mask in binary_masks]
            # Transform
            if self.transforms:
                data = self.transforms(image=img, masks=binary_masks)
                img  = data['image']
                binary_masks  = data['masks']
            binary_masks = torch.stack(binary_masks, dim=0)
            return img.float(), binary_masks.float()
        else:
            # Resize
            img = center_pad_crop(img, CFG.img_size)
            # Transform
            if self.transforms:
                data = self.transforms(image=img)
                img  = data['image']
            return img.float()
    
    def create_mask(self, index, shape=None):
        flags = self.df.iloc[index][['lb_seg_flag', 'sb_seg_flag', 'st_seg_flag']].tolist()
        seg_rle = self.df.iloc[index][['lb_seg_rle', 'sb_seg_rle', 'st_seg_rle']].tolist()

        binary_masks = []
        for i in range(len(flags)):
            if flags[i]:
                binary_masks.append(rle2mask(seg_rle[i], 1, shape))
            else:
                binary_masks.append(np.zeros(shape))

        return binary_masks

if CFG.debug:
    print("TRAIN:\n")
    train_dataset = UWMDataset(train, transforms=data_transforms['train'])
    for i, (x, mask) in enumerate(train_dataset):
        try:
            print(x.shape, mask.shape)
            print(x.mean(), mask.mean())
            if i > 0:
                break
        except: 
            print(x)
            print(mask)
            print("Error!")
            break
    print("\n\n")
    
    print("VALID:\n")
    valid_dataset = UWMDataset(train, transforms=data_transforms['valid'])
    for i, (x, mask) in enumerate(valid_dataset):
        print(x.shape, mask.shape)
        if i > 0:
            break
    print("\n\n")

TRAIN:

torch.Size([3, 320, 320]) torch.Size([3, 320, 320])
tensor(-1.9860) tensor(0.)
torch.Size([3, 320, 320]) torch.Size([3, 320, 320])
tensor(-1.9860) tensor(0.)



VALID:

torch.Size([3, 320, 320]) torch.Size([3, 320, 320])
torch.Size([3, 320, 320]) torch.Size([3, 320, 320])





In [12]:
def prepare_dataloader(ids, df=train, train_mode='train'):
    df_ = df.loc[ids,:].reset_index(drop=True)

    if train_mode == 'train':
        dataset = UWMDataset(df_, transforms=data_transforms[train_mode])
        loader = DataLoader(
            dataset, 
            batch_size=CFG.train_bs,
            num_workers=CFG.num_workers, 
            shuffle=True, 
            pin_memory=True, 
            drop_last=False)
    else:
        dataset = UWMDataset(df_)
        loader = DataLoader(
            dataset, 
            batch_size=CFG.valid_bs,
            num_workers=CFG.num_workers, 
            shuffle=False, 
            pin_memory=True)    
    
    return loader

if CFG.debug:
    trn_idx = [_ for _ in range(100)]
    train_loader = prepare_dataloader(trn_idx, train, 'train')
    val_loader = prepare_dataloader(trn_idx, train)

    print("TRAIN:\n")
    for i, (x, mask) in enumerate(train_loader):
        x = x.to(CFG.device)
        mask = mask.to(CFG.device)
        print(x.shape, mask.shape)
        print(x.mean(), mask.mean())
        if i > 5:
            break
    print("\n\n")
    
    print("VALID:\n")
    for i, (x, mask) in enumerate(val_loader):
        x = x.to(CFG.device)
        mask = mask.to(CFG.device)
        print(x.shape, mask.shape)
        print(x.mean(), mask.mean())
        if i > 5:
            break
    print("\n\n")

TRAIN:

torch.Size([64, 3, 320, 320]) torch.Size([64, 3, 320, 320])
tensor(-1.9855, device='cuda:0') tensor(0.0028, device='cuda:0')
torch.Size([36, 3, 320, 320]) torch.Size([36, 3, 320, 320])
tensor(-1.9855, device='cuda:0') tensor(0.0033, device='cuda:0')



VALID:

torch.Size([64, 3, 320, 320]) torch.Size([64, 3, 320, 320])
tensor(-1.9855, device='cuda:0') tensor(0.0026, device='cuda:0')
torch.Size([36, 3, 320, 320]) torch.Size([36, 3, 320, 320])
tensor(-1.9855, device='cuda:0') tensor(0.0035, device='cuda:0')





## Model

In [13]:
import segmentation_models_pytorch as smp

def build_model():
    model = smp.Unet(
        encoder_name=CFG.backbone, 
        encoder_weights="imagenet",
        in_channels=3,
        classes=CFG.num_classes,
        activation=None,
    )
    model.to(CFG.device)
    return model

def load_model(path):
    model = build_model()
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

In [14]:
if CFG.debug:
    model = build_model()
    imgs, msks = next(iter(train_loader))
    imgs = imgs.to(CFG.device)
    msks = msks.to(CFG.device)
    print(imgs.shape)
    print(msks.shape)
    
    with torch.no_grad():
        preds = model(imgs)
    print(preds.shape)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth
100%|██████████| 20.4M/20.4M [00:00<00:00, 146MB/s] 


torch.Size([64, 3, 320, 320])
torch.Size([64, 3, 320, 320])
torch.Size([64, 3, 320, 320])


## Loss function

In [15]:
JaccardLoss = smp.losses.JaccardLoss(mode='multilabel')
DiceLoss    = smp.losses.DiceLoss(mode='multilabel')
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
LovaszLoss  = smp.losses.LovaszLoss(mode='multilabel', per_image=False)
TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)

def dice_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum(dim=dim)
    den = y_true.sum(dim=dim) + y_pred.sum(dim=dim)
    dice = ((2*inter+epsilon)/(den+epsilon)).mean(dim=(1,0))
    return dice

def iou_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum(dim=dim)
    union = (y_true + y_pred - y_true*y_pred).sum(dim=dim)
    iou = ((inter+epsilon)/(union+epsilon)).mean(dim=(1,0))
    return iou

def criterion(y_pred, y_true):
    return 0.5*BCELoss(y_pred, y_true) + 0.5*TverskyLoss(y_pred, y_true)

In [16]:
if CFG.debug:
    print(criterion(msks, preds))

tensor(0.4989, device='cuda:0')


## Training functions

In [17]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.train()
    scaler = amp.GradScaler()
    
    dataset_size = 0
    running_loss = 0.0
    
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Train ')
    for step, (images, masks) in pbar:         
        images = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)
        
        batch_size = images.size(0)
        
        with amp.autocast(enabled=True):
            y_pred = model(images)
            loss   = criterion(y_pred, masks)
            loss   = loss / CFG.n_accumulate
            
        scaler.scale(loss).backward()
    
        if (step + 1) % CFG.n_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()

            # zero the parameter gradients
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()
                
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(train_loss=f'{epoch_loss:0.4f}',
                        lr=f'{current_lr:0.5f}',
                        gpu_mem=f'{mem:0.2f} GB')
        torch.cuda.empty_cache()
        gc.collect()
    
    return epoch_loss

In [18]:
@torch.no_grad()
def valid_one_epoch(model, optimizer, dataloader, device, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    
    val_scores = []
    
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid ')
    for step, (images, masks) in pbar:        
        images  = images.to(device, dtype=torch.float)
        masks   = masks.to(device, dtype=torch.float)
        
        batch_size = images.size(0)
        
        y_pred  = model(images)
        loss    = criterion(y_pred, masks)
        
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        y_pred = nn.Sigmoid()(y_pred)
        val_dice = dice_coef(masks, y_pred).cpu().detach().numpy()
        val_jaccard = iou_coef(masks, y_pred).cpu().detach().numpy()
        val_scores.append([val_dice, val_jaccard])
        
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}',
                        lr=f'{current_lr:0.5f}',
                        gpu_memory=f'{mem:0.2f} GB')
    val_scores  = np.mean(val_scores, axis=0)
    torch.cuda.empty_cache()
    gc.collect()
    
    return epoch_loss, val_scores

In [19]:
def run_training(model, train_loader, valid_loader, optimizer, scheduler, device, fold, num_epochs):
    if torch.cuda.is_available():
        print("cuda: {}\n".format(torch.cuda.get_device_name()))
    
    start = time.time()
    best_model_wts = model.state_dict()
    best_dice = -np.inf
    best_epoch = -1
    history = defaultdict(list)
    
    for epoch in range(1, num_epochs + 1): 
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='')
        train_loss = train_one_epoch(model, optimizer, scheduler, 
                                           dataloader=train_loader, 
                                           device=CFG.device, epoch=epoch)
        
        val_loss, val_scores = valid_one_epoch(model, optimizer, valid_loader, 
                                                 device=CFG.device, 
                                                 epoch=epoch)
        val_dice, val_jaccard = val_scores
    
        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Valid Dice'].append(val_dice)
        history['Valid Jaccard'].append(val_jaccard)
        
        print(f'Valid Dice: {val_dice:0.4f} | Valid Jaccard: {val_jaccard:0.4f}')
        
        # deep copy the model
        if val_dice >= best_dice:
            print(f"{c_}Valid Score Improved ({best_dice:0.4f} ---> {val_dice:0.4f})")
            best_dice = val_dice
            best_jaccard = val_jaccard
            best_epoch   = epoch
            best_model_wts = model.state_dict()
            PATH = f"best_epoch-{fold:02d}.bin"
            torch.save(model.state_dict(), PATH)
            print(f"Model Saved")
            
        last_model_wts = model.state_dict()
        PATH = f"last_epoch-{fold:02d}.bin"
        torch.save(model.state_dict(), PATH)
            
        print(); print()
    
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Score: {:.4f}".format(best_jaccard))
    
    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, history

## Optimizer

In [20]:
def fetch_scheduler(optimizer):
    if CFG.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max=CFG.T_max, 
                                                   eta_min=CFG.min_lr)
    elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=CFG.T_0, 
                                                             eta_min=CFG.min_lr)
    elif CFG.scheduler == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.1,
                                                   patience=7,
                                                   threshold=0.0001,
                                                   min_lr=CFG.min_lr,)
    elif CFG.scheduer == 'ExponentialLR':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
    elif CFG.scheduler == None:
        return None
        
    return scheduler

## Trainer

In [21]:
def train_model():
    set_seed(CFG.seed)
    
    ## Split folds
    skf = StratifiedGroupKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
    for fold, (train_idx, val_idx) in enumerate(skf.split(train, train['n_segs'], groups = train["case_id"])):
        print(f'#'*15)
        print(f'### Fold: {fold}')
        print(f'#'*15)
        train_loader = prepare_dataloader(train_idx, train, 'train')
        valid_loader = prepare_dataloader(val_idx, train, 'train')
    
        ## Initilize model
        model = build_model()
        optimizer = optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
        scheduler = fetch_scheduler(optimizer)
        
        ## Run model
        model, history = run_training(model, train_loader, valid_loader, optimizer, scheduler, device=CFG.device, fold=fold, num_epochs=CFG.epochs)
        del model, optimizer, scheduler, train_loader, valid_loader
        torch.cuda.empty_cache()
        
        break

## Main

In [22]:
if __name__ == "__main__":
    train_model()

> SEEDING DONE
###############
### Fold: 0
###############
cuda: Tesla P100-PCIE-16GB

Epoch 1/10

Train : 100%|██████████| 465/465 [16:34<00:00,  2.14s/it, gpu_mem=13.90 GB, lr=0.00010, train_loss=0.6280]
Valid : 100%|██████████| 133/133 [04:31<00:00,  2.04s/it, gpu_memory=5.00 GB, lr=0.00010, valid_loss=0.5755]


Valid Dice: 0.1829 | Valid Jaccard: 0.1688
[32mValid Score Improved (-inf ---> 0.1829)
Model Saved


Epoch 2/10

Train : 100%|██████████| 465/465 [14:50<00:00,  1.91s/it, gpu_mem=13.90 GB, lr=0.00009, train_loss=0.4284]
Valid : 100%|██████████| 133/133 [04:03<00:00,  1.83s/it, gpu_memory=5.00 GB, lr=0.00009, valid_loss=0.3893]


Valid Dice: 0.7031 | Valid Jaccard: 0.6861
[32mValid Score Improved (0.1829 ---> 0.7031)
Model Saved


Epoch 3/10

Train : 100%|██████████| 465/465 [14:50<00:00,  1.91s/it, gpu_mem=13.90 GB, lr=0.00008, train_loss=0.2600]
Valid : 100%|██████████| 133/133 [03:57<00:00,  1.79s/it, gpu_memory=5.00 GB, lr=0.00008, valid_loss=0.3016]


Valid Dice: 0.7403 | Valid Jaccard: 0.7214
[32mValid Score Improved (0.7031 ---> 0.7403)
Model Saved


Epoch 4/10

Train : 100%|██████████| 465/465 [14:52<00:00,  1.92s/it, gpu_mem=13.90 GB, lr=0.00007, train_loss=0.1956]
Valid : 100%|██████████| 133/133 [03:41<00:00,  1.66s/it, gpu_memory=5.00 GB, lr=0.00007, valid_loss=0.2854]


Valid Dice: 0.7413 | Valid Jaccard: 0.7225
[32mValid Score Improved (0.7403 ---> 0.7413)
Model Saved


Epoch 5/10

Train : 100%|██████████| 465/465 [14:57<00:00,  1.93s/it, gpu_mem=13.90 GB, lr=0.00005, train_loss=0.1764]
Valid : 100%|██████████| 133/133 [03:59<00:00,  1.80s/it, gpu_memory=5.00 GB, lr=0.00005, valid_loss=0.2876]


Valid Dice: 0.7591 | Valid Jaccard: 0.7406
[32mValid Score Improved (0.7413 ---> 0.7591)
Model Saved


Epoch 6/10

Train : 100%|██████████| 465/465 [15:22<00:00,  1.98s/it, gpu_mem=13.90 GB, lr=0.00004, train_loss=0.1664]
Valid : 100%|██████████| 133/133 [04:00<00:00,  1.81s/it, gpu_memory=5.00 GB, lr=0.00004, valid_loss=0.2682]


Valid Dice: 0.7734 | Valid Jaccard: 0.7544
[32mValid Score Improved (0.7591 ---> 0.7734)
Model Saved


Epoch 7/10

Train : 100%|██████████| 465/465 [15:18<00:00,  1.98s/it, gpu_mem=13.90 GB, lr=0.00002, train_loss=0.1600]
Valid : 100%|██████████| 133/133 [04:13<00:00,  1.91s/it, gpu_memory=5.00 GB, lr=0.00002, valid_loss=0.2716]


Valid Dice: 0.7546 | Valid Jaccard: 0.7351


Epoch 8/10

Train : 100%|██████████| 465/465 [15:11<00:00,  1.96s/it, gpu_mem=13.90 GB, lr=0.00001, train_loss=0.1566]
Valid : 100%|██████████| 133/133 [04:05<00:00,  1.85s/it, gpu_memory=5.00 GB, lr=0.00001, valid_loss=0.2591]


Valid Dice: 0.7866 | Valid Jaccard: 0.7674
[32mValid Score Improved (0.7734 ---> 0.7866)
Model Saved


Epoch 9/10

Train : 100%|██████████| 465/465 [15:35<00:00,  2.01s/it, gpu_mem=13.90 GB, lr=0.00000, train_loss=0.1537]
Valid : 100%|██████████| 133/133 [04:01<00:00,  1.81s/it, gpu_memory=5.00 GB, lr=0.00000, valid_loss=0.2655]


Valid Dice: 0.7805 | Valid Jaccard: 0.7612


Epoch 10/10

Train : 100%|██████████| 465/465 [14:54<00:00,  1.92s/it, gpu_mem=13.90 GB, lr=0.00000, train_loss=0.1519]
Valid : 100%|██████████| 133/133 [04:01<00:00,  1.82s/it, gpu_memory=5.00 GB, lr=0.00000, valid_loss=0.2629]


Valid Dice: 0.7839 | Valid Jaccard: 0.7648


Training complete in 3h 13m 12s
Best Score: 0.7674
