# UNET-Efficientnet-B0

This notebook trains a U-Net segmentation model with an EfficientNet-B0 encoder initialised from ImageNet-pretrained weights.

**Reference Notebook:** https://www.kaggle.com/julian3833/sartorius-starter-baseline-torch-u-net

# Imports

In [None]:
import os
import cv2
import time
import random
import numpy as np
import pandas as pd
import plotly.express as px
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts, CosineAnnealingLR

# !pip install -U git+https://github.com/albumentations-team/albumentations
!pip install ../input/segmentation-libs/albumentations-1.1.0/albumentations-1.1.0
import albumentations as A
from albumentations.pytorch import ToTensorV2

!pip install ../input/segmentation-libs/timm-0.3.2-py3-none-any.whl
!pip install ../input/segmentation-libs/pretrainedmodels-0.7.4-py3-none-any.whl
!pip install ../input/segmentation-libs/efficientnet_pytorch-0.6.3-py3-none-any.whl
!pip install ../input/segmentation-libs/segmentation_models_pytorch-0.1.3-py3-none-any.whl

import collections.abc as container_abcs
torch._six.container_abcs = container_abcs
import segmentation_models_pytorch as smp

# EDA

In [None]:
# Dataset locations (Kaggle input layout) and the working directory used as
# the output folder.
train = "../input/sartorius-cell-instance-segmentation/train"
test = "../input/sartorius-cell-instance-segmentation/test"
save_path = os.getcwd()

# train.csv lives in the parent folder of the train/ image directory.
annotations_csv = os.path.join(os.path.dirname(train), "train.csv")
train_df = pd.read_csv(annotations_csv)
train_df.head()

In [None]:
# Basic dataset statistics derived from the annotation table.
image_ids = train_df["id"].unique()
cell_types = train_df["cell_type"].unique()
cell_instances = len(train_df["annotation"])

print(f"There are {len(image_ids)} images in the dataset")
print(f"There are {cell_instances} cell instances in the dataset")

In [None]:
# Number of annotation rows per cell type. The former `labels=cell_types`
# argument is dropped: plotly expects a dict mapping column names to display
# names there, not the array of unique values.
px.histogram(train_df, x="cell_type")

In [None]:
# Native (height, width) of the images, read from the first row of train.csv
# (assumes every image shares this size — TODO confirm). Positional `.iloc[0]`
# does not depend on the index labels, and plain Python ints are accepted
# everywhere a size is needed (NumPy shapes, cv2 `dsize`).
IMAGE_SHAPE = (int(train_df["height"].iloc[0]), int(train_df["width"].iloc[0]))
# Spatial size the images are resized to by the transform pipelines.
IMAGE_RESIZE = (224, 224)

In [None]:
def rle_decode(mask_rle, shape=None):
    """Decode a run-length-encoded annotation into a binary 2-D mask.

    Args:
        mask_rle (str): Space-separated "start length" pairs. Starts are
            1-indexed positions in the row-major flattened image.
        shape (tuple, optional): (height, width) of the decoded mask.
            Defaults to the module-level IMAGE_SHAPE.

    Returns:
        np.ndarray: float array of ``shape`` holding 1 inside the encoded
        runs and 0 elsewhere.
    """
    if shape is None:
        shape = IMAGE_SHAPE

    # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
    runs = np.array(mask_rle.split(), dtype=int).reshape(-1, 2)
    starts = runs[:, 0] - 1  # 1-indexed -> 0-indexed
    lengths = runs[:, 1]

    mask = np.zeros(shape[0] * shape[1])
    for start, length in zip(starts, lengths):
        mask[start:start + length] = 1
    return mask.reshape(shape)

def prepare_image_mask(mask_annotations):
    """Merge all RLE annotations of one image into a single binary mask.

    Args:
        mask_annotations (iterable of str): RLE strings, one per instance.

    Returns:
        np.ndarray: array of shape IMAGE_SHAPE with values clipped to [0, 1].
    """
    combined = np.zeros(IMAGE_SHAPE)
    for rle in mask_annotations:
        combined = combined + rle_decode(rle)

    # Pixels covered by several instances sum above 1; clamp them back.
    return np.clip(combined, 0, 1)

