This notebook implements semantic segmentation with a simple toy network. The goal is to explore what a simple network with very few layers can achieve.

I tried to keep the code clean. If you see ways to improve it, please let me know in the comments.

Currently, there is no proper validation split; I will add that later.

In [None]:
from pathlib import Path
import pickle
import cv2
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm.auto import tqdm
from statistics import mean, stdev
import pandas as pd
import itertools

In [None]:
# Root folder of the competition data (Kaggle input path).
data_dir = Path("/kaggle/input") / "sartorius-cell-instance-segmentation"
# Folder with the pickled training masks (my public dataset).
mask_dir = Path("/kaggle/input") / "cell-image-masks" / "train_masks"

# Helper code

In [None]:
class MyImageDataset(torch.utils.data.Dataset):
    """Dataset pairing each image with its pickled segmentation mask.

    For an image whose file stem is ``<stem>``, the mask is read from
    ``mask_dir / "mask_<stem>.pkl"``.

    Args:
        filenames: Indexable collection of image paths (``str`` or ``Path``).
        mask_dir: Directory containing the mask pickles (``str`` or ``Path``).
    """

    def __init__(self, filenames, mask_dir):
        self.filenames = filenames
        self.mask_dir = mask_dir

    def __getitem__(self, index):
        """Return the ``(image, mask)`` tensor pair for sample ``index``.

        The image is whatever ``load_img`` returns; the unpickled mask is
        cast to float32. Both tensors get a new leading axis of size 1.

        Raises:
            ValueError: If the image file or its mask file does not exist.
        """
        filename = Path(self.filenames[index])
        mask_filename = Path(self.mask_dir) / f"mask_{filename.stem}.pkl"

        if not filename.exists():
            raise ValueError(f"Image {filename} does not exist")

        if not mask_filename.exists():
            raise ValueError(f"Mask {mask_filename} does not exist")

        img_data = load_img(filename)

        # NOTE: unpickling can execute arbitrary code; only point mask_dir at
        # mask files that come from a trusted source.
        with mask_filename.open("rb") as f:
            mask_data = pickle.load(f)

        img_tensor = torch.tensor(img_data).unsqueeze(0)
        mask_tensor = torch.tensor(mask_data.astype(np.float32)).unsqueeze(0)

        return img_tensor, mask_tensor

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.filenames)
    
    
def load_img(filename):
    """Load an image as a grayscale float32 array shifted by -125.

    Args:
        filename: Path of the image file (``str`` or ``Path``).

    Returns:
        The image converted from BGR to grayscale, cast to float32, with 125
        subtracted from every pixel (presumably to roughly centre 8-bit
        intensities around zero).

    Raises:
        IOError: If the file does not exist or OpenCV cannot decode it.
    """
    filename = Path(filename)
    if not filename.exists():
        raise IOError(f"File {filename} not found")

    # cv2.imread reports failure by returning None rather than raising, and
    # an assert would be stripped under `python -O`, so check explicitly.
    img_data = cv2.imread(str(filename))
    if img_data is None:
        raise IOError(f"File {filename} could not be read as an image")

    gray = cv2.cvtColor(img_data, cv2.COLOR_BGR2GRAY)
    return gray.astype(np.float32) - 125

# Load data

In [None]:
# Number of images per mini-batch.
batch_size = 32

# Every PNG found directly inside the "train" folder of the data directory.
image_filenames = [*(data_dir / "train").glob("*.png")]

train_dataset = MyImageDataset(filenames=image_filenames, mask_dir=mask_dir)

# Batches are drawn in a freshly shuffled order on every pass.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Train simple network

Keep the width and height the same across all layers and apply a simple log-loss (binary cross-entropy) against the full semantic segmentation mask.

In [None]:
n_epochs = 50

# Use the GPU when one is available; otherwise fall back to the CPU so the
# cell still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two stages of (5x5 conv -> batch norm -> ReLU -> 1x1 conv). The 5x5 convs
# use padding="same" and the 1x1 convs need none, so height and width are
# preserved. The single output channel holds raw logits; the sigmoid is
# applied inside BCEWithLogitsLoss.
NN = nn.Sequential(
    nn.Conv2d(1, 20, kernel_size=5, padding="same"),
    nn.BatchNorm2d(20),
    nn.ReLU(),
    nn.Conv2d(20, 10, kernel_size=1),

    nn.Conv2d(10, 10, kernel_size=5, padding="same"),
    nn.BatchNorm2d(10),
    nn.ReLU(),
    nn.Conv2d(10, 1, kernel_size=1),
).to(device)

##########################################################
lossfunc = nn.BCEWithLogitsLoss()

losses = []     # one loss value per optimisation step

optimizer = torch.optim.Adam(NN.parameters())

# Make the mode explicit: BatchNorm must use batch statistics while training.
NN.train()

for epoch in tqdm(range(1, n_epochs+1)):
    for X, Y in train_dataloader:
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()

        pred = NN(X)
        loss = lossfunc(pred, Y)

        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # Progress report: mean loss over the last 10 optimisation steps.
    print(f"{epoch:3}/{n_epochs}: {mean(losses[-10:]):.4g}")

# Release cached GPU memory once, after training. Doing this on every step
# only forces the allocator to re-request memory and slows training down.
torch.cuda.empty_cache()

In [None]:
# Per-step training loss, smoothed with a 21-step rolling mean.
(
    pd.Series(losses)
    .rolling(window=21)
    .mean()
    .plot(title="loss")
);

# Find best prediction mask threshold optimizing IoU score

In [None]:
# Vectorised: the arrays stay in NumPy, so the cost is one argsort per image.

