# Begin

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm

tqdm.pandas()

import time
import os
import gc

# Make sure automatic garbage collection is switched on for the whole run.
gc.enable()
# Wall-clock reference; the training loop compares elapsed time against time_limit.
start_time = time.time()
time_limit = 8*3600 # seconds (8 hours max)

In [None]:
# Training metadata; ImageDataLoader below reads the `id`, `rle` and `img_width` columns.
train_df = pd.read_csv("../input/hubmap-organ-segmentation/train.csv")
train_df

In [None]:
# Summary statistics of the numeric metadata columns.
train_df.describe()

# Helper Functions

In [None]:
#Credits to: ravishah1
#Source: https://www.kaggle.com/competitions/hubmap-organ-segmentation/discussion/332838

def mask2rle(img, orig_dim=160):
    '''
    Encode a mask as a run-length string at its original resolution.

    img: numpy array, 1 - mask, 0 - background. Either a 2-D array or a
         flattened square mask (reshaped to 2-D before resizing).
    orig_dim: side length in pixels of the square image the mask is
              resized to before encoding.
    Returns the run lengths as a space-separated string of "start length"
    pairs. Starts are 1-based positions in the transposed (column-by-column)
    flattening of the resized mask. The string is empty when no pixel is
    above the threshold.
    Raises ValueError if a 1-D input does not hold a square number of values.
    '''
    img = np.asarray(img)
    if img.ndim == 1:
        # PIL does not infer a square layout from a 1-D array, so restore the
        # 2-D shape explicitly before building the image.
        side = int(round(np.sqrt(img.size)))
        if side * side != img.size:
            raise ValueError(
                f"cannot reshape a flat mask of {img.size} values into a square"
            )
        img = img.reshape(side, side)
    orig_dim = int(orig_dim)

    #Rescale image to original size
    n = Image.fromarray(img)
    n = n.resize((orig_dim, orig_dim))
    n = np.array(n).astype(float)
    #Get pixels to flatten
    pixels = n.T.flatten()
    #Round the pixels using the half of the range of pixel value
    #(ndarray.min/max instead of the builtins: one vectorised pass each)
    lo, hi = pixels.min(), pixels.max()
    pixels = (pixels - lo > ((hi - lo)/2)).astype(int)

    #Pad with background so every run has both a start and an end transition
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    #Turn each end position into a run length
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

In [None]:
import matplotlib.pyplot as plt
from skimage import color

def show_masked_img(img, mask, title=''):
    '''
    Plot a mask, its image, and the mask overlaid on the image side by side.

    img: torch tensor or numpy array indexed (row, col, channel).
    mask: torch tensor or numpy array holding img.shape[0]*img.shape[1]
          values (flat or 2-D); it is transposed for the overlay panel.
    title: figure title.
    '''
    def to_numpy(x):
        # Tensors may live on the GPU or be attached to autograd; plain
        # .numpy() raises for both, so detach and move to the CPU first.
        if hasattr(x, "detach"):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    img, mask = to_numpy(img), to_numpy(mask)
    mask = mask.reshape(img.shape[:2])

    fig, ax = plt.subplots(1, 3, figsize=(9, 3))
    fig.suptitle(title, fontsize=16)

    ax[0].imshow(mask); ax[0].set_title('Mask')
    ax[1].imshow(img); ax[1].set_title('Image')
    ax[2].imshow(color.label2rgb(mask.T, img,
                               bg_label=0, bg_color=(1.,1.,1.), alpha=0.25))
    ax[2].set_title('Masked Image')
    plt.show()

In [None]:
def rle_to_2d_arr(rle, l):
    '''
    Decode a run-length string into a (t_size, t_size) binary mask.

    rle: space-separated "start length" pairs with 1-based starts
         (the format mask2rle emits).
    l: side length of the square image the RLE was written for.
    Returns a float numpy array of 0./1. values, resized to the module-level
    t_size. The flat RLE positions are laid out row by row, so the result is
    the transpose of a mask that mask2rle would encode to the same string.
    '''
    rle = np.array(list(map(int, rle.split())))
    label = np.zeros((l*l))

    for start, length in zip(rle[::2], rle[1::2]):
        #Starts are 1-based, numpy slices are 0-based
        label[start-1:start-1+length] = 1

    #Convert label to image
    label = Image.fromarray(label.reshape(l, l))
    #Resize label
    label = label.resize((t_size, t_size))
    label = np.array(label).astype(float)

    #rescale label back to {0, 1}; resampling can leave intermediate values
    lo, hi = label.min(), label.max()
    if hi == lo:
        #Constant mask (e.g. empty RLE): min-max scaling would be 0/0 -> NaN
        return np.full_like(label, 1.0 if lo > 0.5 else 0.0)
    label = np.round((label - lo)/(hi - lo))

    return label

