This is a training notebook that I use for the [HubMAP competition](https://www.kaggle.com/c/hubmap-kidney-segmentation). You can use this notebook to experiment with many other segmentation problems. For convenience, I use the Catalyst framework together with the segmentation_models_pytorch and pytorch_toolbelt libraries, which are great for quick development. If you like my notebook, please give it an upvote, thank you.

## Install and import packages

In [None]:
# Pinned package versions: the code below assumes the APIs of these releases
# (e.g. catalyst 20.12's SupervisedRunner / callback classes).
# timm is intentionally left unpinned.
!pip install -q catalyst==20.12
!pip install -q pytorch-toolbelt==0.4.2
!pip install -q torch-optimizer==0.1.0
!pip install -q segmentation-models-pytorch==0.1.3
!pip install -q ttach==0.0.3
!pip install -q albumentations==0.5.2
!pip install -q timm

In [None]:
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import torch.nn.functional as F
import math
import warnings

from catalyst.contrib.nn import OneCycleLRWithWarmup
from torch.optim.lr_scheduler import (
    ExponentialLR,
    CyclicLR,
    MultiStepLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    ReduceLROnPlateau
)
from pytorch_toolbelt.losses import *
from pytorch_toolbelt.utils import image_to_tensor
from pytorch_toolbelt.utils.random import set_manual_seed
from torch.nn import KLDivLoss
from catalyst import utils
from catalyst.contrib.nn import OneCycleLRWithWarmup
from catalyst import dl
from catalyst.contrib.utils.cv import image as cata_image
from catalyst.contrib.callbacks.wandb_logger import WandbLogger
from catalyst.dl import (
    SupervisedRunner, 
    CriterionCallback, 
    EarlyStoppingCallback, 
    SchedulerCallback, 
    MetricAggregationCallback, 
    IouCallback, 
    DiceCallback
)
from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

import albumentations as A

import os
import cv2
import json
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc, average_precision_score
from pathlib import Path
from tqdm .auto import tqdm
import plotly.express as px
from sklearn.model_selection import train_test_split
from pathlib import Path
from collections import OrderedDict
from typing import List,  Optional, Dict

%matplotlib inline

## Define data paths

In [None]:
# Root of the pre-tiled 256x256 dataset: image tiles live in train/, masks in masks/.
MAIN_PATH = Path('../input') / 'hubmap-256x256'
IMG_PATHS, MASK_PATHS = MAIN_PATH / 'train', MAIN_PATH / 'masks'
# Competition label file (not referenced by the code visible in this notebook).
LABEL = '../input/hubmap-kidney-segmentation/train.csv'

## Initial seed

In [None]:
SEED = 1999
# Seed through both pytorch_toolbelt's and catalyst's helpers, in that order.
for _seed_fn in (set_manual_seed, utils.set_global_seed):
    _seed_fn(SEED)
# cuDNN autotuning on, deterministic mode off: faster kernels, but runs are not
# guaranteed to be bit-for-bit reproducible despite the fixed seed
# (assuming prepare_cudnn maps these onto the torch.backends.cudnn flags).
utils.prepare_cudnn(deterministic=False, benchmark=True)

## Define config class 

In [None]:
class BaseConfig:
    """Default hyper-parameters for one training run.

    Every public class attribute becomes one key of the dict returned by
    :meth:`get_all_attributes`. Experiments override entries of that dict
    (e.g. ``model_name`` / ``model_params``), or subclass this config and
    override only the attributes that change.
    """

    __basedir__ = MAIN_PATH  # dunder name: excluded from get_all_attributes()
    train_img_path = IMG_PATHS
    train_mask_path = MASK_PATHS

    # Data config
    augmentation = 'medium'  # key for get_transform(): 'normal' or 'medium' have classes defined here
    scale_size = 256  # side length passed to the transform classes

    # Train config
    num_epochs = 10
    batch_size = 32
    val_batch_size = 32
    learning_rate = 1e-5  # encoder param-group lr (also the lr handed to get_scheduler)
    learning_rate_decode = 1e-3  # optimizer default lr (decoder / head groups)
    weight_decay = 2.5e-5
    is_fp16 = True

    # Model config (filled in per experiment)
    model_name = None
    model_params = None

    # Metric config
    metric = "dice"
    mode = "max"  # optimisation direction of `metric`

    # Optimize config
    criterion = {"bce": 0.8, 'log_dice': 0.2}  # loss name -> weight in the summed loss
    pos_weights = [200]  # positive-class weight, used only by the 'wbce' loss
    optimizer = "madgrad"
    scheduler = "simple"

    resume_path = None  # checkpoint path to resume training from

    @classmethod
    def get_all_attributes(cls):
        """Return the configuration as a plain ``{name: value}`` dict.

        Attributes are collected along the MRO (base classes first), so a
        subclass inherits every default and only overrides what it redefines.
        Dunder names and methods (functions, class/static methods) are skipped.
        Values are deep-copied so that mutating nested containers of the result
        (``criterion``, ``pos_weights``, ...) cannot alter the class defaults.
        """
        import copy
        import types

        d = {}
        for klass in reversed(cls.__mro__):
            for k, v in vars(klass).items():
                if k.startswith('__'):
                    continue
                if isinstance(v, (classmethod, staticmethod, types.FunctionType)):
                    continue
                d[k] = v

        return copy.deepcopy(d)

In [None]:
# Mutable working copy of the defaults; later cells override entries in place.
train_config = dict(BaseConfig.get_all_attributes())

## Dataset

In [None]:
class HUBMAPSegmentation(Dataset):
    """Image/mask segmentation dataset over tile files on disk.

    Each item is an ``OrderedDict`` with:
      * ``'image'``: float tensor produced by ``image_to_tensor``;
      * ``'mask'`` (only when ``masks`` is given): 0/1 float tensor produced by
        ``image_to_tensor(..., dummy_channels_dim=True)``, in both the
        transformed and the untransformed path;
      * ``'filename'``: ``image_path.name``.

    Args:
        images: image file paths.
        masks: mask file paths aligned index-by-index with ``images``, or
            ``None`` (e.g. for inference).
        transform: albumentations-style callable applied jointly to image and
            mask (called as ``transform(image=..., mask=...)``).
        preprocessing_fn: albumentations-style callable applied to the image
            only, after ``transform``.

    Raises:
        ValueError: if ``masks`` is given but its length differs from ``images``.
    """

    def __init__(self, images: List[Path], masks: Optional[List[Path]] = None, transform=None, preprocessing_fn=None):
        if masks is not None and len(masks) != len(images):
            raise ValueError(
                f"Got {len(images)} images but {len(masks)} masks; they must be paired one-to-one."
            )
        self.images = images
        self.masks = masks
        self.transform = transform
        self.preprocessing_fn = preprocessing_fn

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _read_mask(mask_path) -> np.ndarray:
        """Load a mask as a uint8 array where any non-zero grey level becomes 1."""
        with Image.open(mask_path) as img:
            mask = img.convert('L')
        mask = mask.point(lambda x: 255 if x > 0 else 0, '1')
        return np.asarray(mask).astype(np.uint8)

    def __getitem__(self, index: int) -> dict:
        image_path = self.images[index]
        has_mask = self.masks is not None

        result = OrderedDict()
        result['image'] = cata_image.imread(image_path)
        if has_mask:
            result['mask'] = self._read_mask(self.masks[index])

        if self.transform is not None:
            # Image and mask go through the same call so spatial augmentations stay aligned.
            transformed = self.transform(**result)
            result['image'] = transformed['image']
            if has_mask:
                result['mask'] = transformed['mask']

        image = result['image']
        if self.preprocessing_fn is not None:
            image = self.preprocessing_fn(image=image)['image']
        result['image'] = image_to_tensor(image).float()

        # Convert the mask regardless of whether a transform ran, so batches
        # always carry the same mask type/shape.
        if has_mask:
            result['mask'] = image_to_tensor(result['mask'], dummy_channels_dim=True).float()

        result['filename'] = image_path.name
        return result

## Get preprocessing function

In [None]:
def get_preprocessing_fn(mean=None, std=None):
    """Build the per-image normalisation applied before the network.

    The returned callable computes ``(x / 255 - mean) / std`` channel-wise and
    accepts (and ignores) extra keyword arguments, so it can be wrapped in
    ``albumentations.Lambda``. Passing ``mean=None`` / ``std=None`` to the
    returned callable skips that step.

    Args:
        mean: per-channel means on the [0, 1] scale. Defaults to the
            hard-coded values below (presumably computed on the training
            tiles; not verified here).
        std: per-channel standard deviations on the [0, 1] scale. Defaults
            likewise.

    Returns:
        ``(preprocessing, mean, std)``. ``mean`` and ``std`` are plain
        3-element lists, handy for de-normalising images for display.
    """
    if mean is None:
        mean = [0.65459856, 0.48386562, 0.69428385]
    if std is None:
        std = [0.15167958, 0.23584107, 0.13146145]

    # Convert once here instead of on every call.
    mean_arr = np.asarray(mean, dtype=np.float64)
    std_arr = np.asarray(std, dtype=np.float64)

    def preprocessing(x, mean=mean_arr, std=std_arr, **kwargs):
        x = x / 255.0
        if mean is not None:
            x = x - np.asarray(mean)
        if std is not None:
            x = x / np.asarray(std)
        return x

    return preprocessing, list(mean), list(std)

preprocessing_fn, mean, std = get_preprocessing_fn()

## Get data augmentation

In [None]:
class BaseTransform:
    """Template for augmentation pipelines assembled from lists of albumentations ops.

    Subclasses provide three hooks, each returning a list of transforms:
      * ``resize_transforms()``: applied first during training;
      * ``hard_transform()``: training augmentations applied after resizing;
      * ``pre_transform()``: the full pipeline used for validation and test.

    Args:
        image_size: target side length used by the subclasses' resize ops.
        preprocessing_fn: image-only callable wrapped by :meth:`get_preprocessing`.
    """

    def __init__(self, image_size: int = 1024, preprocessing_fn=None):
        self.image_size = image_size
        self.preprocessing_fn = preprocessing_fn

    # ---- hooks for subclasses ------------------------------------------
    def pre_transform(self):
        raise NotImplementedError()

    def hard_transform(self):
        raise NotImplementedError()

    def resize_transforms(self):
        raise NotImplementedError()

    # ---- pipeline builders ---------------------------------------------
    def _get_compose(self, transform):
        """Flatten a list of transform lists into one ``A.Compose``."""
        flat = []
        for group in transform:
            flat.extend(group)
        return A.Compose(flat)

    def train_transform(self):
        stages = [self.resize_transforms(), self.hard_transform()]
        return self._get_compose(stages)

    def validation_transform(self):
        stages = [self.pre_transform()]
        return self._get_compose(stages)

    def test_transform(self):
        # Test-time pipeline is identical to validation.
        return self.validation_transform()

    def get_preprocessing(self):
        image_only = A.Lambda(image=self.preprocessing_fn)
        return A.Compose([image_only])
    
class NormalTransform(BaseTransform):
    """Light augmentation: flips and 90-degree rotations on top of a square resize/pad."""

    def resize_transforms(self):
        # Rescale so the longer side equals image_size, then pad with constant 0
        # up to image_size x image_size.
        longest_side = A.LongestMaxSize(self.image_size)
        square_pad = A.PadIfNeeded(
            min_height=self.image_size,
            min_width=self.image_size,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
        )
        return [longest_side, square_pad]

    def hard_transform(self):
        return [
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.7),
        ]

    def pre_transform(self):
        # Evaluation uses only the resize/pad, no random augmentation.
        return self.resize_transforms()
            
