## Configuration

In [None]:
import os
from itertools import chain

# Runtime detection: a local `train/` folder marks the Gradient setup, a `../input` mount marks Kaggle.
GRADIENT, KAGGLE = os.path.exists('train'), os.path.exists('../input')

model_name = 'resnet34-v1-20-epoch.pkl'  # exported Learner file; a per-runtime directory prefix is added later
SEED = 42
VAL_PCT = 0.2   # intended validation fraction
INFER = True    # True: load `model_name` and predict; False: train and export

## Libraries and Data

In [None]:
if KAGGLE:
    # Install segmentation_models_pytorch and its dependencies from local copies bundled in a
    # Kaggle dataset (presumably because the Kaggle kernel runs without internet access).
    # The dataset is copied to the working directory first, then each package is installed from it.
    !cp -r ../input/pytorch-segmentation-models-lib/ ./
    !pip install -q ./pytorch-segmentation-models-lib/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
    !pip install -q ./pytorch-segmentation-models-lib/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
    !pip install -q ./pytorch-segmentation-models-lib/timm-0.4.12-py3-none-any.whl
    !pip install -q ./pytorch-segmentation-models-lib/segmentation_models_pytorch-0.2.0-py3-none-any.whl

In [None]:
from fastai.vision.all import *

import gc

In [None]:
# Resolve data and model locations for the detected runtime.
if KAGGLE:
    data_path = '../input/uw-madison-gi-tract-image-segmentation/'
    # basename() keeps the cell idempotent: re-running it must not stack directory prefixes
    model_name = '../input/uw-madison-models/' + os.path.basename(model_name)
elif GRADIENT:
    data_path = ''
    model_name = 'models/' + os.path.basename(model_name)
else:
    # Fail here with a clear message instead of a NameError on `data_path` further down
    raise RuntimeError("Unknown runtime: expected a Kaggle '../input' mount or a local 'train' directory")

In [None]:
# Locate the train/test image folders and load the competition CSVs.
path, test_path = Path(data_path + 'train'), Path(data_path + 'test')
train = pd.read_csv(data_path + 'train.csv', low_memory=False)
sample_submission = pd.read_csv(data_path + 'sample_submission.csv', low_memory=False)
# Every training slice image under `path` (fastai's get_image_files)
fnames = get_image_files(path)

## Helper Functions

In [None]:
# Extract case id from fname
def get_case_id(fname):
    """Build the competition id ``case<N>_day<M>_slice_<SSSS>`` from an image path.

    Expects the dataset layout ``.../case<N>/case<N>_day<M>/scans/slice_<SSSS>_....png``.
    Path components are indexed from the end, so the result does not depend on how deep
    the dataset root is (Kaggle and local roots give identical ids).
    """
    parts = fname.parts
    # parts[-3] is the "case<N>_day<M>" folder; the first 10 chars of the file name are "slice_<SSSS>"
    return parts[-3] + '_' + parts[-1][:10]

def check_file(file_id, fname):
    """Return True if the image path `fname` is the scan identified by `file_id`.

    `file_id` has the form ``case<N>_day<M>_slice_<SSSS>``. `fname` is expected to end in
    ``case<N>/case<N>_day<M>/scans/slice_<SSSS>_....png``. Components are read from the end,
    so any dataset root works. The slice number is compared exactly, not as a substring.
    """
    case_id, day, _, slice_no = file_id.split('_')
    parts = fname.parts
    return (case_id == parts[-4]
            and day == parts[-3].split('_')[1]
            and slice_no == parts[-1].split('_')[1])

def get_file(file_id):
    """Return the first path in the global `fnames` that matches `file_id`.

    `file_id` has the form ``case<N>_day<M>_slice_<SSSS>``. Indexing the filtered result
    with ``[0]`` raises if nothing matches (presumably IndexError).
    """
    return fnames.filter(lambda f: check_file(file_id, f))[0]

