This notebook creates TFRecord files for training models on TPU.
The contents of the TFRecord files are as follows:

| Name | Type | Description |
| --- | --- | --- |
| id | bytes | sample ID taken from the 'id' column in 'train.csv', utf-8 encoded. |
| case number | int64 | case number taken from 'id' at caseNNN |
| day number | int64 | day number taken from 'id' at dayNN |
| slice number | int64 | slice number taken from 'id' at slice_NNNN |
| image | bytes | the image read from the associated file, flattened and serialized in NumPy `.npy` format (`np.save`). |
| mask | bytes | PNG format mask bytes generated from the 'segmentation' column in 'train.csv' |
| fold | int64 | fold number that this sample belongs to |
| height | int64 | slice height taken from the file name |
| width | int64 | slice width taken from the file name |
| space height | float32 | pixel spacing height taken from the file name |
| space width | float32 | pixel spacing width taken from the file name |
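| large bowel dice coef | float32 | dice coefficient of the model's prediction for the large bowel, taken from the prediction analysis CSV |
| small bowel dice coef | float32 | dice coefficient of the model's prediction for the small bowel, taken from the prediction analysis CSV |
| stomach dice coef | float32 | dice coefficient of the model's prediction for the stomach, taken from the prediction analysis CSV |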
| slice count | int64 | number of slices for case/day |

My goals are to:
* include as much information as possible.
* keep the original size of the images and masks, so that various kinds of preprocessing (e.g. resizing or padding) can be tried during training.

# Reference

* [AW-Madison: EDA & In Depth Mask Exploration](https://www.kaggle.com/code/andradaolteanu/aw-madison-eda-in-depth-mask-exploration)

# Preparation

In [None]:
# Debug switch: when True, only the first 300 rows of the training table
# are written to the TFRecord files.
DEBUG = False

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import os
import glob
from sklearn.model_selection import StratifiedGroupKFold
import matplotlib.pyplot as plt
import cv2
from io import BytesIO

print(tf.__version__)

In [None]:
# Number of TFRecord files the samples are divided into.
NUM_TFREC_FILES = 15
# Number of cross-validation folds assigned to the samples.
N_FOLDS = 5

In [None]:
# Input dataset directory that holds 'train.csv' and the 'train' images.
DATA_SRC = 'uw-madison-gi-tract-image-segmentation'
DATA_DIR = os.path.join('..', 'input', DATA_SRC)
# Root directory that is searched recursively for the slice PNG files.
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
# Dataset directory and file name of the CSV that provides the per-sample
# dice coefficients.
ANAL_DIR = 'uwmgi-image-segmentation-pred-analysis'
ANAL_FILE = 'analysis_V32.csv'
# NOTE: ANAL_PATH is the directory; the CSV path itself is built later by
# joining ANAL_PATH and ANAL_FILE.
ANAL_PATH = os.path.join('..', 'input', ANAL_DIR)

# DataFrame

In [None]:
# Load the training table. The following cells rely on its 'id' and
# 'class' columns.
train_csv_path = os.path.join(DATA_DIR, 'train.csv')
train_df = pd.read_csv(train_csv_path)

train_df

In [None]:
# Reshape the table from one row per (id, class) pair to one row per id,
# with the values of each class side by side.
train_unstack_df = (
    train_df
    .set_index(['id', 'class'])
    .unstack()
    .reset_index()
)
# Replace the column labels left by unstack() with plain names.
train_unstack_df.columns = [
    'id', 'large_bowel', 'small_bowel', 'stomach']

train_unstack_df

In [None]:
# Split the sample ID into its case, day and slice numbers.
# The extracted values are digit strings, not integers.
train_unstack_df[['case_no', 'day_no', 'slice_no']] = (
    train_unstack_df['id'].str.extract(
        r'case(\d\d*)_day(\d\d*)_slice_(\d\d*)')
)

train_unstack_df

In [None]:
# For every (case, day) group, store the 'slice_no' of the group's last
# row as 'slice_count' on each row of that group.
# transform() passes the group's 'slice_no' Series to the lambda, and
# Series.iloc[-1] picks its last element:
#
# https://stackoverflow.com/questions/56288949/
# how-to-access-the-last-element-in-a-pandas-series
#
# NOTE(review): this equals the number of slices only if the rows of a
# group are ordered by slice number and the numbering starts at 1 --
# confirm against the data.
train_unstack_df['slice_count'] = (
    train_unstack_df
    .groupby(['case_no', 'day_no'])['slice_no']
    .transform(lambda slice_nos: slice_nos.iloc[-1])
)

train_unstack_df

In [None]:
# Collect every PNG file found at any depth below TRAIN_DIR and start a
# table with one row per file.
file_path_pattern = os.path.join(TRAIN_DIR, '**', '*.png')
file_paths = glob.glob(file_path_pattern, recursive=True)
file_info_df = pd.DataFrame({"file_path": file_paths})

file_info_df

In [None]:
# Rebuild the sample ID (case..._day..._slice_...) from the directory and
# file-name parts of each path so the files can be joined on 'id'.
# The pattern expects '/' as the path separator.
file_info_df['id'] = file_info_df['file_path'].str.replace(
    pat=r'^.*/(case\d\d*)_(day\d\d*)/scans/(slice_\d\d*)_.*$',
    repl=r'\1_\2_\3', regex=True)

# Pull the four numbers that follow the slice number in the file name.
# They are kept as strings here and converted when the examples are
# serialized.
file_name_info_df = (
    file_info_df['file_path']
    .str.extract(
        r'slice_\d\d*_(\d\d*)_(\d\d*)_(\d\d*\.\d\d*)_(\d\d*\.\d\d*)')
    .rename(columns={
        0: 'height', 1: 'width', 2: 'space_h', 3: 'space_w'})
)

# Both frames share the same index, so they are joined column-wise.
file_info_df = pd.concat([file_info_df, file_name_info_df], axis=1)

file_info_df

In [None]:
# Load the prediction analysis table and keep only the sample ID and the
# three per-organ dice coefficient columns.
anal_csv_path = os.path.join(ANAL_PATH, ANAL_FILE)
anal_df = pd.read_csv(anal_csv_path)[[
    'id',
    'large_bowel_dice_coef',
    'small_bowel_dice_coef',
    'stomach_dice_coef',
]]

anal_df

In [None]:
# Join the per-sample segmentations, the file information and the dice
# coefficients on 'id'. Inner joins keep only the samples that appear in
# all three tables.
train_data_df = pd.merge(
    train_unstack_df,
    pd.merge(file_info_df, anal_df, how='inner', on='id'),
    how='inner', on='id')
# A missing segmentation means "no mask" and is represented by an empty
# string. Only the three segmentation columns are filled, so that missing
# values in the numeric columns are not silently turned into strings.
train_data_df = train_data_df.fillna(
    {'large_bowel': '', 'small_bowel': '', 'stomach': ''})

train_data_df

# Folds

In [None]:
def has_mask(row):
    """Return True if the row has a non-empty segmentation for any organ.

    row: a mapping (e.g. a DataFrame row) holding the 'large_bowel',
        'small_bowel' and 'stomach' values; each value must support len().
    """
    organ_columns = ('large_bowel', 'small_bowel', 'stomach')
    return any(len(row[column]) > 0 for column in organ_columns)

# Flag every sample that has at least one non-empty segmentation.
train_data_df['has_mask'] = train_data_df.apply(has_mask, axis=1)

train_data_df['has_mask']

In [None]:
# Assign every sample to one of N_FOLDS folds. The split is stratified on
# 'has_mask' and grouped by 'case_no'.
sgkf = StratifiedGroupKFold(
    n_splits=N_FOLDS, shuffle=True, random_state=53)
train_data_len = len(train_data_df)
fold_X = np.arange(train_data_len)
fold_y = train_data_df['has_mask']
fold_groups = train_data_df['case_no']
fold_data = np.empty(train_data_len)

# Only the validation indices of each split are needed: a sample's fold
# is the split in which it is held out.
for fold_idx, (_, val_idx) in enumerate(
        sgkf.split(X=fold_X, y=fold_y, groups=fold_groups)):
    fold_data[val_idx] = fold_idx

train_data_df['fold'] = fold_data.astype(np.int32)

train_data_df

In [None]:
# Number of samples per fold, split by whether they have a mask.
train_data_df.groupby(['fold', 'has_mask'])['id'].count()

In [None]:
# Save the merged table, including the fold assignment, without the index.
train_data_df.to_csv('train_data.csv', index=False)

!head train_data.csv

# Image/Mask

In [None]:
# https://www.kaggle.com/code/andradaolteanu/aw-madison-eda-in-depth-mask-exploration
def read_cv2_image(path):
    '''Read an image file and convert it to an 8-bit image.

    path: the full path to the .png file.

    Returns a uint8 NumPy array holding the file's pixel values min-max
    scaled to [0, 255].

    Raises OSError if OpenCV cannot read the file.
    '''
    # IMREAD_UNCHANGED loads the file as stored, without converting it.
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread() signals failure by returning None instead of
        # raising; fail here with the offending path rather than with an
        # AttributeError on the next line.
        raise OSError('Failed to read image: {0}'.format(path))
    # Convert to float32 before scaling (the files are presumably 16-bit).
    image = image.astype('float32')
    # Scale to [0, 255]
    image = cv2.normalize(image, None, alpha=0, beta=255,
                          norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    # The cast truncates the fractional part.
    return image.astype(np.uint8)

In [None]:
def read_image(file_path):
    """Read an image file and return it as `np.save` bytes.

    file_path: path of the image file, passed to read_cv2_image().

    Returns a scalar string tensor holding the flattened image serialized
    with np.save().
    """
    pixels = read_cv2_image(file_path).flatten()
    with BytesIO() as npy_buffer:
        np.save(npy_buffer, pixels)
        npy_bytes = npy_buffer.getvalue()
    return tf.constant(npy_bytes)

In [None]:
def make_zero_mask(height, width):
    """Return an all-zero uint8 mask of shape [height, width, 1]."""
    mask_shape = [height, width, 1]
    return tf.zeros(mask_shape, dtype=tf.uint8)

def make_non_zero_mask(rle, height, width):
    """Decode a run-length encoded string into a binary mask.

    rle: space-separated integers forming (start, length) pairs,
        e.g. '1 3 10 5'.
    height, width: size of the mask.

    Returns a uint8 tensor of shape [height, width, 1] that is 1 at every
    flattened position p with start <= p < start + length for some pair,
    and 0 elsewhere.

    NOTE(review): 'start' is used as a 0-based index into the flattened
    mask. If the encoding counts pixels from 1, the mask is shifted by one
    pixel -- confirm against the RLE specification of the data.
    """
    rle_splits = tf.strings.split(rle, sep=' ')
    rle_nums = tf.strings.to_number(rle_splits, out_type=tf.int32)
    # For example, '1 3 10 5'
    # starts = [1 10]
    # ends   = [4 15]  (exclusive)
    starts = rle_nums[0::2]
    ends = starts + rle_nums[1::2]
    # positions = [0 1 2 ... height*width-1]
    positions = tf.range(height * width, dtype=tf.int32)
    # Compare every run with every position in one broadcast operation:
    # in_run[i, p] is True when position p lies inside the i-th run.
    # This replaces a per-run tf.map_fn() call, which relied on the
    # deprecated `dtype` argument.
    in_run = tf.math.logical_and(
        starts[:, tf.newaxis] <= positions[tf.newaxis, :],
        positions[tf.newaxis, :] < ends[:, tf.newaxis])
    # A position belongs to the mask if any run covers it.
    covered = tf.reduce_any(in_run, axis=0)
    mask = tf.cast(covered, tf.uint8)
    return tf.reshape(mask, [height, width, 1])

def make_one_mask(rle, height, width):
    """Build the mask for one organ.

    Returns an all-zero mask when `rle` is an empty string and the decoded
    mask otherwise; both have shape [height, width, 1].
    """
    def _empty():
        return make_zero_mask(height, width)

    def _decoded():
        return make_non_zero_mask(rle, height, width)

    rle_is_empty = tf.strings.length(rle) == 0
    return tf.cond(rle_is_empty, _empty, _decoded)

def make_mask(large_bowel, small_bowel, stomach, height, width):
    """Build a 3-channel mask and return it PNG-encoded.

    large_bowel, small_bowel, stomach: run-length encoded strings, empty
        for "no mask".
    height, width: size of the mask.

    The channels are, in order: large bowel, small bowel, stomach.
    """
    channels = [
        make_one_mask(rle, height, width)
        for rle in (large_bowel, small_bowel, stomach)
    ]
    # Stack the single-channel masks along the channel axis.
    stacked = tf.concat(channels, axis=2)
    return tf.io.encode_png(stacked)

# TFRecords

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes / str / eager tensor value in a BytesList feature."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList cannot take the tensor itself; use its content.
        value = value.numpy()
    elif isinstance(value, str):
        # BytesList takes bytes, so text is encoded first.
        value = value.encode('utf-8')
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
    """Wrap a single integer-like value in an Int64List feature."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def _float32_feature(value):
    """Wrap a single floating-point value in a FloatList feature."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

In [None]:
def serialize_example(row):
    """Serialize one row of the training table into a tf.train.Example.

    row: a row of the training table. Its values are read by column name;
        numeric fields may be given as strings or NumPy scalars, as they
        are converted with int() / float() here.

    Returns the serialized Example as bytes. The image is read from
    row['file_path'] and the mask is built from the three segmentation
    columns.
    """
    height = int(row['height'])
    width = int(row['width'])

    image_bytes = read_image(row['file_path'])
    mask_bytes = make_mask(
        row['large_bowel'], row['small_bowel'], row['stomach'],
        height, width)

    feature = {
        'id': _bytes_feature(row['id'].encode('utf-8')),
        'case_no': _int64_feature(int(row['case_no'])),
        'day_no': _int64_feature(int(row['day_no'])),
        'slice_no': _int64_feature(int(row['slice_no'])),
        'image': _bytes_feature(image_bytes),
        'mask': _bytes_feature(mask_bytes),
        # 'fold' is a NumPy integer in the table; it is converted to a
        # built-in int like the other integer fields, because some
        # protobuf versions reject NumPy scalars.
        'fold': _int64_feature(int(row['fold'])),
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'space_h': _float32_feature(float(row['space_h'])),
        'space_w': _float32_feature(float(row['space_w'])),
        'large_bowel_dice_coef': _float32_feature(
            float(row['large_bowel_dice_coef'])),
        'small_bowel_dice_coef': _float32_feature(
            float(row['small_bowel_dice_coef'])),
        'stomach_dice_coef': _float32_feature(
            float(row['stomach_dice_coef'])),
        'slice_count': _int64_feature(int(row['slice_count'])),
    }

    example_proto = tf.train.Example(
        features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

In [None]:
# In debug mode, keep only the first 300 rows to shorten the run.
if DEBUG:
    train_data_df = train_data_df.head(300)

len(train_data_df)

In [None]:
# Rows per TFRecord file: the row count divided by the number of files,
# rounded up (ceiling division written with floor division).
n_data_per_tfrec = -(-len(train_data_df) // NUM_TFREC_FILES)

n_data_per_tfrec

In [None]:
# Write the rows in consecutive chunks of n_data_per_tfrec, one TFRecord
# file per chunk. The file name holds the file index and the number of
# examples in the file.
# max(..., 1) keeps the range step valid when the table is empty, in
# which case no file is written.
chunk_starts = range(0, len(train_data_df), max(n_data_per_tfrec, 1))
for tfrec_i, chunk_start in enumerate(chunk_starts):
    chunk_df = train_data_df.iloc[
        chunk_start:chunk_start + n_data_per_tfrec]
    tfrec_file_name = "{0:02d}-{1:03d}.tfrec".format(
        tfrec_i, len(chunk_df))
    print("Writing {0}...".format(tfrec_file_name))
    with tf.io.TFRecordWriter(tfrec_file_name) as writer:
        for tfrec_item_i, (_, row) in enumerate(chunk_df.iterrows()):
            # Progress indicator, printed every 100 examples.
            if tfrec_item_i % 100 == 0:
                print(tfrec_item_i, ", ", end='')
            writer.write(serialize_example(row))
    print()

In [None]:
!ls -l

# Verify TFRecords

In [None]:
def decode_image(image_bytes, height, width):
    """Decode the stored image bytes into a uint8 tensor.

    image_bytes: the 'image' feature; expected to be the np.save() output
        of a flat uint8 array, as written by read_image().
    height, width: the 'height' and 'width' features of the same example.

    Only the trailing height * width bytes are kept, which drops whatever
    np.save() wrote in front of the pixel data.

    NOTE(review): the result is reshaped to [width, height, 1], while the
    mask is built as [height, width, 1]. The two orders agree only for
    square images -- confirm which one matches the file-name fields.
    """
    raw = tf.io.decode_raw(image_bytes, out_type=tf.uint8)
    pixel_count = height * width
    pixels = raw[-pixel_count:]
    return tf.reshape(pixels, [width, height, 1])

def decode_mask(mask_bytes):
    """Decode PNG-encoded mask bytes and return them as a float32 tensor."""
    return tf.cast(tf.image.decode_png(mask_bytes), dtype=tf.float32)

def read_tfrecord(example):
    """Parse one serialized example and decode its image and mask.

    example: a serialized tf.train.Example as written by
        serialize_example().

    Returns a tuple, in this order: id, case_no, day_no, slice_no, image,
    mask, fold, height, width, space_h, space_w, large_bowel_dice_coef,
    small_bowel_dice_coef, stomach_dice_coef, slice_count.
    """
    def _scalar(dtype):
        # Every feature holds exactly one value.
        return tf.io.FixedLenFeature([], dtype)

    feature_spec = {
        'id': _scalar(tf.string),
        'case_no': _scalar(tf.int64),
        'day_no': _scalar(tf.int64),
        'slice_no': _scalar(tf.int64),
        'image': _scalar(tf.string),
        'mask': _scalar(tf.string),
        'fold': _scalar(tf.int64),
        'height': _scalar(tf.int64),
        'width': _scalar(tf.int64),
        'space_h': _scalar(tf.float32),
        'space_w': _scalar(tf.float32),
        'large_bowel_dice_coef': _scalar(tf.float32),
        'small_bowel_dice_coef': _scalar(tf.float32),
        'stomach_dice_coef': _scalar(tf.float32),
        'slice_count': _scalar(tf.int64),
    }

    parsed = tf.io.parse_single_example(example, feature_spec)
    height = parsed['height']
    width = parsed['width']
    # The image decoder needs the stored size; the mask is a PNG and
    # carries its own.
    image = decode_image(parsed['image'], height, width)
    mask = decode_mask(parsed['mask'])
    return (
        parsed['id'], parsed['case_no'], parsed['day_no'],
        parsed['slice_no'],
        image, mask, parsed['fold'], height, width,
        parsed['space_h'], parsed['space_w'],
        parsed['large_bowel_dice_coef'], parsed['small_bowel_dice_coef'],
        parsed['stomach_dice_coef'], parsed['slice_count'],
    )

def load_dataset(filenames):
    """Return a dataset of examples parsed from the given TFRecord files.

    filenames: the TFRecord file names to read. Each record is parsed
        with read_tfrecord().
    """
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=None)
    return records.map(read_tfrecord, num_parallel_calls=None)

In [None]:
# Read back every TFRecord file in the working directory, in file-name
# order.
tfrec_file_names = sorted(tf.io.gfile.glob('*.tfrec'))
tfrec_ds = load_dataset(tfrec_file_names)

tfrec_ds

In [None]:
def draw_image_mask(ax, image, mask, title):
    """Draw `image` in grayscale on `ax`, overlay `mask` and set the title."""
    ax.imshow(image, cmap='gray')
    # Drawn second and half-transparent so the image stays visible.
    ax.imshow(mask, alpha=0.5)
    ax.set_title(title)

def draw_tfrec_images(nrows, ncols, figsize):
    """Plot a grid of images with their masks overlaid.

    nrows, ncols: size of the subplot grid.
    figsize: figure size passed to plt.subplots().

    The samples are read from the module-level `tfrec_ds`, skipping its
    first 100 examples, and fill the grid row by row. If the dataset runs
    out before the grid is full, the remaining axes are left empty.
    """
    # squeeze=False always yields a 2-D array of axes, so grids with a
    # single row or column work as well.
    _, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
    # axes.flat walks the grid row by row; zip() stops at the shorter of
    # the grid and the dataset.
    for ax, tfrec_data in zip(axes.flat, tfrec_ds.skip(100)):
        sample_id = tfrec_data[0].numpy().decode('utf-8')
        image = tfrec_data[4]
        mask = tfrec_data[5]
        draw_image_mask(ax, image, mask, sample_id)
    plt.tight_layout()
    plt.show()

In [None]:
# Show 3 rows x 4 columns of samples as a visual check of the files.
draw_tfrec_images(3, 4, (12, 10))