In [None]:
# Libraries
import os
from os.path import join
from tqdm import tqdm
import random

import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 18})
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, sampler

from albumentations import (HorizontalFlip, VerticalFlip, ShiftScaleRotate, Normalize, Resize, Compose, GaussNoise)

def initialize_seeds(seed):
    """Seed every random number generator used in this notebook.

    Sets the ``PYTHONHASHSEED`` environment variable and seeds NumPy,
    Python's ``random`` module and PyTorch (default, current-CUDA-device
    and all-CUDA-device generators) with the same value.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


initialize_seeds(2021)

## Load Dataset

### Understand the Structure of the Dataset


*   train - train images in PNG format
*   The training annotations -> run length encoded masks
*   Images -> PNG format (The number of images is small, but the number of annotated objects is quite high.)
*   Test set -> 240 images

### Files

**train.csv** - IDs and masks for all training objects. None of this metadata is provided for the test set.
* id - unique identifier for object
* annotation - run length encoded pixels for the identified neuronal cell
* width - source image width
* height - source image height
* cell_type - the cell line
* plate_time - time plate was created
* sample_date - date sample was created
* sample_id - sample identifier
* elapsed_timedelta - time since first image taken of sample

**sample_submission.csv** - a sample submission file in the correct format

**train** - train images in PNG format

**test** - test images in PNG format. Only a few test set images are available for download; the remainder can only be accessed by your notebooks when you submit.

**train_semi_supervised** - unlabeled images offered in case you want to use additional data for a semi-supervised approach.

**LIVECell_dataset_2021** - A mirror of the data from the LIVECell dataset. LIVECell is the predecessor dataset to this competition. You will find extra data for the SH-SHY5Y cell line, plus several other cell lines not covered in the competition dataset that may be of interest for transfer learning.

In [None]:
# Locations of the competition files, all relative to DATA_PATH.
DATA_PATH             = '../input/sartorius-cell-instance-segmentation'
# Fixed: this constant previously pointed at the 'train' image folder
# instead of the sample submission file.
SAMPLE_SUBMISSION     = join(DATA_PATH,'sample_submission.csv')
TRAIN_CSV             = join(DATA_PATH,'train.csv')
TRAIN_PATH            = join(DATA_PATH,'train')
TEST_PATH             = join(DATA_PATH,'test')

# One row per annotated object; several rows share the same image "id".
df_train = pd.read_csv(TRAIN_CSV)
print(f'Training Set Shape: {df_train.shape} - {df_train["id"].nunique()} \
Images - Memory Usage: {df_train.memory_usage().sum() / 1024 ** 2:.2f} MB')

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length encoded mask into a 2-D array.

    mask_rle: whitespace-separated string of alternating "start length"
              values; starts are 1-based indices into the flattened image
    shape: (height, width) of the array to return
    color: value written into the pixels covered by the runs
    Returns a float32 numpy array of the given shape, `color` - mask,
    0 - background

    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = color
    return flat.reshape(shape)

def build_masks(df_train, image_id, input_shape):
    """Merge every annotation of one image into a single binary mask.

    Args:
        df_train: DataFrame with an "id" column and an "annotation" column
            holding run-length encoded masks.
        image_id: Value of "id" selecting the rows to decode.
        input_shape: (height, width) of the mask to build.

    Returns:
        Array of shape (height, width) whose values are clipped to [0, 1];
        pixels covered by at least one annotation are 1.
    """
    height, width = input_shape
    annotations = df_train.loc[df_train["id"] == image_id, "annotation"].tolist()

    combined = np.zeros((height, width))
    for encoded in annotations:
        # Overlapping objects add up here and are clipped back to 1 below.
        combined += rle_decode(encoded, shape=(height, width))
    return np.array(combined.clip(0, 1))

In [None]:

class CellDataset(Dataset):
    """Cell-vs-background segmentation dataset built from RLE annotations.

    The unique image ids of ``df`` are shuffled with a fixed seed and split
    90% / 10%; ``train`` selects which part this instance serves. Each item
    is an ``(image, mask)`` pair resized to ``IMAGE_RESIZE`` and normalised
    with ``RESNET_MEAN`` / ``RESNET_STD``.

    Args:
        df: Annotation table with at least the "id" and "annotation" columns.
        train: ``True`` for the 90% training part (random flips enabled),
            ``False`` for the remaining 10% (no random augmentation).
    """

    def __init__(self, df: pd.core.frame.DataFrame, train:bool):
        self.IMAGE_RESIZE = (224, 224)
        self.RESNET_MEAN = (0.485, 0.456, 0.406)
        self.RESNET_STD = (0.229, 0.224, 0.225)
        self.df = df
        self.base_path = TRAIN_PATH
        self.gb = self.df.groupby('id')

        steps = [Resize( self.IMAGE_RESIZE[0],  self.IMAGE_RESIZE[1]),
                 Normalize(mean=self.RESNET_MEAN, std= self.RESNET_STD, p=1)]
        if train:
            # Random flips are training-time augmentation only, so that the
            # held-out samples are deterministic from one pass to the next.
            steps += [HorizontalFlip(p=0.5), VerticalFlip(p=0.5)]
        self.transforms = Compose(steps)

        # Split train and val set. The ids come from the frame passed in
        # (not from a notebook global), and a private RandomState is used so
        # that building a dataset does not reset the global NumPy seed.
        # RandomState(42) yields the same permutation as seeding globally
        # with 42, so the split itself is unchanged.
        all_image_ids = np.array(self.df.id.unique())
        iperm = np.random.RandomState(42).permutation(len(all_image_ids))
        num_train_samples = int(len(all_image_ids) * 0.9)

        if train:
            self.image_ids = all_image_ids[iperm[:num_train_samples]]
        else:
            self.image_ids = all_image_ids[iperm[num_train_samples:]]

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(image, mask)`` for the sample at position ``idx``.

        ``image`` is the transformed image with the channel axis moved to
        the front; ``mask`` is float32 with shape
        ``(1, IMAGE_RESIZE[0], IMAGE_RESIZE[1])``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_id = self.image_ids[idx]
        # Only the rows of this image; avoids rescanning the whole table.
        annotations = self.gb.get_group(image_id)

        # Read image
        image_path = os.path.join(self.base_path, image_id + ".png")
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

        # Create the mask at the resolution of the image that was just read.
        mask = build_masks(annotations, image_id, input_shape=image.shape[:2])
        mask = (mask >= 1).astype('float32')
        augmented = self.transforms(image=image, mask=mask)
        image = augmented['image']
        mask = augmented['mask']
        # Channels-last -> channels-first for the image; add a channel axis
        # to the mask.
        return np.moveaxis(np.array(image),2,0), mask.reshape((1, self.IMAGE_RESIZE[0], self.IMAGE_RESIZE[1]))

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.image_ids)

In [None]:
ds_train = CellDataset(df_train, train=True)
# Shuffle the training samples every epoch so that mini-batches differ
# between epochs (previously the loader iterated in a fixed order).
dl_train = DataLoader(ds_train, batch_size=16, num_workers=2, pin_memory=True, shuffle=True)

In [None]:
# Plot one image and its mask taken from the first batch of the dataloader
batch = next(iter(dl_train))
images, masks = batch
print(f"image shape: {images.shape},\nmask shape:{masks.shape},\nbatch len: {len(batch)}")

# Sample at index 1 of the batch: channel 1 of the image, channel 0 of the mask
sample_image = images[1][1]
sample_mask = masks[1][0]

# (title, layers drawn in that panel); each layer is (array, imshow kwargs)
panels = (
    ('Original image', ((sample_image, {}),)),
    ('Mask', ((sample_mask, {}),)),
    ('Both', ((sample_image, {}), (sample_mask, {'alpha': 0.2}))),
)

plt.figure(figsize=(10, 5))
for position, (panel_title, layers) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    for layer, imshow_kwargs in layers:
        plt.imshow(layer, **imshow_kwargs)
    plt.title(panel_title)
plt.tight_layout()
plt.show()

## U-Net model architecture 

A u-net is commonly used for biological image segmentation because its shape allows for local and global features to be combined to create highly-precise segmentations.

A u-net is shaped like an autoencoder. It has:



1.   a standard convolutional network with downsampling, like one used for imagenet
2.   upsampling layers that ultimately return an image at the same size as the input image.

In addition to these downsampling and upsampling blocks, it has skip connections from the downsampling blocks to the upsampling blocks, which allow it to propagate more precise local information to the later layers.

**Model is based on https://deeplearning.neuromatch.io/projects/Neuroscience/cellular_segmentation.html**

In [None]:

def convbatchrelu(in_channels, out_channels, sz):
  """Build a Conv2d -> BatchNorm2d -> ReLU unit.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of channels produced by the convolution.
    sz: kernel size; the padding is ``sz // 2``.

  Returns:
    An ``nn.Sequential`` holding the three layers in that order.
  """
  layers = [
      nn.Conv2d(in_channels, out_channels, kernel_size=sz, padding=sz//2),
      nn.BatchNorm2d(out_channels, eps=1e-5),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*layers)


class convdown(nn.Module):
  """Two consecutive ``convbatchrelu`` units.

  The first unit maps ``in_channels`` to ``out_channels``; the second keeps
  ``out_channels``. They are registered as ``conv_0`` and ``conv_1``.
  """

  def __init__(self, in_channels, out_channels, kernel_size):
    super().__init__()
    self.conv = nn.Sequential()
    channel_pairs = ((in_channels, out_channels), (out_channels, out_channels))
    for idx, (c_in, c_out) in enumerate(channel_pairs):
      self.conv.add_module('conv_%d'%idx,
                           convbatchrelu(c_in, c_out, kernel_size))

  def forward(self, x):
    """Apply both units one after the other."""
    first, second = self.conv[0], self.conv[1]
    return second(first(x))


class downsample(nn.Module):
  """Encoder: a chain of ``convdown`` blocks with 2x2 max-pooling in between.

  Args:
    nbase: channel counts; block ``n`` maps ``nbase[n]`` to ``nbase[n + 1]``.
    kernel_size: kernel size handed to every ``convdown`` block.
  """

  def __init__(self, nbase, kernel_size):
    super().__init__()
    self.down = nn.Sequential()
    self.maxpool = nn.MaxPool2d(2, 2)
    for level, (c_in, c_out) in enumerate(zip(nbase[:-1], nbase[1:])):
      self.down.add_module('conv_down_%d'%level,
                           convdown(c_in, c_out, kernel_size))

  def forward(self, x):
    """Return the list of outputs of every block, first block first.

    The first block sees ``x`` directly; every later block sees the
    max-pooled output of the block before it.
    """
    features = []
    for level, block in enumerate(self.down):
      block_input = x if level == 0 else self.maxpool(features[level - 1])
      features.append(block(block_input))
    return features


class convup(nn.Module):
  """Decoder block that fuses a skip connection by addition.

  ``conv_0`` maps ``in_channels`` to ``out_channels``; ``conv_1`` keeps
  ``out_channels``.
  """

  def __init__(self, in_channels, out_channels, kernel_size):
    super().__init__()
    first = convbatchrelu(in_channels, out_channels, kernel_size)
    second = convbatchrelu(out_channels, out_channels, kernel_size)
    self.conv = nn.Sequential()
    self.conv.add_module('conv_0', first)
    self.conv.add_module('conv_1', second)

  def forward(self, x, y):
    """Run ``conv_0`` on ``x``, add ``y`` to the result, then run ``conv_1``."""
    reduced = self.conv[0](x)
    return self.conv[1](reduced + y)


class upsample(nn.Module):
  """Decoder: a chain of ``convup`` blocks with 2x nearest upsampling in between.

  Args:
    nbase: channel counts; walking ``n`` from ``len(nbase) - 1`` down to 1,
      a block mapping ``nbase[n]`` to ``nbase[n - 1]`` is registered as
      ``conv_up_<n - 1>``.
    kernel_size: kernel size handed to every ``convup`` block.
  """

  def __init__(self, nbase, kernel_size):
    super().__init__()
    self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')
    self.up = nn.Sequential()
    for level in reversed(range(1, len(nbase))):
      self.up.add_module('conv_up_%d'%(level - 1),
              convup(nbase[level], nbase[level - 1], kernel_size))

  def forward(self, xd):
    """Decode the list of encoder outputs ``xd`` into one feature map.

    Starts from the last entry of ``xd``. Block ``n`` receives the running
    tensor (upsampled first, except for block 0) together with
    ``xd[len(xd) - 1 - n]`` as its skip input.
    """
    x = xd[-1]
    last = len(xd) - 1
    for step, block in enumerate(self.up):
      if step > 0:
        x = self.upsampling(x)
      x = block(x, xd[last - step])
    return x


class Unet(nn.Module):
  """U-Net made of a ``downsample`` encoder, an ``upsample`` decoder and a
  final convolution producing ``nout`` channels.

  Args:
    nbase: channel counts of the encoder, input channels first
      (e.g. ``[3, 32, 64, 128, 256]``).
    nout: number of output channels.
    kernel_size: kernel size used by every convolution.
  """

  def __init__(self, nbase, nout, kernel_size):
    super(Unet, self).__init__()
    self.nbase = nbase
    self.nout = nout
    self.kernel_size = kernel_size
    self.downsample = downsample(nbase, kernel_size)
    # Decoder channels: the encoder channels without the input channel
    # count, with the deepest one repeated. A new list is built so the
    # caller's ``nbase`` is never modified.
    nbaseup = list(nbase[1:]) + [nbase[-1]]
    self.upsample = upsample(nbaseup, kernel_size)
    self.output = nn.Conv2d(nbase[1], self.nout, kernel_size,
                            padding=kernel_size//2)

  def forward(self, data):
    """Encode ``data``, decode it, and apply the output convolution."""
    T0 = self.downsample(data)
    T0 = self.upsample(T0)
    T0 = self.output(T0)
    return T0

  def save_model(self, filename):
    """Write the network's ``state_dict`` to ``filename``."""
    torch.save(self.state_dict(), filename)

  def load_model(self, filename, cpu=False):
    """Load a ``state_dict`` saved by ``save_model`` into this network.

    Args:
      filename: path of the checkpoint file.
      cpu: when true, the checkpoint tensors are mapped to the CPU while
        loading, so a checkpoint written on a GPU machine can be read
        without one. The network itself is not moved to another device.

    The ``cpu=True`` branch used to re-run ``__init__`` with a non-existent
    ``self.concatenation`` attribute and one argument too many, so it always
    raised; re-initialising is not needed because ``load_state_dict``
    overwrites every parameter anyway.
    """
    if cpu:
      state = torch.load(filename, map_location=torch.device('cpu'))
    else:
      state = torch.load(filename)
    self.load_state_dict(state)

