In [None]:
# Copy the package sources/wheels shipped in the attached Kaggle dataset into a
# local pip cache directory; the next cell installs them from there.
# NOTE(review): presumably needed because the scoring run has no internet access
# (the commented-out installs below use `--no-index`) — confirm.
!mkdir -p /tmp/pip/cache/
!cp -r ../input/pytorch-segmentation-models-lib/efficientnet_pytorch-0.6.3 /tmp/pip/cache/efficientnet_pytorch-0.6.3
!cp -r ../input/pytorch-segmentation-models-lib/pretrainedmodels-0.7.4 /tmp/pip/cache/pretrainedmodels-0.7.4
!cp ../input/pytorch-segmentation-models-lib/segmentation_models_pytorch-0.2.0-py3-none-any.whl /tmp/pip/cache/
!cp ../input/pytorch-segmentation-models-lib/timm-0.4.12-py3-none-any.whl /tmp/pip/cache/
# Older alternatives, kept for reference only (not executed):
# !cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.2.1-py3-none-any.whl /tmp/pip/cache/
# !pip install --no-index --find-links /tmp/pip/cache/ efficientnet-pytorch
# !pip install --no-index --find-links /tmp/pip/cache/ segmentation-models-pytorch

In [None]:
# Install the staged packages quietly (-q). segmentation_models_pytorch goes last,
# after pretrainedmodels / efficientnet_pytorch / timm — presumably so pip finds its
# dependencies already installed instead of trying to download them.
!pip install -q /tmp/pip/cache/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
!pip install -q /tmp/pip/cache/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
!pip install -q /tmp/pip/cache/timm-0.4.12-py3-none-any.whl
!pip install -q /tmp/pip/cache/segmentation_models_pytorch-0.2.0-py3-none-any.whl

In [None]:
import pandas as pd
import numpy as np

import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import timm
import segmentation_models_pytorch as smp
import gc

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

from PIL import Image
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
%matplotlib inline

# Sklearn
# from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold

import os
import time
import random
from glob import glob
import copy
from pathlib import Path
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from tqdm import tqdm
tqdm.pandas()

# For colored terminal text
from colorama import Fore, Back, Style

In [None]:
class CFG:
    # ---- filesystem layout (Kaggle) -------------------------------------
    KAGGLE_DIR = Path("/kaggle")
    INPUT_DIR = KAGGLE_DIR / "input"
    OUTPUT_DIR = KAGGLE_DIR / "working"
    DATA_DIR = INPUT_DIR / "uw-madison-gi-tract-image-segmentation"
    load_saved_weights = True
    saved_weights_dir = INPUT_DIR / "uwmgi-unet"  # holds best_epoch-<fold>.pt files

    # ---- run mode & reproducibility ------------------------------------
    # NOTE: FOLDS and NUM_EPOCHS are derived from DEBUG once, while this class body
    # executes; assigning CFG.DEBUG later does not recompute them.
    DEBUG = False  # Debug complete pipeline
    SEED = 42
    N_SPLITS = 5
    FOLDS = [0] if not DEBUG else [0, 1]
    NUM_EPOCHS = 15 if not DEBUG else 2

    # ---- data & model --------------------------------------------------
    IMG_SIZE = [320, 384]  # (height, width) every image/mask is padded to
    USE_DEPTH = True  # True for 2.5D data
    USE_AUGS = True
    NUM_CLASSES = 3
    ARCH = "Unet"
    ENCODER_NAME = "efficientnet-b1"
    ENCODER_WEIGHTS = "imagenet"
    LOSS = ["bce", "tversky"]
    THR = 0.5  # probability threshold used to binarise predictions

    # ---- loaders & optimisation ----------------------------------------
    BATCH_SIZE = 64
    TRAIN_BS = 64
    VALID_BS = 2 * TRAIN_BS
    NUM_WORKERS = 2
    LEARNING_RATE = 2e-3
    WEIGHT_DECAY = 1e-6

    # ---- runtime ---------------------------------------------------------
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    imp_color = Fore.GREEN         # colorama highlight colour
    reset_style = Style.RESET_ALL  # colorama reset sequence

In [None]:
def set_seed(seed = 42):
    '''Seed every random number generator the notebook uses, for reproducible runs.

    Covers Python's `random`, NumPy's global RNG and torch (CPU + current CUDA
    device). cuDNN is switched to deterministic kernels with autotuning disabled.
    Note: exporting PYTHONHASHSEED at runtime only affects processes started
    afterwards, not string hashing in the already-running interpreter.'''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Deterministic cuDNN: slower, but results repeat run-to-run.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