def compute_iou(labels, y_pred):
    """
    Builds the pairwise IoU matrix between labelled and predicted objects.

    Args:
        labels (np array): Ground-truth label image.
        y_pred (np array): Predicted label image.

    Returns:
        np array: IoU matrix of size true_objects x pred_objects, with the
        first row and column (background) removed.
    """
    n_true = len(np.unique(labels))
    n_pred = len(np.unique(y_pred))

    # Joint histogram: entry (i, j) counts pixels in true bin i and predicted bin j.
    overlap, _, _ = np.histogram2d(
        labels.flatten(), y_pred.flatten(), bins=(n_true, n_pred)
    )

    # Per-bin pixel counts, shaped so they broadcast against the overlap matrix.
    true_area = np.histogram(labels, bins=n_true)[0][:, np.newaxis]
    pred_area = np.histogram(y_pred, bins=n_pred)[0][np.newaxis, :]

    iou_matrix = overlap / (true_area + pred_area - overlap)

    # Row 0 / column 0 correspond to the background.
    return iou_matrix[1:, 1:]

def precision_at(threshold, iou):
    """
    Counts matches between true and predicted objects at one IoU threshold.

    Args:
        threshold (float): IoU value a pair must exceed to count as a match.
        iou (np array): IoU matrix (rows: true objects, columns: predictions).

    Returns:
        int: Number of true positives,
        int: Number of false positives,
        int: Number of false negatives.
    """
    hits = iou > threshold
    hits_per_true = np.sum(hits, axis=1)
    hits_per_pred = np.sum(hits, axis=0)

    tp = np.sum(hits_per_true == 1)  # true objects matched exactly once
    fp = np.sum(hits_per_pred == 0)  # predictions matching no true object
    fn = np.sum(hits_per_true == 0)  # true objects matched by no prediction
    return tp, fp, fn

def iou_map(truths, preds, verbose=0):
    """
    Evaluates the competition metric over a set of mask pairs.

    For every IoU threshold from 0.5 up to (but excluding) 1.0 in steps of
    0.05, TP/FP/FN counts are summed over all pairs and the precision
    TP / (TP + FP + FN) is computed; the result is the mean over thresholds.
    Masks contain the segmented pixels where each object has one value
    associated, and 0 is the background.

    Args:
        truths (list of masks): Ground truths.
        preds (list of masks): Predictions.
        verbose (int, optional): Whether to print infos. Defaults to 0.

    Returns:
        float: mAP.
    """
    ious = [compute_iou(truth, pred) for truth, pred in zip(truths, preds)]

    if verbose:
        print("Thresh\tTP\tFP\tFN\tPrec.")

    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        # One (tp, fp, fn) triple per image at this threshold.
        counts = [precision_at(t, iou) for iou in ious]
        tps = sum(c[0] for c in counts)
        fps = sum(c[1] for c in counts)
        fns = sum(c[2] for c in counts)

        p = tps / (tps + fps + fns)
        prec.append(p)

        if verbose:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tps, fps, fns, p))

    mean_precision = np.mean(prec)
    if verbose:
        print("AP\t-\t-\t-\t{:1.3f}".format(mean_precision))

    return mean_precision

In [None]:
# Path of every training image, in the same order as image_ids.
paths = [os.path.join(train, image_id + ".png") for image_id in image_ids]

In [None]:
# Pick a random training image and show it next to its ground-truth mask.
idx = np.random.choice(np.arange(len(paths)))
img_path = paths[idx]
img = cv2.imread(img_path)
is_chosen = train_df["id"] == image_ids[idx]
mask_annotations = train_df["annotation"][is_chosen].values
print(f"Chose image {img_path}.")
# Each annotation row describes one cell, so the cell count is the number of
# rows (the previous `[::2]` slice reported only about half of them).
print(f"This image has {len(mask_annotations)} cells.")

