In [None]:
import os
import sys
import random
import numpy as np
import pandas as pd
import seaborn as sns
import cv2
import tifffile as tiff 
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob
from PIL import Image, ImageOps
from sklearn.metrics import f1_score, roc_auc_score

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, lr_scheduler
from torchvision import transforms
from torchvision.transforms.functional import pad
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models.segmentation.deeplabv3 import DeepLabHead

#### It's a start; there are still some bugs to fix in the training code.

### Exploratory data analysis

In [None]:
# Root folder of the competition data and the sub-folder holding the training TIFFs.
path = '../input/hubmap-organ-segmentation/'
img_path = f'{path}train_images/'

In [None]:
# Read both metadata tables from the competition folder, then preview the training rows.
train_df, test_df = (pd.read_csv(f'{path}{split}.csv') for split in ('train', 'test'))
train_df.head()

In [None]:
# One panel per metadata column: counts for the categorical ones, a histogram for age.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
for panel, column in zip(ax[:2], ('organ', 'sex')):
    sns.countplot(data=train_df, x=column, ax=panel)
sns.histplot(data=train_df, x='age', ax=ax[2])

In [None]:
def rle_encode(img):
    """Run-length encode a 2-D mask into a 'start length start length ...' string.

    The mask is traversed column by column (``img.T`` flattened) and starts are
    1-based. Any change of value between neighbouring pixels opens or closes a run.
    """
    # Pad with a zero on both sides so runs touching either end are still delimited.
    padded = np.concatenate([[0], img.T.flatten(), [0]])
    changes = padded[1:] != padded[:-1]
    runs = np.where(changes)[0] + 1
    # Odd positions hold run ends; turn them into run lengths.
    runs[1::2] = runs[1::2] - runs[::2]
    return ' '.join(map(str, runs))

def rle_decode(mask_rle, shape=(3000, 3000), get_stat=False):
    """Decode a 'start length ...' run-length string into a binary uint8 mask.

    Starts are 1-based indices into a flat buffer of ``shape[0] * shape[1]``
    elements. The buffer is reshaped to ``shape`` and transposed, so the returned
    array has shape ``(shape[1], shape[0])``.

    With ``get_stat=True`` only the number of foreground pixels is returned.
    """
    tokens = mask_rle.split()
    begins = np.asarray(tokens[0::2], dtype=int) - 1
    counts = np.asarray(tokens[1::2], dtype=int)
    stops = begins + counts
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, stop in zip(begins, stops):
        flat[begin:stop] = 1
    if not get_stat:
        return flat.reshape(shape).T
    return flat.sum()

In [None]:
# Number of foreground pixels per training mask, decoded row by row from its RLE string.
train_df['pixel_count'] = train_df.apply(
    lambda row: rle_decode(row['rle'], shape=(row['img_width'], row['img_height']), get_stat=True),
    axis=1,
)

In [None]:
# Mask size and donor age per organ, split by sex.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
violin_kwargs = dict(data=train_df, x='organ', hue='sex')
sns.violinplot(y='pixel_count', ax=ax[0], **violin_kwargs)
sns.violinplot(y='age', ax=ax[1], **violin_kwargs)

In [None]:
# Pick one training sample and overlay its decoded mask on the image.
smp = 290
sample = train_df.loc[smp]  # look the row up once instead of once per field
name = sample['id']
# Read each dimension from its own column (the original took both from
# img_height) and convert to plain Python ints before handing them to cv2.
width = int(sample['img_width'])
height = int(sample['img_height'])
smp_image = img_path + '{}.tiff'.format(name)
img = tiff.imread(smp_image)

# cv2.resize takes dsize as (width, height).
img = cv2.resize(img, (width, height))
rle = sample['rle']
mask = rle_decode(rle, shape=(width, height))

