# Installs

In [None]:
# Install pinned versions of segmentation_models_pytorch and its dependencies
# from local archive files in ../input/gi-seg-downloads.
!cd ../input/gi-seg-downloads && \
pip install -q efficientnet_pytorch-0.6.3.tar.gz pretrainedmodels-0.7.4.tar.gz timm-0.4.12-py3-none-any.whl  segmentation_models_pytorch-0.2.1-py3-none-any.whl

# Imports

In [None]:
from pathlib import Path
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
import os
import albumentations as A
import cupy as cp
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

# Paths & Settings

In [None]:
KAGGLE_DIR = Path("/") / "kaggle"
INPUT_DIR = KAGGLE_DIR / "input"
OUTPUT_DIR = KAGGLE_DIR / "working"  # not referenced in the visible code

INPUT_DATA_DIR = INPUT_DIR / "uw-madison-gi-tract-image-segmentation"
INPUT_DATA_NPY_DIR = INPUT_DIR / "uw-madison-gi-tract-image-segmentation-masks"  # not referenced in the visible code

IMG_SIZE = 356  # side length the stacked slices are resized to (UWDataset.resize) and masks are padded back to (pad_mask)
CROP_SIZE = 320  # center-crop size fed to the model (transforms_val)
USE_AUGS = True  # not referenced in the visible code
BATCH_SIZE = 32  # DataLoader batch size in infer()
NUM_WORKERS = 2  # DataLoader worker processes in infer()
ENCODER_NAME = "efficientnet-b3"  # smp encoder; must match the architecture stored in the checkpoint
GPUS = 1  # not referenced in the visible code
CHANNELS = 5  # model input channels: the slices at offsets -2..+2 stacked by UWDataset
DEVICE = "cuda"  # device the models and input batches are moved to
THR = 0.45  # threshold applied to the ensemble-averaged sigmoid probabilities

DEBUG = False # Debug complete pipeline (not referenced in the visible code)

# Dataset

In [None]:
# Inference transform: center-crop the IMG_SIZE-resized image to CROP_SIZE,
# then convert the HWC numpy image to a CHW torch tensor. pad_mask later
# places predictions back at the same centered offset, so the two must agree.
transforms_val = A.Compose([
    A.CenterCrop(CROP_SIZE, CROP_SIZE, p=1),
    ToTensorV2(transpose_mask=True)
])