# prepare_image_mask already clips the merged mask to [0, 1].
mask = prepare_image_mask(mask_annotations)
fig, ax = plt.subplots(1, 3, figsize=(20, 15))
fig.suptitle(f"Cell Type: {train_df['cell_type'][is_chosen].values[0]}")
ax[0].imshow(img, cmap="bone_r")               # raw image
ax[1].imshow(img, cmap="bone_r")               # image with mask overlay
ax[1].imshow(mask, alpha=0.2, cmap="bone_r")
ax[2].imshow(mask, cmap="gray")                # mask only

# Dataset

In [None]:
class SartoriusDataset(Dataset):
    """Dataset yielding (image, binary mask) pairs for training/validation.

    Args:
        image_paths (sequence of str): Paths of the image files.
        data_df (pd.DataFrame): Annotation table with "id" and "annotation"
            columns; the rows sharing an id are merged into one mask.
        transforms (callable, optional): Called as
            ``transforms(image=..., mask=...)`` and expected to return a dict
            with "image" and "mask" entries (albumentations convention).
    """

    def __init__(self, image_paths, data_df, transforms=None):
        self.image_paths = image_paths
        self.data_df = data_df
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _image_id(image_path):
        """Return the file name of ``image_path`` without folder or extension."""
        # os.path handles the platform's separators, unlike split('/').
        return os.path.splitext(os.path.basename(image_path))[0]

    def __getitem__(self, idx):
        """Load the image at position ``idx`` together with its merged mask.

        Returns:
            tuple: ``(image, mask)``; the mask is a tensor with a leading
            channel axis added via ``unsqueeze(0)``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        image_path = self.image_paths[idx]
        image_id = self._image_id(image_path)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        image_annotations = self.data_df["annotation"][self.data_df["id"] == image_id].values
        mask = prepare_image_mask(image_annotations)
        mask = (mask >= 1).astype(np.float32)
        if self.transforms is not None:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # Without a tensor-producing transform the mask is still a NumPy
        # array, which has no `unsqueeze`; convert it first.
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask)
        return image, mask.unsqueeze(0)

In [None]:
# Training-time augmentation pipeline. SartoriusDataset calls it with both
# `image` and `mask`, so each sample's image and mask go through it together.
# The entries run in list order.
# NOTE(review): IAAPiecewiseAffine, IAASharpen, Cutout and JpegCompression look
# like deprecated aliases in newer albumentations releases — confirm they are
# still available before moving off the pinned 1.1.0 wheel.
transforms_train = A.Compose([
  # Random crop covering 90-100% of the image, resized to IMAGE_RESIZE.
  A.RandomResizedCrop(IMAGE_RESIZE[0], IMAGE_RESIZE[1], scale=(0.9, 1), p=1), 
  A.HorizontalFlip(p=0.5),
  A.ShiftScaleRotate(p=0.5),
  A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.7),
  A.RandomBrightnessContrast(brightness_limit=(-0.2,0.2), contrast_limit=(-0.2, 0.2), p=0.7),
  A.CLAHE(clip_limit=(1,4), p=0.5),
  # At most one geometric distortion, applied 20% of the time.
  A.OneOf([
    A.OpticalDistortion(distort_limit=1.0),
    A.GridDistortion(num_steps=5, distort_limit=1.),
    A.ElasticTransform(alpha=3),
  ], p=0.2),
  # At most one noise/blur transform, applied 20% of the time.
  A.OneOf([
    A.GaussNoise(var_limit=[10, 50]),
    A.GaussianBlur(),
    A.MotionBlur(),
    A.MedianBlur(),
  ], p=0.2),
  A.Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1]),
  # At most one quality degradation, applied 20% of the time.
  A.OneOf([
    A.JpegCompression(),
    A.Downscale(scale_min=0.1, scale_max=0.15),
    ], p=0.2),
  A.IAAPiecewiseAffine(p=0.2),
  A.IAASharpen(p=0.2),
  # Up to 5 holes, each at most 10% of the image side.
  A.Cutout(max_h_size=int(IMAGE_RESIZE[0] * 0.1), max_w_size=int(IMAGE_RESIZE[1] * 0.1), num_holes=5, p=0.5),
  # Mean/std look like the usual ImageNet statistics — presumably chosen to
  # match the pretrained encoder; verify.
  A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  A.VerticalFlip(p=0.5),
  ToTensorV2()
])

# Validation/test pipeline: resize, normalise and convert to tensor only,
# with no random augmentation.
transforms_valid = A.Compose([
  A.Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1]),
  A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ToTensorV2()
])

In [None]:
train_dataset = SartoriusDataset(paths, train_df, transforms=transforms_train)

# Eyeball the first five augmented samples together with their masks.
for sample_idx in range(5):
    image, mask = train_dataset[sample_idx]
    print(f"Image size: {image.size()}, Mask size: {mask.size()}")

    # Channels-first tensor -> channels-last for matplotlib.
    plt.imshow(image.permute(1, 2, 0))
    plt.show()

    # Drop the leading channel axis of the mask.
    plt.imshow(mask[0], cmap="gray")
    plt.show()

In [None]:
def dice_loss(pred, target):
    """Smoothed Dice coefficient between sigmoid(pred) and target.

    Despite its name, the value returned is the overlap score itself
    (higher means better agreement); MixedLoss turns it into a loss by
    taking its negative logarithm.

    Args:
        pred (torch.Tensor): Raw logits.
        target (torch.Tensor): Ground-truth mask with the same number of
            elements as ``pred``.

    Returns:
        torch.Tensor: Scalar (2 * intersection + 1) / (sum_pred + sum_target + 1).
    """
    probs = torch.sigmoid(pred).view(-1)
    truth = target.view(-1)

    overlap = (probs * truth).sum()
    total = probs.sum() + truth.sum()

    # The +1 terms keep the ratio defined when both masks are empty.
    return (2.0 * overlap + 1.0) / (total + 1.0)

class FocalLoss(nn.Module):
    """Binary focal loss computed on raw logits.

    Args:
        gamma (float): Exponent of the modulating factor that down-weights
            well-classified elements.
    """

    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the mean focal loss.

        Raises:
            ValueError: If ``input`` and ``target`` differ in size.
        """
        if target.size() != input.size():
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        # Binary cross-entropy with logits, written with the max-shift trick
        # so the exponentials cannot overflow.
        max_val = torch.clamp(-input, min=0)
        log_term = torch.log(torch.exp(-max_val) + torch.exp(-input - max_val))
        bce = input - input * target + max_val + log_term

        # exp(gamma * logsigmoid(-x * (2t - 1))) is the focal modulating factor.
        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
        modulator = torch.exp(invprobs * self.gamma)
        return (modulator * bce).mean()


