## Load libraries

In [1]:
import sys

# Put this input directory on the module search path before the `monai`
# imports in the following cells run.
# NOTE(review): presumably the directory holds a copy of MONAI that is not
# installed in the runtime — confirm.
sys.path.append('../input/monai-v081/')

In [2]:
import gc
from glob import glob
import os
import numpy as np
import pandas as pd
import torch
from torch import nn
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import json

In [3]:
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    Compose,
    Activations,
    AsDiscrete,
    Activationsd,
    AsDiscreted,
    KeepLargestConnectedComponentd,
    Invertd,
    LoadImage,
    Transposed,
    LoadImaged,
    AddChanneld,
    CastToTyped,
    Lambdad,
    Resized,
    EnsureTyped,
    SpatialPadd,
    EnsureChannelFirstd,
)

## Prepare meta info.

### Thanks to awsaf49 — this section is adapted from:
https://www.kaggle.com/code/awsaf49/uwmgi-2-5d-infer-pytorch

In [4]:
def get_metadata(row):
    """Add integer ``case``, ``day`` and ``slice`` fields parsed from ``row['id']``.

    The id is split on ``_``. The first token (with the ``case`` prefix
    removed) gives the case number, the second token (with the ``day`` prefix
    removed) gives the day number, and the last token gives the slice number.

    Args:
        row: mapping-like object (e.g. a DataFrame row) with a string ``'id'``.

    Returns:
        The same ``row`` object, with the three fields written into it.
    """
    tokens = row['id'].split('_')
    # Parse everything first so a malformed id leaves `row` untouched.
    parsed = {
        'case': int(tokens[0].replace('case', '')),
        'day': int(tokens[1].replace('day', '')),
        'slice': int(tokens[-1]),
    }
    for field, value in parsed.items():
        row[field] = value
    return row

def path2info(row):
    """Add ``height``, ``width``, ``case``, ``day`` and ``slice`` parsed from ``row['image_path']``.

    The path is split on ``/``. The file name (last component) is split on
    ``_``: field 1 is the slice number, field 2 is stored as the width and
    field 3 as the height. The component two levels above the file is split on
    ``_`` to obtain the ``case<N>`` and ``day<N>`` tokens.

    Args:
        row: mapping-like object (e.g. a DataFrame row) with a string
            ``'image_path'``.

    Returns:
        The same ``row`` object, with the five integer fields written into it.
    """
    components = row['image_path'].split('/')
    file_fields = components[-1].split('_')
    scan_fields = components[-3].split('_')

    # Insertion order is kept (height, width, case, day, slice) because it
    # determines the column order when this is used with DataFrame.apply.
    info = {
        'height': int(file_fields[3]),
        'width': int(file_fields[2]),
        'case': int(scan_fields[0].replace('case', '')),
        'day': int(scan_fields[1].replace('day', '')),
        'slice': int(file_fields[1]),
    }
    for field, value in info.items():
        row[field] = value
    return row

In [5]:
# Read the submission template. When it has no rows the notebook switches to
# a debug run driven by the first 3000 rows of the training table instead.
# NOTE(review): presumably the template is empty outside of a scored run — confirm.
sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/sample_submission.csv')
debug = len(sub_df) == 0

if debug:
    sub_df = (
        pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/train.csv')
        .iloc[:1000 * 3]
        .drop(columns=['class', 'segmentation'])
        .drop_duplicates()
    )
else:
    sub_df = sub_df.drop(columns=['class', 'predicted']).drop_duplicates()

# Add integer case/day/slice columns parsed from each id.
sub_df = sub_df.apply(get_metadata, axis=1)

In [6]:
# Debug runs scan the training images, real runs scan the test images.
# The glob result is left unsorted: rows are matched to ids by
# (case, day, slice) in the merge that follows, not by position.
png_pattern = (
    '/kaggle/input/uw-madison-gi-tract-image-segmentation/train/**/*png'
    if debug
    else '/kaggle/input/uw-madison-gi-tract-image-segmentation/test/**/*png'
)
paths = glob(png_pattern, recursive=True)

# One row per image file, with height/width/case/day/slice parsed from the path.
path_df = pd.DataFrame(paths, columns=['image_path'])
path_df = path_df.apply(path2info, axis=1)
path_df.head()

Unnamed: 0,image_path,height,width,case,day,slice
0,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,6
1,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,82
2,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,113
3,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,76
4,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,125