In [None]:
class UWDataset(Dataset):
    """Inference dataset that stacks five neighbouring slices into one image.

    Each item is built from the slice at ``row["image_path"]`` plus the two
    slices before and after it in the same directory. They are stacked along
    the last axis into an (H, W, 5) float32 array, scaled by its maximum
    value and resized to ``IMG_SIZE``. Missing neighbours are replaced by the
    nearest available slice towards the centre.

    Args:
        df: DataFrame with at least ``"image_path"`` and ``"id"`` columns.
        transforms: Optional albumentations-style callable, invoked as
            ``transforms(image=...)`` and returning a dict with ``"image"``.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def resize(self, img, interp):
        """Resize ``img`` to ``(IMG_SIZE, IMG_SIZE)`` using cv2 interpolation ``interp``."""
        return cv2.resize(
            img, (IMG_SIZE, IMG_SIZE), interpolation=interp)

    def load_slice(self, img_file, diff):
        """Read the slice ``diff`` positions away from ``img_file``.

        File names are expected to look like ``slice_0001_...png``. Only the
        file name is rewritten to point at the neighbour, so directory names
        can never be altered by the substitution.

        Returns:
            The image as read by ``cv2.imread(..., IMREAD_UNCHANGED)``, or
            ``None`` if the neighbouring file does not exist.
        """
        directory, basename = os.path.split(img_file)
        slice_num = basename.split('_')[1]
        neighbour = basename.replace(
            'slice_' + slice_num,
            'slice_' + str(int(slice_num) + diff).zfill(4),
            1)
        filename = os.path.join(directory, neighbour)
        if os.path.exists(filename):
            return cv2.imread(filename, cv2.IMREAD_UNCHANGED)
        return None

    def __getitem__(self, idx: int):
        """Return ``{'image', 'id', 'h', 'w'}`` for row ``idx``.

        ``h``/``w`` are the height/width of the stacked slices before the
        resize to ``IMG_SIZE``.

        Raises:
            FileNotFoundError: if the row has no usable image path or the
                centre slice itself cannot be read.
        """
        row = self.df.iloc[idx]
        img_file = row["image_path"]
        # A left merge with the path table leaves NaN when no file matched.
        if not isinstance(img_file, (str, os.PathLike)):
            raise FileNotFoundError(f"No image path for id {row['id']!r}")
        img_file = os.fspath(img_file)

        # Offsets -2..+2; index 2 in `imgs` is the requested (centre) slice.
        imgs = [self.load_slice(img_file, diff) for diff in range(-2, 3)]
        if imgs[2] is None:
            # Nothing to fall back to; fail clearly instead of inside np.stack.
            raise FileNotFoundError(f"Could not read slice image: {img_file}")
        # Fill missing neighbours outwards from the centre.
        for i in (3, 4):
            if imgs[i] is None:
                imgs[i] = imgs[i - 1]
        for i in (1, 0):
            if imgs[i] is None:
                imgs[i] = imgs[i + 1]

        image = np.stack(imgs, axis=2).astype(np.float32)
        h, w = image.shape[:2]
        max_val = image.max()
        if max_val != 0:
            image /= max_val
        image = self.resize(image, cv2.INTER_AREA)

        if self.transforms:
            image = self.transforms(image=image)["image"]
        return {
            'image': image,
            'id': row["id"],
            'h': h,
            'w': w
        }

### Load Test Data

In [None]:
def extract_metadata_from_id(df):
    """Add integer ``case``, ``day`` and ``slice`` columns parsed from ``df["id"]``.

    Ids are expected to look like ``case123_day20_slice_0001``. The columns
    are added to ``df`` in place, and ``df`` is also returned for chaining.
    An empty frame still gets the three (empty) columns.

    Raises:
        ValueError: if an id does not match the expected pattern (the
            unmatched NaN cannot be cast to int).
    """
    # One anchored regex replaces split(expand=True) + str.replace. It always
    # yields three groups, even for an empty column, and the literal
    # prefixes are never subject to regex-vs-literal replace defaults.
    parts = df["id"].str.extract(r"^case(\d+)_day(\d+)_slice_(\d+)$")
    df["case"] = parts[0].astype(int)
    df["day"] = parts[1].astype(int)
    df["slice"] = parts[2].astype(int)
    return df


def extract_metadata_from_path(path_df):
    """Parse scan metadata out of each ``image_path`` in ``path_df``.

    Paths are expected to end in
    ``.../caseC_dayD/scans/slice_SSSS_W_H_S1_S2.png`` (``/``-separated).

    Args:
        path_df: DataFrame with an ``image_path`` string column. It is not
            modified.

    Returns:
        A new DataFrame with ``image_path`` plus numeric ``case``, ``day``,
        ``slice``, ``width``, ``height`` and ``spacing`` columns. The second
        spacing value in the file name is discarded.
    """
    # Work on a copy so the helper columns are not added to the caller's frame.
    path_df = path_df.copy()
    path_df[["parent", "case_day", "scans", "file_name"]] = (
        path_df["image_path"].str.rsplit("/", n=3, expand=True)
    )

    path_df[["case", "day"]] = path_df["case_day"].str.split("_", expand=True)
    # regex=False: these are literal strings. Under regex semantics (the
    # pandas default before 2.0) the "." in ".png" would match any character.
    path_df["case"] = path_df["case"].str.replace("case", "", regex=False)
    path_df["day"] = path_df["day"].str.replace("day", "", regex=False)

    stem = (
        path_df["file_name"]
        .str.replace("slice_", "", regex=False)
        .str.replace(".png", "", regex=False)
    )
    path_df[["slice", "width", "height", "spacing", "spacing_"]] = stem.str.split("_", expand=True)
    path_df = path_df.drop(columns=["parent", "case_day", "scans", "file_name", "spacing_"])

    numeric_cols = ["case", "day", "slice", "width", "height", "spacing"]
    path_df[numeric_cols] = path_df[numeric_cols].apply(pd.to_numeric)

    return path_df

In [None]:
sub_df = pd.read_csv(INPUT_DATA_DIR / "sample_submission.csv")
# When sample_submission.csv has no rows the test set is hidden. In that case
# the first 3000 rows of train.csv (and the train images) stand in for it.
test_set_hidden = len(sub_df) == 0

if not test_set_hidden:
    test_df = sub_df.drop(columns=["class", "predicted"]).drop_duplicates()
    image_paths = list(map(str, (INPUT_DATA_DIR / "test").rglob("*.png")))
else:
    test_df = (
        pd.read_csv(INPUT_DATA_DIR / "train.csv")
        .iloc[:3000]
        .drop(columns=["class", "segmentation"])
        .drop_duplicates()
    )
    image_paths = list(map(str, (INPUT_DATA_DIR / "train").rglob("*.png")))

# Attach each id's image file (plus size/spacing metadata) by (case, day, slice).
path_df = extract_metadata_from_path(pd.DataFrame(image_paths, columns=["image_path"]))
test_df = extract_metadata_from_id(test_df).merge(
    path_df, on=["case", "day", "slice"], how="left"
)

print(len(test_df))
test_df.head()

### Save Test DataFrame

In [None]:
# Write the merged test metadata (ids, image paths, parsed fields) to
# test_preprocessed.csv in the current directory.
test_df.to_csv("test_preprocessed.csv", index=False)

In [None]:
# Dataset instance used only for the sanity checks below.
td = UWDataset(test_df, transforms=transforms_val)

In [None]:
# Shape of the transformed 5-slice input for the first sample.
td[0]['image'].shape

In [None]:
# Id and original (pre-resize) height/width of the first sample.
td[0]['id'],td[0]['h'],td[0]['w']

In [None]:
ls ../input/exp01017/

In [None]:
# Checkpoints whose sigmoid outputs infer() averages; add paths to ensemble.
model_pths = [
    '../input/exp01017/expexp010-bestloss-fold0-7.ckpt'
]

In [None]:
def build_model():
    """Create an untrained smp U-Net matching the training configuration.

    The encoder is ``ENCODER_NAME`` with no pretrained weights. It takes
    ``CHANNELS`` input channels and has 3 output classes with
    ``activation=None``, so outputs are raw logits. Weights are loaded
    afterwards by ``load_model``.
    """
    unet_config = dict(
        encoder_name=ENCODER_NAME,
        encoder_weights=None,
        in_channels=CHANNELS,
        classes=3,
        activation=None,
        decoder_use_batchnorm=True,
        decoder_attention_type='scse',
    )
    return smp.Unet(**unet_config)

In [None]:
def load_model(path):
    """Build the U-Net, load checkpoint weights and prepare it for inference.

    Args:
        path: Checkpoint file whose ``'state_dict'`` entry holds the weights.
            The first 4 characters of every key are stripped before loading.
            Presumably this is a wrapper prefix such as ``"net."``; TODO
            confirm against the training code.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    model = build_model()
    # map_location puts tensors straight onto DEVICE rather than the device
    # they were saved from, which may not exist in this session.
    state = torch.load(path, map_location=DEVICE)['state_dict']
    # Drop the fixed-length key prefix so names match the bare smp.Unet.
    nstate = {k[4:]: v for k, v in state.items()}
    model.load_state_dict(nstate)
    model.to(DEVICE)
    model.eval()
    return model

