# U-Net with Deep Watershed Transform (DWT) [Infer]
[[Train notebook]](https://www.kaggle.com/ebinan92/unet-with-deep-watershed-transform-dwt-train)  
The inference pipeline is almost the same as [Awsaf's notebook](https://www.kaggle.com/awsaf49/pytorch-sartorius-unet-strikes-back-infer), except that a watershed post-processing step has been added.

### Imports, seed, and config

In [None]:
!pip install -q ../input/pytorch-segmentation-models-lib/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
!pip install -q ../input/pytorch-segmentation-models-lib/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
!pip install -q ../input/pytorch-segmentation-models-lib/timm-0.4.12-py3-none-any.whl
!pip install -q ../input/pytorch-segmentation-models-lib/segmentation_models_pytorch-0.2.0-py3-none-any.whl

In [None]:
import skimage.morphology
import segmentation_models_pytorch as smp
import cupy as cp
import os
import skimage
from skimage.morphology import thin
from scipy import ndimage as ndi
from skimage.measure import label
from skimage.segmentation import watershed
import numpy as np
from albumentations.pytorch import ToTensorV2
import albumentations as A
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import cv2
import gc
from tqdm import tqdm
from glob import glob
import random
import pandas as pd

In [None]:
class config:
    """Static settings for this inference notebook: paths, model, preprocessing, runtime, TTA."""

    # --- competition data ---
    TRAIN_CSV = "../input/sartorius-cell-instance-segmentation/train.csv"
    TRAIN_PATH = "../input/sartorius-cell-instance-segmentation/train"
    TEST_PATH = "../input/sartorius-cell-instance-segmentation/test"
    SAMPLE_SUBMISSION = '../input/sartorius-cell-instance-segmentation/sample_submission.csv'

    # --- trained checkpoints / architecture ---
    MODEL_PATH = "../input/resnet101-dwt/models"
    model_name = 'resnet101'   # encoder name passed to smp.Unet
    mask_len = 6               # number of output channels (smp.Unet `classes`)

    # --- preprocessing ---
    IMAGE_RESIZE = (512, 704)  # (height, width) given to A.Resize
    RESNET_MEAN = (0.485, 0.456, 0.406)
    RESNET_STD = (0.229, 0.224, 0.225)

    # --- runtime ---
    device = 'cuda'
    BS = 1                     # inference squeezes the batch axis, so keep this at 1
    num_workers = 2
    ttas = [0, 1, 2, 3]        # test-time augmentation modes (see aug / reverse_aug)
    

def fix_all_seeds(seed):
    """Seed Python's `random`, NumPy, Torch (CPU and all CUDA devices) and PYTHONHASHSEED.

    Note: setting PYTHONHASHSEED at runtime does not change hash randomisation of
    the already-running interpreter; it only affects processes started afterwards.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seeder(seed)


fix_all_seeds(2021)

In [None]:
# One row per test image; `id` is the file name up to its first '.'.
test_df = pd.DataFrame(glob(f'{config.TEST_PATH}/*'), columns=['image_path'])
test_df['id'] = (
    test_df['image_path']
    .str.split('/').str[-1]   # file name
    .str.split('.').str[0]    # strip extension(s)
)

### Dataset and Augmentation

In [None]:
class TestDataset(Dataset):
    """Image dataset that yields `(image, mask)` when masks exist, else `(image, image_path)`.

    Args:
        df: DataFrame with an `image_path` column and, optionally, a `mask_path`
            column pointing at `.npy` mask files.
        transforms: optional albumentations transform called as
            `transforms(image=...)` or `transforms(image=..., mask=...)`.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.img_paths = df['image_path'].values
        # Only a missing column means "test data without masks". The previous
        # `except BaseException` also swallowed KeyboardInterrupt and genuine bugs.
        self.msk_paths = df['mask_path'].values if 'mask_path' in df.columns else None
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        img = cv2.imread(img_path)
        if img is None:  # cv2.imread returns None instead of raising on unreadable files
            raise FileNotFoundError(f'could not read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.msk_paths is None:
            # Test data: return the path so predictions can be mapped back to ids.
            if self.transforms:
                img = self.transforms(image=img)['image']
            return img, img_path
        msk = np.load(self.msk_paths[index])
        if self.transforms:
            data = self.transforms(image=img, mask=msk)
            img, msk = data['image'], data['mask']
        # Add a leading channel axis -> (1, H, W); batches become (B, 1, H, W).
        msk = np.expand_dims(msk, axis=0)
        return img, msk
        
# Inference preprocessing: resize to config.IMAGE_RESIZE (height, width),
# normalise with config.RESNET_MEAN / RESNET_STD, then convert to a CHW tensor.
data_transforms = dict(
    valid=A.Compose(
        [
            A.Resize(*config.IMAGE_RESIZE),
            A.Normalize(mean=config.RESNET_MEAN, std=config.RESNET_STD, p=1),
            ToTensorV2(),
        ],
        p=1.0,
    ),
)

### Utils

In [None]:
def ins2rle(ins):
    '''
    Run-length encode a single binary instance mask.

    Args:
        ins: 2-D array-like (NumPy, nested lists, or a CuPy array); non-zero
            pixels belong to the instance.
    Returns:
        Space-separated "start length" pairs, starts 1-indexed over the
        row-major (C-order) flattened mask; '' for an empty mask.
    '''
    # The caller passes host masks; round-tripping through the GPU and then
    # stringifying CuPy scalars one by one forced a device sync per run.
    if type(ins).__module__.split('.')[0] == 'cupy':
        ins = ins.get()  # accept CuPy input for backward compatibility
    pixels = np.asarray(ins).flatten()
    # Zero padding guarantees every run has both a start and an end transition.
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]  # convert end positions to run lengths
    return ' '.join(map(str, runs))


def mask2rle(lab_mask, cutoff=0.5, min_object_size=1.0):
    """Yield one run-length encoding per labelled instance in `lab_mask`.

    Adapted from https://www.kaggle.com/raoulma/nuclei-dsb-2018-tensorflow-u-net-score-0-352

    Args:
        lab_mask: integer label image; each positive value is one instance and
            0 is background. The array is not modified.
        cutoff: labels `<= cutoff` are treated as background (the default 0.5
            excludes only label 0).
        min_object_size: instances with fewer pixels than this are skipped.

    Yields:
        `ins2rle(lab_mask == label)` for each kept label, in ascending label order.
    """
    lab_mask = np.asarray(lab_mask)
    labels, sizes = np.unique(lab_mask, return_counts=True)
    # Iterate only over labels actually present. The former `range(1, max + 1)`
    # loop yielded an empty RLE for every label id missing from the mask, and the
    # old small-object path overwrote the caller's array and then relabelled by
    # connectivity, merging touching watershed instances into one.
    kept = labels[(labels > cutoff) & (sizes >= min_object_size)]
    for lab in kept:
        yield ins2rle(lab_mask == lab)


def aug(img, axis=0):
    """Apply test-time augmentation mode `axis` to `img` along dims 1 and 2.

    Modes: 1 flips dim 1, 2 flips dim 2, 3 flips both, 4 rotates 90 degrees
    from dim 1 towards dim 2, 5 rotates the opposite way. Any other value
    returns `img` itself. Dims 1/2 are the spatial axes when `img` is
    (C, H, W), as it is in `infer` (assumption for other callers).
    """
    modes = {
        1: lambda t: torch.flip(t, dims=(1,)),
        2: lambda t: torch.flip(t, dims=(2,)),
        3: lambda t: torch.flip(t, dims=(1, 2)),
        4: lambda t: torch.rot90(t, k=1, dims=(1, 2)),
        5: lambda t: torch.rot90(t, k=1, dims=(2, 1)),
    }
    transform = modes.get(axis)
    return img if transform is None else transform(img)


def reverse_aug(img, axis=0):
    """Undo `aug(img, axis)` on a prediction along dims 1 and 2.

    Flip modes 1-3 are their own inverse. Rotation modes swap direction:
    4 rotates from dim 2 towards dim 1, and 5 from dim 1 towards dim 2.
    Any other value returns `img` itself.
    """
    inverses = {
        1: lambda t: torch.flip(t, dims=(1,)),
        2: lambda t: torch.flip(t, dims=(2,)),
        3: lambda t: torch.flip(t, dims=(1, 2)),
        4: lambda t: torch.rot90(t, k=1, dims=(2, 1)),
        5: lambda t: torch.rot90(t, k=1, dims=(1, 2)),
    }
    undo = inverses.get(axis)
    return img if undo is None else undo(img)


def get_aug_img(img, ttas=config.ttas):
    """
    Build a batch of test-time-augmented copies of one image.

    Args:
        img  :  image tensor; (C, H, W) when called from `infer`
        ttas :  tta modes understood by `aug`, e.g. [0, 1]
    Return:
        tensor of shape (len(ttas), *img.shape), with row i = aug(img, ttas[i]);
        `img.unsqueeze(0)` when `ttas` is empty. Rotation modes (4, 5) swap H
        and W, so mixing them with other modes only stacks when H == W.
    """
    if len(ttas) == 0:
        return img.unsqueeze(0)
    return torch.stack([aug(img, axis=mode) for mode in ttas], dim=0)


def fix_aug_img(aug_pred, ttas=config.ttas):
    """
    Undo test-time augmentation on a batch of predictions and average them.

    Args:
        aug_pred  :  predictions with one row per TTA mode, in the same order as `ttas`
        ttas      :  tta modes that produced the rows, e.g. [0, 1]
    Return:
        mean over i of reverse_aug(aug_pred[i], ttas[i]) (one row's shape);
        `aug_pred` unchanged when `ttas` is empty.
    """
    if len(ttas) == 0:
        return aug_pred
    restored = [reverse_aug(aug_pred[row], axis=mode) for row, mode in enumerate(ttas)]
    return torch.stack(restored, dim=0).mean(dim=0)


def watershed_energy(msk=None,
                     energy=None,
                     threshold=0.5,
                     threshold_energy=0.6,
                     line=False):
    """Split a foreground mask into instances by watershed on an energy map.

    Thresholds are fractions of 255, so `msk` and `energy` are expected on a
    0-255 scale (as `infer` passes them). Seeds are the connected components of
    `energy > 255 * threshold_energy`. They are flooded over `-energy`, so that
    high-energy centres grow outwards, restricted to `msk > 255 * threshold`.
    With `line=True`, a one-pixel watershed line separates touching regions.

    Returns:
        Integer label image; 0 is background.
    """
    foreground = (np.asarray(msk) > 255 * threshold).astype(int)
    seed_region = (np.asarray(energy) > 255 * threshold_energy).astype(int)

    seeds = label(seed_region)
    return watershed(-energy, seeds, mask=foreground, watershed_line=line)


def load_model(path):
    """Build the smp U-Net from `config`, load weights from `path`, and return it in eval mode.

    The returned model lives on `config.device`.
    """
    model = smp.Unet(config.model_name, encoder_weights=None, activation=None, classes=config.mask_len)
    # Map tensors to CPU first, so checkpoints saved on a different GPU index
    # (or on a CPU-only machine) still load. The model is moved afterwards.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model.to(config.device).eval()

### Inference

In [None]:
@torch.no_grad()
def infer(model_paths, test_loader, num_log=3):
    """Predict instance masks for every test image and run-length encode them.

    Every checkpoint is applied to every TTA variant of the image. Outputs are
    de-augmented and averaged over TTA by `fix_aug_img`, passed through a
    sigmoid, averaged across models, upsampled to 520x704 and split into
    instances by `watershed_energy`.

    Args:
        model_paths: checkpoint paths accepted by `load_model`.
        test_loader: yields `(image, image_path)` batches. Assumes batch size 1
            (`config.BS`), because the batch axis is squeezed away.
        num_log: how many leading images to return image/mask/energy arrays for.

    Returns:
        `(pred_strings, pred_paths, imgs, msks, energys)`: one RLE string and its
        image path per predicted instance, plus the logged arrays.

    Raises:
        ValueError: if `model_paths` is empty.
    """
    if len(model_paths) == 0:
        raise ValueError('infer() needs at least one model checkpoint')
    out_size = (520, 704)  # resolution masks are upsampled to (presumably the native image size)
    # Load each checkpoint once, not once per test image.
    models = [load_model(path) for path in model_paths]

    pred_strings, pred_paths = [], []
    imgs, msks, energys = [], [], []
    for idx, (img, img_path) in enumerate(tqdm(test_loader, total=len(test_loader), desc='Infer')):
        img = img.to(config.device, dtype=torch.float).squeeze()  # drop batch axis (BS == 1)
        aug_imgs = get_aug_img(img, ttas=config.ttas)
        msk_preds, energy_preds = [], []
        for model in models:
            out = model(aug_imgs)  # one row per TTA variant
            if len(config.ttas) == 0:
                out = out[0]  # get_aug_img only added a single batch axis
            else:
                # The old `squeeze(0)` here removed the TTA axis when exactly one
                # TTA mode was used, so fix_aug_img then indexed channels instead.
                out = fix_aug_img(out, ttas=config.ttas)
            out = torch.sigmoid(out)
            msk_preds.append(out[0])  # channel 0 is used as the foreground mask
            # NOTE(review): energy averages all channels except the last one;
            # confirm this matches the training notebook's channel layout.
            energy_preds.append(torch.mean(out[:-1], dim=0))

        msk = torch.mean(torch.stack(msk_preds, dim=0), dim=0)
        msk = F.interpolate(msk[None, None], size=out_size, mode='nearest')[0, 0].cpu().numpy()
        energy = torch.mean(torch.stack(energy_preds, dim=0), dim=0)
        energy = F.interpolate(energy[None, None], size=out_size, mode='nearest')[0, 0].cpu().numpy()
        msk = watershed_energy(msk * 255, energy * 255, 0.5, 0.7)

        if idx < num_log:
            # Log the un-augmented input. aug_imgs[0] is only un-augmented when ttas[0] == 0.
            vis = F.interpolate(img[None], size=out_size, mode='nearest')[0]
            imgs.append(vis.permute((1, 2, 0)).cpu().numpy())
            msks.append(msk)
            energys.append(energy)

        rle = list(mask2rle(msk))
        pred_strings.extend(rle)
        pred_paths.extend(img_path * len(rle))  # img_path is the collated batch (length 1)
        del img, aug_imgs, msk, energy
        gc.collect()
        torch.cuda.empty_cache()
    return pred_strings, pred_paths, imgs, msks, energys

In [None]:
test_dataset = TestDataset(test_df, transforms=data_transforms['valid'])
test_loader = DataLoader(test_dataset, batch_size=config.BS,
                         num_workers=config.num_workers, shuffle=False, pin_memory=True)
# Sort so the ensemble always sees checkpoints in the same order: glob order is
# filesystem-dependent, and float averaging is order-sensitive in the last bits.
model_paths = sorted(glob(f'{config.MODEL_PATH}/{config.model_name}_*.pth'))
if not model_paths:
    raise FileNotFoundError(f'no {config.model_name}_*.pth checkpoints found in {config.MODEL_PATH}')

pred_strings, pred_paths, imgs, msks, energys = infer(model_paths, test_loader)

### Check results

In [None]:
# Show the logged input image, watershed instance labels and energy map side by side.
for img, msk, energy in zip(imgs, msks, energys):
    plt.figure(figsize=(15, 7))
    for position, (caption, panel) in enumerate(
            (('image', img), ('mask', msk), ('energy', energy)), start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel)
        plt.axis('OFF')
        plt.title(caption)
    plt.tight_layout()
    plt.show()

### Submission

In [None]:
# Image id of each predicted instance: file name up to the first '.' (same rule as test_df['id']).
ids = [path.split('/')[-1].split('.')[0] for path in pred_paths]
pred_df = pd.DataFrame({'id': ids,
                        'predicted': pred_strings})
# Use the configured path rather than a second, hard-coded copy of it.
sub_df = pd.read_csv(config.SAMPLE_SUBMISSION).drop(columns=['predicted'])
# Left merge keeps every test id. Ids with no predicted instance get NaN,
# which to_csv writes as an empty field.
sub_df = sub_df.merge(pred_df, on='id', how='left')
sub_df.to_csv('submission.csv', index=False)