# [UW-Madison GI Tract Image Segmentation](https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/)
> Track healthy organs in medical scans to improve cancer treatment

<img src="https://storage.googleapis.com/kaggle-competitions/kaggle/27923/logos/header.png?t=2021-06-02-20-30-25">

In [None]:
# Pin scikit-learn to 1.0.0; the "Data Split" cell imports StratifiedGroupKFold
# from it (presumably missing from the preinstalled version — TODO confirm).
!pip install -q scikit-learn==1.0.0

# Reference
Check this amazing notebook, [How To Create TFRecords](https://www.kaggle.com/cdeotte/how-to-create-tfrecords) by [Chris Deotte](https://www.kaggle.com/cdeotte)

# How to Create TFRecord

In [None]:
# Random seed passed to the fold splitter and to the tf.data shuffle.
SEED  = 101
# Number of StratifiedGroupKFold splits; one TFRecord file is written per fold.
FOLDS = 40
# Default `size` of load_img; None means images are not resized when loaded.
IMAGE_SIZE = None
# Number of slice paths collected per row for the 2.5D samples.
CHANNELS = 3
# Slice offset between consecutive channels of a 2.5D sample.
STRIDE = 2

# Importing Packages

In [None]:
import numpy as np 
import pandas as pd 
import os, shutil
from glob import glob
from sklearn.cluster import KMeans
from tqdm.notebook import tqdm
from sklearn.preprocessing import LabelEncoder
from tqdm.notebook import tqdm
tqdm.pandas()

# visualization
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Utility

## Mask

In [None]:
# ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length encoded mask.

    mask_rle: string of space-separated "start length" pairs (starts are 1-based)
    shape: (height, width) of the array to return
    Returns a uint8 numpy array of `shape`: 1 - mask, 0 - background
    '''
    tokens = np.asarray(mask_rle.split(), dtype=int)
    run_starts = tokens[0::2] - 1          # convert 1-based starts to 0-based
    run_ends = run_starts + tokens[1::2]   # exclusive end of every run
    flat_mask = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, stop in zip(run_starts, run_ends):
        flat_mask[begin:stop] = 1
    # The runs index the flattened (row-major) image, so a plain reshape restores it.
    return flat_mask.reshape(shape)


# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns the runs as a string of space-separated "start length" pairs
    (starts are 1-based positions in the flattened array)
    '''
    flat = img.flatten()
    # Pad with a background pixel on both sides so runs touching the borders
    # still produce a start and an end transition.
    padded = np.concatenate([[0], flat, [0]])
    transitions = np.where(padded[1:] != padded[:-1])[0] + 1
    # Odd entries are run ends; turn them into lengths.
    transitions[1::2] -= transitions[::2]
    return ' '.join(map(str, transitions))


## Metadata

In [None]:
def get_metadata(row):
    '''
    Parse case, day and slice numbers out of row['id'] and store them on the row.

    row: mapping with an 'id' of the form "case<N>_day<M>_..._<slice>"
    Returns the same row with integer 'case', 'day' and 'slice' entries added
    '''
    parts = row['id'].split('_')
    parsed = {
        'case': int(parts[0].replace('case', '')),
        'day': int(parts[1].replace('day', '')),
        'slice': int(parts[-1]),
    }
    for key, value in parsed.items():
        row[key] = value
    return row

def path2info(row):
    '''
    Parse image size and scan identifiers out of row['image_path'].

    The file name is split on '_' (fields 1, 2, 3 are read as slice, width,
    height) and the folder two levels above the file is read as "case<N>_day<M>".
    Returns the same row with integer 'height', 'width', 'case', 'day' and
    'slice' entries added
    '''
    segments = row['image_path'].split('/')
    file_fields = segments[-1].split('_')
    slice_ = int(file_fields[1])
    scan_fields = segments[-3].split('_')
    case = int(scan_fields[0].replace('case', ''))
    day = int(scan_fields[1].replace('day', ''))
    width = int(file_fields[2])
    height = int(file_fields[3])
    for key, value in (('height', height), ('width', width),
                       ('case', case), ('day', day), ('slice', slice_)):
        row[key] = value
    return row

## Visualization

In [None]:
def id2mask(id_):
    '''
    Build the 3-channel uint8 mask (values 0/1) for one id of the global `df`.

    Channel i is decoded from the i-th RLE in that row's `segmentation` list;
    entries that are NaN leave their channel all zero.
    NOTE(review): the channel-to-organ mapping follows the class order of the
    source csv — confirm against the dataset.
    '''
    record = df.loc[df['id'] == id_]
    height, width = record.height.item(), record.width.item()
    mask = np.zeros((height, width, 3), dtype=np.uint8)
    for channel, rle in enumerate(record.segmentation.squeeze()):
        if pd.isna(rle):
            continue
        mask[..., channel] = rle_decode(rle, (height, width))
    return mask

def rgb2gray(mask):
    '''
    Collapse an (H, W, C) channel mask into an (H, W) label map.

    A zero channel is prepended so that pixels with no channel set get label 0
    and channel k gets label k+1 (argmax over the padded channel axis).
    '''
    with_background = np.pad(mask, pad_width=((0, 0), (0, 0), (1, 0)))
    return with_background.argmax(axis=-1)

def gray2rgb(mask):
    '''
    Expand a label map with values 0..3 into a 3-channel mask.

    Label 0 (background) maps to all-zero channels and label k to a 1 in
    channel k-1. The result has the same dtype as `mask`.

    Implemented with numpy only, so the helper no longer needs TensorFlow
    (which this notebook imports only in a later cell).
    '''
    mask = np.asarray(mask)
    labels = mask.astype(int)
    # Same convention as keras' to_categorical: a trailing axis of length 1 is dropped.
    if labels.ndim > 1 and labels.shape[-1] == 1:
        labels = labels[..., 0]
    one_hot = np.eye(4)[labels]  # (..., 4) one-hot over background + 3 classes
    return one_hot[..., 1:].astype(mask.dtype)

def load_img(path, size=IMAGE_SIZE):
    '''
    Read one image from disk, optionally resizing it.

    path: image file path passed to cv2.imread
    size: `dsize` for cv2.resize (OpenCV expects (width, height)); None skips resizing
    Returns the array produced by cv2.imread with cv2.IMREAD_UNCHANGED; no
    intensity rescaling or dtype conversion is applied here
    Raises FileNotFoundError if OpenCV cannot read the file
    '''
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    if size is not None:
        # Use the caller's `size`, not the module-level default.
        img = cv2.resize(img, dsize=size, interpolation=cv2.INTER_NEAREST)
    return img

def load_imgs(img_paths):
    '''
    Load every path in `img_paths` with load_img and stack the results along a
    new last axis, giving one channel per path.

    Works for any number of paths (previously hard-wired to exactly 3).
    '''
    return np.stack([load_img(img_path) for img_path in img_paths], axis=-1)

def show_img(img, mask=None):
    '''
    Draw `img` on the current matplotlib axes with the 'bone' colormap and, if
    given, overlay `mask` at 50% opacity together with a per-class legend.
    Axis ticks and frame are hidden.
    '''
    plt.imshow(img, cmap='bone')

    if mask is not None:
        plt.imshow(mask, alpha=0.5)
        # Legend swatches: one colour per mask channel.
        legend_colors = [(0.667,0.0,0.0), (0.0,0.667,0.0), (0.0,0.0,0.667)]
        legend_labels = [ "Large Bowel", "Small Bowel", "Stomach"]
        swatches = []
        for swatch_color in legend_colors:
            swatches.append(Rectangle((0,0),1,1, color=swatch_color))
        plt.legend(swatches, legend_labels)
    plt.axis('off')

# Meta Data

## Train

In [None]:
# One row per (id, class) in the csv; a missing RLE is treated as an empty string.
df = pd.read_csv('../input/uwmgi-mask-dataset/train.csv')
df['segmentation'] = df.segmentation.fillna('')
df['rle_len'] = df.segmentation.map(len) # length of each rle mask

# Collapse the per-class rows of every id: list of RLEs and total RLE length.
# 'sum' is passed by name; handing the builtin `sum` to agg is deprecated in recent pandas.
df2 = df.groupby(['id'])['segmentation'].agg(list).to_frame().reset_index() # rle list of each id
df2 = df2.merge(df.groupby(['id'])['rle_len'].agg('sum').to_frame().reset_index()) # total length of all rles of each id

# Keep a single metadata row per id and attach the aggregated columns.
df = df.drop(columns=['segmentation', 'class', 'rle_len'])
df = df.groupby(['id']).head(1).reset_index(drop=True)
df = df.merge(df2, on=['id'])
df['empty'] = (df.rle_len==0) # empty masks

## Remove Faulty Cases

In [None]:
fault1 = 'case7_day0'
fault2 = 'case81_day30'
# Match the full "case<N>_day<M>_" prefix of the id; a bare substring test would
# also drop e.g. "case81_day300" if such a scan existed.
is_faulty = df['id'].str.startswith(fault1 + '_') | df['id'].str.startswith(fault2 + '_')
df = df[~is_faulty].reset_index(drop=True)
df.head()

## 2.5D

In [None]:
# For every slice, collect the paths of CHANNELS slices spaced STRIDE apart within
# the same (case, day) scan. Shifting leaves NaN on the last rows of each scan;
# .ffill() (the non-deprecated spelling of fillna(method="ffill")) fills them from
# the preceding row.
# NOTE(review): the fill is not group-aware — it assumes each scan's rows are
# contiguous and longer than the largest shift; confirm for the input csv.
for i in range(CHANNELS):
    df[f'image_path_{i:02d}'] = df.groupby(['case','day'])['image_path'].shift(-i*STRIDE).ffill()
df['image_paths'] = df[[f'image_path_{i:02d}' for i in range(CHANNELS)]].values.tolist()
df.image_paths[0]

# Check Data

In [None]:
row=1; col=4
plt.figure(figsize=(5*col,5*row))
# Seed the shuffle so the same non-empty samples are shown on every run.
sample_ids = df[df['empty']==0].sample(frac=1.0, random_state=SEED)['id'].unique()[:row*col]
for i, id_ in enumerate(sample_ids):
    img = load_imgs(df[df['id']==id_].squeeze().image_paths).astype('float32')
    img/=img.max(axis=(0,1))  # scale every channel by its own maximum
    mask = id2mask(id_)*255
    plt.subplot(row, col, i+1)
    show_img(img, mask=mask)
# Lay out once, after all panels exist.
plt.tight_layout()

# Data Split

In [None]:
from sklearn.model_selection import StratifiedGroupKFold
# Stratify on `empty` while keeping all rows of a case in the same fold.
skf = StratifiedGroupKFold(n_splits=FOLDS, shuffle=True, random_state=SEED)
# Pre-create an integer column; assigning into a missing column would create floats.
df['fold'] = -1
fold_col = df.columns.get_loc('fold')
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['empty'], groups = df["case"])):
    # split() yields positional indices, so write with iloc rather than label-based loc.
    df.iloc[val_idx, fold_col] = fold
