# **Training Notebook**

This notebook only runs inference. The checkpoint it loads comes from the training notebook: https://www.kaggle.com/code/vexxingbanana/hubmap-unet-semantic-approach-train

# **Install segmentation_models_pytorch**

In [1]:
# Copy the pytorch-segmentation-models-lib input directory into the current
# working directory; the pip install cell below installs from ./pytorch-segmentation-models-lib.
!cp -r ../input/pytorch-segmentation-models-lib/ ./

In [2]:
# Set pip's global `disable-pip-version-check` option to true (written to pip's config file).
!pip config set global.disable-pip-version-check true

Writing to /root/.config/pip/pip.conf


In [3]:
# Quiet (-q) installs from the locally copied directory rather than from an index:
# pretrainedmodels 0.7.4, efficientnet_pytorch 0.6.3, timm 0.4.12 and
# segmentation_models_pytorch 0.2.0.
# NOTE(review): the first three are presumably dependencies of the last one, which
# would make the order significant — confirm before reordering.
!pip install -q ./pytorch-segmentation-models-lib/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
!pip install -q ./pytorch-segmentation-models-lib/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
!pip install -q ./pytorch-segmentation-models-lib/timm-0.4.12-py3-none-any.whl
!pip install -q ./pytorch-segmentation-models-lib/segmentation_models_pytorch-0.2.0-py3-none-any.whl