def get_threshold(Y, pred):
    """Find the score threshold that maximises IoU for one image.

    Pixels are ranked by descending score and every cut "predict the top-k
    pixels as positive" is evaluated with
    ``IoU = TP / (k + total_positives - TP)``.

    Args:
        Y: Ground-truth mask with values 0/1, any shape.
        pred: Predicted scores with the same number of elements as ``Y``.

    Returns:
        ``(best_threshold, best_IoU)``. Applying ``pred >= best_threshold``
        reproduces ``best_IoU``. If several cuts share the best IoU, the
        highest threshold is returned.

    Raises:
        ValueError: If the inputs are empty or differ in size.
    """
    scores = np.asarray(pred).ravel()
    mask = np.asarray(Y).ravel()

    if scores.size != mask.size:
        raise ValueError(
            f"Y and pred must have the same number of elements "
            f"({mask.size} != {scores.size})"
        )
    if scores.size == 0:
        raise ValueError("Cannot determine a threshold for empty input")

    order = np.argsort(scores)[::-1]        # indices by descending score
    scores_sorted = scores[order]

    # true_pos[k-1] = number of positives among the k highest-scoring pixels.
    # float64 keeps the running count exact regardless of the mask dtype.
    true_pos = np.cumsum(mask[order], dtype=np.float64)
    n_predicted = np.arange(1, scores.size + 1)
    IoU = true_pos / (n_predicted + true_pos[-1] - true_pos)

    # A threshold cannot separate pixels with identical scores, so a cut is
    # only realisable at the last pixel of each run of tied scores.
    is_cut = np.append(scores_sorted[:-1] != scores_sorted[1:], True)
    IoU = np.where(is_cut, IoU, -np.inf)

    best_IoU_idx = IoU.argmax()
    best_threshold = scores_sorted[best_IoU_idx]
    best_IoU = IoU[best_IoU_idx]

    return best_threshold, best_IoU
    
img_thresholds = []         # one for each image
img_IoUs = []

N=3                         # number of batches to evaluate

# Run the model on whatever device its parameters live on.
model_device = next(NN.parameters()).device

# Switch to eval mode for inference. In training mode BatchNorm normalises
# with the statistics of the current batch and keeps updating its running
# averages, so predictions would depend on the batch composition and this
# evaluation would modify the model.
was_training = NN.training
NN.eval()
try:
    for X, Y in tqdm(itertools.islice(train_dataloader, N), total=N):
        X = X.to(model_device)
        Y = Y.detach().numpy()

        with torch.no_grad():
            pred = torch.sigmoid(NN(X)).cpu().numpy()

        for img_Y, img_pred in zip(Y, pred):
            best_img_threshold, best_img_IoU = get_threshold(img_Y, img_pred)
            img_thresholds.append(best_img_threshold)
            img_IoUs.append(best_img_IoU)
finally:
    # Restore the previous mode so later training is unaffected.
    NN.train(was_training)

best_threshold = np.mean(img_thresholds)
best_threshold_spread = np.std(img_thresholds)
avg_IoU = mean(img_IoUs)

print(f"Best threshold: {best_threshold:.3g} (+-{best_threshold_spread:.3g}), Avg. Train IoU: {avg_IoU:.3f}")

# Visualize predictions

In [None]:
threshold = best_threshold

###########################
# Predict one (randomly shuffled) batch from the training loader.
X, Y = next(iter(train_dataloader))
X = X.to(next(NN.parameters()).device)     # same device as the model
Y = Y.detach().numpy()

# Inference must run in eval mode: in training mode BatchNorm uses the
# statistics of this batch and updates its running averages.
was_training = NN.training
NN.eval()
try:
    with torch.no_grad():
        pred = torch.sigmoid(NN(X)).cpu().numpy()
finally:
    NN.train(was_training)                  # restore the previous mode

# Boolean mask: True where the predicted probability reaches the threshold.
pred_Y = (pred >= threshold)

# Colour k is used for class code k produced by plot():
# 0 black, 1 gray, 2 orange, 3 green.
cmap = mpl.colors.ListedColormap(['black', 'gray', 'orange', 'green'])

def plot(img_Y, img_pred):
    """Draw a colour-coded comparison of a ground-truth mask and a prediction.

    Each pixel gets a class code that is rendered with the module-level
    ``cmap``:
        0 -- neither of the cases below (e.g. both masks are 0)
        1 -- false positive (``img_Y == 0`` and ``img_pred == 1``)
        2 -- false negative (``img_Y == 1`` and ``img_pred == 0``)
        3 -- true positive  (``img_Y == 1`` and ``img_pred == 1``)

    Args:
        img_Y: 2-D ground-truth mask with values 0/1.
        img_pred: 2-D predicted mask (boolean or 0/1) of the same shape.
    """
    output = np.select(
        [
            (img_Y == 0) & (img_pred == 1),
            (img_Y == 1) & (img_pred == 0),
            (img_Y == 1) & (img_pred == 1),
        ],
        [1, 2, 3],
        default=0,
    )

    plt.figure(figsize=(10,10))
    # Pin the value range so class code k always maps to colour k. With
    # auto-scaling the colours shift whenever an image lacks class 0 or 3.
    # Nearest-neighbour interpolation keeps class codes from being blended.
    plt.imshow(output, cmap=cmap, vmin=0, vmax=3, interpolation="nearest")
    plt.xticks([])
    plt.yticks([])
    

N = 5   # maximum number of images to display

# Slicing caps the count at the batch size, so a batch with fewer than N
# images no longer raises IndexError. Index 0 selects the single channel.
for img_Y, img_pred in zip(Y[:N, 0], pred_Y[:N, 0]):
    plot(img_Y, img_pred)
    plt.show()

# black: correctly predicted background (true negative)
# green: correct prediction (true positive)
# gray: false positive (too much)
# orange: false negative (missed)