In [None]:
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from matplotlib.patches import Rectangle
from sklearn.model_selection import StratifiedGroupKFold
from tqdm import tqdm
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

In [None]:
def rle_decode(mask_rle, shape):
    """Decode a run-length-encoded string into a binary uint8 mask.

    Args:
        mask_rle: whitespace-separated ``start length`` pairs. Starts are
            1-based offsets into the flattened (row-major) mask.
        shape: ``(height, width)`` of the output mask.

    Returns:
        uint8 array of ``shape`` with 1 inside every run and 0 elsewhere.
        An empty string yields an all-zero mask.
    """
    tokens = np.asarray(mask_rle.split(), dtype=int)
    # Convert the 1-based starts to 0-based slice bounds.
    run_starts = tokens[0::2] - 1
    run_ends = run_starts + tokens[1::2]

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    return flat.reshape(shape)

In [None]:
def get_metadata(row):
    """Parse case, day and slice numbers from ``row['id']`` and store them on the row.

    Expects ids shaped like ``case<N>_day<M>_..._<slice>``
    (e.g. ``case123_day20_slice_0001``). The first ``_``-separated token
    gives the case, the second the day, and the last the slice number.

    Returns:
        The same ``row`` object, with ``case``, ``day`` and ``slice`` set.
    """
    parts = row['id'].split('_')
    # Parse everything first so a malformed id leaves the row untouched.
    parsed = {
        'case': int(parts[0].replace('case', '')),
        'day': int(parts[1].replace('day', '')),
        'slice': int(parts[-1]),
    }
    for key, value in parsed.items():
        row[key] = value
    return row


In [None]:
def path2info(row):
    """Parse slice geometry and ids out of ``row['image_path']``.

    Expected layout (as indexed below):
      ``.../case<N>_day<M>/<dir>/slice_<slice>_<width>_<height>_...``
    The directory two levels above the file carries the case/day. The file
    name carries slice index, width and height as its 2nd-4th
    ``_``-separated fields. The remaining fields are presumably pixel
    spacing (not parsed).

    Returns:
        The same ``row`` object, with ``height``, ``width``, ``case``, ``day``
        and ``slice`` set as ints.
    """
    components = row['image_path'].split('/')

    file_fields = components[-1].split('_')
    slice_ = int(file_fields[1])

    scan_fields = components[-3].split('_')
    case = int(scan_fields[0].replace('case', ''))
    day = int(scan_fields[1].replace('day', ''))

    width = int(file_fields[2])
    height = int(file_fields[3])

    for key, value in (('height', height), ('width', width),
                       ('case', case), ('day', day), ('slice', slice_)):
        row[key] = value
    return row

In [None]:
def id2mask(id_, data=None):
    """Build the 3-channel binary mask for one slice id.

    Args:
        id_: value of the ``id`` column to look up.
        data: frame to search. Defaults to the notebook-level ``df`` (the
            original behaviour), so existing calls keep working. Passing a
            frame explicitly avoids the hidden global dependency.

    Returns:
        uint8 array of shape ``(height, width, 3)`` taken from the row's
        ``height``/``width`` columns, i.e. the native slice resolution. No
        resizing is done here, unlike ``load_img``. Channel ``i`` is decoded
        from the ``i``-th RLE in the row's ``segmentation`` list. Missing or
        empty RLEs leave that channel zero.
        NOTE(review): channel order follows the per-id row order of the CSV,
        presumably large bowel / small bowel / stomach as in ``show_img`` --
        confirm.

    Raises:
        ValueError: from ``.item()`` if ``id_`` matches zero or several rows.
    """
    frame = df if data is None else data
    idf = frame[frame['id'] == id_]
    height, width = idf['height'].item(), idf['width'].item()

    mask = np.zeros((height, width, 3), dtype=np.uint8)
    rles = idf['segmentation'].squeeze()
    for channel, rle in enumerate(rles):
        # NaN (missing) and '' (from fillna('')) both mean "no annotation".
        if pd.isna(rle) or not rle:
            continue
        mask[..., channel] = rle_decode(rle, (height, width))
    return mask

In [None]:
def rgb2gray(mask):
    """Collapse a 3-D per-class mask ``(H, W, C)`` into an integer label map.

    A zero "background" channel is prepended before ``argmax``. Pixels with
    no active class therefore get label 0, and class channel ``k`` gets
    label ``k + 1``. When several channels are active, ``argmax`` picks the
    lowest-index one.
    """
    # Pad only the last axis: one zero channel in front, none behind.
    with_background = np.pad(mask, ((0, 0), (0, 0), (1, 0)))
    return np.argmax(with_background, axis=-1)