class MixedLoss(nn.Module):
    """Combination of focal loss and the negative log of the Dice score.

    Args:
        alpha (float): Weight applied to the focal term.
        gamma (float): Focusing exponent forwarded to FocalLoss.
    """

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        """Return alpha * focal(input, target) - log(dice(input, target))."""
        focal_term = self.alpha * self.focal(input, target)
        log_dice = torch.log(dice_loss(input, target))
        return (focal_term - log_dice).mean()

In [None]:
class CustomUnet(nn.Module):
    """Wrapper around ``smp.Unet`` producing single-channel raw logits.

    Args:
        model_name (str): Encoder name handed to ``smp.Unet``.
    """

    def __init__(self, model_name="efficientnet-b0"):
        super().__init__()
        # activation=None: the network outputs logits; sigmoid is applied by
        # the losses and at prediction time.
        self.model = smp.Unet(
            model_name,
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,
        )

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its output unchanged."""
        return self.model(x)

In [None]:
def train_fn(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to train mode.
        train_loader: Iterable of (images, masks) batches.
        criterion: Callable mapping (logits, masks) to a scalar loss.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []

    for step, (batch_images, batch_masks) in enumerate(train_loader):
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        optimizer.zero_grad()
        loss = criterion(model(batch_images), batch_masks)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)
        # Progress line every fifth batch.
        if step % 5 == 0:
            print(f'Training Loss at {step}th batch: {loss_value:.4f}')

    return np.mean(batch_losses)

def val_fn(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        val_loader: Iterable of (images, masks) batches.
        criterion: Callable mapping (logits, masks) to a scalar loss.
        device: Device the batches are moved to.

    Returns:
        tuple: (mean per-batch loss, list of ground-truth arrays, list of
        sigmoid-probability arrays), one array per batch.
    """
    model.eval()
    batch_losses, trues, preds = [], [], []

    with torch.no_grad():
        for step, (batch_images, batch_masks) in enumerate(val_loader):
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            logits = model(batch_images)
            loss_value = criterion(logits, batch_masks).item()

            # Keep probabilities and targets on the CPU as NumPy arrays.
            preds.append(torch.sigmoid(logits).detach().cpu().numpy())
            trues.append(batch_masks.detach().cpu().numpy())
            batch_losses.append(loss_value)
            if step % 5 == 0:
                print(f'Validation Loss at {step}th batch: {loss_value:.4f}')

    return np.mean(batch_losses), trues, preds

