This notebook creates a *train.csv* and a number of `TFRecord`s. Each record corresponds to a grouped fold. The folds respect the cases available in the training PNGs (i.e. `GroupKFold`), so the slices of a case are never split across folds. The *train.csv* has the information about the training examples, as well as the fold each belongs to. The records are available in the following dataset:

[UWMGTIS Training Dataset](https://www.kaggle.com/datasets/jasonprasad/uwmgtis-training-dataset)

In [None]:
import glob
import math
import os
import random

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import pandas as pd
import tensorflow as tf

from IPython.display import HTML
from sklearn.model_selection import GroupKFold

In [None]:
# Root of the competition input data; the train paths below derive from it.
DATA_DIR = "/kaggle/input/uw-madison-gi-tract-image-segmentation"
# Number of GroupKFold folds; one TFRecord file is written per fold.
N_SPLITS = 16  # Works out to be > ~100MB per TFRecord
# Default (height, width) that images and masks are resized/padded to.
OUTPUT_SHAPE = (224, 224)
# Seed applied to the Python, NumPy and TensorFlow RNGs.
SEED = 42
# Directory that is searched recursively for the training PNGs.
TRAIN_DIR = f"{DATA_DIR}/train"
# CSV holding the per-class segmentation labels.
TRAIN_CSV = f"{DATA_DIR}/train.csv"

In [None]:
# The DataFrame shuffle performed before the GroupKFold split adds randomness,
# so every RNG the notebook may touch is seeded for repeatable runs.
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes — confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Creating the image DataFrame

The image `DataFrame` is constructed by reading all of the *.png* paths available in the train directory. The `DataFrame` will contain the case, day, slice, filepath, and pixel dimensions. The filepath is all that is necessary to create the image `tf.data.Dataset`, whereas the other attributes are used to decode the segmentation masks.

In [None]:
def parse_int(s):
    """Return the first run of consecutive digits in ``s`` as an ``int``.

    Characters before the first digit are skipped and everything after the
    first digit run is ignored (``"case123_x45"`` -> ``123``).

    Args:
        s (string): Text to scan.

    Returns:
        The parsed integer, or ``math.nan`` when ``s`` contains no digit.
    """
    start = None
    for pos, char in enumerate(s):
        if char.isdigit():
            # Remember where the first digit run begins.
            if start is None:
                start = pos
        elif start is not None:
            # First non-digit after a digit run closes that run.
            return int(s[start:pos])
    if start is not None:
        # The digit run extends to the end of the string.
        return int(s[start:])
    return math.nan

def extract_image_info(path):
    """Derive the id and numeric attributes of a scan from its file path.

    The path is split on "/" from the right; the third-from-last component is
    read as ``<case>_<day>`` and the last component as the file name, whose
    2nd-4th "_"-separated tokens are read as slice, height and width.
    NOTE(review): "height"/"width" simply follow the token order in the file
    name; confirm which token is which, since ``assemble_masks`` reshapes with
    ``(width, height)``.

    Args:
        path (string): Path to a training PNG, using "/" separators.

    Returns:
        list: ``[id, case, day, slice, height, width]`` where ``id`` is
        ``"<case>_<day>_slice_<slice>"`` and the remaining entries are the
        result of ``parse_int`` on the corresponding token.
    """
    components = path.rsplit("/", maxsplit=3)
    # Elements 1 and 3: the "<case>_<day>" directory and the file name.
    case_day, fname = components[1::2]
    case, day = case_day.split("_")

    name_tokens = fname.split("_", maxsplit=4)
    slice_token, height_token, width_token = name_tokens[1:4]

    record = [f"{case}_{day}_slice_{slice_token}"]
    for token in (case, day, slice_token, height_token, width_token):
        record.append(parse_int(token))
    return record
    
def create_image_df(path):
    """Build the image DataFrame from every PNG found below ``path``.

    Args:
        path (string): Directory that is searched recursively for ``*.png``.

    Returns:
        pandas.DataFrame with one row per PNG and the columns ``path``,
        ``id``, ``case``, ``day``, ``slice``, ``height`` and ``width`` (the
        last six come from ``extract_image_info``). An empty directory
        yields an empty frame with the same columns.
    """
    # Honour the argument (it used to be ignored in favour of the global
    # TRAIN_DIR) and sort, because glob order depends on the filesystem and
    # would otherwise leak into the seeded shuffle that follows.
    paths = sorted(glob.glob(f"{path}/**/*.png", recursive=True))
    info_cols = ["id", "case", "day", "slice", "height", "width"]
    df = pd.DataFrame(
        [extract_image_info(p) for p in paths],
        columns=info_cols,
    )
    # Keep "path" as the first column, as before.
    df.insert(0, "path", paths)
    return df

In [None]:
# Index every training PNG; the last expression previews the first rows.
image_df = create_image_df(TRAIN_DIR)
image_df.head()

# Creating the mask DataFrame

The mask `DataFrame` is constructed by manipulating the *train.csv*. The *class/segmentation* structure is flattened for easier access later on.

In [None]:
# Load the raw label table; the last expression previews the first rows.
df = pd.read_csv(TRAIN_CSV)
df.head()

I took the `unstack` logic from [UWMGI Image Segmentation Make TFRecords](https://www.kaggle.com/code/tt195361/uwmgi-image-segmentation-make-tfrecords)

In [None]:
def create_mask_df(df):
    """Flatten the label table to one row per ``id``.

    The ``class`` level is moved from the rows into the columns, so each id
    carries its segmentations side by side.
    NOTE(review): the new column names are assigned positionally, which
    assumes exactly three classes whose column order after ``unstack``
    corresponds to ``lb_seg``, ``sb_seg``, ``s_seg`` — confirm against
    train.csv.

    Args:
        df (pandas.DataFrame): Table with ``id`` and ``class`` columns plus
            one value column.

    Returns:
        pandas.DataFrame with the columns ``id``, ``lb_seg``, ``sb_seg`` and
        ``s_seg``; missing values are replaced by empty strings.
    """
    wide = df.set_index(["id", "class"]).unstack().reset_index()
    wide.columns = ["id", "lb_seg", "sb_seg", "s_seg"]
    # Missing segmentations become "" rather than NaN.
    return wide.fillna("")

In [None]:
# One row per id with the three segmentation columns side by side.
mask_df = create_mask_df(df)
mask_df.head()

# Join the image and mask DataFrames and create K-folds

The image and mask `DataFrame`s are joined so that they may be split into K-folds, as well as to relate the pixel dimensions to the segmentation masks.

In [None]:
# Left join on "id": every image row is kept and gains its mask columns.
# Note that this rebinds `df`, which previously held the raw train.csv.
df = image_df.merge(mask_df, on="id", how="left")
df.head()

In [None]:
# Shuffle all rows with an explicit seed so the order does not depend on the
# current state of NumPy's global RNG (e.g. when this cell is re-run alone).
# The index is reset so that labels and positions coincide afterwards.
X = df.sample(frac=1, random_state=SEED).reset_index(drop=True)
# Grouping key for GroupKFold: the case each row belongs to.
groups = X["case"]

In [None]:
# Split the rows into N_SPLITS folds, grouped by case. The result is iterated
# once, when the fold column is filled in.
folds = GroupKFold(n_splits=N_SPLITS).split(X=X, groups=groups)

In [None]:
# -1 marks rows that have not been assigned to a fold.
X["fold"] = -1
# The held-out ("test") indices of split k become the members of fold k.
# They are used as labels with .loc, which works because X's index was reset.
for fold, (_, test_indices) in enumerate(folds):
    X.loc[test_indices, "fold"] = fold

In [None]:
# Persist the shuffled table, including the "fold" column, without the index.
X.to_csv("train.csv", index=False)

# Creating TF dataset for images and masks

The following helpers were taken from [UWMGIT - DeepLabV3+ - End-to-End Pipeline [TF]](https://www.kaggle.com/code/dschettler8845/uwmgit-deeplabv3-end-to-end-pipeline-tf#helper_functions) with slight modifications. Two assumptions were made: resizing is done with padding to minimize distortion, and nearest-neighbor interpolation is used when resizing to preserve the `dtype`.

In [None]:
def tf_rle_decode(mask_rle, orig_shape, output_shape):
    """Decode a run-length encoded mask into a resized uint8 tensor.

    Args:
        mask_rle (string tensor): Whitespace-separated numbers read as
            alternating ``start length`` pairs. Starts are treated as
            1-based positions in the flattened mask. An empty string
            contributes no pixels.
        orig_shape (tuple): 2-D shape the flattened mask is reshaped to.
        output_shape (tuple): ``(target_height, target_width)`` passed to
            ``tf.image.resize_with_pad``.

    Returns:
        uint8 tensor with a trailing channel axis of size 1, holding 1 at
        decoded positions and 0 elsewhere, resized with padding.
    """
    # Total number of pixels in the flattened mask.
    shape = tf.convert_to_tensor(orig_shape, tf.int64)
    size = tf.math.reduce_prod(shape)
    
    # Split string
    s = tf.strings.split(mask_rle)
    s = tf.strings.to_number(s, tf.int64)
    
    # Get starts and lengths
    # (even positions are starts, shifted to 0-based; odd positions are lengths)
    starts = s[::2] - 1
    lens = s[1::2]
    
    # Make ones to be scattered
    total_ones = tf.reduce_sum(lens)
    ones = tf.ones([total_ones], tf.uint8)
    
    # Make scattering indices
    # r enumerates every mask pixel; `s` (reused) becomes the index of the run
    # each pixel belongs to. Subtracting the cumulative length of the previous
    # runs turns r into an offset within its run, which is added to the start.
    r = tf.range(total_ones)
    lens_cum = tf.math.cumsum(lens)
    s = tf.searchsorted(lens_cum, r, "right")
    idx = r + tf.gather(starts - tf.pad(lens_cum[:-1], [(1, 0)]), s)
    
    # Scatter ones into flattened mask
    mask_flat = tf.scatter_nd(tf.expand_dims(idx, 1), ones, [size])
    
    # Reshape and resize into mask
    mask = tf.reshape(mask_flat, orig_shape)
    # resize_with_pad is given an image with an explicit channel axis.
    mask = tf.expand_dims(mask, axis=-1)
    mask = tf.image.resize_with_pad(
        mask,
        *output_shape,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    return tf.cast(mask, tf.uint8)

def tf_load_image(path, output_shape):
    """ Read a PNG and fit it to the output shape using only TF ops
    
    Args:
        path (string): Path to the image to be loaded
        output_shape (tuple): (target_height, target_width) the image is
            resized and padded to
    
    Returns:
        3 channel image tensor, decoded as uint16 and resized with padding
        using nearest-neighbor interpolation
    
    """
    png_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_png(png_bytes, channels=3, dtype=tf.uint16)
    return tf.image.resize_with_pad(
        decoded,
        *output_shape,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )

In [None]:
def assemble_masks(lb_seg, sb_seg, s_seg, height, width, output_shape):
    """Decode the three RLE strings and stack them along the channel axis.

    Args:
        lb_seg, sb_seg, s_seg (string tensors): RLE strings, decoded in this
            order into channels 0, 1 and 2.
        height, width: Pixel dimensions of the original scan, as stored in
            the DataFrame.
        output_shape (tuple): (target_height, target_width) of each mask.

    Returns:
        uint8 tensor with the three decoded masks concatenated on the last
        axis.
    """
    # this ordering is important to decode the rle correctly!
    orig_shape = (width, height)
    channels = [
        tf_rle_decode(rle, orig_shape, output_shape)
        for rle in (lb_seg, sb_seg, s_seg)
    ]
    return tf.concat(channels, axis=-1)
    
def make_dataset(df, output_shape=OUTPUT_SHAPE):
    """Build a ``tf.data.Dataset`` of ``(id, image, mask)`` triples.

    Args:
        df (pandas.DataFrame): Rows providing the ``id``, ``path``,
            ``lb_seg``, ``sb_seg``, ``s_seg``, ``height`` and ``width``
            columns.
        output_shape (tuple): (target_height, target_width) for images and
            masks. Defaults to ``OUTPUT_SHAPE``.

    Returns:
        Zipped dataset yielding the id, the loaded image and the assembled
        3-channel mask for each row of ``df``.
    """
    def load_image(path):
        return tf_load_image(path, output_shape)

    def load_mask(*columns):
        return assemble_masks(*columns, output_shape)

    ids = tf.data.Dataset.from_tensor_slices(df.id)

    images = tf.data.Dataset.from_tensor_slices(df.path).map(
        load_image,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Column order must match the positional parameters of assemble_masks.
    mask_inputs = (df.lb_seg, df.sb_seg, df.s_seg, df.height, df.width)
    masks = tf.data.Dataset.from_tensor_slices(mask_inputs).map(
        load_mask,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return tf.data.Dataset.zip((ids, images, masks))

def show_example(ds):
    """Plot the first example of ``ds`` whose mask is not empty.

    The mask is drawn semi-transparently on top of the scan. If no example
    has a non-empty mask, the last example of the dataset is shown.

    Args:
        ds: Iterable (e.g. an eager ``tf.data.Dataset``) of
            ``(id, image, mask)`` triples.

    Raises:
        ValueError: If ``ds`` yields no examples.
    """
    example = None
    for example in ds:
        # reduce_any instead of reduce_sum: summing a uint8 mask stays in
        # uint8 and can wrap around to 0 for a non-empty mask.
        if tf.reduce_any(example[2] > 0):
            break
    if example is None:
        raise ValueError("show_example() received an empty dataset")

    example_id, img, mask = example
    _, ax = plt.subplots(figsize=(6, 6))
    ax.set_title(example_id.numpy().decode())
    ax.imshow(tf.keras.utils.array_to_img(img), cmap="gray")
    ax.imshow(tf.keras.utils.array_to_img(mask), cmap="hot", alpha=0.5)
    plt.show()

In [None]:
# Build the dataset for a single fold and plot one example with a mask as a
# visual sanity check of the image/mask decoding.
example_fold = 0
example_ds = make_dataset(X[X.fold == example_fold])
show_example(example_ds)

# Persisting dataset with TFRecord

In [None]:
def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def encode_example(id, image, mask):
    """PNG-encode the image and the mask; the id is passed through as is."""
    png_image = tf.io.encode_png(image)
    png_mask = tf.io.encode_png(mask)
    return id, png_image, png_mask

def serialize_example(id, image, mask):
    """Serialize one example to the bytes of a ``tf.train.Example``.

    Args:
        id (bytes): Identifier of the slice.
        image (bytes): Encoded image.
        mask (bytes): Encoded mask.

    Returns:
        The serialized proto with the bytes features "id", "image" and
        "mask".
    """
    feature_map = {
        "id": bytes_feature(id),
        "image": bytes_feature(image),
        "mask": bytes_feature(mask),
    }
    proto = tf.train.Example(features=tf.train.Features(feature=feature_map))
    return proto.SerializeToString()

def write_records(ds, fold, n_splits=N_SPLITS, output_shape=OUTPUT_SHAPE):
    """Write every example of ``ds`` to one TFRecord file for ``fold``.

    The file is created in the working directory and named
    ``uwmgtis-<h>-<w>.tfrecord-<fold>-of-<n_splits>`` with both counters
    zero-padded to four digits.

    Args:
        ds: Dataset of already encoded ``(id, image, mask)`` triples.
        fold (int): Fold number used in the file name.
        n_splits (int): Total number of folds used in the file name.
        output_shape (tuple): (height, width) used in the file name.
    """
    h, w = output_shape
    record_path = f"uwmgtis-{h}-{w}.tfrecord-{fold:04d}-of-{n_splits:04d}"
    with tf.io.TFRecordWriter(record_path) as writer:
        for example_id, png_image, png_mask in ds.as_numpy_iterator():
            writer.write(serialize_example(example_id, png_image, png_mask))

In [None]:
# Write one TFRecord file per fold: build the fold's dataset, PNG-encode the
# image and mask of every example, then serialize them to disk.
for fold in range(N_SPLITS):
    fold_ds = make_dataset(X[X.fold == fold])
    encoded_fold_ds = fold_ds.map(encode_example, num_parallel_calls=tf.data.AUTOTUNE)
    write_records(encoded_fold_ds, fold)

# Reading an example TFRecord

In [None]:
def parse_example(example):
    """Parse one serialized record back into ``(id, image, mask)``.

    Args:
        example: Serialized ``tf.train.Example`` holding the string features
            "id", "image" and "mask".

    Returns:
        Tuple of the id string, the image decoded as 3-channel uint16 and
        the mask decoded as 3-channel uint8.
    """
    schema = {
        name: tf.io.FixedLenFeature([], tf.string)
        for name in ("id", "image", "mask")
    }
    parsed = tf.io.parse_single_example(example, schema)

    image = tf.image.decode_png(parsed["image"], channels=3, dtype=tf.uint16)
    mask = tf.image.decode_png(parsed["mask"], channels=3, dtype=tf.uint8)
    return parsed["id"], image, mask

In [None]:
# Name of the first fold's record, built with the same pattern write_records
# uses so it stays valid if OUTPUT_SHAPE or N_SPLITS are changed.
example_record = "uwmgtis-{}-{}.tfrecord-{:04d}-of-{:04d}".format(
    *OUTPUT_SHAPE, 0, N_SPLITS
)
record_ds = tf.data.TFRecordDataset(example_record)
# Decode each serialized record back into (id, image, mask).
record_ds = record_ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
show_example(record_ds)

With the *id* persisted in the records we can retrieve specific slices. In the following case we are using it to animate the scans and masks.

In [None]:
# Collect every slice of case 36, day 8 from the record, ordered by id.
# Ids look like "<case>_<day>_slice_<slice>", so the trailing underscore keeps
# the prefix from also matching other days that merely start with the same
# digits (e.g. "case36_day80_...").
case36_day8 = sorted(
    (
        record
        for record in record_ds.as_numpy_iterator()
        if record[0].decode().startswith("case36_day8_")
    ),
    key=lambda record: record[0],
)

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

# One animation frame per slice: the scan with its mask drawn on top.
ims = []
for _, img, mask in case36_day8:
    im = ax.imshow(tf.keras.utils.array_to_img(img), cmap="gray")
    im2 = ax.imshow(tf.keras.utils.array_to_img(mask), cmap="hot", alpha=0.5)
    ims.append([im, im2])

ani = animation.ArtistAnimation(fig, ims)
# Close the figure so only the HTML animation is rendered as cell output.
plt.close()
HTML(ani.to_jshtml())