## Build the 3D data list for the MONAI Dataset

In [7]:
# Attach the image path (and size) to every submission id; ids without a
# matching file keep their row because the join is a left join.
test_df = pd.merge(sub_df, path_df, how='left', on=['case', 'day', 'slice'])

# Split each id on its first two underscores and keep the pieces as strings:
# piece 0 -> case_id_str, piece 1 -> day_num_str, remainder -> slice_id.
for column, position in (("case_id_str", 0), ("day_num_str", 1), ("slice_id", 2)):
    test_df[column] = test_df["id"].apply(
        lambda sample_id, k=position: sample_id.split("_", 2)[k]
    )

In [8]:
test_data = []

# Build one entry per (case, day) group: the ordered list of slice image paths
# that will be loaded as one volume, plus the matching submission ids.
for (case_id_str, day_num_str), group_df in test_df.groupby(["case_id_str", "day_num_str"]):
    # slice_id is a string, so this is a lexicographic sort; it equals numeric
    # order only while the slice numbers are zero-padded to the same width.
    group_df = group_df.sort_values("slice_id", ascending=True)

    # test_df comes from a left merge, so an id without an image file carries
    # a missing path. Fail here with a clear message instead of deep inside
    # the image loader.
    if group_df["image_path"].isna().any():
        raise FileNotFoundError(
            f"missing image file(s) for {case_id_str}_{day_num_str}"
        )

    test_data.append(
        {
            "image": group_df["image_path"].tolist(),
            "id": group_df["id"].tolist(),
        }
    )

## Prepare Transforms, Dataset, DataLoader

In [9]:
class cfg:
    """Inference settings read by the dataset, model and inference cells."""

    # Where tensors and the model are placed: first CUDA device if one is
    # available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Network input / output channel counts.
    in_channels = 1
    out_channels = 3

    # Window size handed to sliding_window_inference, and how many windows
    # are run through the model at once.
    img_size = (224, 224, 80)
    sw_batch_size = 4

    # DataLoader batch size.
    batch_size = 1

    # Checkpoint files whose predictions are averaged at inference time.
    weights = glob("../input/uw3dweights/large*")

In [10]:
def _scale_to_unit_max(x):
    """Divide ``x`` by its maximum value.

    A volume whose maximum is not positive (e.g. all zeros) is returned
    unchanged, so the division cannot fill it with NaN/inf.
    """
    peak = x.max()
    return x / peak if peak > 0 else x


test_transforms = Compose(
    [
        # Each "image" entry is a list of slice files.
        # NOTE(review): assumes the loader stacks them along axis 0 (slice axis first) — confirm.
        LoadImaged(keys="image"),
        # Prepend a channel axis: (d, A, B) -> (c, d, A, B).
        AddChanneld(keys="image"),
        # Move the slice axis last: (c, d, A, B) -> (c, A, B, d).
        # NOTE(review): the inference cell treats the result as (c, w, h, d),
        # i.e. A = width and B = height — confirm against the image reader.
        Transposed(keys="image", indices=[0, 2, 3, 1]),
        # Per-volume intensity scaling by the maximum value.
        Lambdad(keys="image", func=_scale_to_unit_max),
#         SpatialPadd(keys="image", spatial_size=cfg.img_size),  # in case less than 80 slices
        EnsureTyped(keys="image", dtype=torch.float32),
    ]
)

# cache_rate=0.0: no items are cached up front, so the transforms run when
# an item is requested.
test_ds = CacheDataset(
        data=test_data,
        transform=test_transforms,
        cache_rate=0.0,
        num_workers=2,
    )

test_dataloader = DataLoader(
    test_ds,
    batch_size=cfg.batch_size,
    num_workers=2,
    pin_memory=True,
)

## Prepare Network

In [11]:
# 3D U-Net built from the settings in `cfg`. It is created with fresh
# parameters here; checkpoint weights are loaded in the inference cell.
model = UNet(
    # problem shape
    spatial_dims=3,
    dimensions=None,  # NOTE(review): looks like a legacy alias of spatial_dims — confirm
    in_channels=cfg.in_channels,
    out_channels=cfg.out_channels,
    # encoder / decoder layout: five levels joined by four stride-2 steps
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    # convolution settings
    kernel_size=3,
    up_kernel_size=3,
    bias=True,
    # activation, normalisation and regularisation
    act="PRELU",
    norm="BATCH",
    dropout=0.2,
)
model = model.to(cfg.device)