def post_process_preds(preds, threshold=0.5):
    """
        Binarises every batch of predictions.

        For each array in ``preds`` channel 0 is selected (dropping the
        channel axis) and values above ``threshold`` become 1.0, all
        others 0.0. Returns a list of float32 arrays.
    """
    return [
        (batch[:, 0, :, :] > threshold).astype(np.float32)
        for batch in preds
    ]

In [None]:
def train_loop(epochs, lr, early_stopping, device, print_freq=1):
    """Train a CustomUnet on a 90/10 split of the module-level ``paths``.

    Checkpoints are written to ``save_path`` whenever the average precision
    or the validation loss improves. Training stops early after
    ``early_stopping`` consecutive epochs without a validation-loss
    improvement.

    Args:
        epochs (int): Maximum number of epochs.
        lr (float): Learning rate handed to Adam.
        early_stopping (int): Patience, in epochs, on the validation loss.
        device: Device the model and batches are moved to.
        print_freq (int): Print the epoch summary every ``print_freq`` epochs.

    Returns:
        The model as it is after the last executed epoch (no checkpoint is
        reloaded here).
    """
    
    # This is a very straightforward model with no folding
    # Just do a 90/10 split for train/val.
    # The split follows the order of `paths`; nothing is shuffled beforehand.
    idx = int(len(paths) * 0.9)
    train_dataset = SartoriusDataset(paths[:idx], train_df, transforms=transforms_train)
    val_dataset = SartoriusDataset(paths[idx:], train_df, transforms=transforms_valid)
    
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)
    
    model = CustomUnet()
    model.to(device)
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-6, amsgrad=False)
    scheduler = CosineAnnealingLR(optimizer, T_max=5)
    
    # MixedLoss(alpha=10, gamma=2): 10 * focal loss - log(dice).
    criterion = MixedLoss(10, 2)
    best_loss = np.inf
    best_avg_prec = -np.inf
    
    # Consecutive epochs without a validation-loss improvement.
    early_stop_epochs = 0
    for epoch in range(epochs):
        
        start_time = time.time()
        train_loss = train_fn(model, train_dataloader, criterion, optimizer, device)
        
        val_loss, trues, preds = val_fn(model, val_dataloader, criterion, device)
        # The learning-rate schedule advances once per epoch.
        scheduler.step()

        time_taken = time.time() - start_time
        
        # NOTE(review): trues/preds are binary masks here, not per-instance
        # label maps as iou_map's docstring describes — confirm this is the
        # intended use of the metric.
        preds = post_process_preds(preds)
        avg_prec = iou_map(trues, preds, 1)
        if epoch % print_freq == 0:
            print(f'At Epoch {epoch}, Training Loss: {train_loss:.5f}, Validation Loss: {val_loss:.4f}, Avg Prec: {avg_prec:.4f}. Took {time_taken:.0f}s.')
        
        # Checkpoint on a new best average precision (4-decimal file name).
        if avg_prec > best_avg_prec:
            best_avg_prec = avg_prec
            torch.save(model.state_dict(), os.path.join(save_path, f"unet_{epoch}_{best_avg_prec:.4f}.pth"))
        
        # Checkpoint on a new best validation loss (2-decimal file name); this
        # branch also drives early stopping.
        if val_loss < best_loss:
            best_loss = val_loss
            print(f"Validation loss < Best loss. Saving model...")
            torch.save(model.state_dict(), os.path.join(save_path, f"unet_{epoch}_{val_loss:.2f}.pth"))
            early_stop_epochs = 0
        else:
            early_stop_epochs += 1
            if early_stop_epochs == early_stopping:
                break
    
    return model