def gray2rgb(mask):
    """Expand an integer label map into per-class channels.

    Acts as the inverse of ``rgb2gray`` for non-overlapping masks. Labels are
    one-hot encoded over 4 classes with ``tf.keras.utils.to_categorical``
    (label 0 = background). The background channel is then dropped and the
    result cast back to ``mask``'s dtype, giving channels for labels 1, 2, 3.
    """
    categorical = tf.keras.utils.to_categorical(mask, num_classes=4)
    class_channels = categorical[..., 1:]
    return class_channels.astype(mask.dtype)

In [None]:
def load_img(path, size=IMAGE_SIZE):
    """Read an image from disk unchanged and optionally resize it.

    Args:
        path: file path passed to ``cv2.imread`` with ``IMREAD_UNCHANGED``,
            so bit depth and channel count are kept as stored.
        size: target size forwarded to ``cv2.resize`` as ``dsize``, which
            OpenCV interprets as ``(width, height)``. ``None`` keeps the
            original size. Defaults to the notebook-level ``IMAGE_SIZE``.

    Returns:
        The image array.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"cv2 could not read image: {path}")
    if size is not None:
        # Honour the `size` argument. Previously IMAGE_SIZE was always used,
        # so explicit sizes were silently ignored.
        # Nearest-neighbour keeps original pixel values.
        # NOTE(review): confirm this is intended for images rather than
        # INTER_LINEAR/INTER_AREA.
        img = cv2.resize(img, dsize=tuple(size), interpolation=cv2.INTER_NEAREST)
    return img

def load_imgs(img_paths):
    """Load one image per path with ``load_img`` and stack them channel-last.

    The number of output channels equals ``len(img_paths)``. Previously a
    fixed 3-slot buffer was used, so any other count raised IndexError or
    tried to stack ``None`` placeholders. All images must share a shape for
    ``np.stack``; ``load_img`` resizing them to a common size ensures that.

    Raises:
        ValueError: from ``np.stack`` if ``img_paths`` is empty or the
            loaded images have different shapes.
    """
    imgs = [load_img(img_path) for img_path in img_paths]
    return np.stack(imgs, axis=-1)

In [None]:
def show_img(img, mask=None):
    """Draw ``img`` on the current axes, optionally with a mask overlay.

    The image is shown with the ``bone`` colormap. If ``mask`` is given, it
    is drawn on top at 50% opacity and a legend with the three class names
    is added. The legend colours are fixed RGB tuples, not read from
    ``mask``; presumably they match how the mask renders -- confirm.
    Axis ticks and frame are hidden.
    """
    plt.imshow(img, cmap='bone')
    if mask is not None:
        plt.imshow(mask, alpha=0.5)
        legend_colors = {
            "Large Bowel": (0.667, 0.0, 0.0),
            "Small Bowel": (0.0, 0.667, 0.0),
            "Stomach": (0.0, 0.0, 0.667),
        }
        patches = [Rectangle((0, 0), 1, 1, color=rgb) for rgb in legend_colors.values()]
        plt.legend(patches, list(legend_colors))
    plt.axis('off')

In [None]:
# Load the per-class RLE annotations. Missing RLEs become '' so the string
# length (`rle_len`) is 0 exactly when a row has no annotation.
df = (
    pd.read_csv('../input/uwmgi-mask-dataset/train.csv')
    .assign(segmentation=lambda frame: frame['segmentation'].fillna(''))
    .assign(rle_len=lambda frame: frame['segmentation'].map(len))
)

In [None]:
# Collapse the one-row-per-class table into one row per slice id:
#   - `segmentation`: list of that id's RLE strings, in CSV row order
#   - `rle_len`:      total RLE string length (0 => no annotation at all)
# A single named aggregation replaces the two groupby passes. The string
# 'sum' replaces the builtin `sum`, which triggers a FutureWarning in
# pandas >= 2.1.
df2 = (
    df.groupby('id')
    .agg(segmentation=('segmentation', list), rle_len=('rle_len', 'sum'))
    .reset_index()
)

# Keep the per-slice metadata from the first row of each id, then attach
# the aggregated annotations.
df = (
    df.drop(columns=['segmentation', 'class', 'rle_len'])
    .groupby('id')
    .head(1)
    .reset_index(drop=True)
    .merge(df2, on='id')
)
df['empty'] = df['rle_len'].eq(0)


In [None]:
# For every slice, collect CHANNELS image paths from the same case/day scan:
# the slice itself, then +STRIDE, +2*STRIDE, ... slices further on.
# NOTE(review): CHANNELS, STRIDE, FOLDS and SEED are not defined in any
# visible cell -- presumably a config cell; confirm they exist before this
# runs.
scan_paths = df.groupby(['case', 'day'])['image_path']
# Past the end of a scan, clamp to that scan's last slice. Filling per scan
# cannot borrow a path from the previous scan (which a global ffill could
# when a scan is shorter than the shift). It also avoids the deprecated
# fillna(method=...) API.
last_path_in_scan = scan_paths.transform('last')
for i in range(CHANNELS):
    df[f'image_path_{i:02}'] = scan_paths.shift(-i * STRIDE).fillna(last_path_in_scan)
df['image_paths'] = df[[f'image_path_{i:02d}' for i in range(CHANNELS)]].values.tolist()

# Assign folds: stratified on `empty`, grouped by case, so one case never
# spans train and validation.
skf = StratifiedGroupKFold(n_splits=FOLDS, shuffle=True, random_state=SEED)
df['fold'] = -1
for fold, (_, val_idx) in enumerate(skf.split(df, df['empty'], groups=df['case'])):
    # split() yields positional indices; map them to labels for .loc.
    df.loc[df.index[val_idx], 'fold'] = fold
df.groupby(['fold', 'empty'])['id'].count()

In [None]:
# Augmentation pipeline (imgaug).
# `iaa.ShiftScaleRotate` and `iaa.Jitter` are albumentations-style names that
# imgaug does not provide (AttributeError). The closest imgaug equivalents
# are used below.
# Masks should be passed via `segmentation_maps=` so that only the geometric
# augmenters are applied to them, not the intensity one.
augmentations = iaa.Sequential([
    # Approximates ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1,
    # rotate_limit=45, p=0.5): up to +-10% translation, 0.9-1.1 scaling and
    # +-45 deg rotation, applied to half of the samples.
    iaa.Sometimes(0.5, iaa.Affine(
        translate_percent={'x': (-0.1, 0.1), 'y': (-0.1, 0.1)},
        scale=(0.9, 1.1),
        rotate=(-45, 45),
    )),
    iaa.Fliplr(p=0.5),
    # Stand-in for the intended "Jitter": random +-10% intensity scaling on
    # half of the samples.
    # NOTE(review): the original intent of `Jitter` is unknown -- confirm.
    iaa.Sometimes(0.5, iaa.Multiply((0.9, 1.1))),
])

# NOTE(review): `folds` and `show` are not defined in any visible cell --
# presumably config (e.g. folds = range(FOLDS)); confirm before running.
for fold in tqdm(folds):
    fold_df = df.query("fold==@fold")
    if show:
        print('\nProcessing data for fold %i :' % fold)

    samples = fold_df.shape[0]
    it = tqdm(range(samples)) if show else range(samples)
    for k in it:
        row = fold_df.iloc[k, :]
        image = load_imgs(row['image_paths'])
        image_id = row['id']

        # id2mask returns the mask at the scan's native resolution, while
        # load_img resizes slices. Bring the mask onto the image grid
        # (nearest-neighbour keeps it binary); otherwise the concatenation
        # below fails on mismatched shapes.
        mask = id2mask(image_id)
        if mask.shape[:2] != image.shape[:2]:
            mask = cv2.resize(mask, dsize=(image.shape[1], image.shape[0]),
                              interpolation=cv2.INTER_NEAREST)

        # Augment image and mask together through imgaug's segmentation-map
        # API. Geometric augmenters move both consistently; intensity
        # augmenters only touch the image. The previous code passed a single
        # HxWxC array as `images=`, which imgaug reads as a batch of N
        # grayscale images, and it also exposed the mask channels to
        # intensity changes.
        segmap = SegmentationMapsOnImage(mask, shape=image.shape)
        augmented_image, augmented_segmap = augmentations(image=image, segmentation_maps=segmap)

        # Keep the original 0/255 channel-last image+mask layout for
        # downstream use.
        mask = mask * 255
        image_with_mask = np.concatenate([image, mask], axis=-1)
        augmented_image_with_mask = np.concatenate(
            [augmented_image, augmented_segmap.get_arr() * 255], axis=-1
        )
