# U-Net with Deep Watershed Transform (DWT) [Train]
[[Inference notebook]](https://www.kaggle.com/ebinan92/unet-with-deep-watershed-transform-dwt-infer)  
This notebook is a simple implementation of the DWT method [(paper)](https://arxiv.org/abs/1611.08303). <br>
The original paper's approach uses two networks to learn the distance transform.  
For simplicity, I tried a single U-Net with the multi-task learning approach proposed by snakers41.  
Please see [his blog](https://spark-in.me/post/playing-with-dwt-and-ds-bowl-2018) for details.  
Note: the training cell below currently trains an attention U-Net (`AttU_Net`) with a single binary output and a Jaccard loss.

## Imports, seed, and config

In [None]:
!pip install -q ../input/pytorch-segmentation-models-lib/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
!pip install -q ../input/pytorch-segmentation-models-lib/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
!pip install -q ../input/pytorch-segmentation-models-lib/timm-0.4.12-py3-none-any.whl
!pip install -q ../input/pytorch-segmentation-models-lib/segmentation_models_pytorch-0.2.0-py3-none-any.whl

In [None]:
from tqdm.notebook import tqdm as tqdm
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import StratifiedKFold
from skimage.segmentation import watershed
from skimage.measure import label
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
import torch
import segmentation_models_pytorch as smp
import pandas as pd
import numpy as np
import cv2
from albumentations.pytorch import ToTensorV2
import albumentations as A
import random
import pickle
import os
from statistics import mean, stdev

def fix_all_seeds(seed):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch (CPU,
    the current CUDA device and all CUDA devices). ``PYTHONHASHSEED`` is
    exported as well; setting it at runtime only influences child processes
    started afterwards (e.g. DataLoader workers spawned fresh).

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Each torch seeding entry point is a no-op for CUDA when no GPU is present.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


# Seed Python's `random`, NumPy and PyTorch once, before any random operation
# later in the notebook (fold shuffling, weight initialisation, data loading).
fix_all_seeds(2021)


class config:
    # ---- data locations -------------------------------------------------
    SAMPLE_SUBMISSION = '../input/sartorius-cell-instance-segmentation/sample_submission.csv'
    TRAIN_CSV = "../input/sartorius-cell-instance-segmentation/train.csv"
    TRAIN_PATH = "../input/sartorius-cell-instance-segmentation/train"
    TEST_PATH = "../input/sartorius-cell-instance-segmentation/test"
    MASK_PATH = "../input/cell-masks/train_masks"   # pickled masks, one file per image id
    MODEL_PATH = "models"                           # output directory for checkpoints

    # ---- preprocessing --------------------------------------------------
    # Single-channel normalisation stats passed to T.Normalize (these equal the
    # first channel of the usual ImageNet mean/std).
    RESNET_MEAN = [0.485]
    RESNET_STD = [0.229]
    # Target size; torchvision's T.Resize interprets a 2-sequence as (height, width).
    IMAGE_RESIZE = [512, 704]

    # ---- optimisation ---------------------------------------------------
    LR = 5e-4        # initial Adam learning rate
    min_LR = 5e-5    # eta_min of the cosine-annealing schedule
    BS = 4           # training batch size (validation uses 2x)
    N_EPOCH = 50
    N_FOLD = 5

    # ---- runtime --------------------------------------------------------
    device = 'cuda'
    num_workers = 2


# Create the checkpoint directory up front; exist_ok makes re-running the cell safe.
os.makedirs(config.MODEL_PATH, exist_ok=True)

In [None]:
# Collapse train.csv to exactly one row per image `id` (groupby().first() keeps
# the first non-null value of every column and sorts by id). The folds and the
# CellDataset below work per image id. NOTE(review): train.csv presumably has
# one row per annotated cell — confirm.
df_train = pd.read_csv(config.TRAIN_CSV).groupby('id').first().reset_index()

## Check the mask dataset

In [None]:
# Sanity check: load the pre-computed mask for the first training image; the
# bare expression on the last line is echoed by the notebook.
image_id = df_train.iloc[0].id
# NOTE: pickle.load can execute arbitrary code — only load masks from a trusted dataset.
with open(f'{config.MASK_PATH}/mask_{image_id}.pkl', 'rb') as f:
    masks = pickle.load(f)

masks.shape

## Train

### Dataset and Augmentation

In [None]:
from torchvision import transforms as T
from PIL import Image

class CellDataset(Dataset):
    """Image/mask pairs for the Sartorius training images listed in ``df``.

    Each item is a dict with:
      - ``'image'``: the PNG passed through ``transforms`` (as a float tensor);
      - ``'masks'``: the pickled mask resized to ``config.IMAGE_RESIZE``
        (as an int tensor).

    ``config.IMAGE_RESIZE`` is interpreted as ``(height, width)``, matching
    ``torchvision.transforms.Resize`` so image and mask share the same shape.
    """

    def __init__(self, df, transforms):
        self.df = df
        self.base_path = config.TRAIN_PATH
        self.transforms = transforms
        self.image_ids = df.id.unique().tolist()

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.base_path, image_id + ".png")
        # Use a context manager so the file handle is released once the
        # transforms have read the pixels.
        with Image.open(image_path) as img:
            image = self.transforms(img)

        # NOTE: pickle.load can execute arbitrary code — only use trusted mask files.
        with open(f'{config.MASK_PATH}/mask_{image_id}.pkl', 'rb') as f:
            masks = pickle.load(f)

        # PIL's resize() takes (width, height) whereas T.Resize takes
        # (height, width); passing config.IMAGE_RESIZE unchanged produced a
        # transposed mask. NEAREST avoids interpolating mask values.
        height, width = config.IMAGE_RESIZE
        masks = Image.fromarray(masks).resize((width, height), resample=Image.NEAREST)
        # NOTE(review): ToTensor divides uint8 ('L') images by 255 before .int();
        # confirm the pickled masks are boolean (or 0/255) so positives survive.
        masks = T.ToTensor()(masks)
        return {'image': image.float(), 'masks': masks.int()}

    def __len__(self):
        return len(self.image_ids)


# Both splits currently use the same deterministic pipeline:
# resize to config.IMAGE_RESIZE (height, width) -> tensor -> normalise.
# WARNING: CellDataset applies these transforms to the image only, so adding
# random geometric augmentations here (e.g. flips) would misalign image and mask.
data_transforms = {
    split: T.Compose([
        T.Resize(config.IMAGE_RESIZE),
        T.ToTensor(),
        T.Normalize(mean=config.RESNET_MEAN, std=config.RESNET_STD),
    ])
    for split in ("train", "valid")
}

### Utils

In [None]:
def get_threshold(Y, pred):
    """Find the score threshold that maximises IoU for a single image.

    Pixels are ranked by predicted score (descending). For every cut-off k the
    IoU between "top-k pixels predicted positive" and the ground truth is
    computed in one vectorised pass, and the best cut-off is returned.

    Args:
        Y: ground-truth mask (any shape; values are summed as 0/1 after flattening).
        pred: predicted scores with the same number of elements as ``Y``.
            NumPy arrays and torch tensors (any device) are accepted.

    Returns:
        (best_threshold, best_IoU): the score of the last pixel included at the
        best cut-off, and the IoU achieved there.

    Raises:
        ValueError: if the inputs are empty or differ in element count.
    """
    def _flat(a):
        # Convert once to a flat NumPy array instead of a Python list of
        # per-pixel scalars, which is very slow for full-resolution masks.
        if isinstance(a, torch.Tensor):
            a = a.detach().cpu().numpy()
        return np.asarray(a).ravel()

    scores = _flat(pred)
    mask = _flat(Y)
    if scores.size == 0:
        raise ValueError("get_threshold() needs at least one pixel")
    if scores.size != mask.size:
        raise ValueError(
            f"Y and pred must have the same number of elements ({mask.size} != {scores.size})"
        )

    idxs = np.argsort(scores)[::-1]
    mask_sorted = mask[idxs]
    sum_mask_one = np.cumsum(mask_sorted)
    # IoU = |A ∩ B| / |A ∪ B| with A = top-k predictions and B = ground truth,
    # where |A ∪ B| = k + |B| - |A ∩ B|.
    n_predicted = np.arange(1, mask_sorted.size + 1)
    IoU = sum_mask_one / (n_predicted + np.sum(mask_sorted) - sum_mask_one)
    best_IoU_idx = IoU.argmax()
    best_threshold = scores[idxs[best_IoU_idx]]
    best_IoU = IoU[best_IoU_idx]

    return best_threshold, best_IoU

### Training and validation loops

In [None]:
def train_loop(model, optimizer, loader, criterion):
    """Run one training epoch over ``loader``.

    Each batch dict must provide ``'image'`` and ``'masks'``; both are moved to
    ``config.device``. The loss is computed against ``masks`` cast to float,
    back-propagated, and followed by one optimizer step per batch.

    Returns:
        (mean_loss, mean_lr): the average per-batch loss and the average
        learning rate seen during the epoch (averaged over param groups).
    """
    model.train()
    optimizer.zero_grad()
    step_losses = []
    step_lrs = []
    for batch in loader:
        target = batch['masks'].to(config.device)
        logits = model(batch['image'].to(config.device))
        loss = criterion(logits, target.float())
        step_losses.append(loss.item())
        # Record the LR in effect for this step (mean across parameter groups).
        group_lrs = [group["lr"] for group in optimizer.param_groups]
        step_lrs.append(np.array(group_lrs).mean())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return np.mean(step_losses), np.mean(step_lrs)


def valid_loop(model, loader, criterion):
    """Evaluate ``model`` on every batch of ``loader`` without gradient tracking.

    Each batch dict must provide ``'image'`` and ``'masks'``; both are moved to
    ``config.device`` and the loss is computed against ``masks`` cast to float.

    Returns:
        (mean_loss, true_masks, pred_masks): the average per-batch loss, the
        targets concatenated along dim 0, and the sigmoid of the model outputs
        concatenated along dim 0 (both tensors on the CPU).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    losses, true_masks, pred_masks = [], [], []
    model.eval()
    # Wrap the whole loop so no autograd graph is kept for any stored tensor.
    with torch.no_grad():
        for d in loader:
            y = d['masks'].to(config.device)
            pred_y = model(d['image'].to(config.device))
            loss = criterion(pred_y, y.float())
            losses.append(loss.item())
            # torch.sigmoid replaces the deprecated F.sigmoid.
            pred_masks.append(torch.sigmoid(pred_y.cpu()))
            true_masks.append(y.cpu())
    if not losses:
        raise ValueError("valid_loop() received an empty loader")
    return np.mean(losses), torch.cat(true_masks), torch.cat(pred_masks)

### Network

In [None]:
import torch.nn as nn

class Attention_block(nn.Module):
    """Additive attention gate (Attention U-Net style).

    The gating signal ``g`` (``F_g`` channels) and the skip features ``x``
    (``F_l`` channels) are each projected to ``F_int`` channels by a 1x1 conv
    followed by BatchNorm. The two projections are summed and passed through
    ReLU, then collapsed to a single-channel sigmoid map ``psi``. The skip
    features are returned scaled by ``psi``, broadcast over channels.
    ``g`` and ``x`` must share spatial dimensions.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(in_ch, out_ch):
            # 1x1 conv + BatchNorm; submodule indices 0/1 keep the state_dict
            # keys (e.g. "W_g.0.weight") unchanged.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            )

        self.W_g = project(F_g, F_int)
        self.W_x = project(F_l, F_int)
        # Reuse the projection layers and append the sigmoid as index 2.
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        return x * attention

In [None]:
class conv_block(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    The first conv maps ``ch_in`` -> ``ch_out`` and the second ``ch_out`` ->
    ``ch_out``. Padding 1 with stride 1 keeps the spatial size unchanged.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        # Same module order as before: conv(0) bn(1) relu(2) conv(3) bn(4) relu(5).
        for stage_in in (ch_in, ch_out):
            layers += [
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class up_conv(nn.Module):
    """Upsample x2, then a 3x3 conv (``ch_in`` -> ``ch_out``) with BatchNorm and ReLU."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(ch_out)
        activation = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, activation)

    def forward(self, x):
        return self.up(x)

In [None]:
class AttU_Net(nn.Module):
    """Attention U-Net: a 5-level U-Net whose skip connections pass through
    additive attention gates (``Attention_block``) before concatenation.

    Args:
        img_ch: number of input image channels.
        output_ch: number of channels produced by the final 1x1 convolution.
        scale_factor: divisor applied to the base widths [64, 128, 256, 512, 1024]
            (results truncated with ``int``), e.g. 2 halves every layer.

    ``forward`` returns the raw output of the final 1x1 convolution; no
    activation is applied here. The encoder applies four 2x2 max-pools and the
    decoder four x2 upsamplings before each concatenation with encoder features.
    Input height and width should therefore be divisible by 16 so that the
    concatenated tensors have matching spatial sizes.
    """
    def __init__(self, img_ch=3, output_ch=1, scale_factor=1):
        super(AttU_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / scale_factor) for x in filters]

        # Encoder: one conv_block per resolution level.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=filters[0])
        self.Conv2 = conv_block(ch_in=filters[0], ch_out=filters[1])
        self.Conv3 = conv_block(ch_in=filters[1], ch_out=filters[2])
        self.Conv4 = conv_block(ch_in=filters[2], ch_out=filters[3])
        self.Conv5 = conv_block(ch_in=filters[3], ch_out=filters[4])

        # Decoder: upsample, gate the matching encoder features, concatenate,
        # refine. The concatenation holds 2 * filters[k] channels; this equals
        # filters[k + 1] only when scale_factor divides the widths evenly, so it
        # is spelled out explicitly (e.g. scale_factor=3 used to crash).
        self.Up5 = up_conv(ch_in=filters[4], ch_out=filters[3])
        self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
        self.Up_conv5 = conv_block(ch_in=2 * filters[3], ch_out=filters[3])

        self.Up4 = up_conv(ch_in=filters[3], ch_out=filters[2])
        self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
        self.Up_conv4 = conv_block(ch_in=2 * filters[2], ch_out=filters[2])

        self.Up3 = up_conv(ch_in=filters[2], ch_out=filters[1])
        self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
        self.Up_conv3 = conv_block(ch_in=2 * filters[1], ch_out=filters[1])

        self.Up2 = up_conv(ch_in=filters[1], ch_out=filters[0])
        self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=filters[0] // 2)
        self.Up_conv2 = conv_block(ch_in=2 * filters[0], ch_out=filters[0])

        self.Conv_1x1 = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path: each level halves the spatial size (floor division)
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path: gate each skip with the upsampled decoder
        # features, then concatenate [gated skip, decoder] along channels.
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        return d1

### Training

In [None]:
# Assign every image to one of N_FOLD folds, stratified by cell type. The
# shuffle draws from NumPy's global RNG, which fix_all_seeds has seeded.
skf = StratifiedKFold(n_splits=config.N_FOLD, shuffle=True)
for fold, (_, valid_idx) in enumerate(skf.split(df_train, df_train.cell_type)):
    # df_train has a RangeIndex (reset_index above), so positions == labels here.
    df_train.loc[valid_idx, 'fold'] = fold

for fold in range(config.N_FOLD):
    print(f"Fold: {fold}")
    train_dset = CellDataset(df_train.query(f"fold!={fold}"), data_transforms['train'])
    valid_dset = CellDataset(df_train.query(f"fold=={fold}"), data_transforms['valid'])

    train_loader = DataLoader(train_dset, batch_size=config.BS,
                              pin_memory=True, shuffle=True, num_workers=config.num_workers,
                              worker_init_fn=lambda x: np.random.seed(torch.initial_seed() // 2 ** 32 + x))
    valid_loader = DataLoader(valid_dset, batch_size=config.BS * 2,
                              pin_memory=True, shuffle=False, drop_last=False, num_workers=config.num_workers)

    model = AttU_Net(img_ch=1, output_ch=1, scale_factor=1)
    model = model.to(config.device)

    optimizer = optim.Adam(model.parameters(), lr=config.LR)
    criterion = smp.losses.JaccardLoss(mode='binary')
    scheduler = CosineAnnealingLR(optimizer, T_max=config.N_EPOCH, eta_min=config.min_LR)

    # Previously the model was never saved: the save call was commented out and
    # referenced a config.model_name attribute that does not exist. Keep the
    # checkpoint with the lowest validation loss instead.
    model_name = getattr(config, "model_name", "AttU_Net")
    checkpoint_path = f'{config.MODEL_PATH}/{model_name}_{fold}.pth'
    valid_best_loss = float('inf')
    for epoch in tqdm(range(config.N_EPOCH)):
        train_loss, lrs = train_loop(model, optimizer, train_loader, criterion)
        valid_loss, valid_mask, valid_pred_mask = valid_loop(model, valid_loader, criterion)
        print(f"epoch: {epoch}, train_loss: {train_loss:.3f}, valid_loss: {valid_loss:.3f}")
        if valid_loss < valid_best_loss:
            valid_best_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)
            print(f"saved best model (valid_loss: {valid_loss:.3f}) to {checkpoint_path}")
        scheduler.step()
    # Only the first fold is trained (presumably to fit the notebook time
    # budget); remove this break to run full cross-validation.
    break