In [None]:
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/efficientnetv2weights/efficientnet-b0-355c32eb.pth  /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth

# Training hyper-parameters.
LR = 5e-4
EPOCHS = 100
EARLY_STOPPING = 15

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = train_loop(EPOCHS, LR, EARLY_STOPPING, device)

# Prediction

## Utilities

In [None]:
def separate_mask_components(mask, min_size=300):
    """Split a mask into one binary mask per connected component.

    Args:
        mask (np.ndarray): 2-D mask. It is cast to uint8 before labelling,
            so fractional values below 1 are truncated to background.
        min_size (int): Components with ``min_size`` pixels or fewer are
            discarded.

    Returns:
        list of np.ndarray: float32 masks with the same shape as ``mask``,
        one per kept component.
    """
    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
    predictions = []
    # Start at 1 to skip the background label.
    for c in range(1, num_component):
        p = component == c
        if p.sum() > min_size:
            # The output takes its shape from the input mask rather than the
            # global IMAGE_SHAPE, so masks of any size are handled.
            predictions.append(p.astype(np.float32))
    return predictions

def rle_encoding(x):
    """Run-length encode the pixels of ``x`` that are equal to 1.

    Args:
        x (np.ndarray): Mask; it is flattened in row-major order.

    Returns:
        str: Space-separated "start length" pairs with 1-indexed starts, or
        an empty string when no pixel equals 1.
    """
    flat = np.asarray(x).flatten() == 1

    # Pad with False on both sides so every run has a rising and a falling
    # edge; edges then alternate start, end, start, end, ...
    padded = np.concatenate(([False], flat, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])

    starts = edges[::2] + 1              # 0-indexed -> 1-indexed
    lengths = edges[1::2] - edges[::2]

    run_lengths = np.column_stack((starts, lengths)).ravel()
    return ' '.join(map(str, run_lengths.tolist()))

def one_hot(y, num_classes, dtype=np.uint8):
    """One-hot encode integer labels along a new trailing axis.

    Args:
        y (array-like): Labels; cast to int.
        num_classes (int): Number of classes. When falsy it is inferred as
            ``max(y) + 1``.
        dtype: dtype of the returned array.

    Returns:
        np.ndarray: Array of shape ``y.shape + (num_classes,)``; a trailing
        singleton axis of ``y`` is dropped first.
    """
    labels = np.array(y, dtype='int')
    out_shape = labels.shape
    # Shapes such as (N, 1) are treated as (N,).
    if out_shape and len(out_shape) > 1 and out_shape[-1] == 1:
        out_shape = tuple(out_shape[:-1])

    flat = labels.ravel()
    if not num_classes:
        num_classes = np.max(flat) + 1

    count = flat.shape[0]
    encoded = np.zeros((count, num_classes), dtype=dtype)
    encoded[np.arange(count), flat] = 1
    return np.reshape(encoded, out_shape + (num_classes,))

def fix_overlap(msk):
    """
    Resolve overlaps so that every pixel belongs to at most one instance.

    Where channels overlap, the lowest-indexed channel keeps the pixel.
    Channels left empty afterwards are removed.

    Args:
        mask: multi-channel mask, each channel is an instance of cell, shape:(520,704,None)
    Returns:
        multi-channel mask with non-overlapping values, shape:(520,704,None)
    """
    msk = np.array(msk)
    # Prepend an all-zero background channel so pixels covered by no instance
    # resolve to index 0 instead of being assigned to the first instance.
    msk = np.pad(msk, [[0, 0], [0, 0], [1, 0]])
    ins_len = msk.shape[-1]
    # argmax keeps the first maximal channel per pixel, which removes overlap.
    msk = np.argmax(msk, axis=-1)
    # Back to a multi-channel mask: row k of the identity matrix is the
    # one-hot vector of label k.
    msk = np.eye(ins_len, dtype=np.uint8)[msk]
    msk = msk[..., 1:]  # remove background mask
    msk = msk[..., np.any(msk, axis=(0, 1))]  # remove all-zero masks
    return msk

