Most of the code is adapted from https://www.kaggle.com/code/yiheng/3d-solution-with-monai-infer; thanks to YIHENG WANG.

**Load libraries**

In [1]:
# All code is adapted from https://www.kaggle.com/code/yiheng/3d-solution-with-monai-infer
#used dynunet pipeline instead of unet
import sys

sys.path.append('../input/monai-v081/')

In [2]:
import gc
from glob import glob
import os
import numpy as np
import pandas as pd
import torch
from torch import nn
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.handlers.utils import from_engine
from monai.networks.nets import DynUNet, UNet
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import json
import copy

In [3]:
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    Compose,
    Activations,
    AsDiscrete,
    Activationsd,
    AsDiscreted,
    KeepLargestConnectedComponentd,
    Invertd,
    LoadImage,
    Transposed,
    LoadImaged,
    AddChanneld,
    CastToTyped,
    Lambdad,
    Resized,
    EnsureTyped,
    SpatialPadd,
    EnsureChannelFirstd,
)

## Prepare meta info.

### Thanks to awsaf49; this section is based on:
https://www.kaggle.com/code/awsaf49/uwmgi-2-5d-infer-pytorch

In [4]:
def get_metadata(row):
    """Add integer ``case``, ``day`` and ``slice`` fields parsed from ``row['id']``.

    The id is split on ``'_'``: the first token (minus the ``case`` prefix)
    gives the case number, the second token (minus the ``day`` prefix) gives
    the day, and the last token gives the slice number.

    Args:
        row: mapping-like record (e.g. a DataFrame row) with an ``'id'`` entry.

    Returns:
        The same ``row`` object, updated in place with the three new fields.
    """
    tokens = row['id'].split('_')
    # Parse everything first so a malformed id leaves the row untouched.
    parsed = {
        'case': int(tokens[0].replace('case','')),
        'day': int(tokens[1].replace('day','')),
        'slice': int(tokens[-1]),
    }
    for column, value in parsed.items():
        row[column] = value
    return row

def path2info(row):
    """Add size and scan-identity fields parsed from ``row['image_path']``.

    The path is split on ``'/'``. The file name (last component) is split on
    ``'_'`` and its tokens 1, 2 and 3 are read as slice number, width and
    height. The component two levels above the file is split on ``'_'`` and
    yields the case number (``case`` prefix removed) and the day (``day``
    prefix removed).

    Args:
        row: mapping-like record (e.g. a DataFrame row) with an
            ``'image_path'`` entry.

    Returns:
        The same ``row`` object, updated in place with integer ``height``,
        ``width``, ``case``, ``day`` and ``slice`` fields (added in that order).
    """
    parts = row['image_path'].split('/')
    file_name, scan_dir = parts[-1], parts[-3]

    slice_no = int(file_name.split('_')[1])
    case_no = int(scan_dir.split('_')[0].replace('case',''))
    day_no = int(scan_dir.split('_')[1].replace('day',''))
    width_px = int(file_name.split('_')[2])
    height_px = int(file_name.split('_')[3])

    # Field order is kept as-is: it determines the column order when this is
    # used through DataFrame.apply.
    for column, value in (
        ('height', height_px),
        ('width', width_px),
        ('case', case_no),
        ('day', day_no),
        ('slice', slice_no),
    ):
        row[column] = value
    return row

In [5]:
# The sample submission has no rows outside of a real submission run; in that
# case switch to debug mode and use the first 12,000 rows of train.csv instead.
sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/sample_submission.csv')
debug = len(sub_df) == 0
if debug:
    sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/train.csv')[:1000*12]
    _per_class_columns = ['class','segmentation']
else:
    _per_class_columns = ['class','predicted']

# Drop the per-class columns and de-duplicate, then attach case/day/slice.
sub_df = sub_df.drop(columns=_per_class_columns).drop_duplicates()
sub_df = sub_df.apply(get_metadata, axis=1)

