This is my baseline Unet architecture for the competition segmentation task. I'm currently using [this](https://www.kaggle.com/xhlulu/hubmap-512x512-full-size-tiles) dataset. Please give my notebook an upvote if you find it useful. 

I will keep updating this notebook. This version has the following features:

- Unet with Attention Gates as skip connections
- Data Augmentations with Albumentations
- ResNeSt as Unet Encoder (the current version also supports other *Res* models with the same channel layout as ResNet-50, since the decoder hard-codes 64/256/512/1024 input channels)
- Gradient Accumulation
- Mixed Precision

### Installing necessary libraries

In [None]:
! pip install -q timm

## Config

I'll convert this cell to a **class** later. You can change any parameter here, such as the learning rate, batch size, or encoder model.

In [None]:
import cv2
import albumentations as A
from albumentations import Compose, OneOf, Normalize, Cutout, GaussNoise, Blur, GaussianBlur

SEED = 69
n_epochs = 10
device = 'cuda:0'
data_dir = '../input/hubmap-512x512-full-size-tiles'
loss_thr = 1e6                      # a validation loss above this is treated as divergence
img_path = f'{data_dir}/train'
label_path = f'{data_dir}/masks'
encoder_model = 'resnest50_fast_1s1x64d'
model_name = 'ResUnest50'           # Will come up with a better name later
model_dir = 'model_dir'
history_dir = 'history_dir'
load_model = False                  # resume from f'{model_dir}/{model_name}_loss.pth'
apply_log = False                   # if True, loss = -log(dice) instead of 1 - dice
img_dim = 320
batch_size = 24
accum_step = 2                      # effective batch size = batch_size * accum_step
learning_rate = 2.50e-3
num_workers = 4
mixed_precision = True
patience = 3                        # epochs without improvement tolerated before ReduceLROnPlateau halves the LR

train_aug = A.Compose([
    A.CenterCrop(height=300, width=300, p=0.3),
    A.RandomCrop(280, 280, p=0.3),
    A.Rotate(limit=30, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101, p=0.5),
    # Always resize back to img_dim, since the crops above change the tile size
    A.Resize(img_dim, img_dim, interpolation=cv2.INTER_LINEAR, always_apply=True),
    Cutout(num_holes=8, max_h_size=20, max_w_size=20, fill_value=0, p=0.2),
    # A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, brightness_by_max=True, p=0.3),
    # A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20, p=0.4),
    OneOf([
        GaussNoise(var_limit=0.1),
        Blur(),
        GaussianBlur(blur_limit=3),
        # A.RandomGamma(p=0.7),
    ], p=0.3),
    A.HorizontalFlip(p=0.3),
])
# Not used yet. HuBMAPDataset already scales images to [0, 1], so Normalize needs max_pixel_value=1.0
val_aug = Compose([Normalize(max_pixel_value=1.0, always_apply=True)])
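Since I plan to turn the config cell into a class, here is one possible shape for it: a minimal sketch using Python's standard `dataclasses` module. The class name `Config` and the derived properties are hypothetical and nothing else in this notebook uses them yet; only a subset of the fields is shown, with the same values as above.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical config class holding a subset of the settings from the cell above."""
    seed: int = 69
    n_epochs: int = 10
    device: str = 'cuda:0'
    data_dir: str = '../input/hubmap-512x512-full-size-tiles'
    encoder_model: str = 'resnest50_fast_1s1x64d'
    model_name: str = 'ResUnest50'
    img_dim: int = 320
    batch_size: int = 24
    accum_step: int = 2
    learning_rate: float = 2.5e-3
    mixed_precision: bool = True
    patience: int = 3

    @property
    def img_path(self) -> str:
        return f'{self.data_dir}/train'

    @property
    def label_path(self) -> str:
        return f'{self.data_dir}/masks'

    @property
    def effective_batch_size(self) -> int:
        # Gradients are accumulated over accum_step batches before each optimizer step
        return self.batch_size * self.accum_step


cfg = Config(batch_size=12, accum_step=4)      # override fields per experiment
print(cfg.effective_batch_size, cfg.img_path)  # 48 ../input/hubmap-512x512-full-size-tiles/train
```

A dataclass keeps the defaults in one place while still allowing per-experiment overrides, and derived values such as the paths cannot drift out of sync with `data_dir`.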

## Dataset Class: Borrowed from [here](https://www.kaggle.com/orkatz2/hubmap-res34unet-baseline-train)

In [None]:
from __future__ import print_function, division
import os
import random
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader

class HuBMAPDataset(Dataset):
    """Loads a tile and its mask by file name, resizes both to img_dim and applies optional augmentations."""
    def __init__(self, ids, transforms=None, preprocessing=None):
        self.ids = ids                      # tile file names, shared by the train/ and masks/ folders
        self.transforms = transforms        # albumentations pipeline applied jointly to image and mask
        self.preprocessing = preprocessing  # optional second pipeline (e.g. normalization)

    def __getitem__(self, idx):
        name = self.ids[idx]
        # Image: uint8 in OpenCV's BGR order -> float in [0, 1], shape (img_dim, img_dim, 3)
        img = cv2.imread(f"{img_path}/{name}")
        img = cv2.resize(img, (img_dim, img_dim))/255.
        # Mask: keep a single channel; cv2.resize drops the singleton axis -> shape (img_dim, img_dim)
        mask = cv2.imread(f"{label_path}/{name}")[:,:,0:1]
        mask = cv2.resize(mask, (img_dim, img_dim), interpolation = cv2.INTER_AREA)
        if self.transforms:
            augmented = self.transforms(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']
        if self.preprocessing:
            preprocessed = self.preprocessing(image=img, mask=mask)
            img = preprocessed['image']
            mask = preprocessed['mask']
        # HWC -> CHW, the layout PyTorch convolutions expect
        return img.reshape(img_dim, img_dim, 3).transpose(2, 0, 1), mask

    def __len__(self):
        return len(self.ids)

## Unet with ResNeSt

I've created a Unet architecture that uses [Fast-ResNeSt](https://arxiv.org/abs/2004.08955) as its encoder. In my experiments this model, which is fairly shallow (the encoder only uses the stem and the first three ResNeSt stages), converges faster and gives better results. I've also replaced the plain skip connections with [Attention Gates](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/network.py#L108). Attention mechanisms have recently proven very useful for image classification and segmentation. If you are interested, please read this paper: [Attention U-Net: Learning Where to Look for the Pancreas](https://arxiv.org/abs/1804.03999).
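To make the gating concrete, here is a framework-free NumPy sketch of the computation that `Attention_block` performs: the gating signal `g` (from the decoder) and the skip feature `x` (from the encoder) are each projected by a 1x1 convolution, summed, passed through a ReLU, projected to a single channel and squashed by a sigmoid into per-pixel weights `alpha` that rescale `x`. It assumes both inputs already share the same spatial size (as they do in `ResnestDecoder`) and, for brevity, omits the BatchNorm layers. The helper names (`conv1x1`, `attention_gate`, `random_params`) are illustrative and the weights are random placeholders, not trained values.

```python
import numpy as np


def conv1x1(weight, bias, feat):
    """1x1 convolution on one (C_in, H, W) feature map: a per-pixel linear map over channels."""
    return np.einsum('oc,chw->ohw', weight, feat) + bias[:, None, None]


def attention_gate(g, x, params):
    """Additive attention gate without BatchNorm. Returns the gated skip feature and the weights."""
    g1 = conv1x1(params['w_g'], params['b_g'], g)          # (F_int, H, W)
    x1 = conv1x1(params['w_x'], params['b_x'], x)          # (F_int, H, W)
    q = np.maximum(g1 + x1, 0.0)                           # ReLU
    logits = conv1x1(params['w_psi'], params['b_psi'], q)  # (1, H, W)
    alpha = 1.0 / (1.0 + np.exp(-logits))                  # sigmoid -> weights in [0, 1]
    return x * alpha, alpha                                # alpha broadcasts over the channels of x


def random_params(f_g, f_l, f_int, rng):
    return {
        'w_g': rng.standard_normal((f_int, f_g)), 'b_g': np.zeros(f_int),
        'w_x': rng.standard_normal((f_int, f_l)), 'b_x': np.zeros(f_int),
        'w_psi': rng.standard_normal((1, f_int)), 'b_psi': np.zeros(1),
    }


rng = np.random.default_rng(0)
g = rng.standard_normal((8, 5, 5))   # decoder feature, F_g = 8 channels
x = rng.standard_normal((8, 5, 5))   # encoder skip feature, F_l = 8 channels
gated, alpha = attention_gate(g, x, random_params(8, 8, 4, rng))
print(gated.shape, alpha.shape)      # (8, 5, 5) (1, 5, 5)
```

Because `alpha` has a single channel, the gate can only switch whole pixels of the skip feature on or off; it does not reweight individual channels.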

### Model Utils

In [None]:
import torch
from torch import nn

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, n_filters, is_deconv=False, scale=True):
        super().__init__()

        # B, C, H, W -> B, C/4, H, W
        self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
        self.norm1 = nn.BatchNorm2d(in_channels // 4)
        nonlinearity = nn.ReLU
        self.relu1 = nonlinearity(inplace=True)

        if scale:
            # B, C/4, H, W -> B, C/4, 2H, 2W
            if is_deconv:
                self.upscale = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3,
                                                  stride=2, padding=1, output_padding=1)
            else:
                self.upscale = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.upscale = nn.Conv2d(in_channels // 4, in_channels // 4, 3, padding=1)
        self.norm2 = nn.BatchNorm2d(in_channels // 4)
        self.relu2 = nonlinearity(inplace=True)

        # B, C/4, H, W -> B, n_filters, H, W
        self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nonlinearity(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.upscale(x)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu3(x)
        return x

class Attention_block(nn.Module):
    def __init__(self,F_g,F_l,F_int):
        super(Attention_block,self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
            )
        
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)

        return x*psi

class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )
    def forward(self,x):
        x = self.conv(x)
        return x

In [None]:
def save_model(valid_loss, valid_dice, best_valid_loss, best_valid_dice, best_state, savepath):
    """Keep the best-loss and best-dice checkpoints, and always overwrite the latest one."""
    if valid_loss < best_valid_loss:
        print(f'Validation loss has decreased from: {best_valid_loss:.4f} to: {valid_loss:.4f}. Saving checkpoint')
        torch.save(best_state, savepath + '_loss.pth')
        best_valid_loss = valid_loss
    if valid_dice > best_valid_dice:
        print(f'Validation dice has increased from: {best_valid_dice:.4f} to: {valid_dice:.4f}. Saving checkpoint')
        torch.save(best_state, savepath + '_dice.pth')
        best_valid_dice = valid_dice
    # Always save the latest state, so the most recent epoch can be resumed or inspected
    torch.save(best_state, savepath + '_last.pth')
    return best_valid_loss, best_valid_dice

### Model

In [None]:
import torch
from torch import nn
import timm

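class Resnest(nn.Module):
    """Encoder wrapper: returns the stem output and the first three stages of a ResNet-style backbone."""

    def __init__(self, model_name, out_neurons=600):
        super().__init__()
        try:
            self.backbone = timm.create_model(model_name, pretrained=True)
        except Exception:
            # Fall back to the ResNeSt release on torch hub if timm cannot create the model
            self.backbone = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=True)

        self.in_features = 2048

    def forward(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        # timm names the stem activation `act1`; the torch hub ResNeSt models name it `relu`
        act = self.backbone.act1 if hasattr(self.backbone, 'act1') else self.backbone.relu
        x = act(x)
        x = self.backbone.maxpool(x)

        layer1 = self.backbone.layer1(x)
        layer2 = self.backbone.layer2(layer1)
        layer3 = self.backbone.layer3(layer2)

        # layer4 is not used, which keeps the encoder shallow
        return x, layer1, layer2, layer3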

class ResnestDecoder(nn.Module):
    # Channel sizes assume a ResNet-50-style encoder: x=64, l1=256, l2=512, l3=1024 channels
    def __init__(self):
        super().__init__()
        nonlinearity = nn.ReLU
        self.decode1 = DecoderBlock(1024, 512)
        self.decode2 = DecoderBlock(512, 256)
        self.decode3 = DecoderBlock(256, 64)
        self.decode4 = DecoderBlock(64, 32)
        self.decode5 = DecoderBlock(32, 16)
        self.conv1 = conv_block(1024, 512)
        self.conv2 = conv_block(512, 256)
        self.conv3 = conv_block(128, 64)
        self.Att1 = Attention_block(512, 512, 256)
        self.Att2 = Attention_block(256, 256, 64)
        self.Att3 = Attention_block(64, 64, 32)
        self.Att4 = Attention_block(64, 64, 32)  # currently unused in forward()
        self.conv4 = nn.Conv2d(64, 64, 3, 2, 1)  # stride-2 conv: brings decode3's output back to x's resolution
        self.finalconv2 = nn.Conv2d(16, 4, 3, padding=1)
        self.finalrelu2 = nonlinearity(inplace=True)
        self.finalconv3 = nn.Conv2d(4, 1, 3, padding=1)

    def forward(self, x, l1, l2, l3):
        # Spatial sizes in the comments are for a 320x320 input
        d1 = self.decode1(l3)               # 1024 -> 512 ch, 20 -> 40
        l2 = self.Att1(d1, l2)              # gate the 512-ch, 40x40 skip feature
        d1 = torch.cat((l2,d1),dim=1)       # 1024 ch
        d1 = self.conv1(d1)                 # 512 ch, 40x40
        d2 = self.decode2(d1)               # 256 ch, 80x80
        l1 = self.Att2(d2, l1)
        d2 = torch.cat((l1,d2),dim=1)       # 512 ch
        d2 = self.conv2(d2)                 # 256 ch, 80x80
        d3 = self.decode3(d2)               # 64 ch, 160x160
        d3 = self.conv4(d3)                 # 64 ch, 80x80 (matches x)
        x = self.Att3(d3, x)
        d3 = torch.cat((x,d3),dim=1)        # 128 ch
        d3 = self.conv3(d3)                 # 64 ch, 80x80
        d4 = self.decode4(d3)               # 32 ch, 160x160
        d5 = self.decode5(d4)               # 16 ch, 320x320
        out = self.finalconv2(d5)
        out = self.finalrelu2(out)
        out = self.finalconv3(out)          # 1-channel logits; the sigmoid is applied in the loss/metric
        return out

class resUnest(nn.Module):
    def __init__(self, encoder_model):
        super().__init__()
        self.resnest = Resnest(model_name=encoder_model)
        self.decoder = ResnestDecoder()
        
    def forward(self, x):
        x, l1, l2, l3 = self.resnest(x)
        out = self.decoder(x, l1, l2, l3)
        return out
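Because the decoder hard-codes its channel counts and relies on every `DecoderBlock` exactly doubling the resolution, it is worth checking that the spatial sizes line up with the skip connections. The pure-Python sketch below applies PyTorch's documented output-size formulas for `nn.Conv2d`/`nn.MaxPool2d` and `nn.ConvTranspose2d` to trace one spatial dimension through the network. It assumes a ResNet-style encoder with a stride-2 stem, a stride-2 max-pool and stride-2 `layer2`/`layer3` (which, as far as I know, holds for the ResNeSt and ResNet-50 models in timm); the helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv2d / nn.MaxPool2d along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output length of nn.ConvTranspose2d along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def up(size):
    """DecoderBlock upscaling: Upsample(x2), or the ConvTranspose2d(k=3, s=2, p=1, output_padding=1) branch."""
    doubled = deconv_out(size, 3, stride=2, padding=1, output_padding=1)
    assert doubled == 2 * size  # both branches double the size
    return doubled


def trace_resunest(img_dim):
    """Trace the spatial size through encoder and decoder; raise if a skip connection mismatches."""
    # Encoder (assumed strides): stem conv /2, max-pool /2, layer1 x1, layer2 /2, layer3 /2
    x = conv_out(conv_out(img_dim, 3, 2, 1), 3, 2, 1)
    l1 = x
    l2 = conv_out(l1, 3, 2, 1)
    l3 = conv_out(l2, 3, 2, 1)
    # Decoder
    d1 = up(l3)                       # decode1, concatenated with the gated l2
    d2 = up(d1)                       # decode2, concatenated with the gated l1
    d3 = conv_out(up(d2), 3, 2, 1)    # decode3 then conv4 (stride 2), concatenated with the gated x
    if (d1, d2, d3) != (l2, l1, x):
        raise ValueError(f'skip mismatch for img_dim={img_dim}: {(d1, d2, d3)} vs {(l2, l1, x)}')
    out = up(up(d3))                  # decode4, decode5
    return {'x': x, 'l1': l1, 'l2': l2, 'l3': l3, 'out': out}


print(trace_resunest(320))  # {'x': 80, 'l1': 80, 'l2': 40, 'l3': 20, 'out': 320}
```

For example, `trace_resunest(300)` raises a `ValueError` at the `l1` skip connection (76 vs 75), while 320 and the original tile size 512 pass. Since the augmentation pipeline always resizes to `img_dim`, choosing a size that passes this check is enough.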

### Metric and Loss Function

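The competition metric is the [Dice score](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient). Dice computed on thresholded masks is not differentiable, but its soft version, computed directly on the sigmoid probabilities, is, so I use the soft Dice as the loss function (`dice_coeff`/`DiceLoss`). `dice_value` reports the thresholded, per-image Dice for monitoring.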

In [None]:
def dice_value(pred, target, threshold=0.5, smooth=1e-6):
    """Per-image Dice averaged over the batch; used for monitoring, not for backprop."""
    # Compute in float32: with mixed precision the raw (un-thresholded) per-image sums could overflow fp16
    pred = torch.sigmoid(pred.float())
    if threshold is not None:
        pred = (pred > threshold).float()
    num = pred.size(0)
    m1 = pred.reshape(num, -1)  # Flatten
    m2 = target.reshape(num, -1)  # Flatten
    intersection = (m1 * m2).sum(dim=1)
    dice_score = (2. * intersection + smooth) / (m1.sum(dim=1) + m2.sum(dim=1) + smooth)
    return dice_score.mean()

def dice_coeff(pred, target, threshold=None, smooth=1e-6):
    """Soft Dice pooled over all pixels of the batch; differentiable when threshold is None."""
    pred = torch.sigmoid(pred)
    if threshold:
        pred = (pred > threshold).float()
    num = pred.size(0)
    m1 = pred.reshape(num, -1)  # Flatten
    m2 = target.reshape(num, -1)  # Flatten
    intersection = (m1 * m2).sum()

    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)


def DiceLoss(input, target, apply_log=False):
    """1 - soft Dice, or -log(soft Dice) when apply_log=True."""
    if apply_log:
        loss = - torch.log(dice_coeff(input, target))
    else:
        loss = 1 - dice_coeff(input, target)
    return loss.mean()
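Note that `dice_value` averages a per-image Dice, while `dice_coeff` pools all pixels of the batch before dividing. The two can differ a lot when some tiles contain no glomeruli. The NumPy sketch below mirrors both formulas on already-binarized predictions (so there is no sigmoid or threshold step); the arrays are toy inputs, not competition data, and the function names are illustrative.

```python
import numpy as np


def dice_per_image(pred, target, smooth=1e-6):
    """Mean of per-image Dice scores (mirrors dice_value once predictions are binary)."""
    m1 = pred.reshape(len(pred), -1)
    m2 = target.reshape(len(target), -1)
    inter = (m1 * m2).sum(axis=1)
    return ((2.0 * inter + smooth) / (m1.sum(axis=1) + m2.sum(axis=1) + smooth)).mean()


def dice_pooled(pred, target, smooth=1e-6):
    """One Dice score over all pixels of the batch (mirrors dice_coeff)."""
    inter = (pred * target).sum()
    return (2.0 * inter + smooth) / (pred.sum() + target.sum() + smooth)


# Two 2-pixel "images": the first is predicted perfectly, the second has a
# false positive on a tile whose ground-truth mask is empty.
pred = np.array([[1.0, 1.0], [1.0, 0.0]])
target = np.array([[1.0, 1.0], [0.0, 0.0]])
print(round(dice_per_image(pred, target), 3))  # 0.5
print(round(dice_pooled(pred, target), 3))     # 0.8
```

The per-image score punishes a single false positive on an empty tile heavily, while the pooled score barely notices it. Conversely, an empty prediction on an empty tile scores exactly 1 per image, thanks to `smooth`.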

### Training

In [None]:
import logging
logging.basicConfig(level=logging.ERROR)
from functools import partial
import gc
import time
import pandas as pd
import torch
from torch import nn
from torch import optim

m_p = mixed_precision
if m_p:
    scaler = torch.cuda.amp.GradScaler()

# Seed every RNG that affects the split, shuffling and augmentations
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Sort so the 80/20 split does not depend on the arbitrary order of os.listdir
file_ids = sorted(os.listdir(img_path))
train_len = int(0.8*len(file_ids))
train_ds = HuBMAPDataset(file_ids[:train_len], train_aug)
# Shuffle the training batches every epoch; the validation order can stay fixed
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

val_ds = HuBMAPDataset(file_ids[train_len:])
valid_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

os.makedirs(model_dir, exist_ok=True)
os.makedirs(history_dir, exist_ok=True)

result = pd.DataFrame(columns=['name', 'prediction', 'label', 'difference'])
if os.path.exists(f'{history_dir}/history_{model_name}_{img_dim}.csv'):
    history = pd.read_csv(f'{history_dir}/history_{model_name}_{img_dim}.csv')
else:
    history = pd.DataFrame(columns=['train_loss','train_time','val_loss','val_dice', 'val_time'])

model = resUnest(encoder_model=encoder_model).to(device)
criterion = partial(DiceLoss, apply_log=apply_log)
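The split above holds out the last 20% of file names. Tiles cut from the same slide are strongly correlated, so a validation set that shares slides with the training set can give an optimistic score. A grouped split keeps each slide entirely on one side. The sketch below is a hypothetical helper (`group_split` is not part of this notebook) that assumes tile names have the form `<slide_id>_<tile_index>.png`; check the actual naming in the dataset before relying on it.

```python
import random
from collections import defaultdict


def group_split(file_ids, val_frac=0.2, seed=69):
    """Hypothetical helper: split tile names so that no slide id appears in both sets.

    Assumes names look like '<slide_id>_<tile_index>.png'.
    """
    groups = defaultdict(list)
    for name in sorted(file_ids):
        groups[name.rsplit('_', 1)[0]].append(name)   # slide id = text before the last '_'
    slide_ids = sorted(groups)
    random.Random(seed).shuffle(slide_ids)            # reproducible slide-level shuffle
    n_val = max(1, round(val_frac * len(slide_ids)))
    val_slides = set(slide_ids[:n_val])
    train = [n for s in slide_ids if s not in val_slides for n in groups[s]]
    val = [n for s in slide_ids if s in val_slides for n in groups[s]]
    return train, val


names = [f'{slide}_{i:04d}.png' for slide in ('aaa', 'bbb', 'ccc', 'ddd', 'eee') for i in range(3)]
train_ids, val_ids = group_split(names)
print(len(train_ids), len(val_ids))  # 12 3
```

`train_ids` and `val_ids` could then replace `file_ids[:train_len]` and `file_ids[train_len:]`. With only a handful of slides, the validation score will vary noticeably depending on which slides are held out.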


### Data Visualization 

In [None]:
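from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
imgs, masks = next(iter(valid_loader))  # recent PyTorch DataLoader iterators have no .next() method
grid_imgs = make_grid(imgs, nrow=5)
figure(num=None, figsize=(12, 9), dpi=80, facecolor='w', edgecolor='k')
# cv2 loads images as BGR; reverse the channel axis so matplotlib shows the true colors
plt.imshow(np.transpose(grid_imgs.numpy(), (1, 2, 0))[..., ::-1])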

In [None]:
# Masks come as (B, H, W); add a channel axis so make_grid can tile them
grid_masks = make_grid(masks.unsqueeze(1).float(), nrow=5, normalize=True)
figure(num=None, figsize=(12, 9), dpi=80, facecolor='w', edgecolor='k')
plt.imshow(np.transpose(grid_masks.numpy(), (1, 2, 0)))

I'm going to train the model for 10 epochs (`n_epochs`). At the end of every epoch, I save the checkpoints that have the lowest validation loss and the highest validation Dice score so far.

In [None]:
def train_val(epoch, dataloader, optimizer, pretrained=None, train=True, mode='train', record=True):
    global m_p
    global result
    global batch_size
    global accum_step
    global train_loader, valid_loader
    t1 = time.time()
    running_loss = 0
    epoch_samples = 0
    dice_scores = 0
    raw_dice_coeff = 0
    if pretrained:
        model.load_state_dict(pretrained)
    if train:
        model.train()
        print("Initiating train phase ...")
    else:
        model.eval()
        print("Initiating val phase ...")
    # Zero once before the loop, not every batch; otherwise gradients never accumulate
    optimizer.zero_grad()
    for idx, (img, labels) in enumerate(dataloader):
        img = img.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.float32)
        epoch_samples += len(img)
        with torch.set_grad_enabled(train):
            # autocast runs eligible ops in fp16 by itself, so no manual img.half() is needed
            with torch.cuda.amp.autocast(enabled=m_p):
                outputs = model(img)
                loss = criterion(outputs, labels).sum()
            running_loss += loss.item()*len(img)

            if train:
                # Divide so the gradients summed over accum_step batches match one larger batch
                loss = loss/accum_step
                # Backward runs outside autocast, as the PyTorch AMP docs recommend
                if m_p:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                # Step every accum_step batches, and on the last batch so no gradients leak into the next epoch
                if (idx+1) % accum_step == 0 or (idx+1) == len(dataloader):
                    if m_p:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    optimizer.zero_grad()
                    # cyclic_scheduler.step()

        elapsed = int(time.time() - t1)
        eta = int(elapsed / (idx+1) * (len(dataloader)-(idx+1)))
        with torch.no_grad():
            dice_val = dice_value(outputs, labels).cpu().numpy()
            raw_dice_val = dice_value(outputs, labels, None).cpu().numpy()
        dice_scores += dice_val * len(labels)
        raw_dice_coeff += raw_dice_val * len(labels)
        # if mode != 'train':
            # result = analyzer(idx, name, torch.sigmoid(outputs), labels, loss, result)
        msg = f"Epoch: {epoch} Progress: [{idx}/{len(dataloader)}] loss: {(running_loss/epoch_samples):.4f} Time: {elapsed}s ETA: {eta} s"
        print(msg, end='\r')
    epoch_loss = running_loss/epoch_samples
    history.loc[epoch, f'{mode}_loss'] = epoch_loss
    history.loc[epoch, f'{mode}_time'] = elapsed
    if mode == 'val' or mode == 'test':
        val_dice = dice_scores/epoch_samples
        raw_val_dice = raw_dice_coeff/epoch_samples
        lr_reduce_scheduler.step(epoch_loss)
        msg = f'{mode} Loss: {epoch_loss:.4f} \n {mode} Dice Score: {val_dice:.4f} \n {mode} Raw Dice Score: {raw_val_dice:.4f}'
        print(msg)
        history.loc[epoch, f'{mode}_dice'] = val_dice
        history.loc[epoch, f'Raw_{mode}_dice'] = raw_val_dice
        # Divergence / NaN check (NaN is the only value that is not equal to itself)
        if epoch_loss > loss_thr or epoch_loss != epoch_loss:
            print('\033[91mMixed Precision\033[0m rendering nan value. Forcing \033[91mMixed Precision\033[0m to be False ...')
            m_p = False
            batch_size = batch_size//2
            accum_step = accum_step*2
            # The loaders were built with the old batch_size, so rebuild them for the change to take effect
            train_loader = DataLoader(train_loader.dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
            valid_loader = DataLoader(valid_loader.dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
            print('Loading last best model ...')
            tmp = torch.load(os.path.join(model_dir, model_name+'_loss.pth'))
            model.load_state_dict(tmp['model'])
            optimizer.load_state_dict(tmp['optim'])
            lr_reduce_scheduler.load_state_dict(tmp['scheduler'])
            # cyclic_scheduler.load_state_dict(tmp['cyclic_scheduler'])
            del tmp

        if record:
            history.to_csv(f'{history_dir}/history_{model_name}_{img_dim}.csv', index=False)
        else:
            result.to_csv('Result.csv', index=False)
            # confusion_matrix_generator(result)
        return epoch_loss, raw_val_dice, val_dice


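# Discriminative learning rates: the pretrained encoder is fine-tuned with a 100x smaller LR than the decoder
plist = [
    {'params': model.resnest.parameters(), 'lr': learning_rate/100},
    {'params': model.decoder.parameters(), 'lr': learning_rate}
]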
optimizer = optim.Adam(plist, lr=learning_rate)
lr_reduce_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=patience, verbose=True, threshold=1e-4, threshold_mode='rel', cooldown=0, min_lr=1e-7, eps=1e-08)
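Why divide the loss by `accum_step`, and why zero the gradients only after `optimizer.step()`? The NumPy sketch below uses a toy least-squares model with a hand-written gradient (an illustrative stand-in for `loss.backward()`, not the notebook's model): summing the gradients of `loss / accum_step` over `accum_step` equal-sized micro-batches reproduces the gradient of the mean loss over the combined batch, which is what a single larger batch would give. Zeroing the gradients before every micro-batch would keep only the last micro-batch's contribution.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)


rng = np.random.default_rng(69)
batch_size, accum_step = 24, 2
X = rng.standard_normal((batch_size * accum_step, 3))  # one "large" batch of 48 rows
y = rng.standard_normal(batch_size * accum_step)
w = np.zeros(3)

full_grad = mse_grad(w, X, y)        # gradient of a single 48-sample batch

accum = np.zeros_like(w)             # plays the role of .grad, zeroed once before accumulating
for step in range(accum_step):
    sl = slice(step * batch_size, (step + 1) * batch_size)
    accum += mse_grad(w, X[sl], y[sl]) / accum_step   # backward of loss / accum_step adds into .grad

# What zero_grad() before every micro-batch would leave behind
last_only = mse_grad(w, X[batch_size:], y[batch_size:]) / accum_step

print(np.allclose(accum, full_grad))      # True
print(np.allclose(last_only, full_grad))  # False for this random data
```

Two caveats apply to this notebook: `dice_coeff` pools pixels over the batch, so the accumulated loss only approximates the Dice of a 48-image batch, and BatchNorm still sees 24 images at a time.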


In [None]:
def main():
    start_epoch = 0
    best_valid_loss = np.inf
    best_valid_dice = 0.0
    best_raw_val_dice = 0.0

    if load_model:
        tmp = torch.load(os.path.join(model_dir, model_name+'_loss.pth'))
        model.load_state_dict(tmp['model'])
        optimizer.load_state_dict(tmp['optim'])
        lr_reduce_scheduler.load_state_dict(tmp['scheduler'])
        # cyclic_scheduler.load_state_dict(tmp['cyclic_scheduler'])
        if m_p:
            try:
                scaler.load_state_dict(tmp['scaler'])
            except KeyError:  # checkpoint was saved without mixed precision
                pass
        prev_epoch_num = tmp['epoch']
        start_epoch = prev_epoch_num + 1  # resume after the loaded epoch instead of repeating it
        # Re-evaluate the loaded weights to initialise the best loss/dice trackers
        best_valid_loss, best_raw_val_dice, best_valid_dice = train_val(prev_epoch_num+1, valid_loader, optimizer=optimizer, train=False, mode='val', record=False)
        del tmp
        print('Model Loaded!')

    for epoch in range(start_epoch, n_epochs):
        torch.cuda.empty_cache()
        gc.collect()
        train_val(epoch, train_loader, optimizer=optimizer, train=True, mode='train')
        valid_loss, raw_valid_dice, valid_dice = train_val(epoch, valid_loader, optimizer=optimizer, train=False, mode='val')
        print("#"*20)
        print(f"Epoch {epoch} Report:")
        print(f"Validation Loss: {valid_loss:.4f} Validation Dice: {valid_dice:.4f} Raw Validation Dice: {raw_valid_dice:.4f}")
        best_state = {'model': model.state_dict(), 'optim': optimizer.state_dict(), 'scheduler': lr_reduce_scheduler.state_dict(),
                      # 'cyclic_scheduler': cyclic_scheduler.state_dict(),
                      'best_loss': valid_loss, 'best_dice': valid_dice, 'epoch': epoch}
        if m_p:
            best_state['scaler'] = scaler.state_dict()
        best_valid_loss, best_valid_dice = save_model(valid_loss, valid_dice, best_valid_loss, best_valid_dice, best_state, os.path.join(model_dir, model_name))
        print("#"*20)

if __name__ == '__main__':
    main()
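Mixed precision relies on `GradScaler`'s dynamic loss scaling: the loss is multiplied by a scale factor before `backward()` so that small fp16 gradients do not underflow. If any gradient overflows to inf/NaN, the optimizer step is skipped and the scale is reduced; after a long run of clean steps the scale is increased again. The pure-Python simulation below mimics that update rule using what I understand to be `GradScaler`'s default settings (initial scale 2**16, backoff factor 0.5, growth factor 2.0, growth interval 2000); the overflow pattern is made up for illustration. Note that the scaler only protects the optimizer step: a NaN produced in the forward pass still shows up in the loss, which is what the fallback in `train_val` checks for.

```python
def simulate_loss_scaling(overflow_steps, n_steps, init_scale=2.0 ** 16,
                          growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Mimic GradScaler's scale update. Returns (optimizer steps applied, final scale)."""
    scale = init_scale
    growth_tracker = 0          # consecutive steps without inf/NaN gradients
    applied = 0
    for step in range(n_steps):
        if step in overflow_steps:
            scale *= backoff_factor      # overflow: skip optimizer.step() and shrink the scale
            growth_tracker = 0
        else:
            applied += 1                 # gradients are unscaled and the step is taken
            growth_tracker += 1
            if growth_tracker == growth_interval:
                scale *= growth_factor   # long clean run: try a larger scale again
                growth_tracker = 0
    return applied, scale


applied, scale = simulate_loss_scaling(overflow_steps={0, 1}, n_steps=2002)
print(applied, scale)  # 2000 32768.0
```

The first two steps overflow, so the scale drops from 65536 to 16384 and no update is made; after 2000 clean steps it doubles to 32768. A few skipped steps at the start of training are therefore expected and harmless.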

### Future Tasks
I intend to add the following features in upcoming versions of this notebook:

- Data Preprocessing
- Stratified Train-Validation Split
- Performance Visualization through **[Wandb](https://wandb.ai)**