In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed.
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load:

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# For example, the snippet below (run by clicking Run or pressing Shift+Enter) would list all files under the input directory.
# NOTE(review): the snippet is wrapped in a string literal, so it is NOT executed; since that string is
# the last expression of this cell, Jupyter echoes the string itself as the cell output.

'''import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))'''

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All".
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session.

In [None]:
!pip install -q timm
!pip install -q tqdm
!pip install -q pytorch-lightning
!pip install -q segmentation-models-pytorch

In [None]:
import numpy as np 
import pandas as pd 

import os

import timm
from tqdm.notebook import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import torchvision
from torchvision.utils import make_grid
from torchvision import transforms

import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold

from PIL import Image
import segmentation_models_pytorch as smp

In [None]:
# NOTE(review): BASE_PATH is a local (non-Kaggle) path and is not referenced by the cells here.
BASE_PATH  = '/home/matteo/Documents/uw-madison-gi-tract-image-segmentation/dataset'
df = pd.read_csv('../input/create-mask-dataset-for-uw-madison/train_processed.csv')
# Preview a few rows instead of rendering the whole frame.
df.head()

In [None]:
# Re-root the stored mask paths into the Kaggle input mount: the first character of each
# stored path (presumably the leading "." of a relative path) is replaced by the dataset root.
# Paths that already start with the root are left alone, so re-running this cell is safe;
# previously every re-run prepended the prefix again and corrupted the paths.
MASK_DATASET_ROOT = "../input/create-mask-dataset-for-uw-madison"
df["mask_path"] = df["mask_path"].apply(
    lambda p: p if p.startswith(MASK_DATASET_ROOT) else MASK_DATASET_ROOT + p[1:]
)
df.head()

In [None]:
df["mask_path"].iloc[0]

