## Load Libraries

In [None]:
import sys

# Make the MONAI copy shipped as a Kaggle input dataset importable
# (presumably because the notebook runs without internet access).
sys.path.append('/kaggle/input/monai-v081/')

In [None]:
import gc
from glob import glob
import os
import numpy as np
import pandas as pd
import torch
from torch import nn
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.data import CacheDataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import json


## Prepare Metadata

### Thanks to awsaf49; this section is adapted from:
https://www.kaggle.com/code/awsaf49/uwmgi-2-5d-infer-pytorch

In [None]:
def get_metadata(row):
    """Parse the case, day and slice numbers out of ``row['id']``.

    Args:
        row: a mapping-like row (e.g. a ``pd.Series`` handed over by
            ``DataFrame.apply(..., axis=1)``) whose ``'id'`` looks like
            ``case123_day20_slice_0001``.

    Returns:
        The same ``row`` with integer ``'case'``, ``'day'`` and ``'slice'``
        entries set (in that order).
    """
    tokens = row['id'].split('_')
    # Parse everything first so a malformed id leaves ``row`` untouched.
    parsed = {
        'case': int(tokens[0].replace('case', '')),
        'day': int(tokens[1].replace('day', '')),
        'slice': int(tokens[-1]),
    }
    for field, value in parsed.items():
        row[field] = value
    return row

def path2info(row):
    """Extract case/day/slice numbers and the image size from ``row['image_path']``.

    The path is expected to end in ``.../caseX_dayY/<dir>/slice_NNNN_W_H_...png``:
    the third-from-last component carries the case and day, and the file name
    carries the slice index followed by the width and height in pixels.

    Args:
        row: a mapping-like row with an ``'image_path'`` entry.

    Returns:
        The same ``row`` with integer ``'height'``, ``'width'``, ``'case'``,
        ``'day'`` and ``'slice'`` entries set (in that order).
    """
    components = row['image_path'].split('/')
    file_tokens = components[-1].split('_')
    slice_num = int(file_tokens[1])

    session_tokens = components[-3].split('_')
    case_num = int(session_tokens[0].replace('case', ''))
    day_num = int(session_tokens[1].replace('day', ''))

    img_width = int(file_tokens[2])
    img_height = int(file_tokens[3])

    for field, value in (
        ('height', img_height),
        ('width', img_width),
        ('case', case_num),
        ('day', day_num),
        ('slice', slice_num),
    ):
        row[field] = value
    return row

In [None]:
# Rows to predict. When sample_submission.csv has no rows (presumably the
# interactive run, before the hidden test set is mounted), switch to debug
# mode and use the last 3000 rows of train.csv instead.
sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/sample_submission.csv')
debug = len(sub_df) == 0
if debug:
    sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/train.csv')[-1000*3:]
    sub_df = sub_df.drop(columns=['class', 'segmentation'])
else:
    sub_df = sub_df.drop(columns=['class', 'predicted'])
# Dropping the per-class columns leaves one duplicate row per class; keep one
# row per slice and parse case/day/slice out of its id.
sub_df = sub_df.drop_duplicates().apply(get_metadata, axis=1)
print(f'debug:{debug}')
sub_df.head()


In [None]:
# Gather every slice PNG from the split that matches sub_df (train in debug
# mode, test otherwise), then parse case/day/slice/size from each path.
paths = glob(
    '/kaggle/input/uw-madison-gi-tract-image-segmentation/train/**/*.png'
    if debug
    else '/kaggle/input/uw-madison-gi-tract-image-segmentation/test/**/*.png',
    recursive=True,
)
path_df = pd.DataFrame(paths, columns=['image_path']).apply(path2info, axis=1)
path_df.head()


## Build the 2D Data List for the MONAI Dataset

In [None]:
# Attach each slice's image path (and size) to the rows to predict, then keep
# the three underscore-separated pieces of the id as string columns.
test_df = sub_df.merge(path_df, on=['case', 'day', 'slice'], how='left')
test_df = test_df.assign(
    case_id_str=test_df["id"].apply(lambda s: s.split("_", 2)[0]),
    day_num_str=test_df["id"].apply(lambda s: s.split("_", 2)[1]),
    slice_id=test_df["id"].apply(lambda s: s.split("_", 2)[2]),
)
test_df.head()