# https://www.kaggle.com/code/dschettler8845/uwm-gi-tract-image-segmentation-eda
def get_custom_df(df, fnames, root):
    """Collapse a long (one row per id and organ class) frame into one row per image slice.

    Adapted from the EDA notebook linked above.

    Args:
        df: frame with columns ``id`` (``case<N>_day<M>_slice_<SSSS>``), ``class``
            (``large_bowel`` / ``small_bowel`` / ``stomach``) and ``segmentation``
            (RLE string or NaN). The input frame is not modified (a copy is taken).
        fnames: image paths whose string form is
            ``<root>/case<N>/case<N>_day<M>/scans/slice_<SSSS>_<a>_<b>_<c>_<d>.png``.
        root: dataset split directory without a trailing slash; used to rebuild the path prefix.

    Returns:
        One row per ``id`` with the matching image path, per-organ RLE plus presence flag,
        ``n_segs``, slice size / pixel spacing parsed from the file name, and the parsed
        case/day/slice identifiers. Ids with no matching file are dropped (inner merge).
    """
    
    df = df.copy()
    
    # 1. Get Case-ID as a column (str and int)
    df["case_id_str"] = df["id"].apply(lambda x: x.split("_", 2)[0])
    df["case_id"] = df["id"].apply(lambda x: int(x.split("_", 2)[0].replace("case", "")))

    # 2. Get Day as a column
    df["day_num_str"] = df["id"].apply(lambda x: x.split("_", 2)[1])
    df["day_num"] = df["id"].apply(lambda x: int(x.split("_", 2)[1].replace("day", "")))

    # 3. Get Slice Identifier as a column ("slice_<SSSS>": everything after the 2nd underscore)
    df["slice_id"] = df["id"].apply(lambda x: x.split("_", 2)[2])

    # 4. Get full file paths for the representative scans.
    # The ids do not carry the size/spacing suffix of the real file names, so build the path
    # prefix up to "slice_<SSSS>" and join on it.
    df["_partial_fname"] = (root+'/'+ # /kaggle/input/uw-madison-gi-tract-image-segmentation/train/
                          df["case_id_str"]+"/"+ # .../case###/
                          df["case_id_str"]+"_"+df["day_num_str"]+ # .../case###_day##/
                          "/scans/"+df["slice_id"]) # .../slice_####
    
    # Strip the four trailing "_"-separated fields from each file name to get the same join key
    _tmp_merge_df = pd.DataFrame({"_partial_fname":[str(x).rsplit("_",4)[0] for x in fnames], "fname": fnames})
    # Default (inner) merge: rows whose image file is not among `fnames` are dropped
    df = df.merge(_tmp_merge_df, on="_partial_fname").drop(columns=["_partial_fname"])
    
    # Minor cleanup of our temporary workaround
    del _tmp_merge_df; gc.collect(); gc.collect()
    
    # 5. Get slice dimensions from filepath (int in pixels).
    # [:-4] drops the file extension (presumably ".png"); the last four "_" fields are parsed.
    # NOTE(review): the first number is stored as slice_h and the second as slice_w, while
    # get_mask builds masks with shape (slice_w, slice_h) — confirm which one is rows vs columns.
    df["slice_h"] = df["fname"].apply(lambda x: int(str(x)[:-4].rsplit("_",4)[1]))
    df["slice_w"] = df["fname"].apply(lambda x: int(str(x)[:-4].rsplit("_",4)[2]))

    # 6. Pixel spacing from filepath (float in mm)
    df["px_spacing_h"] = df["fname"].apply(lambda x: float(str(x)[:-4].rsplit("_",4)[3]))
    df["px_spacing_w"] = df["fname"].apply(lambda x: float(str(x)[:-4].rsplit("_",4)[4]))

    # 7. Merge the 3 per-class rows into a single row (the segmentation RLE is the only
    # information that differs across those rows): one RLE column per organ, then deduplicate by id.
    l_bowel_train_df = df[df["class"]=="large_bowel"][["id", "segmentation"]].rename(columns={"segmentation":"lb_seg_rle"})
    s_bowel_train_df = df[df["class"]=="small_bowel"][["id", "segmentation"]].rename(columns={"segmentation":"sb_seg_rle"})
    stomach_train_df = df[df["class"]=="stomach"][["id", "segmentation"]].rename(columns={"segmentation":"st_seg_rle"})
    df = df.merge(l_bowel_train_df, on="id", how="left")
    df = df.merge(s_bowel_train_df, on="id", how="left")
    df = df.merge(stomach_train_df, on="id", how="left")
    df = df.drop_duplicates(subset=["id",]).reset_index(drop=True)
    # Flags: True when the organ has a (non-NaN) RLE on this slice
    df["lb_seg_flag"] = df["lb_seg_rle"].apply(lambda x: not pd.isna(x))
    df["sb_seg_flag"] = df["sb_seg_rle"].apply(lambda x: not pd.isna(x))
    df["st_seg_flag"] = df["st_seg_rle"].apply(lambda x: not pd.isna(x))
    # Number of organs segmented on this slice (0-3)
    df["n_segs"] = df["lb_seg_flag"].astype(int)+df["sb_seg_flag"].astype(int)+df["st_seg_flag"].astype(int)

    # 8. Reorder columns into a new ordering (drops class and segmentation as no longer necessary)
    df = df[["id", "fname", "n_segs",
             "lb_seg_rle", "lb_seg_flag",
             "sb_seg_rle", "sb_seg_flag", 
             "st_seg_rle", "st_seg_flag",
             "slice_h", "slice_w", "px_spacing_h", 
             "px_spacing_w", "case_id_str", "case_id", 
             "day_num_str", "day_num", "slice_id",]]

    return df

# ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
# modified from: https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle_decode(mask_rle, shape, color=1):
    """Decode a run-length-encoded mask into a dense float32 array.

    Args:
        mask_rle (str): space-separated ``start length`` pairs. Starts are 1-based positions
            in the row-major flattened image.
        shape (tuple of ints): ``(height, width)`` or ``(height, width, depth)`` of the result.
            With a depth axis, every channel of a covered pixel is filled.
        color: value written into covered pixels (default 1).

    Returns:
        np.ndarray (float32) of shape `shape`: `color` on covered pixels, 0 elsewhere.
    """
    tokens = np.array(mask_rle.split(), dtype=int)
    run_starts = tokens[0::2] - 1          # 1-based -> 0-based offsets
    run_ends = run_starts + tokens[1::2]   # exclusive end of each run

    # The runs index the flattened image, so decode into a flat buffer first
    if len(shape) == 3:
        n_rows, n_cols, depth = shape
        flat = np.zeros((n_rows * n_cols, depth), dtype=np.float32)
    else:
        n_rows, n_cols = shape
        flat = np.zeros((n_rows * n_cols,), dtype=np.float32)

    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = color

    return flat.reshape(shape)

def get_image(row):
    """Load the slice at ``row['fname']`` and min-max rescale its intensities to [0, 255].

    Returns a float64 array with the image's shape. A constant image (e.g. an all-black
    slice) is returned as all zeros: `np.interp` with ``xp=[c, c]`` would map every pixel
    to the right endpoint (255) and turn a blank slice into a white one.
    """
    # Context manager closes the file handle PIL keeps open for lazily-loaded images
    with Image.open(row['fname']) as im:
        img = np.array(im)
    lo, hi = np.min(img), np.max(img)
    if hi == lo:
        return np.zeros(img.shape, dtype=np.float64)
    return np.interp(img, [lo, hi], [0, 255])
                   

def get_mask(row):
    """Rasterise the three organ RLEs of `row` into a (slice_w, slice_h, 3) uint8 mask.

    Channel 0 is large bowel, channel 1 small bowel, channel 2 stomach. Covered pixels are 255
    and all others 0. Organs whose ``*_seg_flag`` is False stay empty and their RLE is not read.
    """
    mask_shape = (row['slice_w'], row['slice_h'])
    mask = np.zeros((*mask_shape, 3))
    organ_columns = (('lb_seg_flag', 'lb_seg_rle'),
                     ('sb_seg_flag', 'sb_seg_rle'),
                     ('st_seg_flag', 'st_seg_rle'))
    for channel, (flag_col, rle_col) in enumerate(organ_columns):
        if row[flag_col]:
            mask[..., channel] += rle_decode(row[rle_col], shape=mask_shape, color=255)
    return mask.astype(np.uint8)

## Prepare Data

In [None]:
# Frame to predict on. If the test folder has no images (presumably when the notebook runs
# without the hidden test set), fall back to the training images so the pipeline still runs.
test_fnames = get_image_files(test_path)
root = data_path + 'test'
if not test_fnames:
    test_fnames, root = fnames, data_path + 'train'