In [6]:
# Debug runs read the training images, real submissions read the test images.
# Note: the glob result is used unsorted (a sort was left disabled in the original).
_png_pattern = (
    '/kaggle/input/uw-madison-gi-tract-image-segmentation/train/**/*png'
    if debug
    else '/kaggle/input/uw-madison-gi-tract-image-segmentation/test/**/*png'
)
paths = glob(_png_pattern, recursive=True)

# One row per image file, with height/width/case/day/slice parsed from the path.
path_df = pd.DataFrame(paths, columns=['image_path'])
path_df = path_df.apply(path2info, axis=1)
path_df.head()

Unnamed: 0,image_path,height,width,case,day,slice
0,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,6
1,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,82
2,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,113
3,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,76
4,/kaggle/input/uw-madison-gi-tract-image-segmen...,266,266,36,14,125


## Produce 3D data list for MONAI Dataset

In [7]:
# Attach the image path (and size) of every submission row by case/day/slice.
test_df = sub_df.merge(path_df, on=['case','day','slice'], how='left')

# Split each id into at most three parts on "_": case string, day string and
# the remaining slice identifier.
_id_fields = test_df["id"].map(lambda scan_id: scan_id.split("_", 2))
for _position, _column in enumerate(("case_id_str", "day_num_str", "slice_id")):
    test_df[_column] = _id_fields.map(lambda fields, _position=_position: fields[_position])

In [8]:
# Build one record per (case, day) scan: the slice image paths and the matching
# submission ids, both ordered by slice id. Each record is one 3D volume for
# the MONAI dataset.
test_data = []

for _, group_df in test_df.groupby(["case_id_str", "day_num_str"]):
    ordered_df = group_df.sort_values("slice_id", ascending=True)
    # Pull whole columns instead of materialising every row with .iloc.
    test_data.append({
        "image": ordered_df["image_path"].tolist(),
        "id": ordered_df["id"].tolist(),
    })

## Prepare Transforms, Dataset, DataLoader

In [9]:
class cfg:
    """Inference settings shared by the cells below."""

    # Patch size used for spatial padding and for the sliding-window inference.
    img_size = (192, 192, 80)
    in_channels = 1
    out_channels = 3
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # Checkpoints that are ensembled at inference time.
    # (Earlier variant, left disabled: glob("../input/unet-ddp-models/*").)
    weights = [
        "../input/bunet-dice-sgd-cosine/bUnet_finetune_bestloss_0.072235_epoch_1124.pt",
        "../input/bunet-dice-sgd-cosine/bUnet_Dice_bestloss_0.065661_epoch_1088.pt",
        "../input/bunet-dice-sgd-cosine/bUnet_bestloss_0.064913_epoch_1095.pt",
    ]
    # The inference loop only reads the first item of each batch, so keep this at 1.
    batch_size = 1
    sw_batch_size = 1

In [10]:
def _scale_by_volume_max(volume):
    """Divide a volume by its own maximum value."""
    return volume / volume.max()


# Pre-processing applied to each scan. The axis notes are the original
# author's; the transpose note was also marked "wrong" in the original, so the
# in-plane axis order should be confirmed - TODO.
_test_steps = [
    LoadImaged(keys="image"),  # d, h, w
    AddChanneld(keys="image"),  # c, d, h, w
    Transposed(keys="image", indices=[0, 2, 3, 1]),  # c, w, h, d (flagged "wrong" in the original)
    Lambdad(keys="image", func=_scale_by_volume_max),
    # Pad at the end of each axis up to cfg.img_size, in case of less than 80 slices.
    SpatialPadd(keys="image", spatial_size=cfg.img_size, method="end"),
    EnsureTyped(keys="image", dtype=torch.float32),
]
test_transforms = Compose(_test_steps)

# cache_rate=0.0: nothing is cached up front.
test_ds = CacheDataset(
    data=test_data,
    transform=test_transforms,
    cache_rate=0.0,
    num_workers=2,
)

test_dataloader = DataLoader(
    test_ds,
    batch_size=cfg.batch_size,
    num_workers=2,
    pin_memory=True,
)

In [11]:
# #https://github.com/Project-MONAI/tutorials/blob/main/modules/dynunet_pipeline/create_network.py

# def get_kernels_strides(img_size):
#     sizes, spacings = img_size, (1.5, 1.5, 1)   
#     input_size=sizes
#     strides, kernels = [], []