class MediumTransform(NormalTransform):
    """NormalTransform's flips/rotations plus a random distortion and photometric jitter."""

    def hard_transform(self):
        # Same vertical/horizontal flip and RandomRotate90 as NormalTransform.
        geometric = super().hard_transform()

        # With probability 0.5, apply one of three elastic-type distortions.
        distortion = A.OneOf([
            A.ElasticTransform(alpha=120, sigma=120 * 0.05,
                               alpha_affine=120 * 0.03, p=0.5),
            A.GridDistortion(p=0.5),
            A.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=0.5),
        ], p=0.5)

        photometric = [
            A.CLAHE(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.RandomGamma(p=0.5),
        ]
        return geometric + [distortion] + photometric
    

def get_transform(name):
    """Return the augmentation class registered under ``name``.

    Args:
        name: ``'normal'``, ``'easy'``, ``'medium'`` or ``'advanced'``
            (case-insensitive).

    Returns:
        The transform *class* (not an instance); instantiate it with
        ``(image_size, preprocessing_fn)``.

    Raises:
        KeyError: if ``name`` is not one of the known options.
        NotImplementedError: if the option is known but its class is not
            defined (``EasyTransform`` / ``AdvancedTransform`` must be
            provided before they can be used).
    """
    class_names = {
        'normal': 'NormalTransform',
        'easy': 'EasyTransform',
        'medium': 'MediumTransform',
        'advanced': 'AdvancedTransform',
    }
    key = name.lower()
    if key not in class_names:
        raise KeyError(f"Unknown augmentation {name!r}; expected one of {sorted(class_names)}")

    # Looked up at call time so a class defined in a later cell is picked up too.
    transform_cls = globals().get(class_names[key])
    if transform_cls is None:
        raise NotImplementedError(
            f"Augmentation {name!r} requires a {class_names[key]} class, which is not defined"
        )
    return transform_cls

## Show some examples from the dataset

In [None]:
# Example training images with their masks overlaid (one augmented batch).
demo_transforms = get_transform(train_config['augmentation'])(train_config['scale_size'], preprocessing_fn)
demo_ds = HUBMAPSegmentation(
    sorted(train_config['train_img_path'].glob('*.*')),
    sorted(train_config['train_mask_path'].glob('*.*')),
    demo_transforms.train_transform(),
    demo_transforms.get_preprocessing(),
)

# Not named `dl`: that name is the `catalyst.dl` module imported at the top.
demo_loader = DataLoader(demo_ds, batch_size=64, shuffle=False, num_workers=2)
batch_dict = next(iter(demo_loader))
imgs = batch_dict['image']
masks = batch_dict['mask']

# Inverse of preprocessing_fn; reshape(-1) also copes with mean/std given as a
# nested sequence.
std_arr = np.asarray(std, dtype=np.float32).reshape(-1)
mean_arr = np.asarray(mean, dtype=np.float32).reshape(-1)

plt.figure(figsize=(16, 16))
for i, (img, mask) in enumerate(zip(imgs, masks)):
    rgb = img.permute(1, 2, 0).numpy() * std_arr + mean_arr
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap.
    rgb = np.clip(rgb * 255.0, 0, 255).astype(np.uint8)
    plt.subplot(8, 8, i + 1)
    plt.imshow(rgb, vmin=0, vmax=255)
    plt.imshow(mask.squeeze().numpy(), alpha=0.2)
    plt.axis('off')
plt.subplots_adjust(wspace=None, hspace=None)

del demo_ds, demo_loader, imgs, masks

## Get data loader

In [None]:
def get_loader(
    images: List[Path],
    random_state: int,
    valid_size: float = 0.2,
    batch_size: int = 4,
    val_batch_size: int = 8,
    num_workers: int = 4,
    train_transforms_fn=None,
    valid_transforms_fn=None,
    preprocessing_fn=None,
    masks: List[Path] = None,
):    
    indices = np.arange(len(images))

    train_indices, valid_indices = train_test_split(
        indices, test_size=valid_size, random_state=random_state, shuffle=True)

    np_images = np.array(images)
    train_images = np_images[train_indices].tolist()
    val_images = np_images[valid_indices].tolist()
    np_masks = np.array(masks)
    train_masks = np_masks[train_indices].tolist()
    val_masks = np_masks[valid_indices].tolist()

    train_dataset = HUBMAPSegmentation(
        sorted(train_images),
        masks=sorted(train_masks),
        transform=train_transforms_fn,
        preprocessing_fn=preprocessing_fn,
    )

    valid_dataset = HUBMAPSegmentation(
        sorted(val_images),
        masks=sorted(val_masks),
        transform=valid_transforms_fn,
        preprocessing_fn=preprocessing_fn,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=val_batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True
    )

    loaders = OrderedDict()
    loaders['train'] = train_loader
    loaders['valid'] = valid_loader

    return loaders

## Get model 

In [None]:
def get_model(params, model_name):
    """Instantiate a segmentation_models_pytorch architecture by class name.

    Args:
        params: keyword arguments forwarded to the model constructor.
        model_name: attribute name of the architecture in ``smp``
            (e.g. ``'UnetPlusPlus'``).

    Returns:
        The constructed model. It is expected to output raw logits (no
        activation); confirm ``params`` does not set ``activation``.
    """
    architecture = getattr(smp, model_name)
    return architecture(**params)

In [None]:
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any

class MADGRAD(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam, and may converge faster and generalize better. The reference
    implementation describes it as GPU-only; nothing in this copy restricts the
    device (NOTE(review): confirm before relying on CPU training).
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    On sparse problems both weight_decay and momentum should be set to 0;
    ``step`` raises RuntimeError for sparse gradients otherwise.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2). Must be positive.
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. an L2 penalty added to the gradient (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to
            improve numerical stability; ``step`` also adds it to ``lr``
            (default: 1e-6).
    """

    def __init__(
        self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, eps: float = 1e-6,
    ):
        # Validate hyper-parameters eagerly so misconfiguration fails at construction.
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError(f"Eps must be non-negative")

        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)

    # Capability flags kept from the reference implementation; they are not
    # read anywhere in the visible code.
    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.

        Note:
            With non-zero ``weight_decay`` the decay term is added to
            ``p.grad`` in place, so this call modifies the stored gradients.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # step counter must be stored in state to ensure correct behavior under
        # optimizer sharding (kept under the string key 'k', shared by all groups)
        if 'k' not in self.state:
            self.state['k'] = torch.tensor([0], dtype=torch.long)
        k = self.state['k'].item()

        for group in self.param_groups:
            eps = group["eps"]
            # eps is folded into lr: with eps > 0, lamb stays non-zero even if a
            # scheduler sets group["lr"] to 0.
            lr = group["lr"] + eps
            decay = group["weight_decay"]
            momentum = group["momentum"]

            # Averaging coefficient: p moves a fraction ck toward z each step.
            ck = 1 - momentum
            # Dual-averaging weight for this step: lr * sqrt(k + 1).
            lamb = lr * math.pow(k + 1, 0.5)

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # Lazily create per-parameter state: weighted sum of squared
                # gradients, weighted sum of gradients and, when momentum is
                # used, a copy of the starting point x0.
                if "grad_sum_sq" not in state:
                    state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
                    state["s"] = torch.zeros_like(p.data).detach()
                    if momentum != 0:
                        state["x0"] = torch.clone(p.data).detach()

                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

                grad_sum_sq = state["grad_sum_sq"]
                s = state["s"]

                # Apply weight decay (in place on p.grad)
                if decay != 0:
                    if grad.is_sparse:
                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")

                    grad.add_(p.data, alpha=decay)

                if grad.is_sparse:
                    # Sparse path: only the entries present in the gradient are
                    # read and updated, through sparse_mask views of p and state.
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        # (x0 = p + s / rms, using the state before this step)
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.data.addcdiv(s, rms, value=1)
                    else:
                        x0 = state["x0"]

                    # Accumulate second moments (lamb-weighted), then take the
                    # cube root as the adaptive denominator
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s (lamb-weighted sum of gradients)
                    s.data.add_(grad, alpha=lamb)

                    # Step: z = x0 - s / rms
                    if momentum == 0:
                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.data.mul_(1 - ck).add_(z, alpha=ck)


        # Advance the shared step counter once per call (not once per group).
        self.state['k'] += 1
        return loss

## Many optimizers to try

In [None]:
def get_optimizer(
    optimizer_name: str, parameters, learning_rate: float, weight_decay=1e-5, eps=1e-5, **kwargs
) -> Optimizer:
    """Build an optimizer by name.

    Args:
        optimizer_name: one of ``sgd``, ``adam``, ``rms``, ``adamw``,
            ``radam``, ``ranger``, ``lamb``, ``diffgrad``, ``novograd``,
            ``madgrad`` (case-insensitive). Any ``_suffix`` (for example
            ``"madgrad_lookahead"``) additionally wraps the result in
            ``torch_optimizer.Lookahead``.
        parameters: iterable of tensors or param-group dicts.
        learning_rate: default lr for groups that do not set their own.
        weight_decay: forwarded to every optimizer.
        eps: forwarded to every optimizer that takes it (not SGD / RMSprop).
        **kwargs: extra constructor arguments.

    Returns:
        The optimizer, possibly wrapped in Lookahead.

    Raises:
        ValueError: if the base optimizer name is not recognised.
    """
    from torch.optim import SGD, Adam, RMSprop, AdamW

    base_name, sep, _ = optimizer_name.partition('_')
    lookahead = bool(sep)
    name = base_name.lower()

    # A single if/elif chain: each name builds exactly one optimizer and only
    # unknown names reach the error.
    if name == "sgd":
        base_optim = SGD(parameters, learning_rate, momentum=0.9, nesterov=True, weight_decay=weight_decay, **kwargs)
    elif name == "adam":
        # As Jeremy suggests
        base_optim = Adam(parameters, learning_rate, weight_decay=weight_decay, eps=eps, **kwargs)
    elif name == "rms":
        base_optim = RMSprop(parameters, learning_rate, weight_decay=weight_decay, **kwargs)
    elif name == "adamw":
        base_optim = AdamW(parameters, learning_rate, weight_decay=weight_decay, eps=eps, **kwargs)
    elif name in {"radam", "ranger", "lamb", "diffgrad", "novograd"}:
        # Optimizers from torch-optimizer, imported only when one is requested.
        import torch_optimizer
        optim_cls = {
            "radam": torch_optimizer.RAdam,
            "ranger": torch_optimizer.Ranger,
            "lamb": torch_optimizer.Lamb,
            "diffgrad": torch_optimizer.DiffGrad,
            "novograd": torch_optimizer.NovoGrad,
        }[name]
        base_optim = optim_cls(parameters, learning_rate, eps=eps, weight_decay=weight_decay, **kwargs)
    elif name == "madgrad":
        base_optim = MADGRAD(parameters, learning_rate, eps=eps, weight_decay=weight_decay, **kwargs)
    else:
        raise ValueError("Unsupported optimizer name " + base_name)

    if lookahead:
        from torch_optimizer import Lookahead
        return Lookahead(base_optim)
    return base_optim

## Many losses to try 

In [None]:
class WeightedBCEWithLogits(nn.Module):
    """Binary cross-entropy on logits with a positive-class weight and an ignore label.

    Args:
        pos_weights: weight of positive examples, forwarded as ``pos_weight``
            to ``F.binary_cross_entropy_with_logits``. May be a tensor, a
            sequence or a number; ``None`` disables weighting.
        ignore_index: target value whose elements contribute zero loss;
            ``None`` disables masking.
        reduction: ``"mean"`` (over all elements, ignored ones included in the
            denominator), ``"sum"``, or anything else for the per-element loss.
    """

    def __init__(self, pos_weights=None, ignore_index: Optional[int] = -100, reduction="mean"):
        super().__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.pos_weights = pos_weights

    def forward(self, label_input: torch.Tensor, target: torch.Tensor):
        not_ignored_mask = None
        if self.ignore_index is not None:
            not_ignored_mask = (target != self.ignore_index).to(label_input.dtype)
            # Zero ignored labels so they cannot yield inf/NaN before masking.
            target = target * not_ignored_mask

        pos_weight = None
        if self.pos_weights is not None:
            pos_weight = torch.as_tensor(self.pos_weights, dtype=label_input.dtype, device=label_input.device)

        # Element-wise loss: masking must happen before any reduction.
        loss = F.binary_cross_entropy_with_logits(
            label_input, target, pos_weight=pos_weight, reduction="none"
        )

        if not_ignored_mask is not None:
            loss = loss * not_ignored_mask

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss

class KLDivLossWithLogits(KLDivLoss):
    """KL divergence between a soft binary target and the prediction given as logits.

    ``input`` holds raw logits and ``target`` probabilities in [0, 1]. Both are
    treated as ``(N, C, H, W)``-style tensors: the target is bilinearly resized
    to the input's trailing spatial size (presumably C == 1; confirm with
    callers). Positive and negative classes are stacked along dim 1 before the
    divergence is taken.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # Resize target to size of input
        target = F.interpolate(target, size=input.size()[2:], mode="bilinear", align_corners=False)

        # For a logit x: log p = logsigmoid(x) and log(1 - p) = logsigmoid(-x).
        log_p = F.logsigmoid(torch.cat([input, -input], dim=1))
        target = torch.cat([target, 1 - target], dim=1)

        return F.kl_div(log_p, target, reduction="mean")

class SymmetricLovasz(nn.Module):
    """Mean of the Lovasz hinge on the foreground and on the inverted (background) problem."""

    def forward(self, outputs, targets):
        foreground = lovasz_hinge(outputs, targets)
        background = lovasz_hinge(-outputs, 1.0 - targets)
        return 0.5 * (foreground + background)

def get_loss(loss_name: str, ignore_index=None):
    """Return a loss module by (case-insensitive) name.

    Args:
        loss_name: one of ``kl``, ``bce``, ``wbce``, ``ce``, ``soft_bce``,
            ``focal``, ``jaccard``, ``log_jaccard``, ``dice``, ``log_dice``.
        ignore_index: label to ignore; only supported by the BCE-style and
            focal losses (asserted ``None`` for Jaccard/Dice).

    Raises:
        KeyError: for an unknown ``loss_name``.
    """
    name = loss_name.lower()

    if name == "kl":
        return KLDivLossWithLogits()

    if name == "bce":
        return SoftBCEWithLogitsLoss(ignore_index=ignore_index)

    if name == 'wbce':
        # Unweighted here; construct WeightedBCEWithLogits directly to set pos_weights.
        return WeightedBCEWithLogits(pos_weights=None, ignore_index=ignore_index)

    if name == "ce":
        return nn.CrossEntropyLoss()

    if name == "soft_bce":
        return SoftBCEWithLogitsLoss(smooth_factor=0.1, ignore_index=ignore_index)

    if name == "focal":
        return BinaryFocalLoss(alpha=None, gamma=1.5, ignore_index=ignore_index)

    if name == "jaccard":
        assert ignore_index is None
        return JaccardLoss(mode="binary")

    if name == "log_jaccard":
        assert ignore_index is None
        return JaccardLoss(mode="binary", log_loss=True)

    if name == "dice":
        assert ignore_index is None
        return DiceLoss(mode="binary", log_loss=False)

    if name == "log_dice":
        assert ignore_index is None
        return DiceLoss(mode="binary", log_loss=True)

    raise KeyError(loss_name)

## Many LR schedulers to try

In [None]:
class CosineAnnealingWarmRestartsWithDecay(CosineAnnealingWarmRestarts):
    """Cosine annealing with warm restarts whose peak lr decays geometrically.

    At epoch ``e`` the lr is
    ``eta_min + (base_lr * gamma**e - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2``,
    i.e. the parent schedule with ``base_lr`` scaled by ``gamma**e``.

    Args:
        gamma: per-epoch multiplicative decay of the peak lr.
        Other arguments are as for ``CosineAnnealingWarmRestarts``.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, gamma=0.9):
        # gamma must exist before the base constructor runs: it performs an
        # initial step(), which calls get_lr().
        self.gamma = gamma
        super().__init__(optimizer, T_0, T_mult, eta_min, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
                DeprecationWarning,
            )

        return [
            self.eta_min
            + (base_lr * self.gamma ** self.last_epoch - self.eta_min)
            * (1 + math.cos(math.pi * self.T_cur / self.T_i))
            / 2
            for base_lr in self.base_lrs
        ]

def get_scheduler(scheduler_name: str, optimizer, lr, num_epochs, batches_in_epoch=None, mode=None):
    """Build an lr scheduler by (case-insensitive) name.

    Args:
        scheduler_name: ``none``/``None``, ``reduce``, ``cos``, ``cos2``,
            ``cosr``, ``cosrd``, ``1cycle``/``one_cycle``, ``exp``, ``clr``,
            ``multistep`` or ``simple``.
        optimizer: optimizer to schedule.
        lr: reference lr (peak for ``1cycle``/``clr``, base of ``cos2``'s floor).
        num_epochs: total training epochs.
        batches_in_epoch: steps per epoch; required by ``1cycle`` and ``clr``.
        mode: ``'min'``/``'max'`` for ``reduce``; defaults to ``'min'`` when ``None``.

    Returns:
        The scheduler, or ``None`` for ``none``.

    Raises:
        KeyError: for an unknown ``scheduler_name``.
    """
    if scheduler_name is None or scheduler_name.lower() == "none":
        return None

    name = scheduler_name.lower()

    if name == "reduce":
        return ReduceLROnPlateau(optimizer, mode=mode or "min", patience=5)

    if name == "cos":
        return CosineAnnealingLR(optimizer, num_epochs, eta_min=1e-6)

    if name == "cos2":
        return CosineAnnealingLR(optimizer, num_epochs, eta_min=float(lr * 0.5))

    if name == "cosr":
        return CosineAnnealingWarmRestarts(optimizer, T_0=max(2, num_epochs // 4), eta_min=1e-6)

    if name == "cosrd":
        return CosineAnnealingWarmRestartsWithDecay(optimizer, T_0=max(2, num_epochs // 6), gamma=0.96, eta_min=1e-6)

    if name in {"1cycle", "one_cycle"}:
        return OneCycleLRWithWarmup(
            optimizer,
            lr_range=(lr, 1e-6),
            num_steps=batches_in_epoch * num_epochs,
            warmup_fraction=0.05,
            decay_fraction=0.1,
        )

    if name == "exp":
        return ExponentialLR(optimizer, gamma=0.95)

    if name == "clr":
        return CyclicLR(
            optimizer,
            base_lr=1e-6,
            max_lr=lr,
            # At least one step: CyclicLR divides by the cycle size.
            step_size_up=max(1, batches_in_epoch // 4),
            # mode='exp_range',
            cycle_momentum=True,
            gamma=0.99,
        )

    if name == "multistep":
        return MultiStepLR(
            optimizer, milestones=[int(num_epochs * 0.5), int(num_epochs * 0.7), int(num_epochs * 0.9)], gamma=0.3
        )

    if name == "simple":
        return MultiStepLR(optimizer, milestones=[int(num_epochs * 0.4), int(num_epochs * 0.7)], gamma=0.1)

    raise KeyError(scheduler_name)

## Prepare everything we need for the experiment

In [None]:
def _build_param_groups(model, encoder_lr):
    """Group trainable parameters: encoder at ``encoder_lr``, everything else at the optimizer default."""
    param_group = []
    if hasattr(model, 'encoder'):
        encoder_params = [p for p in model.encoder.parameters() if p.requires_grad]
        param_group.append({'params': encoder_params, 'lr': encoder_lr})
    if hasattr(model, 'decoder'):
        decoder_params = [p for p in model.decoder.parameters() if p.requires_grad]
        param_group.append({'params': decoder_params})
    if hasattr(model, 'segmentation_head'):
        head_params = [p for p in model.segmentation_head.parameters() if p.requires_grad]
        param_group.append({'params': head_params})

    if not param_group:
        return list(model.parameters())

    # Trainable parameters outside the three known sub-modules (e.g. an
    # auxiliary head) would otherwise never be optimised.
    grouped = {id(p) for group in param_group for p in group['params']}
    leftover = [p for p in model.parameters() if p.requires_grad and id(p) not in grouped]
    if leftover:
        param_group.append({'params': leftover})
    return param_group


def _build_criterion(config):
    """Instantiate one loss module per key of ``config['criterion']``."""
    criterion = {}
    for loss_name in config['criterion']:
        if loss_name == 'wbce':
            # Float dtype: pos_weight participates in float arithmetic inside BCE.
            pos_weights = torch.tensor(config['pos_weights'], dtype=torch.float32, device=utils.get_device())
            loss_fn = WeightedBCEWithLogits(pos_weights=pos_weights)
        else:
            loss_fn = get_loss(loss_name)
        criterion[loss_name] = loss_fn
    return criterion


def _build_callbacks(config, scheduler):
    """Create weighted-loss, scheduler, early-stopping and IoU/Dice metric callbacks."""
    callbacks = []
    losses = []
    for loss_name, loss_weight in config['criterion'].items():
        criterion_callback = CriterionCallback(
            input_key="mask",
            output_key="logits",
            criterion_key=loss_name,
            prefix="loss_" + loss_name,
            multiplier=float(loss_weight)
        )
        callbacks.append(criterion_callback)
        losses.append(criterion_callback.prefix)

    # The optimised "loss" is the sum of the weighted individual losses.
    callbacks.append(MetricAggregationCallback(prefix="loss", mode="sum", metrics=losses))

    # Per-batch schedulers need a batch-mode callback; plateau needs the metric.
    if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
        callbacks.append(SchedulerCallback(mode="batch"))
    elif isinstance(scheduler, ReduceLROnPlateau):
        callbacks.append(SchedulerCallback(reduced_metric=config['metric']))

    callbacks += [
        # Stopping direction follows the configured optimisation mode.
        EarlyStoppingCallback(patience=10, metric=config['metric'], minimize=config['mode'] == 'min'),
        IouCallback(input_key="mask", activation="Sigmoid", threshold=0.5),
        DiceCallback(input_key="mask", activation="Sigmoid", threshold=0.5),
    ]
    return callbacks


def _save_config(config, log_dir):
    """Write ``config`` as ``config.json`` in ``log_dir`` (Path values stored as strings)."""
    save_config = dict(config)
    save_config['train_img_path'] = str(save_config['train_img_path'])
    save_config['train_mask_path'] = str(save_config['train_mask_path'])
    with open(os.path.join(log_dir, 'config.json'), 'w') as f:
        json.dump(save_config, f, indent=2)


def prepare_everything(exp_name):
    """Assemble everything ``runner.train`` needs for one experiment.

    Reads the module-level ``train_config`` dict, creates the log directory
    ``./<exp_name>`` (if missing) and stores the config there as JSON.

    Args:
        exp_name: name of the log directory under the current directory.

    Returns:
        dict with keys ``model``, ``loader``, ``optimizer``, ``criterion``,
        ``callbacks``, ``scheduler`` and ``log_dir``.
    """
    print("===> Get model")
    model = get_model(
        train_config['model_params'],
        train_config['model_name']
    )

    print("===> Get transformation")
    # Define transform (augmentation)
    Transform = get_transform(train_config['augmentation'])
    transforms = Transform(
        train_config['scale_size'],
        preprocessing_fn=preprocessing_fn
    )

    train_transform = transforms.train_transform()
    val_transform = transforms.validation_transform()
    preprocessing = transforms.get_preprocessing()

    print("===> Get data loader")
    loader = get_loader(
        images=sorted(train_config['train_img_path'].glob('*.*')),
        masks=sorted(train_config['train_mask_path'].glob('*.*')),
        random_state=SEED,
        valid_size=0.2,
        batch_size=train_config['batch_size'],
        val_batch_size=train_config['val_batch_size'],
        num_workers=2,
        train_transforms_fn=train_transform,
        valid_transforms_fn=val_transform,
        preprocessing_fn=preprocessing
    )

    print("===> Get optimizer")
    param_group = _build_param_groups(model, train_config['learning_rate'])

    total = int(sum(p.numel() for p in model.parameters()))
    trainable = int(sum(p.numel() for p in model.parameters() if p.requires_grad))
    count_parameters = {"total": total, "trainable": trainable}
    print(
        f'[INFO] total and trainable parameters in the model {count_parameters}'
    )

    # Groups without an explicit lr use learning_rate_decode.
    optimizer = get_optimizer(
        train_config['optimizer'], param_group, train_config['learning_rate_decode'], train_config['weight_decay'])

    print("===> Get scheduler")
    scheduler = get_scheduler(
        train_config['scheduler'], optimizer, train_config['learning_rate'], train_config['num_epochs'],
        batches_in_epoch=len(loader['train']), mode=train_config['mode']
    )

    print("===> Get loss")
    criterion = _build_criterion(train_config)

    print("===> Get callbacks")
    callbacks = _build_callbacks(train_config, scheduler)

    log_dir = os.path.join("./", exp_name)
    os.makedirs(log_dir, exist_ok=True)

    print("===> Saving config setting")
    _save_config(train_config, log_dir)

    print("===> Done")

    return {
        'model': model,
        'loader': loader,
        'optimizer': optimizer,
        'criterion': criterion,
        'callbacks': callbacks,
        'scheduler': scheduler,
        'log_dir': log_dir
    }

## Model 1

In [None]:
# Model 1: UNet++ on an ImageNet-pretrained efficientnet-b2 encoder with scSE
# decoder attention; one output channel.
train_config.update(
    model_name='UnetPlusPlus',
    model_params=dict(
        classes=1,
        decoder_attention_type='scse',
        decoder_use_batchnorm=True,
        encoder_depth=5,
        encoder_name='efficientnet-b2',
        encoder_weights='imagenet',
        in_channels=3,
    ),
)

In [None]:
runner = SupervisedRunner(
    device=utils.get_device(), input_key="image", input_target_key="mask")

# Passed to runner.train(fp16=...): dict(amp=True) requests mixed precision,
# None keeps full precision.
fp16_params = dict(amp=True) if train_config['is_fp16'] else None
# Best-checkpoint selection follows the configured optimisation direction.
minimize = train_config['mode'] == 'min'

everything = prepare_everything(f"experiment_{train_config['model_name'].lower()}")
print("Let's go!!!!")
runner.train(
    model=everything['model'],
    criterion=everything['criterion'],
    optimizer=everything['optimizer'],
    callbacks=everything['callbacks'],
    logdir=everything['log_dir'],
    loaders=everything['loader'],
    num_epochs=train_config['num_epochs'],
    scheduler=everything['scheduler'],
    main_metric=train_config['metric'],
    minimize_metric=minimize,
    timeit=True,
    fp16=fp16_params,
    resume=train_config['resume_path'],
    verbose=False,
)

## Model 2

In [None]:
# Model 2: DeepLabV3+ on an ImageNet-pretrained se_resnext50_32x4d encoder;
# one output channel.
train_config.update(
    model_name='DeepLabV3Plus',
    model_params=dict(
        classes=1,
        decoder_atrous_rates=[6, 12, 18],
        encoder_depth=5,
        encoder_name='se_resnext50_32x4d',
        encoder_weights='imagenet',
        in_channels=3,
    ),
)

In [None]:
runner = SupervisedRunner(
    device=utils.get_device(), input_key="image", input_target_key="mask")

# Passed to runner.train(fp16=...): dict(amp=True) requests mixed precision,
# None keeps full precision.
fp16_params = dict(amp=True) if train_config['is_fp16'] else None
# Best-checkpoint selection follows the configured optimisation direction.
minimize = train_config['mode'] == 'min'

everything = prepare_everything(f"experiment_{train_config['model_name'].lower()}")
print("Let's go!!!!")
runner.train(
    model=everything['model'],
    criterion=everything['criterion'],
    optimizer=everything['optimizer'],
    callbacks=everything['callbacks'],
    logdir=everything['log_dir'],
    loaders=everything['loader'],
    num_epochs=train_config['num_epochs'],
    scheduler=everything['scheduler'],
    main_metric=train_config['metric'],
    minimize_metric=minimize,
    timeit=True,
    fp16=fp16_params,
    resume=train_config['resume_path'],
    verbose=True,
)

## THE END