## Define the network

In [None]:
# Network hyper-parameters
nbase = [3, 32, 64, 128, 256]  # number of channels per layer
nout = 1  # number of outputs
kernel_size = 3

# Use the GPU when one is available, otherwise stay on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the network and move it to the selected device
net = Unet(nbase, nout, kernel_size).to(device)
print(f"The device is {device}!!")

## Train the network

### Loss Function

In [None]:
def dice_loss(input, target):
    """Smoothed Dice coefficient between ``sigmoid(input)`` and ``target``.

    Despite its name, the returned value grows with the overlap (it is the
    coefficient itself, not one minus it).

    Args:
        input: raw logits.
        target: tensor with the same number of elements as ``input``.

    Returns:
        Scalar tensor ``(2 * intersection + 1) / (sum(pred) + sum(target) + 1)``.
    """
    smooth = 1.0
    probs = torch.sigmoid(input).view(-1)
    truth = target.view(-1)
    overlap = torch.sum(probs * truth)
    numerator = 2.0 * overlap + smooth
    denominator = probs.sum() + truth.sum() + smooth
    return numerator / denominator


class FocalLoss(nn.Module):
    """Focal loss computed on raw logits.

    Args:
        gamma: exponent of the modulating factor; with ``gamma == 0`` the
            factor is 1 and only the cross-entropy term remains.
    """

    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the mean focal loss over all elements.

        Raises:
            ValueError: if ``input`` and ``target`` differ in size.
        """
        if target.size() != input.size():
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        # Binary cross-entropy with logits; max_val is subtracted inside the
        # exponentials and added back outside so they cannot overflow.
        max_val = torch.clamp(-input, min=0)
        log_term = torch.log(torch.exp(-max_val) + torch.exp(-input - max_val))
        bce = input - input * target + max_val + log_term

        # Modulating factor: exp(gamma * logsigmoid(-input * (2 * target - 1)))
        signed_target = target * 2.0 - 1.0
        modulator = torch.exp(F.logsigmoid(-input * signed_target) * self.gamma)

        return (modulator * bce).mean()


class MixedLoss(nn.Module):
    """Weighted focal loss minus the log of the Dice coefficient.

    Args:
        alpha: weight applied to the focal term.
        gamma: focusing exponent forwarded to ``FocalLoss``.
    """

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        """Return ``alpha * focal(input, target) - log(dice(input, target))``."""
        focal_term = self.alpha * self.focal(input, target)
        # dice_loss returns the Dice coefficient, so its negative log
        # shrinks as the overlap improves.
        dice_term = torch.log(dice_loss(input, target))
        combined = focal_term - dice_term
        return combined.mean()

In [None]:
from datetime import datetime

# train the network
# parameters related to training the network
### you will want to increase n_epochs!
n_epochs = 20  # number of times to cycle through all the data during training
learning_rate = 0.1
weight_decay = 1e-5 # L2 regularization of weights
momentum = 0.9 # how much to use previous gradient direction
n_epochs_per_save = 5 # how often to save the network
val_frac = 0.05 # what fraction of data to use for validation (not referenced in this cell)

# where to save the network
# make sure to clean these out every now and then, as you will run out of space
now = datetime.now()
timestamp = now.strftime('%Y%m%dT%H%M%S')
n_train=72


# gradient descent flavor
# (the momentum setting above is now actually used instead of a second,
# hard-coded 0.9)
optimizer = torch.optim.SGD(net.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay,
                            momentum=momentum)
# set learning rate schedule, one entry per epoch:
# 10 entries rising linearly from 0 to learning_rate, then n_epochs - 5
# entries at learning_rate, then five groups of 10 entries, each group at
# half the rate of the one before it. Only the first n_epochs entries are used.
LR = np.linspace(0, learning_rate, 10)
LR = np.append(LR, learning_rate*np.ones(n_epochs-5))
for i in range(5):
    LR = np.append(LR, LR[-1]/2 * np.ones(10))

criterion = MixedLoss(10.0, 2.0)

# store loss per epoch
epoch_losses = np.zeros(n_epochs)
epoch_losses[:] = np.nan

# when we last saved the network
saveepoch = None

# the progress bar advances by the number of samples in each batch, so its
# total is the number of training samples (previously hard-coded to 545)
n_train_samples = len(dl_train.dataset)

# loop through entire training data set nepochs times
for epoch in range(n_epochs):
  net.train() # put in train mode (affects batchnorm)
  epoch_loss = 0
  iters = 0
  for param_group in optimizer.param_groups:
    param_group['lr'] = LR[epoch]
  with tqdm(total=n_train_samples, desc=f"Epoch {epoch + 1}/{n_epochs}", unit='img') as pbar:
    # loop through each batch in the training data
    for batch_idx, batch in enumerate(dl_train):
      images, masks = batch

      # transfer to the selected device (GPU when available)
      images = images.to(device=device)
      masks = masks.to(device=device)

      # compute the loss
      y = net(images)
      loss = criterion(y, masks)
      epoch_loss += loss.item()
      pbar.set_postfix(**{'loss (batch)': loss.item()})
      # gradient descent
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      iters+=1
      pbar.update(masks.shape[0])

    # epoch_loss is the sum of the batch losses of this epoch
    epoch_losses[epoch] = epoch_loss
    pbar.set_postfix(**{'loss (epoch)': epoch_loss})  #.update('loss (epoch) = %f'%epoch_loss)

  # save checkpoint networks every now and then
  # (epoch is 0-based, so this fires on epochs 1, 6, 11, ... as printed)
  if epoch % n_epochs_per_save == 0:
    print(f"\nSaving network state at epoch {epoch+1}")
    saveepoch = epoch
    savefile = f"unet_epoch{saveepoch+1}.pth"
    net.save_model(savefile)
# always keep the weights of the final epoch
print(f"\nSaving network state at epoch {epoch+1}")
net.save_model(f"unet_epoch{epoch+1}.pth")

In [None]:
# Held-out part of the split, served in a fixed order in batches of 4
ds_val = CellDataset(df_train, train=False)
dl_val = DataLoader(
    ds_val,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [None]:
def post_process(probability, threshold=0.5, min_size=20):
    """Turn a probability map into per-object masks and one combined mask.

    The map is binarised at ``threshold``, split into connected components,
    and every component with more than ``min_size`` pixels is kept.

    Fixes over the previous version: the ``probability`` argument is used
    (it used to read a notebook global named ``probability_mask``),
    ``threshold`` is honoured (it used to be a hard-coded 0.25), the output
    size follows the input instead of a fixed 240x240, and components that
    fail the size filter no longer appear in the combined mask.

    Args:
        probability: 2-D array of per-pixel probabilities.
        threshold: pixels strictly above this value count as foreground.
        min_size: components with this many pixels or fewer are dropped.

    Returns:
        ``(predictions, im)``: ``predictions`` is a list with one float32
        binary mask per kept component; ``im`` is a float32 binary mask
        that is the union of the kept components. Both have the shape of
        ``probability``.
    """
    # cv2.threshold is fed float32 explicitly so other float inputs work too
    probability = np.asarray(probability, dtype=np.float32)
    mask = cv2.threshold(probability, threshold, 1,  cv2.THRESH_BINARY)[1]
    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
    predictions = []
    im = np.zeros(component.shape, np.float32)
    # Label 0 is the background, so the components start at 1.
    for c in range(1, num_component):
        p = (component == c)
        if p.sum() > min_size:
            a_prediction = np.zeros(component.shape, np.float32)
            a_prediction[p] = 1
            predictions.append(a_prediction)
            im[p] = 1
    return predictions, im

In [None]:
net.eval()  # inference mode (affects batchnorm)

submission = []
# No gradients are needed for inference; skipping them saves memory.
with torch.no_grad():
    for i, batch in enumerate(tqdm(dl_val)):
        val_images, val_masks = batch
        # Use the device the network lives on rather than assuming a GPU.
        preds = torch.sigmoid(net(val_images.to(device)))
        preds = preds.cpu().numpy()[:, 0, :, :] # (batch_size, 1, size, size) -> (batch_size, size, size)
        for index, probability_mask in enumerate(preds):
            print(f"\nsum prob: {np.sum(probability_mask)}")
            plt.figure(figsize=(10, 5))
            probability_mask = cv2.resize(probability_mask, dsize=(240, 240), interpolation=cv2.INTER_LINEAR)
            plt.subplot(1, 3, 1)
            plt.imshow(val_masks[index][0])
            plt.title('Original Mask')

            plt.subplot( 1, 3, 2)
            plt.imshow(probability_mask)
            plt.title('Probality Prediction')

            predictions, pred_mask = post_process(probability_mask, threshold=0.65)
            print(f"prediction num {len(predictions)}")
            plt.subplot(1, 3, 3)
            plt.imshow(pred_mask)
            plt.title('Predicted Mask')
            plt.show()