display(df.groupby(['fold','empty'])['id'].count())

# TFRecord Data

In [None]:
import tensorflow as tf

def _bytes_feature(value):
    """Wrap a single string / byte value in a tf.train.Feature holding a BytesList."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList won't unpack a string from an EagerTensor, so pull out the raw value.
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def _float_feature(value):
    """Wrap a single float / double in a tf.train.Feature holding a FloatList."""
    payload = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=payload)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a tf.train.Feature holding an Int64List."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)

# Writing TFRecord (Train)

In [None]:
def train_serialize_example(feature0, feature1, feature2, feature3, feature4,
                           feature5, feature6, feature7, feature8):
    """Serialize one training sample into a tf.train.Example byte string.

    The positional arguments map, in order, to the keys
    image, id, case, day, slice, height, width, empty, mask:
    image / id / mask are stored as bytes features, the rest as int64 features.
    """
    features = tf.train.Features(feature={
        'image': _bytes_feature(feature0),
        'id': _bytes_feature(feature1),
        'case': _int64_feature(feature2),
        'day': _int64_feature(feature3),
        'slice': _int64_feature(feature4),
        'height': _int64_feature(feature5),
        'width': _int64_feature(feature6),
        'empty': _int64_feature(feature7),
        'mask': _bytes_feature(feature8),
    })
    return tf.train.Example(features=features).SerializeToString()

In [None]:
show=True
os.makedirs('/tmp/uwmgi', exist_ok=True)
folds = df.fold.unique().tolist()
for fold in tqdm(folds): # create tfrecord for each fold
    fold_df = df.query("fold==@fold")
    samples = fold_df.shape[0]
    # The sample count is embedded in the file name; count_data_items parses it back.
    filepath = '/tmp/uwmgi/train%.2i-%i.tfrec'%(fold,samples)
    if show:
        print(); print('Writing TFRecord of fold %i :'%(fold))
    with tf.io.TFRecordWriter(filepath) as writer:
        it = tqdm(range(samples)) if show else range(samples)
        for k in it: # images in fold
            row = fold_df.iloc[k,:]
            image = load_imgs(row['image_paths'])
            image_id = row['id']
            mask = id2mask(image_id)*255 # [0, 1] => [0, 255]
            example  = train_serialize_example(
                image.tobytes(),
                str.encode(image_id),
                row['case'],
                row['day'],
                row['slice'],
                row['height'],
                row['width'],
                row['empty'],
                mask.tobytes(),
                )
            writer.write(example)
    # Report the size only after the `with` block has closed the writer;
    # measuring earlier can miss records that are still buffered.
    if show:
        filename = os.path.basename(filepath)
        filesize = os.path.getsize(filepath)/10**6
        print(filename,':',np.around(filesize, 2),'MB')

# Reading TFRecord

In [None]:
import re, math
def decode_image(data, height, width, target_size=(224, 224)):
    """Decode a serialized image into a float32 tensor scaled by its maximum.

    data: raw bytes of a uint16 array with height*width*3 elements
    height, width: shape used to rebuild the array before resizing
    target_size: output (height, width) passed to tf.image.resize (nearest neighbour)
    Returns a float32 tensor of shape target_size + (3,); an all-zero image
    stays all zero instead of turning into NaN.
    """
    image = tf.io.decode_raw(data, out_type=tf.uint16)
    image = tf.reshape(image, [height, width, 3]) # explicit size needed for TPU
    image = tf.image.resize(image, target_size, method='nearest')
    image = tf.cast(image, tf.float32)
    # divide_no_nan returns 0 where the divisor is 0, so blank slices do not produce NaN.
    image = tf.math.divide_no_nan(image, tf.reduce_max(image))
    return image
def decode_mask(data, height, width, target_size=(224, 224)):
    """Decode a serialized uint8 mask and resize it with nearest-neighbour sampling.

    data: raw bytes of a uint8 array with height*width*3 elements
    Returns the mask tensor resized to target_size with 3 channels.
    """
    flat_mask = tf.io.decode_raw(data, out_type=tf.uint8)
    shaped_mask = tf.reshape(flat_mask, [height, width, 3]) # explicit size needed for TPU
    return tf.image.resize(shaped_mask, target_size, method='nearest')

def read_labeled_tfrecord(example):
    """Parse one serialized training example into an (image, mask) pair.

    Only the image, height, width and mask features are read; the stored
    height/width are used to rebuild both arrays before they are resized.
    """
    feature_spec = {
        "image" : tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "height" : tf.io.FixedLenFeature([], tf.int64),
        "width" : tf.io.FixedLenFeature([], tf.int64),
        "mask" : tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    height, width = parsed['height'], parsed['width']
    image = decode_image(parsed['image'], height, width)
    mask = decode_mask(parsed['mask'], height, width)
    return image, mask # a single (image, mask) pair

def load_dataset(fileids, labeled=True, ordered=False):
    """Build a tf.data pipeline of (image, mask) pairs from TFRecord files.

    fileids: TFRecord file names, read in parallel (num_parallel_reads=AUTO)
    labeled: accepted for API symmetry but not used by this implementation
    ordered: when False, deterministic ordering is switched off so records are
             consumed as soon as they stream in
    """
    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False # disable order, increase speed

    # Interleaves reads from multiple files, then decodes every record.
    records = tf.data.TFRecordDataset(fileids, num_parallel_reads=AUTO)
    return records.with_options(read_options).map(read_labeled_tfrecord)

def get_training_dataset():
    """Return the training pipeline: repeated indefinitely, shuffled with a
    20-element buffer (seeded with SEED), batched to BATCH_SIZE and prefetched."""
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .repeat()                 # the training dataset must repeat for several epochs
        .shuffle(20, seed=SEED)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)           # prefetch next batch while training
    )

def count_data_items(fileids):
    """Sum the item counts embedded in TFRecord file names.

    The count is the run of digits between '-' and '.', e.g.
    train00-230.tfrec holds 230 items.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for fileid in fileids:
        counts.append(int(count_pattern.search(fileid).group(1)))
    return np.sum(counts)