def random_rotate(X, Y):
    '''
    Randomly flip each image/label pair vertically and/or horizontally.

    X: iterable of images indexed (row, col, channel).
    Y: iterable of 2-D labels. Each label is transposed before the flips are
       applied (the same flips as its image) and transposed back afterwards.
    Returns (images, labels): the flipped images stacked into one array and
    the flipped labels, each flattened, stacked into another. Both are empty
    arrays when X and Y are empty.
    '''
    flipped_x, flipped_y = [], []
    for x, y in zip(X, Y):
        flip_v = np.random.random()>0.5
        flip_h = np.random.random()>0.5

        y = y.T
        #Flip the x or y on their axis
        if flip_v:
            x = x[::-1, :, :]
            y = y[::-1, :]

        if flip_h:
            x = x[:, ::-1, :]
            y = y[:, ::-1]

        flipped_x.append(x)
        flipped_y.append(y.T.flatten())
    #Return every processed pair, not just the last one of the loop
    return np.array(flipped_x), np.array(flipped_y)

# Data Loader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Current device:", DEVICE.upper())

# Side length in pixels of the square masks; rle_to_2d_arr and MaskerModel read this global.
t_size = 160

class ImageDataLoader(Dataset):
    '''
    Map-style dataset that reads one .tiff image per dataframe row.

    df: dataframe with an `id` column; in train mode the `rle` and
        `img_width` columns are read as well.
    img_size: side length the image is resized to. Labels are resized by
              rle_to_2d_arr, which uses the module-level t_size instead.
    rotate: when True, apply random vertical/horizontal flips per access.
    train: True reads from train_images and returns (image, label);
           False reads from test_images and returns the image only.
    '''
    def __init__(self, df, img_size=160, rotate=True, train=True):
        self.df = df
        self.img_size = img_size
        self.is_train = train
        self.rotate = rotate
        loc = "train" if self.is_train else "test"

        self.paths = df["id"].apply(
            lambda x: f"../input/hubmap-organ-segmentation/{loc}_images/{x}.tiff"
        )

    def __len__(self):
        '''Number of rows (images) in the dataframe.'''
        return self.df.shape[0]

    def __getitem__(self, idx):
        '''Fetch image at index idx. If train, fetch the labels.'''
        # Positional lookups: samplers yield 0..len-1, which only matches the
        # dataframe labels when its index happens to have been reset.
        path = self.paths.iloc[idx]
        #Read Image; the context manager closes the file once it is resized
        with Image.open(path) as raw:
            image = raw.resize((self.img_size, self.img_size))
        image = np.array(image).astype(float)
        image /= 255.
        image = torch.Tensor(image)

        if self.is_train:
            label = rle_to_2d_arr(self.df.rle.iloc[idx], self.df.img_width.iloc[idx])
            label = torch.Tensor(label)
            if self.rotate:
                image, label = self.flip_image(image, label)
            return image, label.flatten()

        if self.rotate:
            image = self.flip_image(image)
        return image

    def flip_image(self, image, label=None):
        '''
        Flip image (and label, in train mode) along randomly chosen axes.

        Each of the two spatial axes is flipped with probability 0.5. The
        label is transposed before the flip and transposed back afterwards,
        so it receives the same flips as the image.
        Returns (image, label) in train mode, otherwise the image alone.
        '''
        flip_v = int(np.random.random()>0.5)
        flip_h = int(np.random.random()>0.5)
        dims = []
        if flip_v: dims.append(0)
        if flip_h: dims.append(1)
        #Flip the image or label on their axis
        image = torch.flip(image, dims)

        if not self.is_train:
            return image

        label = torch.flip(label.T, dims)
        return image, label.T

In [None]:
# Smoke test of the dataset on the full training frame (random flips are on by default).
test = ImageDataLoader(train_df)
# Shapes of the image tensor and of the flattened label for one sample.
print(test[100][0].shape, test[100][1].shape)
# Every test[...] access re-reads the file and redraws the random flips.
show_masked_img(*test[1], 'Sample')

In [None]:
#Initialize train and validation data
from sklearn.model_selection import train_test_split as tts

BATCH_SIZE = 16

def data_collate(batch):
    '''
    Merge a list of (image, label) samples into two batched tensors.

    batch: sequence of samples whose first item is the image tensor and
           second item is the label tensor.
    Returns (images, labels), each stacked along a new leading batch axis.
    '''
    image_list = [sample[0] for sample in batch]
    label_list = [sample[1] for sample in batch]
    return torch.stack(image_list), torch.stack(label_list)