# Panels: raw image, mask alone, mask blended over the image.
fig, ax = plt.subplots(1, 3, figsize=(12, 10), sharey=True)
ax[0].imshow(img)
ax[1].imshow(mask, cmap='gray')

ax[2].imshow(img)
ax[2].imshow(mask, cmap='cool', alpha=0.5)

ax[2].set_title(name)
ax[2].axis("off")

In [None]:
def create_patches(img_id, rle, width, height, wh=3000, window=500, save_path='/kaggle/working/train_patches/', min_patch_size=5000, testing=False):
    """Cut one slide and its mask into ``window`` x ``window`` patches.

    The image is read from ``img_path + '<img_id>.tiff'`` (``img_path`` is a
    notebook-level variable) and the mask is decoded from ``rle``. Both are
    cropped or zero-padded to ``wh`` x ``wh`` and then tiled in row-major order.

    Unless ``testing`` is set, patches are written to ``save_path``:
      * ``image_<id>_<idx>.png`` / ``mask_<id>_<idx>.png`` when the mask patch
        holds more than ``min_patch_size`` foreground pixels;
      * ``neg_<id>_<idx>.png`` when the mask patch is empty and the tile is not
        on the outer ring of the grid.
    Files are written with ``tiff.imwrite`` despite the ``.png`` suffix.

    Args:
        img_id: slide identifier used to build the file names.
        rle: run-length encoded mask string.
        width, height: mask dimensions passed to ``rle_decode``.
        wh: side length of the square canvas the slide is fitted to.
        window: patch side length.
        save_path: output folder (created if missing when not testing).
        min_patch_size: minimum foreground pixel count for a positive patch.
        testing: if True, nothing is written and the patches are returned.

    Returns:
        ``(img_patches, mask_patches)`` when ``testing`` is True, else ``True``.
    """

    def check_pad(arr):
        """Fit the first two axes of ``arr`` to ``wh``: top-left crop or centred zero-pad."""
        # Uses the enclosing ``wh`` (the original always used 3000) and each
        # axis' own length, so non-default canvases are handled correctly.
        arr = arr[:wh, :wh]
        pad_width = []
        for size in arr.shape[:2]:
            before = (wh - size) // 2
            # Any odd remainder goes to the trailing side, as in the original.
            pad_width.append((before, wh - size - before))
        if not any(before or after for before, after in pad_width):
            return arr
        # Leave channel axes (if any) untouched.
        pad_width.extend([(0, 0)] * (arr.ndim - 2))
        return np.pad(arr, pad_width=pad_width, mode='constant', constant_values=0)

    img = check_pad(tiff.imread(img_path + '{}.tiff'.format(img_id)))
    mask = check_pad(rle_decode(rle, shape=(width, height)))

    if not testing:
        os.makedirs(save_path, exist_ok=True)

    offsets = range(0, wh, window)
    grid = len(offsets)
    # Tile indices that are not on the outer ring of the grid; only these may
    # be saved as negative examples. Derived from the grid size instead of a
    # hard-coded 6x6 layout.
    possible_negative = set(np.arange(grid * grid).reshape(grid, grid)[1:-1, 1:-1].ravel().tolist())

    img_patches = []
    mask_patches = []
    idx = 0
    for r in offsets:
        for c in offsets:
            img_patch = img[r: r+window, c: c+window]
            mask_patch = mask[r: r+window, c: c+window]
            img_patches.append(img_patch)
            mask_patches.append(mask_patch)
            if not testing:
                foreground = mask_patch.sum()
                if foreground > min_patch_size:
                    tiff.imwrite(file=save_path + 'image_{}_{}.png'.format(img_id, idx), data=img_patch)
                    tiff.imwrite(file=save_path + 'mask_{}_{}.png'.format(img_id, idx), data=mask_patch)
                elif (foreground == 0) and (idx in possible_negative):
                    tiff.imwrite(file=save_path + 'neg_{}_{}.png'.format(img_id, idx), data=img_patch)

            idx += 1

    if testing:
        return img_patches, mask_patches
    return True