# Seed all RNGs once, before anything random below (e.g. the DEBUG-mode
# test_df.sample shuffle).
set_seed(CFG.SEED)

In [None]:
def get_metadata(row):
    '''Parse case / day / slice numbers out of an id and store them on `row`.

    row: mapping-like record (pandas Series row or dict) with an 'id' of the form
        "case<N>_day<M>_slice_<K>".
    Returns the same `row` with integer 'case', 'day' and 'slice' entries added.
    '''
    fields = row['id'].split('_')
    # Parse everything first so a malformed id leaves `row` untouched.
    case_no, day_no, slice_no = (
        int(fields[0].replace('case', '')),
        int(fields[1].replace('day', '')),
        int(fields[-1]),
    )
    row['case'], row['day'], row['slice'] = case_no, day_no, slice_no
    return row

def path2info(row):
    '''Extract scan metadata from an image path and store it on `row`.

    row: mapping-like record with an 'image_path' whose last component looks like
        "slice_<slice>_<width>_<height>_..." and whose third-from-last component
        looks like "case<N>_day<M>".
    Returns the same `row` with integer 'height', 'width', 'case', 'day' and
    'slice' entries added.
    '''
    parts = row['image_path'].split('/')
    name_tokens = parts[-1].split('_')
    case_day_tokens = parts[-3].split('_')
    slice_no = int(name_tokens[1])
    case_no = int(case_day_tokens[0].replace('case', ''))
    day_no = int(case_day_tokens[1].replace('day', ''))
    width, height = int(name_tokens[2]), int(name_tokens[3])
    row['height'], row['width'] = height, width
    row['case'], row['day'], row['slice'] = case_no, day_no, slice_no
    return row

In [None]:
#ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
def rle_decode(mask_rle, shape):
    '''Decode a run-length encoded mask.

    mask_rle: run-length as string formatted "start length start length ..."
        with 1-based starts. A missing mask (None or NaN, as stored for empty
        masks in the CSVs) or an empty string decodes to an all-zero mask.
    shape: (height, width) of the array to return.
    Returns a uint8 numpy array of `shape`: 1 - mask, 0 - background.
    '''
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    if mask_rle is None or (isinstance(mask_rle, float) and np.isnan(mask_rle)):
        # Previously `.split()` on a NaN float raised AttributeError.
        return img.reshape(shape)
    s = mask_rle.split()
    starts = np.asarray(s[0::2], dtype=int) - 1  # RLE starts are 1-based
    lengths = np.asarray(s[1::2], dtype=int)
    for lo, hi in zip(starts, starts + lengths):
        img[lo:hi] = 1
    # Row-major (C order) reshape, matching rle_encode's flatten().
    return img.reshape(shape)


# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    '''Run-length encode a binary mask in row-major order.

    img: numpy array, 1 - mask, 0 - background
    Returns "start length start length ..." with 1-based starts; an
    all-background mask gives the empty string.
    '''
    # Zero sentinels on both ends guarantee every run has a start and an end edge.
    padded = np.concatenate(([0], img.ravel(), [0]))
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Even slots hold run starts; turn each odd slot (run end) into a length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

In [None]:
def show_img(img, mask=None):
    '''Display a slice (CLAHE-equalised, 'bone' colormap) with an optional mask overlay.

    img: single-channel image accepted by cv2's CLAHE.
    mask: optional array drawn on top at 50% opacity; when given, a legend for
        Large Bowel / Small Bowel / Stomach is added.
    '''
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    plt.imshow(equalizer.apply(img), cmap='bone')

    if mask is not None:
        plt.imshow(mask, alpha=0.5)
        legend_colors = [(0.667, 0.0, 0.0), (0.0, 0.667, 0.0), (0.0, 0.0, 0.667)]
        legend_handles = [Rectangle((0, 0), 1, 1, color=rgb) for rgb in legend_colors]
        plt.legend(legend_handles, ["Large Bowel", "Small Bowel", "Stomach"])
    plt.axis('off')

In [None]:
def get_transforms(split='train', img_size = CFG.IMG_SIZE):
    '''Build the albumentations pipeline for one data split.

    split: 'train' (random flip, shift/scale/rotate, grid or elastic distortion,
        coarse dropout) or 'valid' (tensor conversion only - the Dataset already
        pads images to the target size).
    img_size: (height, width); only used to size the CoarseDropout holes.
    Returns an A.Compose ending in ToTensorV2(transpose_mask=True).
    Raises ValueError for any other split. It used to return None, and the Dataset
    then silently skipped every transform, including the tensor conversion.
    '''
    if split == 'valid':
        return A.Compose([
            ToTensorV2(transpose_mask=True),
        ], p=1.0)
    if split == 'train':
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
            A.OneOf([
                A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
                # NOTE(review): `alpha_affine` is deprecated in newer albumentations
                # releases; revisit when upgrading.
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),
            ], p=0.25),
            # 5-8 holes, each at most 1/20 of the image height / width.
            A.CoarseDropout(max_holes=8, max_height=img_size[0]//20, max_width=img_size[1]//20,
                            min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
            ToTensorV2(transpose_mask=True),
        ], p=1.0)
    raise ValueError(f"unknown split {split!r}; expected 'train' or 'valid'")

In [None]:
class Dataset(torch.utils.data.Dataset):
    """2.5D slice dataset.

    Each row of `df` supplies `image_paths` (a list of slice PNGs stacked as
    channels) and an `id`. When `label=True`, it also supplies `mask_path` (a .npy
    mask). Images and masks are zero-padded, centred, up to `size` (defaults to
    CFG.IMG_SIZE, read at call time).

    __getitem__ returns:
      * label=True:  (image, mask)
      * label=False: (image, id, original_height, original_width)
    """

    def __init__(self, df, label=False, transforms=None):
        self.df         = df
        self.label      = label
        self.img_paths  = df['image_paths'].tolist()
        self.ids        = df['id'].tolist()
        # Masks are optional: only labelled frames carry a `mask_path` column.
        self.mask_paths = df['mask_path'].tolist() if 'mask_path' in df.columns else None
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        img, shape0 = self.load_imgs(self.img_paths[index])
        if self.label:
            # Call through `self`: the bare `load_mask(...)` used before raised NameError.
            mask = self.load_mask(self.mask_paths[index])
            if self.transforms:
                data = self.transforms(image=img, mask=mask)
                img, mask = data['image'], data['mask']
            return img, mask
        if self.transforms:
            img = self.transforms(image=img)['image']
        h, w = shape0
        return img, self.ids[index], h, w

    @staticmethod
    def _pad_widths(shape0, size):
        """Return ([top, bottom], [left, right]) zero padding that centres an array of
        spatial shape `shape0` inside `size` (odd pixel goes bottom/right)."""
        diff = np.asarray(size) - np.asarray(shape0)
        if np.any(diff < 0):
            raise ValueError(f"array of shape {tuple(shape0)} is larger than target size {tuple(size)}")
        pady = [int(diff[0] // 2), int(diff[0] // 2 + diff[0] % 2)]
        padx = [int(diff[1] // 2), int(diff[1] // 2 + diff[1] % 2)]
        return pady, padx

    @staticmethod
    def load_img(path, size=None):
        """Read one image unchanged (e.g. 16-bit PNG) and centre-pad it to `size`.

        Returns (padded_image, original_shape) with original_shape =
        np.array([height, width]) before padding.
        Raises FileNotFoundError if OpenCV cannot read `path`.
        """
        if size is None:
            size = CFG.IMG_SIZE
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {path}")
        shape0 = np.array(img.shape[:2])
        if np.any(shape0 != np.array(size)):
            pady, padx = Dataset._pad_widths(shape0, size)
            img = np.pad(img, [pady, padx])
        return img, shape0

    @staticmethod
    def load_imgs(img_paths, size=None):
        """Stack several slices into one float32 array of shape (*size, len(img_paths)).

        Each slice is padded like `load_img` and divided by its own maximum
        (when non-zero). Returns (stack, original_shape_of_first_slice).
        """
        if size is None:
            size = CFG.IMG_SIZE
        if len(img_paths) == 0:
            raise ValueError("img_paths must contain at least one path")
        imgs = np.zeros((*size, len(img_paths)), dtype=np.float32)
        shape0 = None
        for i, img_path in enumerate(img_paths):
            img, slice_shape = Dataset.load_img(img_path, size=size)
            if i == 0:
                shape0 = slice_shape
            img = img.astype(np.float32)  # original is uint16
            mx = np.max(img)
            if mx:
                img /= mx  # scale image to [0, 1]
            imgs[..., i] = img
        return imgs, shape0

    @staticmethod
    def load_mask(path, size=None):
        """Load a 3-channel mask from a .npy file, centre-pad it to `size`, and
        return it as float32 divided by 255."""
        if size is None:
            size = CFG.IMG_SIZE
        msk = np.load(path)
        shape0 = np.array(msk.shape[:2])
        resize = np.array(size)
        if np.any(shape0 != resize):
            pady, padx = Dataset._pad_widths(shape0, resize)
            msk = np.pad(msk, [pady, padx, [0, 0]])
            msk = msk.reshape((*resize, 3))  # also asserts 3 channels
        msk = msk.astype('float32')
        msk /= 255.0
        return msk

In [None]:
class SegmentationModel(nn.Module):
    """nn.Module wrapper around a segmentation_models_pytorch network.

    The wrapped network takes 3-channel input and has `activation=None`, so it
    returns raw logits (infer() applies the sigmoid).
    """

    def __init__(self, arch: str, encoder_name: str, encoder_weights: str, num_classes: int):
        super().__init__()
        self.arch, self.encoder_name, self.encoder_weights = arch, encoder_name, encoder_weights
        smp_kwargs = dict(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=num_classes,
            activation=None,
        )
        self.model = smp.create_model(arch, **smp_kwargs)

    def forward(self, images):
        """Return the wrapped network's output for `images`."""
        network = self.model
        return network(images)
    
def load_model(model, path, map_location='cpu'):
    '''Load the state_dict saved at `path` into `model` and return `model`.

    map_location: forwarded to torch.load. It defaults to 'cpu' so checkpoints saved
        on a GPU also load on CPU-only machines. load_state_dict then copies the
        weights into the model's existing parameters, and the caller moves the model
        to its device afterwards (infer() calls `.to(CFG.device)`).
    '''
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

In [None]:
import cupy as cp

def mask2rle(msk, thr):
    '''Run-length encode one binary mask channel using CuPy.

    msk: 2-D array (numpy or cupy) with 1 - mask, 0 - background. It must already
        be binarised; masks2rles receives thresholded masks from infer().
    thr: unused; kept so existing callers that pass a threshold keep working.
    Returns "start length ..." pairs with 1-based starts in row-major order, or
    "" for an empty mask.
    '''
    pixels = cp.asarray(msk).ravel()
    pad = cp.array([0])
    pixels = cp.concatenate([pad, pixels, pad])
    runs = cp.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    # Copy to host once: iterating a CuPy array element-by-element forces one
    # device->host transfer per element.
    return ' '.join(str(x) for x in cp.asnumpy(runs))

def masks2rles(msks, ids, heights, widths):
    '''Crop a batch of padded predicted masks back to original size and RLE-encode them.

    msks: array (n, H_pad, W_pad, 3) of 0/1 masks, centre-padded the same way the
        Dataset pads its inputs.
    ids: the n image ids of the batch.
    heights, widths: original (unpadded) sizes. Each element must support `.item()`
        (e.g. tensors from DataLoader collation).
    Returns (pred_strings, pred_ids, pred_classes): parallel lists with one entry
    per (image, class), classes ordered large_bowel, small_bowel, stomach.
    Raises ValueError if an original size exceeds the padded mask size.
    '''
    class_names = ['large_bowel', 'small_bowel', 'stomach']
    pred_strings = []; pred_ids = []; pred_classes = []
    for idx in range(msks.shape[0]):
        msk = msks[idx]
        height = int(heights[idx].item())
        width = int(widths[idx].item())
        # Use the mask's own padded size instead of a hard-coded [320, 384], so this
        # stays in sync with whatever size the Dataset padded to.
        pad_h, pad_w = msk.shape[:2]
        if height > pad_h or width > pad_w:
            raise ValueError(f"original size {height}x{width} exceeds padded mask size {pad_h}x{pad_w}")
        top = (pad_h - height) // 2
        left = (pad_w - width) // 2
        # Explicit end indices: the old `msk[a:-b]` form returned an empty slice
        # whenever one side needed no padding (b == 0).
        msk = msk[top:top + height, left:left + width, :]
        rle = [mask2rle(msk[..., midx], CFG.THR) for midx in range(len(class_names))]
        pred_strings.extend(rle)
        pred_ids.extend([ids[idx]] * len(rle))
        pred_classes.extend(class_names)
    return pred_strings, pred_ids, pred_classes

In [None]:
@torch.no_grad()
def infer(test_loader, num_log=1, thr=CFG.THR):
    '''Run the fold ensemble over `test_loader` and RLE-encode its predictions.

    For each fold in CFG.FOLDS, one SegmentationModel is loaded from
    CFG.saved_weights_dir / best_epoch-<fold>.pt. The models' sigmoid outputs are
    averaged and thresholded at `thr`.
    NOTE: CFG.FOLDS is fixed when the CFG class is defined; later changes to
    CFG.DEBUG do not change it.

    test_loader: yields (img, ids, heights, widths) batches (Dataset with label=False).
    num_log: in DEBUG mode, the number of leading batches from which every 5th
        image and mask is kept for plotting.
    Returns (pred_strings, pred_ids, pred_classes, imgs, msks). imgs and msks stay
    empty unless CFG.DEBUG.
    Raises ValueError if CFG.FOLDS is empty.
    '''
    msks = []; imgs = []
    pred_strings = []; pred_ids = []; pred_classes = []
    models = []
    for fold in CFG.FOLDS:
        model = SegmentationModel(CFG.ARCH, CFG.ENCODER_NAME, None, CFG.NUM_CLASSES)
        model = load_model(model, CFG.saved_weights_dir / f'best_epoch-{fold:02d}.pt')
        model.to(CFG.device)
        model.eval()
        models.append(model)
    if not models:
        # Without any model the old per-batch `del out` raised NameError.
        raise ValueError('CFG.FOLDS is empty: no fold weights to run inference with')
    for idx, (img, ids, heights, widths) in enumerate(tqdm(test_loader, total=len(test_loader), desc='Infer ')):
        img = img.to(CFG.device, dtype=torch.float)
        n, _, h, w = img.shape
        # Mean of the per-model probabilities.
        prob = torch.zeros((n, CFG.NUM_CLASSES, h, w), device=CFG.device, dtype=torch.float32)
        for model in models:
            prob += torch.sigmoid(model(img)) / len(models)
        msk = (prob.permute((0, 2, 3, 1)) > thr).to(torch.uint8).cpu().numpy()  # shape: (n, h, w, c)
        strings, batch_ids, classes = masks2rles(msk, ids, heights, widths)
        pred_strings.extend(strings)
        pred_ids.extend(batch_ids)
        pred_classes.extend(classes)
        if idx < num_log and CFG.DEBUG:
            img_np = img.permute((0, 2, 3, 1)).cpu().numpy()
            imgs.append(img_np[::5])
            msks.append(msk[::5])
        # Release this batch's tensors before the next one.
        del img, msk, prob
        gc.collect()
        torch.cuda.empty_cache()
    return pred_strings, pred_ids, pred_classes, imgs, msks

In [None]:
# Build one row per image id to predict.
# An empty sample_submission.csv switches the notebook into DEBUG mode, which uses
# the first 1000*3 annotated train.csv rows instead. NOTE(review): presumably the
# public copy is empty and the real test ids appear only at scoring time — confirm.
# NOTE(review): CFG.FOLDS / CFG.NUM_EPOCHS were computed from CFG.DEBUG when the CFG
# class was defined, so flipping CFG.DEBUG here does not update them.
sub_df = pd.read_csv(CFG.DATA_DIR / 'sample_submission.csv')
if not len(sub_df):
    CFG.DEBUG = True
    sub_df = pd.read_csv(CFG.DATA_DIR / 'train.csv')
    # Keep rows with a non-empty segmentation, then collapse the per-class rows
    # of each id into a single row.
    sub_df = sub_df[~sub_df.segmentation.isna()][:1000*3]
    sub_df = sub_df.drop(columns=['class','segmentation']).drop_duplicates()
else:
    CFG.DEBUG = False
    # One row per id (the template has one row per id and class).
    sub_df = sub_df.drop(columns=['class','predicted']).drop_duplicates()
# Add integer case / day / slice columns parsed from the id.
sub_df = sub_df.progress_apply(get_metadata,axis=1)

In [None]:
# Locate every slice PNG of the split being predicted: the train folder in DEBUG
# mode (ids come from train.csv), otherwise the test folder.
image_split = 'train' if CFG.DEBUG else 'test'
# Build the pattern from CFG.DATA_DIR instead of repeating the path literal. Sort
# because glob order depends on the filesystem; sorting makes path_df reproducible.
paths = sorted(glob(str(CFG.DATA_DIR / image_split / '**' / '*png'), recursive=True))
path_df = pd.DataFrame(paths, columns=['image_path'])
# Parse case / day / slice and the original width / height from each path.
path_df = path_df.progress_apply(path2info, axis=1)
path_df.head()

In [None]:
# Attach each id's image file and original height / width.
test_df = sub_df.merge(path_df, on=['case','day','slice'], how='left')
# The left merge leaves NaN where no PNG matched, and the 2.5D stacking cell
# forward-fills gaps with a neighbouring slice's path. Fail loudly instead of
# silently predicting on the wrong image.
missing_images = test_df['image_path'].isna()
if missing_images.any():
    raise FileNotFoundError(
        f"{int(missing_images.sum())} ids have no matching image file, e.g. "
        f"{test_df.loc[missing_images, 'id'].head(3).tolist()}"
    )
test_df.head()

In [None]:
# 2.5D input: channel i of each sample is the slice `i*stride` positions further
# along the same (case, day) scan.
channels = 3
stride = 2
# shift() follows row order within each group, so sort by slice rather than rely
# on the order of the submission file.
test_df = test_df.sort_values(['case', 'day', 'slice']).reset_index(drop=True)
for i in range(channels):
    # Slices near the end of a scan have no neighbour `i*stride` ahead, so ffill
    # repeats the scan's last path. With sorted rows this stays inside the scan as
    # long as it has more than (channels-1)*stride slices.
    # `.ffill()` replaces `fillna(method="ffill")`, which newer pandas deprecates/removes.
    test_df[f'image_path_{i:02}'] = (
        test_df.groupby(['case', 'day'])['image_path'].shift(-i * stride).ffill()
    )
test_df['image_paths'] = test_df[[f'image_path_{i:02d}' for i in range(channels)]].values.tolist()
if CFG.DEBUG:
    test_df = test_df.sample(frac=1.0)
test_df.image_paths[0]

In [None]:
# Inference loader: padded 2.5D stacks converted to tensors, read in a fixed order.
test_dataset = Dataset(test_df, transforms=get_transforms('valid'))
# Use the configured worker count instead of a duplicated literal.
test_loader  = DataLoader(test_dataset, batch_size=CFG.VALID_BS,
                          num_workers=CFG.NUM_WORKERS, shuffle=False, pin_memory=False)
pred_strings, pred_ids, pred_classes, imgs, msks = infer(test_loader)

In [None]:
# DEBUG only: visual sanity check. Shows up to 5 of the images infer() kept from
# the first batch (every 5th image), next to their predicted masks and an overlay.
# NOTE(review): `img` is the stacked 2.5D input with one channel per slice (3 here),
# so matplotlib presumably ignores `cmap='bone'` for it — confirm intent.
if CFG.DEBUG:
    for img, msk in zip(imgs[0][:5], msks[0][:5]):
        plt.figure(figsize=(12, 7))
        plt.subplot(1, 3, 1); plt.imshow(img, cmap='bone');
        plt.axis('OFF'); plt.title('image')
        # msk is 0/1 uint8; *255 makes the mask channels visible.
        plt.subplot(1, 3, 2); plt.imshow(msk*255); plt.axis('OFF'); plt.title('mask')
        plt.subplot(1, 3, 3); plt.imshow(img, cmap='bone'); plt.imshow(msk*255, alpha=0.4);
        plt.axis('OFF'); plt.title('overlay')
        plt.tight_layout()
        plt.show()

In [None]:
# The logged debug images/masks are no longer needed; free them before building
# the submission.
del imgs, msks
gc.collect()

In [None]:
# One row per (id, class) prediction produced by infer().
pred_df = pd.DataFrame({
    "id":pred_ids,
    "class":pred_classes,
    "predicted":pred_strings
})
if not CFG.DEBUG:
    # Re-read the template so the submission keeps its exact rows and order.
    sub_df = pd.read_csv(CFG.DATA_DIR / 'sample_submission.csv')
    del sub_df['predicted']
else:
    sub_df = pd.read_csv(CFG.DATA_DIR / 'train.csv')[:1000*3]
    del sub_df['segmentation']

# A left merge keeps every template row. An (id, class) pair without a prediction
# gets an empty RLE (no mask) instead of being silently dropped, as an inner merge did.
sub_df = sub_df.merge(pred_df, on=['id','class'], how='left')
sub_df['predicted'] = sub_df['predicted'].fillna('')
sub_df.to_csv('submission.csv',index=False)
display(sub_df.head(5))