## Infer

In [12]:
# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    """Run-length encode a binary mask.

    Args:
        img (np.array): mask with 1 for foreground and 0 for background.
            It is flattened with ``img.flatten()`` before encoding.

    Returns:
        str: space-separated ``start length`` pairs, where ``start`` is the
        1-based position of a run of foreground pixels in the flattened mask
        and ``length`` is the number of pixels in that run. An all-background
        mask gives the empty string.
    """
    flat = img.flatten()
    # A zero on each side guarantees every run has both a rising and a
    # falling edge, even when it touches the first or last pixel.
    padded = np.pad(flat, 1)
    # Positions (1-based in `flat`) where the value changes.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts = edges[::2]
    lengths = edges[1::2] - starts
    # Interleave as start, length, start, length, ...
    pairs = np.column_stack((starts, lengths)).ravel()
    return ' '.join(map(str, pairs))

In [13]:
outputs = []

# Averaged raw model outputs -> sigmoid probabilities -> binary masks at 0.5.
post_pred = Compose([
    Activations(sigmoid=True),
    AsDiscrete(threshold=0.5),
])

if not cfg.weights:
    raise FileNotFoundError("cfg.weights is empty: no checkpoint files were found to run inference with")

# Read every checkpoint from disk once instead of once per volume.
# map_location="cpu" keeps the copies in host memory and lets the notebook
# run on a machine without a GPU; load_state_dict copies the values into the
# model on cfg.device.
state_dicts = [torch.load(weights, map_location="cpu") for weights in cfg.weights]

model.eval()
# Scoped no-grad instead of switching autograd off for the whole session.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        test_inputs = batch["image"].to(cfg.device)
        # The code below keeps only element 0 of the batch axis.
        if test_inputs.shape[0] != 1:
            raise ValueError("this inference loop only supports batch_size == 1")

        # One prediction per (checkpoint, augmentation): the plain input plus
        # three flips (axis 2, axis 3, both), each flipped back afterwards.
        pred_all = []
        for state_dict in state_dicts:
            model.load_state_dict(state_dict)
            pred = sliding_window_inference(test_inputs, cfg.img_size, cfg.sw_batch_size, model)
            pred_all.append(pred)
            for dims in [[2], [3], [2, 3]]:
                flipped_inputs = torch.flip(test_inputs, dims=dims)
                flip_pred = sliding_window_inference(flipped_inputs, cfg.img_size, cfg.sw_batch_size, model)
                pred_all.append(torch.flip(flip_pred, dims=dims))

        # Average over all predictions, then drop the batch axis.
        pred_all = torch.mean(torch.stack(pred_all), dim=0)[0]
        pred_all = post_pred(pred_all)
        # c, w, h, d to d, c, h, w
        pred_all = torch.permute(pred_all, [3, 0, 2, 1]).cpu().numpy().astype(np.uint8)
        id_outputs = from_engine(["id"])(batch)[0]

        # One (c, h, w) mask stack per slice; emit one RLE row per class.
        for test_output, id_output in zip(pred_all, id_outputs):
            # Each id arrives wrapped in a length-1 batch container.
            id_name = id_output[0]
            lb, sb, st = test_output
            outputs.append([id_name, "large_bowel", rle_encode(lb)])
            outputs.append([id_name, "small_bowel", rle_encode(sb)])
            outputs.append([id_name, "stomach", rle_encode(st)])


100%|██████████| 7/7 [00:29<00:00,  4.20s/it]


In [14]:
# Build the frame directly from the list of [id, class, rle] rows.
# Going through np.array() first would create a fixed-width unicode array in
# which every cell is as wide as the longest RLE string, and it would fail
# outright when `outputs` is empty.
submit = pd.DataFrame(outputs, columns=["id", "class", "predicted"])

In [15]:
# Fix sub error, refers to: https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/discussion/320541
if debug:
    # Debug run: write the predictions as they are.
    submit.to_csv('submission.csv', index=False)
else:
    # Real run: start from the template's (id, class) rows and attach the
    # predictions to them by key before writing the file.
    sub_df = (
        pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/sample_submission.csv')
        .drop(columns=['predicted'])
        .merge(submit, on=['id', 'class'])
    )
    sub_df.to_csv('submission.csv', index=False)