## Prepare Transforms, Dataset, DataLoader

In [None]:
class cfg:
    """Inference settings shared by the cells below (used as a namespace)."""

    # Network input/output
    img_size = (160, 160)  # every slice is resized to this before the model
    in_channels = 1        # single grayscale channel
    out_channels = 4       # background plus organ labels 1-3

    # Runtime
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = 1
    sw_batch_size = 4

    # Trained checkpoint (a state_dict)
    weights = '/kaggle/input/uwmgi-models-2d/best_metric_model_segmentation2d_dict.pth'


print(cfg.device)

In [None]:
# One image path per slice, in test_df row order; the inference loop pairs
# these with dataloader batches by position.
test_data = test_df['image_path'].tolist()

In [None]:
import monai.transforms as T

# Per-image preprocessing; each dataset item is just a PNG path.
# NOTE(review): CropForeground crops to the foreground bounding box before the
# resize to cfg.img_size, but the inference loop only resizes predictions back
# to the full slice size and never restores the crop offset, so masks may be
# misaligned with the original image. Confirm this matches how the model was
# trained, and consider inverting the crop before encoding.
val_transforms = T.Compose([
    T.LoadImage(image_only=True),
    T.AddChannel(),
    T.CropForeground(),
    T.Resize(spatial_size=cfg.img_size, mode="bilinear"),
    T.EnsureType(),
])

# cache_rate=0.0: nothing is cached up front; items are transformed on access.
test_ds = CacheDataset(
    data=test_data,
    transform=val_transforms,
    cache_rate=0.0,
    num_workers=2,
)

# No shuffle argument is passed, so batches arrive in test_data order.
test_dataloader = DataLoader(
    test_ds,
    batch_size=cfg.batch_size,
    num_workers=2,
    pin_memory=True,
)

## Prepare Network

In [None]:
# 2D residual UNet. The hyper-parameters must match the checkpoint in
# cfg.weights, because load_state_dict is later called with its default
# strict key/shape checking.
model = UNet(
    spatial_dims=2,
    in_channels=cfg.in_channels,
    out_channels=cfg.out_channels,
    channels=(16, 32, 64, 128, 256),  # feature maps per resolution level
    strides=(2, 2, 2, 2),             # x2 downsampling between levels
    num_res_units=2,
)
model = model.to(cfg.device)

## Inference

In [None]:
# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img, min_area=10):
    """Run-length encode a mask as a ``"start length start length ..."`` string.

    Args:
        img: array-like mask (anything ``np.asarray`` accepts, including a CPU
            torch tensor). Non-zero pixels are foreground and 0 is background.
            The mask is flattened in C (row-major) order.
        min_area: masks with ``min_area`` or fewer foreground pixels are treated
            as empty. This suppresses tiny spurious predictions; the default of
            10 keeps the historical behaviour.

    Returns:
        The run-length string, with 1-based starts, or ``np.nan`` when the mask
        is empty or not larger than ``min_area``. ``np.nan`` is what
        ``dropna``/``to_csv`` later treat as "no mask".
    """
    # Binarise, then pad with background on both ends so that every run has a
    # visible start and end transition.
    pixels = np.concatenate([[0], np.asarray(img).ravel() != 0, [0]])

    # Indices where the value changes are alternately run starts and run ends
    # (1-based). Turn each end into a length by subtracting its start.
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]

    if runs.size == 0 or runs[1::2].sum() <= min_area:
        return np.nan
    return ' '.join(str(x) for x in runs)


In [None]:
a = 0  # debug counter; not read by the inference code
cfg.device  # echo the selected device in the notebook output

In [None]:
import matplotlib.pyplot as plt

# Restore the trained weights onto whichever device cfg selected.
model.load_state_dict(torch.load(cfg.weights, map_location=cfg.device))