In [None]:
class CFG:
    """Experiment configuration exposed as plain class attributes (the class is never instantiated)."""

    # --- bookkeeping --------------------------------------------------------
    seed     = 101
    debug    = False  # keep False for full training
    exp_name = 'Baselinev2'
    comment  = 'unet-efficientnet_b1-224x224-aug2-split2'

    # --- architecture -------------------------------------------------------
    # NOTE(review): UW_model builds an smp.Unet with a resnet18 encoder; `model_name`
    # and `backbone` are not read by any cell shown here.
    model_name  = 'Unet'
    backbone    = 'efficientnet-b1'
    num_classes = 4
    img_size    = [224, 224]

    # --- optimisation -------------------------------------------------------
    train_bs      = 128
    valid_bs      = 128
    epochs        = 100
    lr            = 2e-3
    wd            = 1e-6
    # NOTE(review): UW_model.configure_optimizers uses StepLR regardless of `scheduler`.
    scheduler     = 'ReduceLROnPlateau'
    min_lr        = 1e-6
    warmup_epochs = 0
    T_0           = 25
    n_fold        = 5

    # Derived values; they must follow the attributes they read.
    T_max        = int(30000 / train_bs * epochs) + 50
    n_accumulate = max(1, 32 // train_bs)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Utils

In [None]:
def id2mask(id_):
    """Build an (H, W, 3) uint8 mask for one slice id from the RLE rows of the global ``df``.

    Channels 0/1/2 hold large bowel / small bowel / stomach. A channel stays zero when
    that class has no row for ``id_`` or its ``segmentation`` value is NaN.
    """
    rows = df[df['id'] == id_]
    dims = rows[['height', 'width']].iloc[0]
    shape = (dims.height, dims.width, 3)
    mask = np.zeros(shape, dtype=np.uint8)
    organs = ('large_bowel', 'small_bowel', 'stomach')
    for channel, organ in enumerate(organs):
        organ_rows = rows[rows['class'] == organ]
        encoded = organ_rows.segmentation.squeeze()
        if not len(organ_rows) or pd.isna(encoded):
            continue
        mask[..., channel] = rle_decode(encoded, shape[:2])
    return mask

def rgb2gray(mask):
    """Collapse an (H, W, C) one-hot mask into an (H, W) index map.

    A zero channel is prepended, so pixels with no class set map to 0 and input
    channel k maps to k + 1 (ties resolve to the lowest index, as with argmax).
    """
    with_background = np.pad(mask, pad_width=((0, 0), (0, 0), (1, 0)))
    return with_background.argmax(axis=-1)

def gray2rgb(mask, num_classes=4):
    """Expand an index map (index 0 = background) back into a one-hot mask without background.

    Args:
        mask: integer-valued array of class indices in ``[0, num_classes)``. A trailing
            singleton axis is squeezed, mirroring ``keras.utils.to_categorical``.
        num_classes: total number of classes including background (default 4).

    Returns:
        Array of shape ``mask.shape + (num_classes - 1,)`` with dtype ``mask.dtype``.
    """
    # Pure NumPy one-hot: TensorFlow (`tf`) is never imported in this notebook, so the
    # former `tf.keras.utils.to_categorical` call raised NameError.
    mask = np.asarray(mask)
    indices = mask.astype(np.intp)
    if indices.ndim > 1 and indices.shape[-1] == 1:
        indices = indices[..., 0]
    one_hot = np.eye(num_classes, dtype=np.float32)[indices]
    return one_hot[..., 1:].astype(mask.dtype)

# ref: https://www.kaggle.com/code/awsaf49/uwmgi-unet-train-pytorch#%F0%9F%93%92-Notebooks

def load_img(path):
    """Read an image file into a float32 NumPy array (pixel values are not rescaled).

    The file is opened in a ``with`` block so its handle is released deterministically
    rather than whenever PIL or the garbage collector closes it.
    """
    with Image.open(path) as im:
        return np.array(im).astype(np.float32)

def load_msk(path):
    """Load a ``.npy`` mask as float32 multiplied by 255 (e.g. a 0/1 mask becomes 0/255)."""
    raw = np.load(path)
    return raw.astype('float32') * np.float32(255.0)
    

def show_img(img, mask=None):
    """Draw ``img`` on the current matplotlib axes, optionally overlaying ``mask``.

    Args:
        img: image array passed straight to ``plt.imshow``.
        mask: optional mask, multiplied by 255 and blended at 50% opacity; a legend for
            large bowel / small bowel / stomach is added (presumably one mask channel per
            organ, in that order — confirm against the mask files).
    """
    # The CLAHE object formerly created here was never used, so it has been dropped.
    plt.imshow(img)  # , cmap='bone')
    if mask is not None:
        plt.imshow(mask * 255, alpha=0.5)
        legend_colours = [(0.667, 0.0, 0.0), (0.0, 0.667, 0.0), (0.0, 0.0, 0.667)]
        handles = [Rectangle((0, 0), 1, 1, color=colour) for colour in legend_colours]
        labels = ["Large Bowel", "Small Bowel", "Stomach"]
        plt.legend(handles, labels)
    plt.axis('off')
    
# ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length-encoded string into a binary mask.

    mask_rle: whitespace-separated "start length" pairs with 1-based starts
    shape: (height, width) of the array to return
    Returns a uint8 numpy array: 1 - mask, 0 - background.
    Runs index the flattened array, which is reshaped row-major to `shape`.
    '''
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)
    stops = starts + lengths
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(starts, stops):
        flat[begin:stop] = 1
    return flat.reshape(shape)


# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    '''
    Run-length encode a binary mask (inverse of rle_decode).

    img: numpy array, 1 - mask, 0 - background
    Returns "start length start length ..." with 1-based starts over the
    flattened image; an all-background mask yields "".
    '''
    padded = np.concatenate([[0], img.flatten(), [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    boundaries[1::2] -= boundaries[::2]
    return ' '.join(map(str, boundaries))

In [None]:
# Preview a slice flagged `empty == False` so the overlay has something to show.
# (Previously this row was selected and then immediately overwritten by df.iloc[0].)
test = df.query("empty == False")
file_path, mask_path = test.iloc[10]["file_path"], test.iloc[10]["mask_path"]

im = load_img(file_path)
# Raw mask straight from disk: show_img applies the *255 scaling itself.
mask = np.load(mask_path)
show_img(im, mask)

# Dataset

In [None]:
class BuildDataset(torch.utils.data.Dataset):
    """Dataset yielding images (and, when ``label`` is True, masks) listed in a DataFrame.

    Args:
        df: frame with ``file_path`` and ``mask_path`` columns (both are read eagerly,
            even when ``label`` is False).
        label: if True ``__getitem__`` returns ``(img, msk)``, otherwise just ``img``.
        transforms: optional callable applied to the image.
        target_transforms: optional callable applied to the mask.
    """

    def __init__(self, df, label=True, transforms=None, target_transforms=None):
        self.df = df
        self.label = label
        self.img_paths = df['file_path'].tolist()
        self.msk_paths = df['mask_path'].tolist()
        self.transforms = transforms
        self.target_transforms = target_transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        img = load_img(self.img_paths[index])
        if not self.label:
            return self.transforms(img) if self.transforms else img

        msk = load_msk(self.msk_paths[index])
        if self.transforms:
            img = self.transforms(img)
        if self.target_transforms:
            msk = self.target_transforms(msk)
        return img, msk

# Dataloaders

In [None]:
# RGB colour associated with each organ class (not read by the cells shown here).
classes = dict(
    Large_Bowel=[255, 0, 0],
    Small_Bowel=[0, 255, 0],
    Stomach=[0, 0, 255],
)


# Train and validation currently use the same deterministic pipeline:
# numpy array -> tensor (ToTensor), then resize to CFG.img_size.
train_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(CFG.img_size)]
)
val_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(CFG.img_size)]
)

# Lookup used by the dataset builders, keyed by split name.
data_transforms = dict(train=train_transforms, valid=val_transforms)


def prepare_loaders(fold=0, label = True, debug=False):
    """Build (train_loader, valid_loader) for one CV fold of the global ``df``.

    Rows whose ``fold`` equals ``fold`` form the validation split; all other rows train.
    With ``debug=True`` only the first 160 train / 96 validation rows are kept, further
    filtered to ``empty == 0``, and both loaders use a batch size of 20.
    """
    in_valid = df["fold"] == fold
    train_df = df[~in_valid].reset_index(drop=True)
    valid_df = df[in_valid].reset_index(drop=True)
    if debug:
        train_df = train_df.head(32 * 5)
        train_df = train_df[train_df["empty"] == 0]
        valid_df = valid_df.head(32 * 3)
        valid_df = valid_df[valid_df["empty"] == 0]

    train_dataset = BuildDataset(train_df, label=label,
                                 transforms=data_transforms['train'],
                                 target_transforms=data_transforms['train'])
    valid_dataset = BuildDataset(valid_df, label=label,
                                 transforms=data_transforms['valid'],
                                 target_transforms=data_transforms['valid'])

    train_bs = 20 if debug else CFG.train_bs
    valid_bs = 20 if debug else CFG.valid_bs
    common = dict(num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, batch_size=train_bs,
                              shuffle=True, drop_last=False, **common)
    valid_loader = DataLoader(valid_dataset, batch_size=valid_bs,
                              shuffle=False, **common)
    return train_loader, valid_loader

# Metrics

In [None]:
"""Common image segmentation metrics, computed from a confusion matrix.
"""

import torch

# Added to every metric denominator so a class with no pixels scores 0
# instead of triggering a division by zero.
EPS = 1e-10

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def nanmean(x):
    """Return the arithmetic mean of ``x`` over its non-NaN entries."""
    not_nan = ~torch.isnan(x)
    return x[not_nan].mean()


def fast_hist(true, pred, num_classes, device = device):
    """Confusion matrix of flat integer label tensors ``true`` vs ``pred``.

    Entry [t, p] counts pixels with ground truth t predicted as p. Pixels whose ground
    truth lies outside ``[0, num_classes)`` are skipped; ``pred`` is assumed to already be
    in range. Returns a ``(num_classes, num_classes)`` long tensor on ``device`` (whose
    default is bound, at definition time, to the module-level ``device``).
    """
    valid = (true >= 0) & (true < num_classes)
    codes = true[valid] * num_classes + pred[valid]
    counts = torch.bincount(codes, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes).to(device=device, dtype=torch.long)


def overall_pixel_accuracy(hist):
    """Return the fraction of all pixels whose predicted label matches the ground truth.

    Overall pixel accuracy gives an intuitive, coarse picture of how well the label map
    matches as a whole, but says little about details or rare classes.

    Args:
        hist: confusion matrix.
    Returns:
        overall_acc: scalar tensor, trace(hist) / (sum(hist) + EPS).
    """
    n_correct = hist.diagonal().sum()
    n_total = hist.sum()
    return n_correct / (n_total + EPS)


def per_class_pixel_accuracy(hist, verbose = False):
    """Per-class pixel accuracy: correctly labelled pixels of a class / pixels of that class.

    Unlike the overall accuracy, this is not dominated by large classes: a model that only
    gets the dominant labels right scores low here.

    Args:
        hist: confusion matrix (rows = ground truth, columns = prediction).
        verbose: if True, return the per-class tensor instead of its mean.
    Returns:
        The NaN-ignoring mean over classes, or the per-class tensor when ``verbose``.
    """
    class_acc = hist.diagonal() / (hist.sum(dim=1) + EPS)
    if verbose:
        return class_acc
    return nanmean(class_acc)



def jaccard_index(hist, verbose = False):
    """Jaccard index (Intersection over Union) per class, from a confusion matrix.

    Args:
        hist: confusion matrix (rows = ground truth, columns = prediction).
        verbose: if True, return the per-class tensor instead of its mean.
    Returns:
        The NaN-ignoring mean IoU over classes, or the per-class tensor when ``verbose``.
    """
    intersection = hist.diagonal()
    gt_total = hist.sum(dim=1)
    pred_total = hist.sum(dim=0)
    union = gt_total + pred_total - intersection
    class_iou = intersection / (union + EPS)
    if verbose:
        return class_iou
    return nanmean(class_iou)


def dice_coefficient(hist, verbose = False):
    """Sørensen–Dice coefficient (F1 score) per class, from a confusion matrix.

    Args:
        hist: confusion matrix (rows = ground truth, columns = prediction).
        verbose: if True, return the per-class tensor instead of its mean.
    Returns:
        The NaN-ignoring mean Dice over classes, or the per-class tensor when ``verbose``.
    """
    overlap = hist.diagonal()
    class_dice = (2 * overlap) / (hist.sum(dim=1) + hist.sum(dim=0) + EPS)
    if verbose:
        return class_dice
    return nanmean(class_dice)

def ret_hist(true, pred, num_classes):
    """Sum the confusion matrices of paired samples from ``true`` and ``pred``.

    Each sample pair is flattened before counting; the accumulator is a long tensor on
    the module-level ``device``.
    """
    total = torch.zeros((num_classes, num_classes), dtype=torch.long, device=device)
    for sample_true, sample_pred in zip(true, pred):
        total += fast_hist(sample_true.flatten(), sample_pred.flatten(), num_classes)
    return total

def eval_metrics(hist, verbose = False):
    """Compute summary segmentation metrics from a confusion matrix.

    Args:
        hist: (num_classes, num_classes) confusion matrix as produced by ``fast_hist``
            (rows = ground truth, columns = prediction).
        verbose: if False, return class-averaged scalars; if True, return per-class lists
            for accuracy, Jaccard and Dice (the overall accuracy stays a scalar).

    Returns:
        Tuple ``(overall_acc, per_class_acc, jaccard, dice)``. ``overall_acc`` is a Python
        float. The other three are Python floats (NaN-ignoring means over classes) when
        ``verbose`` is False, or lists with one NumPy value per class when ``verbose`` is True.
    """
    overall_acc = overall_pixel_accuracy(hist)
    per_class_acc = per_class_pixel_accuracy(hist, verbose)
    jacc = jaccard_index(hist, verbose)
    dice = dice_coefficient(hist, verbose)
    if verbose:
        return (
            overall_acc.item(),
            list(per_class_acc.cpu().numpy()),
            list(jacc.cpu().numpy()),
            list(dice.cpu().numpy()),
        )
    return overall_acc.item(), per_class_acc.item(), jacc.item(), dice.item()

# Mask helpers: RGB colour ↔ class index

In [None]:
# RGB colour per class; the trailing comment is the integer label rgb2idx assigns.
# NOTE(review): UW_model.shared_step uses index 3 (not 4) for background.
classes = dict(
    Large_Bowel=[255, 0, 0],  # 0
    Small_Bowel=[0, 255, 0],  # 1
    Stomach=[0, 0, 255],      # 2
    Background=[0, 0, 0],     # 4
)


def rgb2idx(msks):
    """Convert RGB-coloured masks into integer label maps.

    Args:
        msks: tensor of shape (3, H, W) or (B, 3, H, W) whose pixels are [255, 0, 0]
            (large bowel), [0, 255, 0] (small bowel) or [0, 0, 255] (stomach).

    Returns:
        Long tensor of shape (H, W) or (B, H, W), on the same device as ``msks``, with
        label 0 / 1 / 2 for the three organ colours and 4 for every other pixel.
    """
    assert len(msks.shape) == 4 or len(msks.shape) == 3, "inconsistent input shape"

    # Vectorised: the former per-pixel Python loop issued B*H*W torch.equal calls,
    # which is prohibitively slow for training-sized batches.
    colours = [
        torch.tensor([255, 0, 0]).float(),  # Large_Bowel -> 0
        torch.tensor([0, 255, 0]).float(),  # Small_Bowel -> 1
        torch.tensor([0, 0, 255]).float(),  # Stomach     -> 2
    ]
    channel_dim = len(msks.shape) - 3
    res = torch.full_like(msks.select(channel_dim, 0), 4, dtype=torch.long)
    for label, colour in enumerate(colours):
        # A (3, 1, 1) view broadcasts against both (3, H, W) and (B, 3, H, W).
        match = (msks == colour.to(msks.device).view(3, 1, 1)).all(dim=channel_dim)
        res.masked_fill_(match, label)
    return res


rgb2idx_jit = torch.jit.script(rgb2idx)

def idx2rgb(msks):
    """Convert integer label maps back into RGB-coloured masks.

    Args:
        msks: integer tensor of shape (H, W) or (B, H, W) with labels 0 (large bowel),
            1 (small bowel) or 2 (stomach); any other value is treated as background.

    Returns:
        float32 tensor of shape (3, H, W) or (B, 3, H, W), on the same device as ``msks``,
        holding [255, 0, 0] / [0, 255, 0] / [0, 0, 255] for labels 0 / 1 / 2 and 0 elsewhere.
    """
    assert len(msks.shape) == 2 or len(msks.shape) == 3, "inconsistent input shape"

    # Compare label values directly. The former loop compared each 0-d element against
    # 1-element tensors such as torch.tensor([0]) via torch.equal, which requires equal
    # sizes and therefore never matched, so every output was all zeros.
    colours = [
        torch.tensor([255, 0, 0]).float(),  # label 0: Large_Bowel
        torch.tensor([0, 255, 0]).float(),  # label 1: Small_Bowel
        torch.tensor([0, 0, 255]).float(),  # label 2: Stomach
    ]
    channel_dim = len(msks.shape) - 2
    base = torch.zeros_like(msks, dtype=torch.float32)
    res = torch.stack([base, base, base], dim=channel_dim)
    for label, colour in enumerate(colours):
        hit = (msks == label).unsqueeze(channel_dim)
        res = torch.where(hit, colour.to(msks.device).view(3, 1, 1), res)
    return res


idx2rgb_jit = torch.jit.script(idx2rgb)

# LightningModule

In [None]:
class UW_model(pl.LightningModule):
    """Lightning module training an smp U-Net on one cross-validation fold.

    The network is ``smp.Unet`` with an ImageNet-pretrained ResNet-18 encoder, 1 input
    channel and 4 output classes. Label indices are produced in ``shared_step``:
    0 large bowel, 1 small bowel, 2 stomach, 3 background. Train/validation data come
    from ``df`` (rows with ``fold == fold`` validate, the rest train). Encoder ``layer1``
    starts frozen and is unfrozen at the end of epoch index 15.

    Args:
        df: DataFrame with ``file_path``, ``mask_path`` and ``fold`` columns.
        fold: index of the validation fold.
    """

    def __init__(self,df,fold):
        super().__init__()

        self.model = smp.Unet(
            encoder_name="resnet18",
            encoder_weights="imagenet",
            in_channels=1,
            classes=4,
        )
        # Background (index 3) must contribute to the loss. The former `ignore_index=3`
        # excluded every background pixel, so the background logit never received a
        # positive training signal and argmax predictions could not learn background.
        self.criterion = nn.CrossEntropyLoss()
        self.fold = fold
        self.df = df
        # Latest epoch-level metrics per stage ("Train" / "Val"), set by shared_epoch_end.
        self.metrics = {}

        self.freeze_encoder_upper_layers()

    def print_freezed_layers(self):
        """Print the names of encoder ``layer1`` parameters that are currently frozen."""
        for name, param in self.model.encoder.layer1.named_parameters():
            if not param.requires_grad:
                print(f"[{name}] layer freezed")

    def freeze_encoder_upper_layers(self):
        """Disable gradients for every parameter of encoder ``layer1``."""
        for param in self.model.encoder.layer1.parameters():
            param.requires_grad = False
        print("[INFO] encoder layer 1 freezed")

    def unfreeze_encoder_upper_layers(self):
        """Re-enable gradients for every parameter of encoder ``layer1``."""
        for param in self.model.encoder.layer1.parameters():
            param.requires_grad = True
        print("[INFO] encoder layer 1 unfreezed")

    def forward(self, image):
        """Return the raw per-class logits produced by the U-Net."""
        mask = self.model(image)
        return mask

    def shared_step(self, batch, stage):
        """Run one batch; return the loss plus per-class metrics on argmax labels.

        ``mask`` is assumed to be (B, 3, H, W) with one channel per organ (0/255 values
        after ``load_msk``) — TODO confirm. A constant 0.5 fourth channel is appended, so
        pixels where no organ channel exceeds 0.5 get index 3 (background) under argmax.
        ``stage`` is not used here.
        """
        image, mask = batch

        predicted_mask = self.forward(image)

        pad = torch.full(
            (mask.size(0), 1, mask.size(2), mask.size(3)), 0.5,
            dtype=mask.dtype, device=mask.device,
        )
        # argmax is unaffected by softmax, so the labels are taken from the raw values.
        msks_index = torch.argmax(torch.cat((mask, pad), 1), dim=1)
        pred_index = torch.argmax(predicted_mask, dim=1)

        loss = self.criterion(predicted_mask, msks_index.long())

        hist = fast_hist(msks_index.flatten().type(torch.LongTensor),
                         pred_index.flatten().type(torch.LongTensor), 4)
        OA, acc, iou, f1 = eval_metrics(hist, verbose=True)

        return {
            "loss": loss,
            "OA": OA,
            "acc": acc,
            "iou": iou,
            "f1": f1,
        }

    def shared_epoch_end(self, outputs, stage):
        """Average per-batch results from ``shared_step`` and log them under ``stage``.

        ``m_acc`` / ``m_iou`` / ``m_f1`` average every per-class value of every batch; the
        per-organ entries average each class over batches. The dict is also stored in
        ``self.metrics[stage]``.
        """
        m_loss = torch.stack([x['loss'] for x in outputs]).mean()

        m_OA = np.vstack([x['OA'] for x in outputs]).mean()
        m_acc = np.hstack([x['acc'] for x in outputs]).mean(axis=0)
        m_iou = np.hstack([x['iou'] for x in outputs]).mean(axis=0)
        m_f1 = np.hstack([x['f1'] for x in outputs]).mean(axis=0)

        # (num_batches, num_classes) -> per-class mean over batches.
        acc = np.vstack([x['acc'] for x in outputs]).mean(axis=0)
        iou = np.vstack([x['iou'] for x in outputs]).mean(axis=0)
        f1 = np.vstack([x['f1'] for x in outputs]).mean(axis=0)

        # NOTE(review): the "Backgorund" spelling is kept so logged tag names stay
        # comparable with earlier runs.
        metrics = {
            f"loss/{stage}": m_loss,
            f"m_OA/{stage}": m_OA,
            f"m_acc/{stage}": m_acc,
            f"m_iou/{stage}": m_iou,
            f"m_f1/{stage}": m_f1,
            f"m_acc_Large_Bowel/{stage}": acc[0],
            f"m_acc_Small_Bowel/{stage}": acc[1],
            f"m_acc_Stomach/{stage}": acc[2],
            f"m_acc_Backgorund/{stage}": acc[3],
            f"m_iou_Large_Bowel/{stage}": iou[0],
            f"m_iou_Small_Bowel/{stage}": iou[1],
            f"m_iou_Stomach/{stage}": iou[2],
            f"m_iou_Backgorund/{stage}": iou[3],
            f"m_f1_Large_Bowel/{stage}": f1[0],
            f"m_f1_Small_Bowel/{stage}": f1[1],
            f"m_f1_Stomach/{stage}": f1[2],
            f"m_f1_Backgorund/{stage}": f1[3],
        }

        self.metrics[stage] = metrics

        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "Train")

    def training_epoch_end(self, outputs):
        res = self.shared_epoch_end(outputs, "Train")
        # Fine-tune encoder layer1 once epoch index 15 (0-based) has finished.
        if self.current_epoch == 15:
            self.unfreeze_encoder_upper_layers()
        return res

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "Val")

    def validation_epoch_end(self, outputs):
        res = self.shared_epoch_end(outputs, "Val")

        opt = self.optimizers()
        for param_group in opt.param_groups:
            self.log("learning_rate", (param_group['lr']))

        return res

    def configure_optimizers(self):
        """Adam over all parameters (frozen ones included, so they train once unfrozen),
        with StepLR halving the learning rate every 15 epochs."""
        optimizer = torch.optim.Adam(self.parameters(), lr=CFG.lr)
        # Alternative tried earlier: ReduceLROnPlateau(optimizer, mode='min', factor=0.2,
        # patience=5, min_lr=CFG.min_lr) stepped on "loss/Val".
        scheduler = StepLR(optimizer, step_size=15, gamma=0.5)
        return [optimizer], [scheduler]

    @staticmethod
    def _mask_transform():
        """ToTensor + nearest-neighbour resize to ``CFG.img_size`` for masks.

        Masks become labels by argmax against a constant 0.5 background channel, so any
        non-zero value produced by bilinear interpolation at organ borders would be
        labelled as organ. Nearest-neighbour resizing only copies existing mask values.
        """
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(CFG.img_size, interpolation=transforms.InterpolationMode.NEAREST),
        ])

    def train_dataloader(self):
        fold = self.fold
        train_df = self.df.query("fold!=@fold").reset_index(drop=True)
        train_dataset = BuildDataset(train_df,
                                     label=True,
                                     transforms=data_transforms['train'],
                                     target_transforms=self._mask_transform())
        train_loader = DataLoader(train_dataset,
                                  batch_size=CFG.train_bs,
                                  num_workers=4,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True)
        return train_loader

    def val_dataloader(self):
        fold = self.fold
        valid_df = self.df.query("fold==@fold").reset_index(drop=True)
        valid_dataset = BuildDataset(valid_df,
                                     label=True,
                                     transforms=data_transforms['valid'],
                                     target_transforms=self._mask_transform())
        # drop_last=False: every slice of the validation fold is evaluated.
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=CFG.valid_bs,
                                  num_workers=4,
                                  shuffle=False,
                                  pin_memory=True,
                                  drop_last=False)
        return valid_loader

In [None]:
# Disabled experiment, kept inside a string literal so it does not run: sample 10% (with
# replacement) of the slices with empty == False and assign StratifiedGroupKFold folds
# grouped by case.
# NOTE(review): as this cell's only expression, Jupyter echoes the string as the cell output.
'''tmp = df.query("empty==False")
df_reduced = tmp.sample(frac=0.10, replace=True, random_state=42).reset_index(drop=True)

skf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=CFG.seed)

