In [None]:
# Import libraries
import os
import warnings
warnings.filterwarnings('ignore')
import torch.nn as nn
import torch
import segmentation_models_pytorch as sm
import numpy as np
import pandas as pd
import skimage.io as io
from PIL import Image
import cv2
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import tifffile
from sklearn.model_selection import KFold
import glob
import torch_optimizer as t_optim
import torch.optim as optim
from tqdm.notebook import tqdm
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import OneCycleLR
import albumentations 
from torch.nn.modules.loss import _Loss
import segmentation_models_pytorch as sm
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
class Dataset(Dataset):
    """Patch-level segmentation dataset driven by a CSV of image/mask TIFF paths.

    Column 0 of the CSV is read as the image path and column 1 as the mask
    path (both are opened with ``tifffile.imread``).

    Args:
        data_csv_path: CSV file listing one image/mask pair per row.
        indexes: Positional row selection applied with ``DataFrame.iloc``.
            ``None`` (the default) keeps every row of the CSV.
        valid: Stored on the instance for callers; not used by the dataset itself.
        transform: Optional augmentation callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` keys (albumentations-style).
        target_transform: Stored on the instance; not applied by ``__getitem__``.
        preprocessing: Optional callable with the same calling convention as
            ``transform``; applied after ``transform``.
    """

    def __init__(self, data_csv_path: str = "train_data.csv", indexes: list = None, valid: bool = False,
                 transform: "A.Compose" = None, target_transform: "A.Compose" = None, preprocessing=None):
        self.data = pd.read_csv(data_csv_path)
        # ``iloc[None, :]`` raises, so treat "no selection" as "use every row".
        if indexes is None:
            self.indexed_data = self.data
        else:
            self.indexed_data = self.data.iloc[indexes, :]
        self.valid = valid
        self.transform = transform
        self.target_transform = target_transform
        self.preprocessing = preprocessing

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for row ``idx`` as a float tensor and a long tensor."""
        image = tifffile.imread(self.indexed_data.iloc[idx, 0])
        mask = tifffile.imread(self.indexed_data.iloc[idx, 1]).astype(float)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        if self.preprocessing:
            preprocessed = self.preprocessing(image=image, mask=mask)
            image, mask = preprocessed['image'], preprocessed['mask']

        # Without a tensor-producing preprocessing step the arrays are still
        # numpy; convert them here so the tensor calls below do not fail.
        if not torch.is_tensor(image):
            image = np.asarray(image)
            if image.ndim == 3:
                # assumes channel-last (H, W, C) input, as in Configuration.INPUT_SHAPE — TODO confirm
                image = image.transpose(2, 0, 1)
            image = torch.from_numpy(np.ascontiguousarray(image))
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        # Dividing by 255 and casting to integer truncates toward zero, so only
        # pixels equal to 255 become class 1.
        # NOTE(review): presumably masks are stored as 0/255 — confirm.
        return image.float(), (mask[:, :] / 255.0).type(torch.LongTensor)

    def __len__(self):
        """Number of selected rows."""
        return len(self.indexed_data)

In [None]:
# Experiment configuration shared by the dataset, model and training loop
class Configuration:
    """Namespace of constants for one training run.

    NOTE: ``DiceScore`` must already be defined when this cell executes; in
    this notebook it lives in the utility cell further down, so run that cell
    first on a fresh kernel.
    """

    # --- model ---------------------------------------------------------
    ARCHITECTURE = "UNET"
    ENCODER = "efficientnet-b2"
    PRETRAINED_WEIGHTS = "imagenet"
    ACTIVATION = None
    # Encoder-specific input normalisation provided by segmentation_models_pytorch.
    PREPROCESS = sm.encoders.get_preprocessing_fn(ENCODER, PRETRAINED_WEIGHTS)

    # --- data ----------------------------------------------------------
    INPUT_CHANNELS = 3
    INPUT_SHAPE = (512, 512, 3)
    # (crypts 1 background 0)
    CLASSES = 2
    BATCH_SIZE = 32

    # --- optimisation --------------------------------------------------
    EPOCHS = 50
    LEARNING_RATE = 1e-3
    MAX_LR_FOR_ONECYCLELR = 1e-3
    WEIGHT_DECAY = 1e-4
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- losses / metrics ----------------------------------------------
    LOSS_CROSSENTROPY = nn.CrossEntropyLoss()
    LOSS_DICE = DiceScore(loss=True)
    DICE_COEF = DiceScore(loss=False)

    # --- bookkeeping ---------------------------------------------------
    MODEL_SAVEPATH = "model/"
    # Placeholder; overwritten per fold by the training driver.
    MODEL_NAME = 0


cfg = Configuration()

In [None]:
class Trainer:
    """Train a U-Net segmentation model with early stopping on validation loss.

    Args:
        cfg: Configuration object supplying encoder, loss, device, epoch count,
            learning rate, weight decay and output path settings.
        train_data_loader: Loader yielding ``(image_batch, mask_batch)`` for training.
        valid_data_loader: Loader yielding ``(image_batch, mask_batch)`` for validation.

    Attributes:
        max_patience: Number of consecutive non-improving epochs tolerated.
        patience: Remaining non-improving epochs before training stops.
        val_for_early_stopping: Best (lowest) validation loss seen so far.
        log: Per-epoch DataFrame of train/valid loss and dice, also written to CSV.
    """

    def __init__(self, cfg: "Configuration", train_data_loader: DataLoader, valid_data_loader: DataLoader) -> None:
        self.cfg = cfg
        self.max_patience = 5
        self.patience = self.max_patience
        self.device = self.cfg.DEVICE
        self.model = sm.Unet(encoder_name=self.cfg.ENCODER,
                             encoder_weights=self.cfg.PRETRAINED_WEIGHTS,
                             in_channels=self.cfg.INPUT_CHANNELS,
                             classes=self.cfg.CLASSES)
        # Move once up front instead of on every epoch.
        self.model.to(self.device)
        self.loss_function = self.cfg.LOSS_CROSSENTROPY
        self.lr = self.cfg.LEARNING_RATE
        self.batch_size = self.cfg.BATCH_SIZE
        self.train_dataloader = train_data_loader
        self.valid_dataloader = valid_data_loader
        self.epochs = self.cfg.EPOCHS
        self.track_best_valid = []
        self.val_for_early_stopping = float("inf")
        self.log = pd.DataFrame(columns=["model_name", "train_loss", "train_dice", "valid_loss", "valid_dice"])
        # Pass the configured learning rate explicitly so LEARNING_RATE is honoured.
        self.optimizer = t_optim.Ranger(self.model.parameters(), lr=self.lr,
                                        weight_decay=self.cfg.WEIGHT_DECAY)

    def calculate_metrics(self, data_loader: DataLoader):
        """Evaluate the model on ``data_loader`` without gradient tracking.

        Returns:
            ``(mean_dice, mean_loss)`` averaged over the batches of the loader.
        """
        self.model.eval()
        total_loss = 0.0
        total_dice = 0.0
        with torch.no_grad():
            for data in tqdm(data_loader, total=len(data_loader)):
                im = data[0].to(self.device)
                mask = data[1].to(self.device)
                out = self.model(im)
                total_loss += self.loss_function(out, mask).item()
                # The dice metric is evaluated on CPU copies of the batch.
                total_dice += float(self.cfg.DICE_COEF(out.cpu(), mask.cpu()))
        return total_dice / len(data_loader), total_loss / len(data_loader)

    def earlystopping(self, val_loss):
        """Record ``val_loss`` for early stopping.

        Returns:
            ``True`` when ``val_loss`` is a new best (the best value is updated);
            otherwise ``False`` and ``patience`` is decremented by one.
        """
        # Implemented early stopping used this in Applied Machine Learning CSCI-P 556 course work
        if val_loss < self.val_for_early_stopping:
            self.val_for_early_stopping = val_loss
            return True
        self.patience -= 1
        return False

    def fit(self) -> None:
        """Run the training loop, logging every epoch and checkpointing on improvement.

        After each epoch the metrics are appended to ``self.log`` and the log is
        rewritten to CSV. The model weights are saved whenever validation loss
        improves; training stops once ``patience`` is exhausted.
        """
        print("started fitting the model")

        # Make sure the output directory exists before the first write.
        if self.cfg.MODEL_SAVEPATH:
            os.makedirs(self.cfg.MODEL_SAVEPATH, exist_ok=True)
        log_path = os.path.join(
            self.cfg.MODEL_SAVEPATH,
            f"fold_{self.cfg.MODEL_NAME}__{self.cfg.ENCODER}_{self.cfg.BATCH_SIZE}_CE_Valid_slicing_all.csv")
        checkpoint_path = os.path.join(
            self.cfg.MODEL_SAVEPATH,
            f"fold_{self.cfg.MODEL_NAME}_{self.cfg.ENCODER}_{self.cfg.BATCH_SIZE}_CE_Valid_slicing_all.pth")

        for epoch in range(self.epochs):
            # calculate_metrics() leaves the model in eval mode, so re-enable training mode.
            self.model.train()

            dice_score_ = 0.0
            loss_ = 0.0

            for data in tqdm(self.train_dataloader, total=len(self.train_dataloader)):
                input_image_batch = data[0].to(self.device)
                mask_batch = data[1].to(self.device)
                self.optimizer.zero_grad()
                output = self.model(input_image_batch)
                loss = self.loss_function(output, mask_batch)
                loss.backward()
                self.optimizer.step()
                loss_ += loss.item()
                dice_score_ += float(self.cfg.DICE_COEF(output.detach().cpu(), mask_batch.cpu()))

            dice_score_valid, loss_valid = self.calculate_metrics(self.valid_dataloader)
            train_dice = dice_score_ / len(self.train_dataloader)
            train_loss = loss_ / len(self.train_dataloader)
            print(f"train dice score : {train_dice}, train loss {train_loss}")
            print(f"valid dice score : {dice_score_valid}, valid loss {loss_valid}")

            # Identify the row by fold/model name (the original repeated the encoder twice).
            self.log.loc[epoch, :] = [f"fold_{self.cfg.MODEL_NAME}_{self.cfg.ENCODER}.pth",
                                      train_loss, train_dice, loss_valid, dice_score_valid]
            self.log.to_csv(log_path, index=False)

            if self.earlystopping(loss_valid):
                print("saving model")
                torch.save(self.model.state_dict(), checkpoint_path)
                self.patience = self.max_patience
            elif self.patience <= 0:
                print("Training terminated, no improvement in valid loss")
                break

In [None]:
train_data_csv_path = "train.csv"
patches_csv_path = "train_data.csv"


def training(train_data_csv_path: str, patches_csv_path: str, n_folds: int) -> list:
    """Run K-fold cross-validated training, splitting by source image id.

    Folds are built over the ``id`` column of ``train_data_csv_path`` so that all
    patches cut from one image land on the same side of the split. Patches are
    assigned to an id when their ``Train_image_path`` contains that id.

    Args:
        train_data_csv_path: CSV with an ``id`` column, one row per source image.
        patches_csv_path: CSV of patch image/mask paths with a
            ``Train_image_path`` column.
        n_folds: Number of cross-validation folds.

    Returns:
        A list with one entry per fold: that fold's ``Trainer.log`` DataFrame.
    """
    import re  # local import: only needed to escape ids used inside a regex

    # NOTE(review): the last row of the id CSV is excluded here — presumably a
    # held-out image; confirm this is intentional.
    train_ids_from_csv = pd.read_csv(train_data_csv_path).iloc[0:-1, :]['id'].astype(str).values
    # After taking all the unique ids from train data i have used k fold cross validation method inspired from Applied Machine Learning CSCI-P 556 course work
    nfold = KFold(n_folds, shuffle=True, random_state=0)

    patch_dataframe = pd.read_csv(patches_csv_path)
    track_best_model = []

    def _patch_rows_for(image_ids):
        """Index of the patch rows whose image path mentions any of ``image_ids``."""
        # Escape ids so regex metacharacters in an id are matched literally.
        pattern = "|".join(re.escape(image_id) for image_id in image_ids)
        matches = patch_dataframe.Train_image_path.str.contains(pattern, na=False)
        return patch_dataframe[matches].index

    for i, (train_idx, val_idx) in enumerate(nfold.split(train_ids_from_csv)):
        train_ids = _patch_rows_for(train_ids_from_csv[train_idx])
        valid_ids = _patch_rows_for(train_ids_from_csv[val_idx])

        # Datasets re-read the patch CSV and select rows positionally; the
        # default integer index of ``patch_dataframe`` makes labels == positions.
        train_dataset = Dataset(patches_csv_path, indexes=train_ids, transform=get_train_transforms(),
                                preprocessing=preprocessing_fucntion(cfg.PREPROCESS))
        valid_dataset = Dataset(patches_csv_path, indexes=valid_ids,
                                preprocessing=preprocessing_fucntion(cfg.PREPROCESS))

        train_dataloader = DataLoader(train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True)
        valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False)

        # The trainer reads MODEL_NAME from the shared config to name its output files.
        cfg.MODEL_NAME = str(i) + "_" + cfg.ARCHITECTURE
        trainer = Trainer(cfg, train_dataloader, valid_dataloader)
        trainer.fit()
        track_best_model.append(trainer.log)

    return track_best_model



In [None]:
# Kick off 2-fold cross-validated training using the paths defined above.
training(
    train_data_csv_path=train_data_csv_path,
    patches_csv_path=patches_csv_path,
    n_folds=2,
)

started fitting the model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.2052617073059082, train loss 2.1625138180596486
valid dice score : 0.24767351150512695, valid loss 1.6185424327850342
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.2179231494665146, train loss 1.9829088960375105
valid dice score : 0.23884367942810059, valid loss 1.663023418850369


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.24957017600536346, train loss 1.7968011583600725
valid dice score : 0.25929373502731323, valid loss 1.57472734981113
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.28074172139167786, train loss 1.6084792954581124
valid dice score : 0.299227774143219, valid loss 1.4128799968295627
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.3228704333305359, train loss 1.402959167957306
valid dice score : 0.33597663044929504, valid loss 1.2758646541171603
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.3748754560947418, train loss 1.2024165391921997
valid dice score : 0.36505866050720215, valid loss 1.1621722645229764
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.42776599526405334, train loss 1.0289493245737893
valid dice score : 0.426025390625, valid loss 1.0027911596828036
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.48640966415405273, train loss 0.8548740148544312
valid dice score : 0.4799657464027405, valid loss 0.8506386942333646
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.5251966118812561, train loss 0.7585522404738835
valid dice score : 0.49896490573883057, valid loss 0.7890498638153076
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.5667153000831604, train loss 0.660918116569519
valid dice score : 0.5341483354568481, valid loss 0.6988677514923943
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.5997697710990906, train loss 0.5894110075065068
valid dice score : 0.5626924633979797, valid loss 0.64733929766549
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.6326979398727417, train loss 0.5203958302736282
valid dice score : 0.5912550091743469, valid loss 0.5809802545441521
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.6478986144065857, train loss 0.490439310669899
valid dice score : 0.608151376247406, valid loss 0.5598744518227048
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.6714735627174377, train loss 0.4487676854644503
valid dice score : 0.6332492232322693, valid loss 0.513359526793162
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.6901776194572449, train loss 0.41643014337335316
valid dice score : 0.6462050080299377, valid loss 0.48296276066038346
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.694923460483551, train loss 0.40736652059214457
valid dice score : 0.6567431092262268, valid loss 0.4724320438173082
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.7180500626564026, train loss 0.36833253289972034
valid dice score : 0.6807740330696106, valid loss 0.4266561037964291
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.7338312268257141, train loss 0.3449004611798695
valid dice score : 0.6959984302520752, valid loss 0.40570955475171405
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.7488778829574585, train loss 0.3220762951033456
valid dice score : 0.7187990546226501, valid loss 0.36705615785386825
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.7669095993041992, train loss 0.2942676033292498
valid dice score : 0.7402395009994507, valid loss 0.33334749274783665
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.7811325192451477, train loss 0.2733152572597776
valid dice score : 0.7537763714790344, valid loss 0.3134782877233293
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.7946535348892212, train loss 0.2528963450874601
valid dice score : 0.7789229154586792, valid loss 0.2756347606579463
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.8125752210617065, train loss 0.22834313128675734
valid dice score : 0.8211769461631775, valid loss 0.22472400135464138
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.8318995833396912, train loss 0.20255965207304275
valid dice score : 0.8398833274841309, valid loss 0.19840564330418906
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.8443679809570312, train loss 0.18737934742655074
valid dice score : 0.8489030599594116, valid loss 0.19044371445973715
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.8570351004600525, train loss 0.17116719165018626
valid dice score : 0.8591689467430115, valid loss 0.1739905443456438
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.8662228584289551, train loss 0.15941989847591945
valid dice score : 0.866603672504425, valid loss 0.1653993311855528
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.8763852119445801, train loss 0.14684608791555678
valid dice score : 0.8795632123947144, valid loss 0.15228956606653002
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.886483371257782, train loss 0.13506082338946207
valid dice score : 0.8898022770881653, valid loss 0.14057531704505286
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.893578052520752, train loss 0.12696398262466704
valid dice score : 0.8935868740081787, valid loss 0.13592714650763404
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.8989971280097961, train loss 0.1219525124345507
valid dice score : 0.8951770067214966, valid loss 0.13286159021986854
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9058681130409241, train loss 0.11257686465978622
valid dice score : 0.8975934982299805, valid loss 0.13000566429562038
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9128373861312866, train loss 0.10519900864788465
valid dice score : 0.90700364112854, valid loss 0.11952547646231121
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9193894267082214, train loss 0.09727247804403305
valid dice score : 0.9127349853515625, valid loss 0.11471685187684165
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9242850542068481, train loss 0.09221186648522105
valid dice score : 0.9182853698730469, valid loss 0.11168013927009371
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9281825423240662, train loss 0.08833548905594009
valid dice score : 0.9213710427284241, valid loss 0.10464549561341603
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9332866668701172, train loss 0.08161306753754616
valid dice score : 0.9249355792999268, valid loss 0.10189292828241985
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.936835765838623, train loss 0.07745571248233318
valid dice score : 0.9293911457061768, valid loss 0.10101876656214397
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9412189722061157, train loss 0.07339128905108996
valid dice score : 0.9318115711212158, valid loss 0.0973506619532903
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9426506757736206, train loss 0.0733142111982618
valid dice score : 0.9313256740570068, valid loss 0.10132050514221191


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.946719229221344, train loss 0.06798279551523072
valid dice score : 0.9350230097770691, valid loss 0.09499799087643623
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9519217610359192, train loss 0.05929619020649365
valid dice score : 0.9397463202476501, valid loss 0.09250110346410009
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9515948295593262, train loss 0.0634199074868645
valid dice score : 0.9415901899337769, valid loss 0.09197533958488041
saving model


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9531935453414917, train loss 0.06132739117102964
valid dice score : 0.9399341940879822, valid loss 0.10006757287515534


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9544306993484497, train loss 0.05864100796835763
valid dice score : 0.9411132335662842, valid loss 0.09777414136462742


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9585186839103699, train loss 0.05312738727246012
valid dice score : 0.9445450305938721, valid loss 0.0932818026178413


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.9586935639381409, train loss 0.05564226768910885
valid dice score : 0.9391500949859619, valid loss 0.13231181932820213


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

train dice score : 0.960152804851532, train loss 0.0555925401193755
valid dice score : 0.9437630772590637, valid loss 0.10265875111023585
Training terminated, no improvement in valid loss
started fitting the model


  0%|          | 0/18 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Utility methods
class DiceScore(nn.Module):
    """Soft Dice coefficient (or Dice loss) computed from class logits.

    Adapted from
    https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/losses/dice.html

    Args:
        eps: Small constant added to the denominator to avoid division by zero.
        loss: When ``True`` ``forward`` returns ``mean(1 - dice)``; otherwise
            it returns ``mean(dice)``.
    """

    def __init__(self, eps=1e-6, loss=False) -> None:
        super(DiceScore, self).__init__()
        # Honour the caller-supplied epsilon (it used to be overwritten with 1e-6).
        self.eps = eps
        self.loss = loss

    def forward(
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        """Compute the batch-mean Dice score or loss.

        Args:
            input: Logits of shape ``(N, C, H, W)``; softmax is taken over dim 1.
            target: Integer class labels of shape ``(N, H, W)`` or ``(N, 1, H, W)``.

        Returns:
            A scalar tensor: the mean over the batch of the per-sample Dice
            score (or of ``1 - score`` when ``loss`` is set).
        """
        # compute softmax over the classes axis
        input_soft = F.softmax(input, dim=1)

        # Drop an explicit channel axis only when one is present, so an
        # (N, H, W) target with H == 1 is not squeezed by accident.
        if target.dim() == input.dim():
            target = target.squeeze(1)

        # One-hot encode with as many classes as the logits provide, on the
        # same device/dtype as the probabilities.
        target_one_hot = F.one_hot(target.long(), num_classes=input.shape[1])
        target_one_hot = target_one_hot.permute(0, 3, 1, 2).to(input_soft)

        # Sum over classes and spatial axes, leaving one score per sample.
        dims = (1, 2, 3)
        intersection = torch.sum(input_soft * target_one_hot, dims)
        cardinality = torch.sum(input_soft + target_one_hot, dims)

        dice_score = 2. * intersection / (cardinality + self.eps)
        return torch.mean(1. - dice_score) if self.loss else torch.mean(dice_score)
   
def get_train_transforms()->transforms:
    """Build the training augmentation pipeline.

    The pipeline is flips and 90-degree rotation, followed by two ``OneOf``
    groups (spatial distortion / blur / noise, then colour adjustments), each
    configured with ``p=0.3``.
    """
    flips_and_rotation = [
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
    ]

    distortion_or_noise = A.OneOf(
        [
            A.ElasticTransform(p=.3),
            A.GaussianBlur(p=.3),
            A.GaussNoise(p=.3),
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=.1),
            A.PiecewiseAffine(p=0.3),
        ],
        p=0.3,
    )

    colour_adjustment = A.OneOf(
        [
            A.HueSaturationValue(15, 25, 0),
            A.CLAHE(clip_limit=2),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        ],
        p=0.3,
    )

    return A.Compose([*flips_and_rotation, distortion_or_noise, colour_adjustment])

def get_val_transforms()->transforms:
    """Build the validation pipeline: a Compose containing only ``ToTensorV2``."""
    return A.Compose([ToTensorV2()])

def preprocessing_fucntion(preprocesing_function=None):
    """Wrap an image preprocessing callable and tensor conversion in one Compose.

    Args:
        preprocesing_function: Callable handed to ``A.Lambda`` as its ``image``
            function; defaults to ``None``.

    Returns:
        ``A.Compose`` of the ``A.Lambda`` step followed by ``ToTensorV2``.
    """
    image_step = A.Lambda(image=preprocesing_function)
    tensor_step = ToTensorV2()
    return A.Compose([image_step, tensor_step])