[0m

# **Import Libraries**

In [4]:
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
import cv2
import glob
import os
import shutil
import timm
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.cuda import amp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import transformers
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold
import multiprocessing as mp
import segmentation_models_pytorch as smp
import copy
from collections import defaultdict
import gc
from tqdm import tqdm
import tifffile
from colorama import Fore, Back, Style

# **Config**

In [5]:
class CFG:
    """Notebook-wide settings, read as class attributes (e.g. ``CFG.device``).

    Only some attributes are referenced by the cells visible in this notebook
    (batch_size, backbone, img_size, device, base_path, num_workers, num_classes,
    ckpt_path, threshold). The remaining ones look like training settings
    carried over unchanged — TODO confirm they can be dropped here.
    """
    seed = 0  # not applied anywhere in the visible cells
    batch_size = 16  # batch size of the inference DataLoader
    head = "UNet"
    backbone = "efficientnet-b0"  # passed as encoder_name to smp.Unet
    img_size = [512, 512]  # unpacked into A.Resize; presumably (height, width) — confirm
    lr = 1e-3
    scheduler = 'CosineAnnealingLR'  # listed option(s): ['CosineAnnealingLR']
    epochs = 20
    warmup_epochs = 2
    n_folds = 5
    folds_to_run = [0]
    # Use the first CUDA device when one is available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    base_path = '../input/hubmap-organ-segmentation'  # root used to build the test image paths
    num_workers = mp.cpu_count()  # DataLoader workers: one per CPU reported by multiprocessing
    num_classes = 1  # number of output channels requested from smp.Unet
    # Evaluates to 1 for any batch_size >= 16; must stay below batch_size because it reads it.
    n_accumulate = max(1, 16//batch_size)
    loss = 'Dice'
    optimizer = 'Adam'
    weight_decay = 1e-6
    ckpt_path = '../input/hubmap-unet-semantic-approach-train/last_epoch-00.bin'  # checkpoint loaded by load_model
    threshold = 0.5  # sigmoid outputs strictly above this value become mask pixels

# **Helper Functions**

In [6]:
# ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length-encoded mask into a binary array.

    mask_rle: whitespace-separated "start length" pairs as one string;
              starts are 1-based positions into the flattened mask
    shape: (height, width) of the array to return
    Returns a uint8 numpy array, 1 - mask, 0 - background.

    NOTE(review): the flat buffer is reshaped in row-major order with no
    transpose, whereas rle_encode_less_memory flattens the transposed image;
    confirm which orientation the stored RLE strings use.
    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int)
    run_lengths = np.asarray(tokens[1::2], dtype=int)

    run_starts -= 1  # convert 1-based RLE positions to 0-based slice bounds
    run_ends = run_starts + run_lengths

    flat_mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat_mask[begin:end] = 1

    return flat_mask.reshape(shape)


# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    '''
    Run-length encode a mask, scanning it in row-major (flatten) order.

    img: numpy array, 1 - mask, 0 - background
    Returns the run lengths as a string of space-separated
    "start length" pairs with 1-based starts ('' when there is no run).
    '''
    flat = img.flatten()
    # Zero sentinels on both sides so runs touching either end are still
    # delimited by a value change.
    padded = np.concatenate([[0], flat, [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Odd slots hold run ends; subtracting the matching start yields lengths.
    boundaries[1::2] -= boundaries[::2]
    return ' '.join(map(str, boundaries))

#ref: https://www.kaggle.com/code/bguberfain/memory-aware-rle-encoding/notebook
def rle_encode_less_memory(img):
    '''
    Run-length encode a mask, scanning the transposed image (img.T) in flatten order.

    img: numpy array, 1 - mask, 0 - background
    Returns the run lengths as a string of space-separated
    "start length" pairs with 1-based starts.

    No padded copy is built; instead the first and last pixel of the
    flattened copy are forced to zero, so a mask pixel sitting exactly at
    either of those two positions is dropped from the encoding.
    The input array itself is not modified.
    '''
    scan = img.T.flatten()

    # This simplified method requires first and last pixel to be zero
    scan[0] = 0
    scan[-1] = 0

    # +2: one for the shifted comparison, one for the 1-based positions.
    edges = np.flatnonzero(scan[1:] != scan[:-1]) + 2
    edges[1::2] -= edges[::2]

    return ' '.join(map(str, edges))

In [7]:
def read_tiff(path, scale=None, verbose=0): #Modified from https://www.kaggle.com/code/abhinand05/hubmap-extensive-eda-what-are-we-hacking
    """
    Load a TIFF file and return it as a float32 array scaled by its maximum.

    path: file to read with tifffile.imread
    scale: optional divisor; when truthy, width and height are integer-divided
           by it and the image is resized with cv2.resize
    verbose: when truthy, print the image shape (and the resized shape)
    Returns a float32 numpy array divided by the maximum value found before
    the float conversion; an image whose maximum is 0 is returned undivided.
    """
    pixels = tifffile.imread(path)
    if pixels.ndim == 5:
        # Drop singleton axes, then move the first remaining axis to the end.
        pixels = pixels.squeeze().transpose(1, 2, 0)

    if verbose:
        print(f"[{path}] Image shape: {pixels.shape}")

    if scale:
        # cv2.resize takes the target size as (width, height).
        target_size = (pixels.shape[1] // scale, pixels.shape[0] // scale)
        pixels = cv2.resize(pixels, target_size)

        if verbose:
            print(f"[{path}] Resized Image shape: {pixels.shape}")

    peak = np.max(pixels)
    pixels = pixels.astype(np.float32)
    if peak:
        # In-place division keeps the array float32.
        pixels /= peak # scale image to [0, 1]
    return pixels

# **Grab Metadata**

In [8]:
# Test-set metadata. The path is built from CFG.base_path so the dataset root
# is configured in a single place (same root the image paths are built from).
df = pd.read_csv(os.path.join(CFG.base_path, "test.csv"))
df.head()

Unnamed: 0,id,organ,data_source,img_height,img_width,pixel_size,tissue_thickness
0,10078,spleen,Hubmap,2023,2023,0.4945,4


# **Data Processing**

In [9]:
# Derive every test image's file path from its id: <base_path>/test_images/<id>.tiff
test_images_dir = os.path.join(CFG.base_path, 'test_images')
df['image_path'] = [os.path.join(test_images_dir, str(image_id) + '.tiff') for image_id in df['id']]

# **Dataset**

In [10]:
class HuBMAP_Dataset(torch.utils.data.Dataset):
    """
    Dataset that reads one image per row of a metadata DataFrame.

    df: DataFrame with the columns 'image_path', 'img_height', 'img_width'
        and 'id' (plus 'rle' when labeled=True)
    labeled: when True each item is (image, mask); otherwise each item is
        (image, img_height, img_width, id)
    transforms: optional callable invoked with image= (and mask= when
        labeled) keyword arguments; its result is indexed with 'image'
        (and 'mask')

    Rows are addressed by integer position, so the frame's index labels do
    not need to be 0..len-1 (e.g. a filtered subset works without
    reset_index).
    """

    def __init__(self, df, labeled=True, transforms=None):
        self.df = df
        self.labeled = labeled
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def _cell(self, index, column):
        """Return the value of `column` for the row at integer position `index`."""
        # Column-first lookup keeps each column's own dtype (a whole-row lookup
        # would upcast mixed columns to a common dtype).
        return self.df[column].iloc[index]

    def __getitem__(self, index):
        img_path = self._cell(index, 'image_path')
        img_height = self._cell(index, 'img_height')
        img_width = self._cell(index, 'img_width')
        id_ = self._cell(index, 'id')
        img = read_tiff(img_path)

        if self.labeled:
            rle_mask = self._cell(index, 'rle')
            mask = rle_decode(rle_mask, (img_height, img_width))

            if self.transforms:
                data = self.transforms(image=img, mask=mask)
                img = data['image']
                mask = data['mask']

            # Add a leading axis to the mask and move the image's last axis first.
            mask = np.expand_dims(mask, axis=0)
            img = np.transpose(img, (2, 0, 1))

            return torch.tensor(img), torch.tensor(mask)

        if self.transforms:
            data = self.transforms(image=img)
            img = data['image']

        img = np.transpose(img, (2, 0, 1))

        # Original size and id travel with the image so predictions can be
        # resized back and matched to their row.
        return torch.tensor(img), img_height, img_width, id_

# **Augmentations**

In [11]:
# Transform pipelines keyed by stage; only "inference" is defined here.
# The single step resizes to CFG.img_size with nearest-neighbour interpolation.
_inference_resize = A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST)
data_transforms = {"inference": A.Compose([_inference_resize], p=1.0)}

# **Models**

In [12]:
def build_model():
    """
    Create the smp.Unet described by CFG and move it to CFG.device.

    The encoder is built without pretrained weights (encoder_weights=None)
    and no output activation is attached (activation=None).
    Returns the model.
    """
    unet_kwargs = dict(
        encoder_name=CFG.backbone,
        encoder_weights=None,
        in_channels=3,
        classes=CFG.num_classes,
        activation=None,
    )
    network = smp.Unet(**unet_kwargs)
    network.to(CFG.device)
    return network

def load_model(path):
    """
    Build the model and load trained weights into it for inference.

    path: checkpoint file holding a state dict that model.load_state_dict accepts
    Returns the model on CFG.device, switched to eval mode.
    """
    model = build_model()
    # map_location remaps the stored tensors onto CFG.device, so a checkpoint
    # saved from a GPU still loads when CFG.device has fallen back to the CPU.
    state_dict = torch.load(path, map_location=CFG.device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# **Dataloader**

In [13]:
def prepare_loaders():
    """
    Build the inference DataLoader over the module-level `df`.

    The dataset is unlabeled and uses data_transforms['inference']; batches
    are read in order (shuffle=False) and the last partial batch is kept.
    Returns the DataLoader.
    """
    dataset = HuBMAP_Dataset(df, labeled=False, transforms=data_transforms['inference'])
    loader_options = dict(
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=False,
        pin_memory=True,
        drop_last=False,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)

# **Inference**

In [14]:
# Run the model over every test image and collect one RLE string per image id.
infer_loader = prepare_loaders()
model = load_model(CFG.ckpt_path)

pred_ids = []
pred_rles = []
with torch.no_grad():
    for (images, heights, widths, ids) in infer_loader:
        images = images.to(CFG.device)
        output = torch.sigmoid(model(images))

        # Move axis 1 to the end, binarise at CFG.threshold, bring to host memory.
        binary = output.permute((0, 2, 3, 1)) > CFG.threshold
        msks = binary.to(torch.uint8).cpu().detach().numpy()

        for idx, small_mask in enumerate(msks):
            height = heights[idx].item()
            width = widths[idx].item()
            id_ = ids[idx].item()
            # Resize the mask back to the image's original size; cv2 takes (width, height).
            msk = cv2.resize(
                small_mask.squeeze(),
                dsize=(width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            rle = rle_encode_less_memory(msk)
            pred_rles.append(rle)
            pred_ids.append(id_)

        # Release Python garbage and cached CUDA blocks after each batch.
        gc.collect()
        torch.cuda.empty_cache()

In [15]:
# Sanity check: number of encoded masks collected by the inference loop.
len(pred_rles)

1

In [16]:
# Assemble the submission table (one row per image id with its RLE mask),
# write it without the index column, and preview the first rows.
submission_columns = {"id": pred_ids, "rle": pred_rles}
pred_df = pd.DataFrame(submission_columns)
pred_df.to_csv('submission.csv', index=False)
display(pred_df.head(5))

Unnamed: 0,id,rle
0,10078,168950 8 168973 28 169005 8 170973 8 170996 28...