# Long format expected by get_custom_df: three rows (one per organ) per slice, no RLE yet
test = pd.DataFrame({
    'id': [case_id for case_id in map(get_case_id, test_fnames) for _ in range(3)],
    'class': [organ for _ in test_fnames for organ in ('large_bowel', 'small_bowel', 'stomach')],
    'segmentation': [np.nan] * (3 * len(test_fnames)),
})

test = get_custom_df(test, test_fnames, root)
train = get_custom_df(train, fnames, data_path + 'train')

## Validation Set

In [None]:
# Validation fraction comes from the config cell instead of a second hard-coded 0.2
valid_pct = VAL_PCT
# Seed the RNGs via fastai's set_seed for reproducible splits and training
set_seed(SEED, True)

### Custom validation

In [None]:
# Hand-rolled split: hold out a random `valid_pct` of cases, then also hold out every row
# whose day number is one of a random `valid_pct` of the day numbers still in training.
# NOTE: the GroupShuffleSplit section below overwrites `is_valid`.
np.random.seed(SEED)

cases = train.case_id.unique()
n_cases = len(cases)
random_cases = np.random.choice(cases, int(n_cases*valid_pct), replace=False)

train['is_valid'] = False
train.loc[train.case_id.isin(random_cases), 'is_valid'] = True

days = train.loc[~train['is_valid'], 'day_num'].unique()
n_days = len(days)
random_days = np.random.choice(days, int(n_days*valid_pct), replace=False)

# The sampled values are day numbers, so match them against day_num (not case_id)
train.loc[train.day_num.isin(random_days), 'is_valid'] = True

train['is_valid'].mean()

### Group Validation by Case (GroupShuffleSplit)

In [None]:
from sklearn.model_selection import GroupShuffleSplit

In [None]:
# Hold out whole cases: every slice of a case lands on the same side of the split
gss = GroupShuffleSplit(n_splits=1, test_size=valid_pct, random_state=SEED)
# n_splits=1, so the generator yields exactly one (train indices, valid indices) pair
train_idx, val_idx = next(gss.split(train, train, train['case_id']))

train['is_valid'] = False
train.loc[val_idx, 'is_valid'] = True

train['is_valid'].mean()

## Use Datablock API

In [None]:
# Extra `encodes` for fastai's ToTensor, dispatched on PILMask inputs: convert the mask with
# image2tensor and wrap it in the mask's tensor class (`o._tensor_cls`).
# NOTE(review): presumably added because get_mask produces 3-channel (H, W, 3) masks that need
# the channel-first layout image2tensor yields — confirm against fastai's default PILMask handling.
@ToTensor
def encodes(self, o:PILMask): return o._tensor_cls(image2tensor(o))

In [None]:
# Extend fastai's Normalize for TensorMask. encodes scales masks by 1/255, so the 0/255 pixels
# written by get_mask become 0/1 targets. decodes reverses that and returns integer (long) values.
@Normalize
def encodes(self, o:TensorMask): return o / 255

@Normalize
def decodes(self, o:TensorMask): 
    # `to_cpu` when the tensor is already on CPU, otherwise `noop` (leave it on its device).
    # NOTE(review): appears to mirror fastai's own Normalize.decodes device handling — confirm.
    f = to_cpu if o.device.type=='cpu' else noop
    return f((o * 255).long())

In [None]:
import matplotlib.patches as mpatches