## Run inference

In [None]:
def mask2rle(mask):
    """Run-length encode a binary mask.

    Args:
        mask: Array-like of 0 (background) / 1 (mask) values. It is
            flattened in row-major (C) order.

    Returns:
        A space-separated ``"start length start length ..."`` string with
        1-based start positions, or ``""`` for an all-background mask.
    """
    # Pure NumPy: callers pass host arrays. The previous CuPy version also
    # stringified a device array element by element, forcing one
    # device->host copy per number.
    pixels = np.concatenate([[0], np.asarray(mask).ravel(), [0]])
    # Positions (1-based) where the value changes: alternating run starts/ends.
    runs = np.flatnonzero(pixels[1:] != pixels[:-1]) + 1
    runs[1::2] -= runs[::2]  # turn each end position into a run length
    return " ".join(map(str, runs.tolist()))

def pad_mask(mask, size=None):
    """Zero-pad a center-cropped mask back onto a ``size`` x ``size`` canvas.

    The mask is placed at top/left offset ``(size - dim) // 2``. This assumes
    the crop that produced it used that same centered offset.

    Args:
        mask: Array of shape (h, w) or (h, w, ...). Trailing axes (e.g. the
            class channels) are preserved.
        size: Side length of the output; defaults to ``IMG_SIZE``.

    Returns:
        Array of shape ``(size, size) + mask.shape[2:]`` with ``mask``'s dtype.

    Raises:
        ValueError: if ``mask`` is larger than ``size`` in either dimension.
    """
    if size is None:
        size = IMG_SIZE
    h, w = mask.shape[:2]
    if h > size or w > size:
        raise ValueError(f"mask of shape {mask.shape} does not fit in {size}x{size}")
    padded = np.zeros((size, size) + mask.shape[2:], dtype=mask.dtype)
    top = (size - h) // 2
    left = (size - w) // 2
    padded[top:top + h, left:left + w] = mask
    return padded