#     while True:
#         spacing_ratio = [sp / min(spacings) for sp in spacings]
#         stride = [
#             2 if ratio <= 2 and size >= 8 else 1
#             for (ratio, size) in zip(spacing_ratio, sizes)
#         ]
#         kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
#         if all(s == 1 for s in stride):
#             break
#         for idx, (i, j) in enumerate(zip(sizes, stride)):
#             if i % j != 0:
#                 raise ValueError(
#                     f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
#                 )
#         sizes = [i / j for i, j in zip(sizes, stride)]
#         spacings = [i * j for i, j in zip(spacings, stride)]
#         kernels.append(kernel)
#         strides.append(stride)

#     strides.insert(0, len(spacings) * [1])
#     kernels.append(len(spacings) * [3])
   
#     return kernels, strides

## Prepare Network

In [12]:
# kernels, strides = get_kernels_strides(cfg.img_size)
# model = DynUNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=3,
#     strides=strides,
#     kernel_size=kernels,
#     upsample_kernel_size=strides[1:],
#     norm_name="instance",
#     deep_supervision=True,
#     deep_supr_num=3,
#     res_block=True,
# ).to(cfg.device)


# 3D residual U-Net used for inference. The architecture arguments have to
# match the checkpoints listed in cfg.weights, because the inference loop loads
# them into this model with load_state_dict.
# Channel counts come from cfg so they cannot drift from the config cell, and
# the deprecated no-op `dimensions=None` keyword is no longer passed.
model = UNet(
    spatial_dims=3,
    in_channels=cfg.in_channels,
    out_channels=cfg.out_channels,
    channels=(48, 96, 192, 384, 768),
    strides=(2, 2, 2, 2),
    kernel_size=3,
    up_kernel_size=3,
    num_res_units=4,
    act="PRELU",
    norm="instance",
    dropout=0.0,
    bias=True,
).to(cfg.device)

## Infer