# Hold out 10% of the rows for validation. The fixed random_state keeps the
# split identical across kernel restarts, so validation losses stay comparable.
df_train, df_valid = tts(train_df, test_size = 0.1, shuffle=True, random_state=42)
# The dataset is indexed by position, so give both frames a clean 0..n-1 index.
df_train, df_valid = df_train.reset_index(drop=True), df_valid.reset_index(drop=True)
print("Train:", df_train.shape)
print("Validation:", df_valid.shape)

# Random flips only for training; validation sees the images unmodified.
train_data = ImageDataLoader(df_train)
valid_data = ImageDataLoader(df_valid, rotate=False)

# NOTE: prefetch_factor counts batches loaded ahead per worker, not samples.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                          drop_last=True, collate_fn=data_collate, num_workers=2,
                          prefetch_factor=BATCH_SIZE//2)

valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False,
                          drop_last=False, collate_fn=data_collate, num_workers=2,
                          prefetch_factor=BATCH_SIZE//2)

# Model

In [None]:
import torchvision

class MaskerModel(nn.Module):
    '''
    ResNet-18 backbone followed by a dense head that emits one probability
    per mask pixel.

    forward() returns a (batch, t_size**2) tensor of sigmoid outputs, where
    t_size is the module-level mask side length.
    '''
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet18(pretrained=False)
        num_ftrs = self.resnet.fc.in_features
        # Replace the classification layer with a wider feature layer.
        self.resnet.fc = nn.Linear(num_ftrs, 2048)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.25)
        self.out_fc = nn.Linear(2048, t_size**2)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _channels_first(image, size):
        '''Return image as a (batch, 3, size, size) tensor.'''
        if image.dim() == 3 and tuple(image.shape) == (size, size, 3):
            # Single channels-last image: add the batch axis.
            image = image.unsqueeze(0)
        if image.dim() == 4 and tuple(image.shape[1:]) == (size, size, 3):
            # Channels-last batch: the channel axis has to be moved. A plain
            # reshape would reinterpret memory and scramble the pixels.
            return image.permute(0, 3, 1, 2)
        # Any other layout keeps the previous behaviour.
        return image.reshape((-1, 3, size, size))

    def forward(self, image):
        '''
        image: batch of images, channels-last (batch, t_size, t_size, 3) or
               already channels-first (batch, 3, t_size, t_size).
        Returns per-pixel probabilities of shape (batch, t_size**2).
        '''
        i = self._channels_first(image, t_size)
        o = self.resnet(i)
        return self.sigmoid(self.out_fc(self.dropout(self.relu(o))))

In [None]:
# Throwaway instance (never moved to DEVICE) for inspecting the architecture;
# it is used for one forward pass below and then deleted.
m = MaskerModel()
m

In [None]:
# Run a single validation batch through the untrained model to check shapes.
# NOTE(review): m is never switched to eval mode, so dropout is presumably
# active during this forward pass — confirm that is acceptable for a preview.
for images, labels in valid_loader:
    #print(images, labels)
    preds = m(images)
    print(preds.shape)
    break

# Plot ground truth and prediction for the first five samples of that batch.
# a = image, b = label, c = predicted mask.
counts = 0
for c, (a, b) in zip(preds.detach(), zip(images, labels)):
    print("#"*100)
    show_masked_img(a, b, f'Original #{counts}')
    show_masked_img(a, c, f'Predicted #{counts}')
    counts += 1
    if counts > 4: print('#'*100);break

In [None]:
# Drop the preview tensors and the throwaway model before training starts.
del images, labels, preds, a, b, c, m
gc.collect()

# Train

In [None]:
model = MaskerModel().to(DEVICE)

optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=4, eta_min=1e-6, last_epoch=-1)
# The model already ends in a sigmoid, so plain BCE on probabilities is used.
criterion = nn.BCELoss()

EPOCHS = 100
no_improvement_limit = 20  # epochs without a better validation loss before stopping
no_impr_count = 0
best_loss = np.inf