for fold, (train_idx, val_idx) in enumerate(skf.split(df_reduced, df_reduced['empty'], groups = df_reduced["case"])):
    df_reduced.loc[val_idx, 'fold'] = fold
display(df_reduced.groupby(['fold','empty'])['id'].count())'''

In [None]:
# Assign each row a validation fold. groups=case keeps every slice of a case in one fold;
# stratifying on `empty` balances empty vs non-empty slices across folds.
skf = StratifiedGroupKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)

# Initialise the column so it is created with an integer dtype (not float from NaN fill).
df["fold"] = -1
# NOTE: val_idx is positional, so this relies on df having a default RangeIndex.
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['empty'], groups=df["case"])):
    df.loc[val_idx, 'fold'] = fold
display(df.groupby(['fold', 'empty'])['id'].count())

In [None]:
fold = 0
(df["fold"] != fold).sum().item(), (df["fold"] == fold).sum().item()

In [None]:
ax = df.query("fold == 0")['empty'].value_counts().plot.bar()
ax.set(title="Validation fold 0: empty vs non-empty slices", xlabel="empty", ylabel="number of slices");

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# The number of epochs comes from CFG.epochs (the former unused `epochs = 5` was misleading).
model = UW_model(df = df, fold = 0)

drive_logs_folder = os.path.join("./experiments", "logs")
os.makedirs(drive_logs_folder, exist_ok=True)

model_name = "ResUNet18"
tb_logger = pl_loggers.TensorBoardLogger(save_dir=drive_logs_folder, name="{}_logs".format(model_name))

# Keep the two best checkpoints by validation loss, plus the last epoch.
checkpoint_callback = ModelCheckpoint(
    monitor="loss/Val",
    save_top_k=2,
    mode="min",
    save_last=True,
)

# Stop on the same signal used to select checkpoints; the training loss (monitored
# before) keeps falling while the model overfits, so it cannot trigger a useful stop.
early_stopping_cb = EarlyStopping(monitor="loss/Val", mode="min", patience=30)

# `gpus=` is deprecated (removed in Lightning 2.0); choose the accelerator explicitly and
# fall back to CPU so the cell also runs without CUDA.
trainer = pl.Trainer(
    max_epochs=CFG.epochs,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    logger=tb_logger,
    default_root_dir=drive_logs_folder,
    callbacks=[checkpoint_callback, early_stopping_cb],
)

In [None]:
trainer.fit(model=model)