In [13]:
# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    """Run-length encode a binary mask.

    The mask is flattened, and every run of non-zero pixels is reported as a
    ``start length`` pair, where ``start`` is the 1-based position of the run
    in the flattened array.

    Args:
        img (np.array):
            - 1 indicating mask
            - 0 indicating background

    Returns:
        The pairs joined by single spaces; an empty string when the mask has
        no foreground pixels.
    """
    flat = img.flatten()
    # Zero sentinels on both ends so runs touching the borders get both edges.
    padded = np.concatenate([[0], flat, [0]])
    # 1-based positions where the value changes: run starts and run ends alternate.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every "end" position into a run length.
    boundaries[1::2] -= boundaries[::2]
    return ' '.join(map(str, boundaries))

In [14]:
def choose_longest(label_array, length_threshold = 13):
    """Erase short groups of consecutive non-empty slices (treated as false alarms).

    The array is scanned along its first axis. A group opens at the first
    non-empty slice and stays open across gaps of up to three empty slices.
    Each group gets a length score; groups scoring below ``length_threshold``
    are set to zero.

    How the score is counted (unchanged from the original implementation):
    every slice visited while the group is open adds 1, including tolerated
    empty slices, and the slice that opens or re-opens the group adds 1 extra.
    An isolated run of ``k`` non-empty slices followed by at least three empty
    ones therefore scores ``k + 4``.

    Fix compared with the original: only the slices actually visited while the
    group was open are erased. The original erased ``start:start + score``, and
    since the score can exceed the group's real extent, that range could spill
    into the first slice of a following group that was meant to be kept.

    Args:
        label_array: array indexed by slice along axis 0; modified in place.
        length_threshold: groups with a score below this value are erased.

    Returns:
        ``label_array`` (the same object that was passed in).
    """
    toler_length = 3  # number of consecutive empty slices a group survives

    scores = {}       # group start index -> length score
    last_visited = {} # group start index -> last slice index visited for it
    start = -1        # start index of the current group, -1 when none is open
    tolerance = 0     # empty slices still tolerated before the group closes

    for idx, plane in enumerate(label_array):
        occupied = bool(np.any(plane))

        if tolerance == 0:
            if not occupied:
                # Nothing open and nothing here: forget the previous group.
                start = -1
                continue
            if start in scores:
                # Tolerance just ran out but this slice is non-empty: re-open.
                scores[start] += 1
            else:
                start = idx
                scores[start] = 1
            tolerance = toler_length

        # A group is open here (tolerance > 0): count this slice for it.
        scores[start] += 1
        last_visited[start] = idx
        tolerance = toler_length if occupied else tolerance - 1

    for group_start, score in scores.items():
        if score < length_threshold:
            span = slice(group_start, last_visited[group_start] + 1)
            label_array[span] = np.zeros_like(label_array[span])

    return label_array


In [15]:
# Rows of the submission: [id, class name, RLE string].
outputs = []

# Turn the averaged logits into binary masks: sigmoid, then threshold at 0.5.
post_pred = Compose([
    Activations(sigmoid=True),
    AsDiscrete(threshold=0.5),
])

# Read every checkpoint from disk once instead of once per batch.
# map_location="cpu" keeps the copies in host memory (load_state_dict copies
# them into the model's parameters) and also lets the notebook run on a
# machine without a GPU.
state_dicts = [torch.load(weights, map_location="cpu") for weights in cfg.weights]

model.eval()
torch.set_grad_enabled(False)
for batch in tqdm(test_dataloader):
    test_inputs = batch["image"].to(cfg.device)
    pred_all = []
    for state_dict in state_dicts:
        model.load_state_dict(state_dict)
        pred = sliding_window_inference(test_inputs, cfg.img_size, cfg.sw_batch_size, model, overlap=0.6)
        pred_all.append(pred)
        # Test-time augmentation: with the pass above this gives 4 predictions
        # per checkpoint. Each flipped prediction is flipped back before it is
        # added to the ensemble.
        for dims in [[2], [3], [2, 3]]:
            flip_pred = sliding_window_inference(torch.flip(test_inputs, dims=dims), cfg.img_size, cfg.sw_batch_size, model, overlap=0.6)
            flip_pred = torch.flip(flip_pred, dims=dims)
            pred_all.append(flip_pred)

    # Average all checkpoint/TTA predictions; [0] takes the only item of the batch.
    pred_all = torch.mean(torch.stack(pred_all), dim=0)[0]
    pred_all = post_pred(pred_all)
    # c, w, h, d to d, c, h, w
    pred_all = torch.permute(pred_all, [3, 0, 2, 1]).cpu().numpy().astype(np.uint8)

    # Post-process each class separately: drop short groups of slices.
    # choose_longest only looks at whether a slice is empty, so the in-plane
    # axis order does not matter here.
    for channel in range(pred_all.shape[1]):
        pred_all[:, channel] = choose_longest(pred_all[:, channel])

    id_outputs = from_engine(["id"])(batch)[0]

    # zip stops at the shorter sequence, so slices without a matching id
    # (e.g. depth padding) produce no rows.
    for test_output, id_output in zip(pred_all, id_outputs):
        id_name = id_output[0]
        lb, sb, st = test_output
        outputs.append([id_name, "large_bowel", rle_encode(lb)])
        outputs.append([id_name, "small_bowel", rle_encode(sb)])
        outputs.append([id_name, "stomach", rle_encode(st)])


100%|██████████| 28/28 [20:43<00:00, 44.39s/it]


In [16]:
# Build the frame directly from the list of rows. Going through np.array first
# creates a fixed-width unicode array in which every cell is padded to the
# longest RLE string, and it fails when `outputs` is empty.
submit = pd.DataFrame(outputs, columns=["id", "class", "predicted"])

In [17]:
# Fix sub error, refers to: https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/discussion/320541
if debug:
    submit.to_csv('submission.csv', index=False)
else:
    # Re-read the sample submission and merge the predictions onto its
    # (id, class) rows, so the output follows the sample file.
    sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/sample_submission.csv')
    sub_df = sub_df.drop(columns='predicted')
    sub_df = sub_df.merge(submit, on=['id','class'])
    sub_df.to_csv('submission.csv',index=False)