@typedispatch
def show_batch(x:TensorImage, y:TensorMask, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    "Plot each image in grayscale with its 3-channel mask overlaid and an organ colour legend."
    figsize = figsize if figsize is not None else (ncols*3, max_n//ncols * 3)
    if ctxs is None:
        ctxs = get_grid(max_n, nrows=nrows, ncols=ncols, figsize=figsize)
    # Legend colours follow the mask channel order: R = large bowel, G = small bowel, B = stomach
    legend_entries = (('red', 'lb'), ('green', 'sb'), ('blue', 'st'))
    for idx, ax in enumerate(ctxs):
        # Divide by the per-image maximum so the grayscale display uses the full range
        show_image(x[idx] / x[idx].max(), ctx=ax, cmap='gray', **kwargs)
        show_image(y[idx], ctx=ax, cmap='Spectral_r', alpha=0.35, **kwargs)
        handles = [mpatches.Patch(color=colour, label=name) for colour, name in legend_entries]
        ax.legend(handles=handles, fontsize=figsize[0]/2)

In [None]:
def get_aug_dls(aug=[], method='squish', bs=16, sample=False, show=True):
    """Build segmentation DataLoaders from the global `train` frame.

    Args:
        aug: batch transforms applied before ImageNet normalisation (falsy -> none).
        method: resize method for the 224-px item resize.
        bs: batch size.
        sample: if True, use a reproducible 20% sample of `train` for quick experiments.
        show: if True, display one batch.

    Returns:
        (dls, dev): the DataLoaders and the frame they were built from.
    """
    # Augmentations first, ImageNet normalisation last
    batch_tfms = [*(aug or []), Normalize.from_stats(*imagenet_stats)]

    db = DataBlock(
        blocks=(ImageBlock(cls=PILImageBW), MaskBlock),
        get_x=get_image,
        get_y=get_mask,
        splitter=ColSplitter(),   # splits on the boolean validation column set earlier
        item_tfms=[Resize(224, method=method)],
        batch_tfms=batch_tfms,
    )

    dev = train.sample(frac=0.2, random_state=SEED) if sample else train

    dls = db.dataloaders(dev, bs=bs, shuffle=True)
    dls.rng.seed(SEED)

    if show:
        dls.show_batch(nrows=bs//4, ncols=4, max_n=bs, figsize=(12, 12))

    return dls, dev

## Metrics

In [None]:
from scipy.spatial.distance import directed_hausdorff

def mod_acc(inp, targ):
    """Pixel accuracy of thresholded sigmoid predictions, restricted to foreground target pixels.

    If the batch has no foreground at all, background pixels are scored instead.
    Returns a Python float.
    """
    targ = targ.squeeze(1)
    foreground = targ != 0
    selected = foreground if foreground.sum() != 0 else (targ == 0)
    hard_preds = torch.where(sigmoid(inp) > 0.5, 1, 0)
    return (hard_preds[selected] == targ[selected]).float().mean().item()

def dice_coeff(inp, targ):
    """Mean Dice coefficient over a batch of multi-channel binary masks.

    Args:
        inp: model output logits (the U-Net is built with ``activation=None``); the sums over
            axes (2, 3) assume shape (B, C, H, W).
        targ: binary targets of the same shape.

    Returns:
        numpy scalar: Dice per (sample, channel), averaged over both. Thanks to `eps`, an empty
        prediction against an empty target scores 1.
    """
    # Threshold probabilities, not logits (as mod_acc does). Thresholding logits at 0.5 would
    # amount to a probability cut-off of sigmoid(0.5) ~= 0.62.
    probs = inp.float().sigmoid().cpu().detach().numpy()
    inp = np.where(probs > 0.5, 1, 0)
    targ = targ.cpu().detach().numpy()
    eps = 1e-5
    I = (targ * inp).sum((2, 3))
    U = targ.sum((2, 3)) + inp.sum((2, 3))
    return ((2. * I + eps) / (U + eps)).mean((1, 0))

# def dice_coeff(inp, targ):
#     if torch.is_tensor(inp):
#         inp = torch.where(sigmoid(inp) > 0.5, 1, 0).cpu().detach().numpy().astype(np.uint8)
#     if torch.is_tensor(targ):
#         targ = targ.cpu().detach().numpy().astype(np.uint8)
#     # mask = targ == 1
#     # I = (inp[mask] == targ[mask]).sum((2, 3))
#     eps = 1e-5
#     I = (targ & inp).sum((2, 3))
#     # U = inp.sum((2, 3)) + targ.sum((2, 3))
#     U = (targ | inp).sum((2, 3))
#     return ((2*I)/(U+I+1) + (U==0)).mean((1, 0))

# def dice_coeff2(inp, targ, thr=0.5, dim=(2,3), epsilon=0.001):
#     targ = targ.to(torch.float32)
#     inp = (inp>thr).to(torch.float32)
#     inter = (targ*inp).sum(dim=dim)
#     den = targ.sum(dim=dim) + inp.sum(dim=dim)
#     dice = ((2*inter+epsilon)/(den+epsilon)).mean(dim=(1,0))
#     return dice

def hd_dist_per_slice(inp, targ, seed: int = 0):
    """Hausdorff *score* (1 - distance, floored at 0) between two binary 2-D masks.

    Pixel coordinates are divided by the mask shape, so each axis lies in [0, 1). The distance
    is the symmetric Hausdorff distance: the max of the two directed distances.

    Edge cases: both masks empty -> 1.0 (perfect); exactly one empty -> 0.0 (worst).
    `seed` is forwarded to scipy's `directed_hausdorff`. It only affects the order in which
    points are visited, not the returned distance.
    """
    inp_pts = np.argwhere(inp) / np.array(inp.shape)
    targ_pts = np.argwhere(targ) / np.array(targ.shape)
    if len(inp_pts) == 0 and len(targ_pts) == 0:
        return 1.0
    if len(inp_pts) == 0 or len(targ_pts) == 0:
        # directed_hausdorff gives no meaningful distance for an empty point set; without this a
        # missed or hallucinated organ would score as perfect
        return 0.0
    dist = max(directed_hausdorff(inp_pts, targ_pts, seed)[0],
               directed_hausdorff(targ_pts, inp_pts, seed)[0])
    return max(1.0 - dist, 0.0)

def hd_dist(inp, targ):
    """Mean per-slice Hausdorff score over a batch.

    Args:
        inp: model output logits, indexed as ``inp[i, j]`` -> one 2-D mask per sample i and
            channel j (assumed shape (B, C, H, W)).
        targ: binary targets of the same shape.

    Returns:
        float: mean over samples of the mean over channels of `hd_dist_per_slice`.
    """
    # Model outputs are logits (activation=None in build_model): sigmoid before thresholding
    probs = inp.float().sigmoid().cpu().detach().numpy()
    inp = np.where(probs > 0.5, 1, 0)
    targ = targ.cpu().detach().numpy()
    n_channels = inp.shape[1]  # not hard-coded to 3
    return np.mean([np.mean([hd_dist_per_slice(inp[i, j], targ[i, j]) for j in range(n_channels)])
                    for i in range(len(inp))])

def custom_metric(inp, targ):
    """Weighted blend of the batch scores: 0.4 * Dice + 0.6 * Hausdorff score."""
    hd_score = hd_dist(inp, targ)
    dice_score = dice_coeff(inp, targ)
    return 0.4 * dice_score + 0.6 * hd_score


def custom_loss(inp, targ):
    """Binary cross-entropy on logits, averaged over all elements."""
    # BCEWithLogitsLoss is a module: instantiate it, then call it on the tensors
    return nn.BCEWithLogitsLoss()(inp, targ.float())

In [None]:
# custom_metric(torch.zeros_like(y), torch.zeros_like(y))

## Loss

In [None]:
class DiceBCEModule(Module):
    "Equal-weight sum of soft-Dice loss and binary cross-entropy over all pixels and channels."
    def __init__(self, eps:float=1e-5, from_logits=True):
        # eps: smoothing term of the Dice ratio
        # from_logits: True if `inp` holds logits, False if it already holds probabilities
        store_attr()

    def forward(self, inp:Tensor, targ:Tensor) -> Tensor:
        if self.from_logits:
            bce_loss = nn.BCEWithLogitsLoss()(inp, targ)
            inp = torch.sigmoid(inp)
        else:
            # Probabilities: BCE written out (clamped for log stability) rather than nn.BCELoss,
            # which PyTorch refuses to run inside autocast regions (the learner uses to_fp16)
            probs = inp.float().clamp(1e-7, 1 - 1e-7)
            bce_loss = -(targ * probs.log() + (1 - targ) * (1 - probs).log()).mean()
        # reshape (not view): the inputs are not guaranteed to be contiguous
        inp = inp.reshape(-1)
        targ = targ.reshape(-1)

        intersection = (inp * targ).sum()
        dice = (2. * intersection + self.eps) / (inp.sum() + targ.sum() + self.eps)

        return 0.5 * (1 - dice) + 0.5 * bce_loss


class DiceBCELoss(BaseLoss):
    "fastai `BaseLoss` wrapper around `DiceBCEModule`; outputs are decoded by thresholding sigmoid probabilities."
    def __init__(self, *args, eps:float=1e-5, from_logits=True, thresh:float=0.5, **kwargs):
        super().__init__(DiceBCEModule, *args, eps=eps, from_logits=from_logits, flatten=False, is_2d=True, floatify=True, **kwargs)
        # Probability threshold used by `decodes`
        self.thresh = thresh

    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        # getattr fallback: learners pickled before `thresh` existed unpickle without the attribute
        return (x > getattr(self, 'thresh', 0.5)).long()

    def activation(self, x:Tensor) -> Tensor:
        "`nn.BCEWithLogitsLoss`'s fused activation function applied to model output"
        return torch.sigmoid(x)

## Learner

In [None]:
def splitter(model):
    """Split an smp segmentation model into two parameter groups for freezing / discriminative LRs.

    Group 0: the encoder. Group 1: the decoder plus, when present, the `segmentation_head`.
    """
    # NOTE(review): smp models keep the final head in `segmentation_head`, outside `decoder`
    # (per the smp docs). Parameters missing from every group would never be optimized.
    head = getattr(model, 'segmentation_head', None)
    head_params = params(head) if head is not None else []
    return [params(model.encoder), params(model.decoder) + head_params]

In [None]:
def build_model(encoder_name, device='cuda'):
    """Build a 3-input-channel, 3-class `smp.Unet` with an ImageNet-pretrained encoder.

    Args:
        encoder_name: smp encoder identifier, e.g. 'resnet34', 'mobilenet_v2', 'efficientnet-b7'.
        device: device the model is moved to (default 'cuda', as before).

    Returns:
        The model. It outputs raw logits (``activation=None``).
    """
    # Local import: the notebook installs segmentation_models_pytorch but never imports it as `smp`
    import segmentation_models_pytorch as smp

    model = smp.Unet(
        encoder_name=encoder_name,      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
        in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=3,                      # model output channels (one per organ)
        activation=None,                # logits; sigmoid is applied by the loss/metrics
    )
    model.to(device)
    return model

In [None]:
if not INFER:
    # Full training set with default fastai augmentations, and a ResNet34 U-Net
    dls, dev = get_aug_dls(aug_transforms(), sample=False, bs=16, show=False)
    unet = build_model('resnet34')
    # NOTE(review): `splitter` is not passed, so the learner likely has a single parameter group.
    # freeze() would then leave everything trainable and slice(lo, hi) LRs would collapse to one value.
    # Pass `splitter=splitter` only once it provably returns every model parameter.
    learn = Learner(
        dls,
        unet,
        metrics=[dice_coeff, hd_dist, custom_metric],
        loss_func=DiceBCELoss(),
    ).to_fp16()

## Training

In [None]:
if not INFER:
    # Freeze, then run the LR finder to choose `lr` for the cells below.
    # NOTE(review): with no `splitter=` on the Learner, freeze() may leave all weights trainable — confirm.
    learn.freeze()
    learn.lr_find()

In [None]:
# Base learning rate for fit_one_cycle (presumably read off the lr_find plot)
lr = 0.001

In [None]:
if not INFER:
    # One warm-up epoch at the base learning rate
    learn.fit_one_cycle(1, lr_max=slice(lr))

In [None]:
if not INFER:
    learn.unfreeze()
    # Ten epochs with discriminative LRs: lr/400 for the earliest group up to lr/10 for the last
    learn.fit_one_cycle(10, lr_max=slice(lr / 400, lr / 10))

## Show Results

In [None]:
if not INFER:
    # Show up to 16 predictions next to their targets.
    # NOTE(review): this likely decodes predictions via DiceBCELoss.decodes, which reads `self.thresh` — confirm it is set.
    learn.show_results(max_n=16)

## Save Model

In [None]:
if not INFER:
    # Serialise the trained Learner to the runtime-specific `model_name` path
    learn.export(fname=model_name)

## Test Inference

In [None]:
if INFER:
    # Unpickles the exported Learner: only load model files from a trusted source
    learn = load_learner(fname=model_name)

In [None]:
from tqdm import tqdm

# @torch.no_grad()
# def get_preds(learn, dl, thresh=0.5):
#     learn.model.eval()
#     preds = []
#     for b in tqdm(dl):
#         b[0].to('cuda')
#         b_preds = (sigmoid(learn.model(b[0])) > 0.5).permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
#         preds.append(b_preds)
#         torch.cuda.empty_cache()
#         gc.collect()
#     preds_arr = np.concatenate(preds)
#     del preds; gc.collect()
#     return preds_arr


In [None]:
if INFER:
    # Unshuffled so batches follow the row order of `test`
    test_dl = learn.dls.test_dl(test, shuffle=False)
    test_dl = test_dl.to('cuda')
    b = test_dl.one_batch()  # one sample batch

In [None]:
if INFER:
    learn.model = learn.model.cuda()
    learn.model.eval()
    # Binary (N, H, W, C) uint8 predictions, row-aligned with `test` (test_dl is unshuffled).
    # Allocated from the first prediction batch so C is the model's output channel count.
    test_preds = None
    offset = 0  # running row offset: no assumption about the DataLoader batch size
    with torch.no_grad():
        for b in tqdm(test_dl):
            # test_dl was moved to 'cuda' when it was built, so b[0] is already on the GPU
            b_preds = (sigmoid(learn.model(b[0])) > 0.5).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
            if test_preds is None:
                test_preds = np.zeros((test.shape[0], *b_preds.shape[1:]), dtype=np.uint8)
            test_preds[offset:offset + len(b_preds)] = b_preds
            offset += len(b_preds)
            torch.cuda.empty_cache()
            gc.collect()
    assert offset == test.shape[0], f'predicted {offset} rows for {test.shape[0]} test slices'

In [None]:
# if INFER:
#     learn.model = learn.model.cuda()
#     test_dl = learn.dls.test_dl(test, shuffle=False).to('cuda')
#     test_preds = get_preds(learn, test_dl)

In [None]:
# Source: https://www.kaggle.com/code/clemchris/gi-seg-pytorch-train-infer

def mask2rle(mask):
    """Run-length encode a binary mask (1 = mask, 0 = background).

    Pixels are read in row-major order. Returns a space-separated string of
    ``start length`` pairs with 1-based starts, or "" for an empty mask.
    """
    flat = np.asarray(mask).ravel()
    # Zero padding on both ends makes every run start and end at a value change
    padded = np.concatenate(([0], flat, [0]))
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each (start, end) pair into (start, length)
    boundaries[1::2] -= boundaries[::2]
    return ' '.join(map(str, boundaries))

In [None]:
import cv2

def get_rle_masks(preds, df):
    """Resize each predicted mask to its slice's original size and RLE-encode every channel.

    Args:
        preds: per-slice predictions, each indexed as (rows, cols, channel) with 3 channels.
        df: frame aligned row-for-row with `preds`, providing ``slice_w`` and ``slice_h``.

    Returns:
        list of RLE strings, three per slice (channels 0, 1, 2), in `preds` order.
    """
    encoded = []
    for pred, n_rows, n_cols in tqdm(zip(preds, df['slice_w'], df['slice_h'])):
        # cv2 takes dsize as (columns, rows). This orientation matches get_mask, whose masks have
        # shape (slice_w, slice_h). Nearest-neighbour interpolation keeps the mask binary.
        full_res = cv2.resize(pred, dsize=(n_cols, n_rows), interpolation=cv2.INTER_NEAREST)
        encoded.extend(mask2rle(full_res[:, :, channel]) for channel in range(3))
    return encoded

In [None]:
if INFER:
    # Three RLE strings per test slice, in the row order of `test`
    masks = get_rle_masks(preds=test_preds, df=test)

In [None]:
if INFER:
    # One row per (slice, organ). `masks` was built from the rows of `test`, three channels
    # each, in large_bowel / small_bowel / stomach order (get_mask's channel order). Take the
    # ids from `test` itself so they stay aligned even if some files were dropped when `test` was built.
    submission = pd.DataFrame({
        'id': np.repeat(test['id'].to_numpy(), 3),
        'class': np.tile(['large_bowel', 'small_bowel', 'stomach'], len(test)),
        'predicted': masks,
    })

    if sample_submission.shape[0] > 0:
        # drop() (not `del`) leaves sample_submission intact, so this cell can be re-run.
        # A left merge keeps every required row even if a prediction is missing (written as empty).
        submission = sample_submission.drop(columns=['predicted']).merge(
            submission, on=['id', 'class'], how='left')

    submission.to_csv('submission.csv', index=False)