# Helper

In [None]:
def display_batch(batch, size=5):
    """Plot the first `size` samples of an (images, targets) batch side by side,
    each image in the 'bone' colormap with its target overlaid at 50% opacity."""
    images, targets = batch
    plt.figure(figsize=(size*5, 5))
    for position in range(1, size + 1):
        sample = position - 1
        plt.subplot(1, size, position)
        plt.imshow(images[sample], cmap='bone')
        plt.imshow(targets[sample], alpha=0.5)
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    plt.show()

# Total Images

In [None]:
BATCH_SIZE = 32
AUTO = tf.data.experimental.AUTOTUNE
# Collect the TFRecord files by prefix; the per-file item counts come from the file names.
TRAINING_FILENAMES = tf.io.gfile.glob('/tmp/uwmgi/train*.tfrec')
TEST_FILENAMES     = tf.io.gfile.glob('/tmp/uwmgi/test*.tfrec')
print(
    'There are %i train & %i test images'
    % (count_data_items(TRAINING_FILENAMES), count_data_items(TEST_FILENAMES))
)

# Display Image from TFRecord

In [None]:
# DISPLAY TRAIN IMAGES
# Re-batch the training pipeline into groups of 20 and plot the first 5 samples.
training_dataset = get_training_dataset().unbatch().batch(20)
train_batch = next(iter(training_dataset))
display_batch(train_batch, 5);

In [None]:
# Inspect which distinct values (and how many of each) occur in the decoded mask batch.
img, label = train_batch
np.unique(label.numpy(), return_counts=True)

# Compress Files

In [None]:
# Zip the `uwmgi` directory (relative to /tmp) into /kaggle/working/tfrecord.zip.
shutil.make_archive('/kaggle/working/tfrecord',
                    'zip',
                    '/tmp',
                    'uwmgi')