# Model logits -> label map. This assumes a single-item batch: squeeze the
# leading dim, take the argmax over channels, then squeeze again.
post_trans = T.Compose([
    T.SqueezeDim(),
    T.AsDiscrete(argmax=True),
    T.SqueezeDim(),
])

# Nearest-neighbour resizers back to the native slice sizes, given as
# (width, height).
resize_266, resize_276, resize_360 = (
    T.Compose([
        T.AddChannel(),
        T.Resize(spatial_size=size, mode='nearest'),
        T.SqueezeDim(),
    ])
    for size in ((266, 266), (276, 276), (360, 310))
)

plt.figure(figsize=(20, 20))
submit = pd.DataFrame()
def get_slice_id(path):
    """Derive the submission id and native image size from a slice image path.

    Args:
        path: a path ending in ``.../caseX_dayY/<dir>/slice_NNNN_W_H_...png``.

    Returns:
        ``(slice_id, width, height)``, where ``slice_id`` is
        ``caseX_dayY_slice_NNNN`` and ``width``/``height`` are ints.

    Width and height are read from the underscore-separated file-name fields
    rather than fixed character offsets, so names whose pixel-spacing suffix
    has a different length are still parsed correctly.
    """
    _, session_dir, _, file_name = path.rsplit('/', 3)
    fields = file_name.split('_')
    slice_id = session_dir + '_' + '_'.join(fields[:2])
    width = int(fields[2])
    height = int(fields[3])
    return slice_id, width, height
# Inference: one forward pass per slice, then one RLE row per organ.
# Batches are paired with `test_data` by position, which only holds for a
# batch size of 1 with an unshuffled loader.
if cfg.batch_size != 1:
    raise ValueError(f'inference loop expects cfg.batch_size == 1, got {cfg.batch_size}')

model.eval()  # switch dropout/batch-norm layers (if any) to inference mode
rows = []
with torch.no_grad():  # no autograd graph is needed at inference time
    for path, batch in zip(test_data, test_dataloader):
        slice_id, width, height = get_slice_id(path)
        # Assumes post_trans maps the (1, C, h, w) logits to an (h, w) label map.
        pred = post_trans(model(batch.to(cfg.device)))
        # Resize the label map back to the slice's native size. Nearest
        # neighbour keeps the labels integral. The (width, height) order is the
        # same one used by the fixed resize_266/276/360 transforms; building the
        # resizer per slice covers every resolution present in the data.
        # NOTE(review): confirm this axis order matches the row-major
        # (height, width) layout that the competition RLE expects.
        to_native = T.Compose([
            T.AddChannel(),
            T.Resize(spatial_size=(width, height), mode='nearest'),
            T.SqueezeDim(),
        ])
        pred = to_native(pred)
        for label, organ in ((1, 'stomach'), (2, 'large_bowel'), (3, 'small_bowel')):
            mask = (pred == label).cpu().numpy().astype(np.uint8)
            rows.append({'id': slice_id, 'class': organ, 'predicted': rle_encode(mask)})

# Build the frame once: repeated DataFrame.append is quadratic and was removed
# in pandas 2.0. Explicit columns keep the schema even when nothing was predicted.
submit = pd.DataFrame(rows, columns=['id', 'class', 'predicted'])

In [None]:
submit.dropna().head(n=5)  # preview rows that carry a non-empty mask

In [None]:
# Build the submission from sample_submission.csv to avoid scoring errors; see
# https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/discussion/320541
if not debug:
    # Start from the official (id, class) rows so the file has exactly the
    # expected rows in the expected order. A left merge keeps rows that got no
    # prediction; their NaN 'predicted' is written as an empty field (no mask).
    sub_df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/sample_submission.csv')
    sub_df = sub_df.drop(columns=['predicted']).merge(submit, on=['id', 'class'], how='left')
    sub_df.to_csv('submission.csv', index=False)
else:
    submit.to_csv('submission.csv', index=False)

In [None]:
submit.dropna().head(n=5)  # final look at the non-empty predictions