def resize_mask(mask, height, width):
    """Resize the first three channels of ``mask`` to ``(height, width)``.

    Each channel is resized independently with nearest-neighbour
    interpolation, so no new mask values are introduced. ``cv2.resize``
    takes its target as (width, height). The result keeps ``mask``'s dtype.
    """
    target = (width, height)
    channels = [
        cv2.resize(mask[:, :, c], target, interpolation=cv2.INTER_NEAREST)
        for c in range(3)
    ]
    return np.stack(channels, axis=-1).astype(mask.dtype, copy=False)

def masks2rles(masks, ids, heights, widths):
    """Turn a batch of cropped 3-channel masks into per-class RLE rows.

    Each item's mask is padded back to ``IMG_SIZE``, resized to the item's
    original (height, width) and run-length encoded channel by channel.

    Returns:
        Three parallel lists ``(pred_strings, pred_ids, pred_classes)`` with
        one entry per (item, channel). Channels 0, 1, 2 map to
        large_bowel, small_bowel and stomach.
    """
    class_names = ["large_bowel", "small_bowel", "stomach"]
    pred_strings, pred_ids, pred_classes = [], [], []

    for idx, cropped in enumerate(masks):
        full_size = resize_mask(
            pad_mask(cropped), heights[idx].item(), widths[idx].item()
        )
        pred_strings.extend(mask2rle(full_size[..., c]) for c in range(3))
        pred_ids.extend([ids[idx]] * 3)
        pred_classes.extend(class_names)

    return pred_strings, pred_ids, pred_classes


@torch.no_grad()
def infer(model_paths, thr, df=None):
    """Predict RLE masks for every row of ``df`` with an averaged ensemble.

    Args:
        model_paths: Sized sequence of checkpoint paths accepted by
            ``load_model``.
        thr: Threshold applied to the ensemble-averaged sigmoid outputs.
        df: Frame to predict on; defaults to the global ``test_df``.

    Returns:
        DataFrame with columns ``id``, ``class`` and ``predicted`` (RLE).
    """
    if df is None:
        df = test_df
    test_dataloader = DataLoader(
        UWDataset(df, transforms=transforms_val),
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=False,
        drop_last=False,
    )
    # Load each checkpoint once instead of rebuilding/re-reading it for every
    # batch. This keeps all ensemble members resident on DEVICE at once.
    models = [load_model(path) for path in model_paths]

    pred_strings = []
    pred_ids = []
    pred_classes = []

    for batch in tqdm(test_dataloader):
        imgs = batch['image'].to(DEVICE, dtype=torch.float)
        n, _, h, w = imgs.shape

        probs = torch.zeros((n, 3, h, w), device=DEVICE, dtype=torch.float32)
        for model in models:
            probs += torch.sigmoid(model(imgs)) / len(models)

        # (n, 3, h, w) -> (n, h, w, 3) binary uint8 masks on the host.
        masks = (probs.permute(0, 2, 3, 1) > thr).to(torch.uint8).cpu().numpy()

        strings, ids, classes = masks2rles(masks, batch['id'], batch['h'], batch['w'])
        pred_strings.extend(strings)
        pred_ids.extend(ids)
        pred_classes.extend(classes)

    return pd.DataFrame({"id": pred_ids, "class": pred_classes, "predicted": pred_strings})

In [None]:
# Run the ensemble over test_df at threshold THR.
pred_df = infer(model_pths, THR)

In [None]:
# Scratch shape checks below; a, b, c are not referenced elsewhere in the visible code.
a = np.zeros((100, 200, 2));a.shape

In [None]:
b = np.zeros((200, 300));b.shape

In [None]:
# cv2.resize takes dsize as (width, height), so this gives shape (200, 100).
# That is why resize_mask passes (width, height).
c = cv2.resize(b, (100, 200));c.shape

## Submit

In [None]:
# Rebuild the row template (id, class) and attach predictions to it.
# Paths go through INPUT_DATA_DIR like the rest of the notebook.
if not test_set_hidden:
    sub_df = pd.read_csv(INPUT_DATA_DIR / "sample_submission.csv")
    del sub_df["predicted"]
else:
    sub_df = pd.read_csv(INPUT_DATA_DIR / "train.csv")[: 1000 * 3]
    del sub_df["segmentation"]

# A left merge keeps every required row even if a prediction is missing.
# Such rows get "", the same encoding mask2rle produces for an empty mask,
# instead of being silently dropped from the submission.
sub_df = sub_df.merge(pred_df, on=["id", "class"], how="left")
sub_df["predicted"] = sub_df["predicted"].fillna("")
sub_df.to_csv("submission.csv", index=False)
display(sub_df.head(5))

## 