def check_overlap(msk):
    """Return True when any pixel is non-zero in more than one channel.

    Args:
        msk (np.ndarray): Multi-channel mask with the instances on the last axis.
    """
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` is equivalent.
    msk = msk.astype(bool).astype(np.uint8)  # binary mask
    # Without overlap, at most one channel is set at each pixel.
    return np.any(np.sum(msk, axis=-1) > 1)

In [None]:
class SartoriusTestDataset(Dataset):
    """Dataset over the ``.png`` images of a folder, yielding (image, image_id).

    Args:
        root (str): Folder containing the test images.
        transforms (callable, optional): Called as ``transforms(image=...)``
            and expected to return a dict with an "image" entry.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        # Only .png files are images; anything else in the folder is ignored
        # instead of being turned into a bogus id. Sorting makes the order
        # independent of the file system.
        self.image_ids = sorted(
            os.path.splitext(f_name)[0]
            for f_name in os.listdir(self.root)
            if f_name.endswith(".png")
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Load image ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.root, image_id + ".png")
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        if self.transforms is not None:
            augmented = self.transforms(image=image)
            image = augmented["image"]
        return image, image_id

    def __len__(self):
        return len(self.image_ids)

In [None]:
import re

test_dataset = SartoriusTestDataset(test, transforms_valid)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# train_loop writes two checkpoint families into save_path:
#   unet_{epoch}_{avg_prec:.4f}.pth   best average precision (4 decimals)
#   unet_{epoch}_{val_loss:.2f}.pth   best validation loss   (2 decimals)
# Only the 4-decimal names carry a precision, so they are matched explicitly;
# otherwise a low validation loss could be mistaken for a high precision.
avg_prec_pattern = re.compile(r"^unet_\d+_(\d+\.\d{4})\.pth$")
weights = []
avgs = []
for w_file in os.listdir(save_path):
    found = avg_prec_pattern.match(w_file)
    if found:
        weights.append(w_file)
        avgs.append(float(found.group(1)))

if weights:
    best_avg_weights = weights[np.argmax(avgs)]
    print(best_avg_weights)
    model.load_state_dict(torch.load(os.path.join(save_path, best_avg_weights)))
else:
    print("No average-precision checkpoint found; using the in-memory model.")
model.eval()

submission = []
with torch.no_grad():  # inference only: no autograd bookkeeping needed
    for image, image_id in test_dataloader:
        # batch_size=1: the loader wraps the single id in a container.
        image_id = image_id[0]

        # Upsample the probabilities first and threshold afterwards.
        # Thresholding before the bilinear resize leaves fractional border
        # values that the later uint8 cast truncates to background.
        probability_mask = model(image.to(device)).sigmoid().cpu().numpy()[0, 0]
        probability_mask = cv2.resize(
            probability_mask,
            dsize=(int(IMAGE_SHAPE[1]), int(IMAGE_SHAPE[0])),
            interpolation=cv2.INTER_LINEAR,
        )
        binary_mask = (probability_mask > 0.5).astype(np.float32)

        cell_instances = separate_mask_components(binary_mask)
        if not cell_instances:
            # Every image gets a row; an image without any detected
            # instance gets an empty prediction.
            submission.append([image_id, ""])
            continue

        # Instances are stacked on the LAST axis: (height, width, n_instances).
        cell_instances = np.stack(cell_instances, axis=-1)
        if check_overlap(cell_instances):
            cell_instances = fix_overlap(cell_instances)

        # Iterate over instance channels (not over image rows).
        for channel in range(cell_instances.shape[-1]):
            submission.append([image_id, rle_encoding(cell_instances[..., channel])])

df = pd.DataFrame(submission, columns=["id", "predicted"])
df.to_csv("submission.csv", index=False)