for epoch in range(EPOCHS):
    print(f"Epoch [{epoch+1}/{EPOCHS}]:")

    if time.time() - start_time > time_limit:
        break #Avoid TimeLimitExceeded Error

    model.train()
    train_loss = 0
    for images, labels in tqdm(train_loader, desc='train...'):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        optim.zero_grad()
        preds = model(images)
        loss = criterion(preds, labels)
        train_loss += loss.item()
        loss.backward()
        optim.step()
    train_loss /= len(train_loader)

    # Evaluation mode: disables dropout and stops BatchNorm from updating its
    # running statistics with validation data.
    model.eval()
    valid_loss = 0
    with torch.no_grad():
        for images, labels in tqdm(valid_loader, desc='validation...'):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            preds = model(images)
            loss = criterion(preds, labels)
            valid_loss += loss.item()
    valid_loss /= len(valid_loader)

    # Advance the cosine schedule once per epoch; without this call the
    # learning rate never changes.
    scheduler.step()

    print(f"[epoch {epoch+1}/{EPOCHS}] train_loss:{train_loss:.4f}, valid_loss:{valid_loss:.4f}")

    if epoch%5 == 0:
        # Plot the first sample of the last validation batch. Tensors are
        # moved to the CPU first because plotting converts them to numpy.
        show_masked_img(images[0].detach().cpu(), labels[0].detach().cpu(), f'Original epoch {epoch}')
        show_masked_img(images[0].detach().cpu(), preds[0].detach().cpu(), f'Predicted epoch {epoch}')

    if valid_loss < best_loss:
        no_impr_count = 0
        best_loss = valid_loss
        torch.save(model.state_dict(), "hubseg_model.pth")
        print("Saved...")
    else:
        no_impr_count += 1
        print("Loss did not improve...")
        if no_impr_count >= no_improvement_limit:
            print("Early stopping!")
            break

## Preview of train predictions

In [None]:
# Predict one training batch with the trained model. Inputs are moved to the
# model's device and the output is brought back to the CPU for plotting;
# eval mode and no_grad keep the preview deterministic and cheap.
model.eval()
with torch.no_grad():
    for images, labels in train_loader:
        preds = model(images.to(DEVICE)).cpu()
        print(preds.shape)
        break

# Plot ground truth and prediction for the first five samples of that batch.
# a = image, b = label, c = predicted mask.
counts = 0
for c, (a, b) in zip(preds.detach(), zip(images, labels)):
    print("#"*100)
    show_masked_img(a, b, f'Original #{counts}')
    show_masked_img(a, c, f'Predicted #{counts}')
    counts += 1
    if counts > 4: print('#'*100);break

## Preview of validation predictions

In [None]:
# Predict one validation batch with the trained model. Inputs are moved to the
# model's device and the output is brought back to the CPU for plotting;
# eval mode and no_grad keep the preview deterministic and cheap.
model.eval()
with torch.no_grad():
    for images, labels in valid_loader:
        preds = model(images.to(DEVICE)).cpu()
        print(preds.shape)
        break

# Plot ground truth and prediction for the first five samples of that batch.
# a = image, b = label, c = predicted mask.
counts = 0
for c, (a, b) in zip(preds.detach(), zip(images, labels)):
    print("#"*100)
    show_masked_img(a, b, f'Original #{counts}')
    show_masked_img(a, c, f'Predicted #{counts}')
    counts += 1
    if counts > 4: print('#'*100);break

In [None]:
# Release the training/validation loaders and the preview tensors before inference.
del train_loader, valid_loader, images, labels, preds, a, b, c
gc.collect()

# Prediction

In [None]:
test_df = pd.read_csv("../input/hubmap-organ-segmentation/test.csv")
# Test mode: no labels are read and no random flips are applied.
test_data = ImageDataLoader(test_df, rotate=False, train=False)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False,
                          drop_last=False, num_workers=2,
                          prefetch_factor=BATCH_SIZE//2)

# Restore the checkpoint with the best validation loss.
model.load_state_dict(torch.load("hubseg_model.pth", map_location=DEVICE))
# Inference must run in eval mode, otherwise dropout randomises the output.
model.eval()

preds = []

with torch.no_grad():
    for image in tqdm(test_loader, desc='Predicting...'):
        # The model lives on DEVICE; move the batch there and bring the
        # result back to the CPU before converting to numpy.
        pred = model(image.to(DEVICE))
        preds.append(pred.cpu().numpy())

preds = np.concatenate(preds,axis=0)
# Side length of each original image; mask2rle resizes to a square of this size.
orig_size = test_df.img_height.values

sub = {'id':[], 'rle':[]}
for i, (p, s) in zip(test_df.id, zip(preds, orig_size)):
    # The model emits a flat vector in the layout produced by rle_to_2d_arr,
    # which is the transpose of what mask2rle encodes. Restore the 2-D shape
    # and transpose it so decoding the submitted RLE reproduces the prediction.
    mask = np.ascontiguousarray(p.reshape(t_size, t_size).T)
    sub['id'].append(i)
    sub['rle'].append(mask2rle(mask, int(s)))

sub = pd.DataFrame(sub)
sub.to_csv('submission.csv', index=False)
sub