In [None]:
# Create the output folder from Python; unlike `!mkdir`, this is safe to re-run.
os.makedirs('/kaggle/working/train_patches/', exist_ok=True)
# Write the patches for every training slide. The 'shape' column only stores the
# value returned by create_patches for each row.
train_df['shape'] = train_df.apply(lambda x: create_patches(x['id'], x['rle'], x['img_width'], x['img_height']), axis=1)

In [None]:
# Show the 6x6 tiling of the sample slide with each mask patch blended on top.
img_patches, mask_patches = create_patches(name, rle, width, height, testing=True)
fig, ax = plt.subplots(nrows=6, ncols=6, figsize=(16, 10), sharey=True)
for idx, panel in enumerate(ax.ravel()):
    panel.imshow(img_patches[idx])
    panel.imshow(mask_patches[idx], cmap='cool', alpha=0.5)
    panel.set_title(idx)
    panel.axis("off")
plt.tight_layout()

### Building dataset

In [None]:
# Index every saved patch file: 'iid' is the slide id parsed from the file name,
# 'neg' flags the negative patches.
# The pattern is a raw string with an escaped dot (the original relied on the
# invalid escape '\d' in a normal string); expand=False returns a Series so the
# new column is assigned from a 1-D object.
df = pd.DataFrame({'path': glob('/kaggle/working/train_patches/*.png')}).assign(
    iid=lambda x: x['path'].str.extract(r'(\d+)_\d+\.png', expand=False),
    neg=0,
)
df.loc[df['path'].str.contains('neg'), 'neg'] = 1

In [None]:
# Training table: every positive image patch plus one negative patch per slide id.
training_df = pd.concat(
    [
        df.loc[df['path'].str.contains('image') & (df['neg'] == 0)],
        df.loc[df['neg'] == 1].drop_duplicates(subset=['iid']),
    ],
    ignore_index=True,
)

In [None]:
class FlowFromDataFrame(Dataset):
    """Dataset serving (image, mask) pairs for the patch files listed in a DataFrame.

    Args:
        df: DataFrame with a ``path`` column pointing at patch files.
        img_transforms: callable applied to the image array.
        mask_transforms: callable applied to the ``(1, H, W)`` float mask tensor.
        test: if True, ``__getitem__`` returns only the raw, untransformed image.
    """

    def __init__(self, df, img_transforms, mask_transforms, test=False):
        super().__init__()
        self.df = df
        self.img_transforms = img_transforms
        self.mask_transforms = mask_transforms
        self.test = test

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Load the patch at ``index`` and return ``(image, mask)`` after the transforms."""
        img_path = self.df.iloc[index]['path']
        image = tiff.imread(img_path)
        if self.test:
            # NOTE(review): the image is returned as read, without img_transforms.
            return image

        # Look only at the file name: a folder called e.g. 'train_images' must
        # neither be rewritten by the image->mask substitution nor be mistaken
        # for a negative patch.
        folder, filename = os.path.split(img_path)
        if 'neg' in filename:
            # Negative patches have no mask file: use an all-zero mask with the
            # image's own height and width rather than a fixed 250x250.
            mask = torch.zeros((1,) + tuple(image.shape[:2]))
        else:
            mask_path = os.path.join(folder, filename.replace('image', 'mask'))
            mask = torch.Tensor(tiff.imread(mask_path)).unsqueeze(0)
        image = self.img_transforms(image)
        mask = self.mask_transforms(mask)
        return image, mask

In [None]:
# Mask pipeline: resize only; nearest-neighbour interpolation picks existing pixel values.
mask_transforms = transforms.Compose([
    transforms.Resize(size=(250, 250), interpolation=transforms.InterpolationMode.NEAREST),
])

# Image pipeline: array -> PIL -> 250x250 -> tensor -> per-channel normalisation.
img_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=(250, 250), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

In [None]:
# Wrap the training table in a Dataset and serve it in shuffled mini-batches.
batch_size = 64
train_ds = FlowFromDataFrame(df=training_df, img_transforms=img_transforms, mask_transforms=mask_transforms)
train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True, pin_memory=True)

In [None]:
# Preview the first few samples of one batch.
images, masks = next(iter(train_dl))
cols = 4
fig, axes = plt.subplots(nrows=2, ncols=cols, figsize=(12, 10))

# Top row: masks; bottom row: the matching images (channels moved last for imshow).
for col in range(cols):
    axes[0, col].imshow(masks[col][0], cmap='gray')
    axes[1, col].imshow(images[col].permute(1, 2, 0))
for panel in axes.ravel():
    panel.axis("off")

In [None]:
def get_model():
    """Return a pretrained DeepLabV3-ResNet50 whose classifier head outputs one channel."""
    net = deeplabv3_resnet50(pretrained=True)
    # Swap the stock head for a single-channel one fed by the 2048-channel features.
    net.classifier = DeepLabHead(in_channels=2048, num_classes=1)
    return net

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid-activated logits.

    ``forward(targets, inputs)`` returns ``1 - dice`` where
    ``dice = (2 * |P * T| + smooth) / (|P| + |T| + smooth)`` and ``P`` is
    ``sigmoid(inputs)``. Note the argument order: targets first, logits second.
    """

    def __init__(self):
        super().__init__()

    def forward(self, targets, inputs, smooth=1e-8):
        """Return the Dice loss between ``targets`` and the logits ``inputs``."""
        # reshape (not view) so non-contiguous tensors are accepted too.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        # smooth sits in numerator and denominator, as in dice_accuracy: an empty
        # target matched by an empty prediction scores dice 1 (loss 0), not loss 1.
        dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice

def dice_accuracy(y_true, y_pred, smooth=1e-8):
    """Dice coefficient between two tensors of the same size, flattened first.

    ``smooth`` is added to numerator and denominator so two empty inputs score 1.
    """
    truth = y_true.view(-1)
    pred = y_pred.view(-1)
    overlap = torch.sum(truth * pred)
    total = truth.sum() + pred.sum()
    return (2. * overlap + smooth) / (total + smooth)


def loss_fn(y_true, y_pred, bce_fn, dice_fn):
    """Weighted sum of two losses: 0.8 * BCE + 0.2 * Dice.

    ``bce_fn`` is called as ``bce_fn(y_pred, y_true)`` and ``dice_fn`` as
    ``dice_fn(y_true, y_pred)``.
    """
    bce_weight, dice_weight = 0.8, 0.2
    return bce_weight * bce_fn(y_pred, y_true) + dice_weight * dice_fn(y_true, y_pred)

In [None]:
def train(model, epochs):
    """Train ``model`` on the notebook-level ``train_dl`` with BCE-with-logits loss.

    After every epoch the mean loss and mean Dice score are printed, and the
    whole model object is saved to
    ``/kaggle/working/deeplabv3_resnet50_test.pt`` whenever the epoch loss
    improves on the best one seen so far.

    Args:
        model: network whose forward pass returns a dict with an ``'out'`` entry
            of logits shaped like the masks.
        epochs: number of passes over ``train_dl``.
    """
    # Fall back to CPU when no GPU is present; mixed precision is CUDA-only here.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_amp = device == 'cuda'
    bce_fn = nn.BCEWithLogitsLoss()
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    # autocast must be paired with a GradScaler so the loss is scaled before
    # backward(); the original ran autocast with an unscaled backward pass.
    scaler = GradScaler(enabled=use_amp)
    best_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        dice_total = 0.0
        loss_total = 0.0

        for image, mask in tqdm(train_dl, leave=False):
            image, mask = image.to(device), mask.to(device)
            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                output = model(image)
                loss = bce_fn(output['out'], mask)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            with torch.no_grad():
                y_pred = (torch.sigmoid(output['out']) > 0.5).float()
                dice_score = dice_accuracy(mask, y_pred)

            # Weight both per-batch means by the batch size so that dividing by
            # the dataset size yields per-sample averages. The original summed
            # one Dice value per batch but divided by the sample count.
            n_batch = image.size(0)
            dice_total += dice_score.item() * n_batch
            loss_total += loss.item() * n_batch

        n_samples = len(train_dl.dataset)
        epoch_loss = loss_total / n_samples
        epoch_dice_score = dice_total / n_samples

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model,'/kaggle/working/deeplabv3_resnet50_test.pt')

        print('Epoch: {} Loss: {:.4f} Dice Score: {:.4f}'.format(epoch, epoch_loss, epoch_dice_score))

In [None]:
# Build the model and train it for a fixed number of epochs.
# `model` stays a notebook-level variable because the prediction cell below reads it.
epochs = 30
model = get_model()
train(model, epochs)

In [None]:
# Load one test slide and display it; `test_img` is reused by the overlay cell at the end.
test_img = tiff.imread('../input/hubmap-organ-segmentation/test_images/10078.tiff')
plt.imshow(test_img)

In [None]:
def make_prediction(model, test_transforms, size=2000, window=500, threshold=0.5,
                    img_file='../input/hubmap-organ-segmentation/test_images/10078.tiff'):
    """Predict a binary mask for the top-left ``size`` x ``size`` region of a slide.

    The region is processed in ``window`` x ``window`` tiles. Each tile goes
    through ``test_transforms`` and the model, and the sigmoid of the ``'out'``
    logits is thresholded.

    Args:
        model: network returning a dict with an ``'out'`` entry of logits, shaped
            ``(1, 1, h, w)`` for an ``h`` x ``w`` tile.
        test_transforms: callable turning an image tile into a ``(C, h, w)`` tensor.
        size: side length of the predicted region; the slide must be at least
            this large along both axes.
        window: tile side length.
        threshold: probability above which a pixel counts as foreground.
        img_file: path of the TIFF slide to read.

    Returns:
        ``np.uint8`` array of shape ``(size, size)`` holding 0/1 values.
    """
    # Run on whatever device the model already lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    test_img = tiff.imread(img_file)
    pred_patches = np.zeros((size, size), dtype=np.uint8)
    model.eval()
    with torch.no_grad():
        for r in range(0, size, window):
            for c in range(0, size, window):
                img_patch = test_transforms(test_img[r: r+window, c: c+window]).unsqueeze(0).to(device)
                output = model(img_patch)
                # The model emits logits: convert to probabilities before
                # thresholding, exactly as the training loop does.
                probs = torch.sigmoid(output['out'])
                # Drop the batch and channel axes explicitly before storing.
                y_pred = (probs[0, 0] > threshold).cpu().numpy().astype(np.uint8)
                pred_patches[r: r+window, c: c+window] = y_pred
    return pred_patches

In [None]:
# Inference pipeline: tensor conversion and per-channel normalisation (no resizing).
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
prediction = make_prediction(model=model, test_transforms=test_transforms)

In [None]:
# The prediction may cover less than the full slide; zero-pad it on the bottom
# and right up to the image size. The amount is derived from the two shapes
# rather than hard-coded, and the padding is plain background: the previous
# linear ramp towards 1 marked the outer edge as foreground.
pad_rows = max(test_img.shape[0] - prediction.shape[0], 0)
pad_cols = max(test_img.shape[1] - prediction.shape[1], 0)
test_mask = np.pad(prediction, pad_width=((0, pad_rows), (0, pad_cols)), mode='constant', constant_values=0)
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(test_img)
# Pin the colour range so an all-background mask renders the same as any other.
ax.imshow(test_mask, cmap='cool', alpha=0.5, vmin=0, vmax=1)