In [None]:
import os, math, glob, re
import numpy as np
import pandas as pd
import cv2

import matplotlib.pyplot as plt

import tensorflow as tf

from tqdm import tqdm

# Read Data

In [None]:
# Height / width in pixels used to shape decoded masks and deserialized images.
IMAGE_HEIGHT = 520
IMAGE_WIDTH = 704

# One row per annotated instance; several rows can share the same image "id".
df = pd.read_csv("../input/sartorius-cell-instance-segmentation/train.csv")
uuids = pd.unique(df["id"])

# Decode Masks and Load a Sample Image

In [None]:
def rle_decode(mask_rle, shape):
    """Decode a run-length-encoded mask string into a binary ``uint8`` array.

    The RLE string is whitespace-separated ``start length`` pairs with 1-based
    starts, indexing the spatial grid flattened in row-major order (the result
    is produced with a plain ``reshape(shape)``, no transpose).

    Args:
        mask_rle: RLE string, e.g. ``"1 3 10 2"``.
        shape: ``(height, width)`` or ``(height, width, *extra)``; every
            trailing (e.g. channel) position receives the same mask.

    Returns:
        ``np.uint8`` array of ``shape`` with 1 on encoded pixels, 0 elsewhere.
    """
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
    lengths = np.asarray(tokens[1::2], dtype=int)
    height, width = shape[0], shape[1]
    # Flatten only the spatial axes so each run also covers all trailing axes;
    # this also accepts a plain 2-D (height, width) shape.
    flat = np.zeros((height * width,) + tuple(shape[2:]), dtype=np.uint8)
    for start, length in zip(starts, lengths):
        flat[start:start + length] = 1
    return flat.reshape(shape)


def build_masks(image_id, shape):
    """Stack the decoded instance masks of one image along a trailing axis.

    Reads every ``annotation`` RLE string for ``image_id`` from the global
    ``df`` (loaded from train.csv) and decodes each with ``rle_decode``.

    Args:
        image_id: value of the ``id`` column to look up.
        shape: per-mask shape passed to ``rle_decode``, e.g. ``(H, W, 1)``.

    Returns:
        ``uint8`` array whose last axis indexes instances, e.g. ``(H, W, N)``
        for ``shape=(H, W, 1)``. The instance axis is kept even when N == 1,
        and an image without annotations yields a zero-length instance axis.
    """
    labels = df.loc[df["id"] == image_id, "annotation"].tolist()
    if not labels:
        # np.stack raises on an empty list; return an empty instance axis,
        # with singleton dims dropped exactly as the squeeze below would.
        return np.zeros(tuple(dim for dim in shape if dim != 1) + (0,), dtype=np.uint8)

    stacked = np.stack([rle_decode(label, shape=shape) for label in labels], axis=-1)
    # Drop singleton axes (e.g. the channel axis of (H, W, 1)) but never the
    # trailing instance axis, which a bare np.squeeze removes when N == 1.
    singleton_axes = tuple(ax for ax in range(stacked.ndim - 1) if stacked.shape[ax] == 1)
    if singleton_axes:
        stacked = np.squeeze(stacked, axis=singleton_axes)
    return stacked

In [None]:
# Load one training image and its instance masks to sanity-check the decoding.
sample_filename = '0030fd0e6378'
sample_path = os.path.join("../input/sartorius-cell-instance-segmentation/train", f"{sample_filename}.png")
sample_img = cv2.imread(sample_path)
if sample_img is None:
    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    raise FileNotFoundError(f"Could not read image: {sample_path}")
sample_img = (sample_img / 255.).astype('float32')
sample_masks = build_masks(sample_filename, shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1))
print("Image Shape: ", sample_img.shape, sample_img.dtype)
print("Label Shape: ", sample_masks.shape, sample_masks.dtype)

fig, axs = plt.subplots(1, 2, figsize=(20, 20))
# cv2 loads channels as BGR while matplotlib expects RGB; convert for display only.
axs[0].imshow(cv2.cvtColor(sample_img, cv2.COLOR_BGR2RGB))
axs[0].set_title("Image")
axs[0].axis("off")
# Each mask is 0/1, so summing over the instance axis counts masks per pixel.
axs[1].imshow(np.sum(sample_masks, axis=-1))
axs[1].set_title("Instance count per pixel")
axs[1].axis("off")
plt.show()

# Convert to TFRecord

In [None]:
def _bytes_feature(value):
    """Wrap one bytes value in a ``tf.train.Feature`` holding a ``BytesList``.

    An eager tensor is unwrapped with ``.numpy()`` first, since ``BytesList``
    does not unpack tensors itself.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
    """Wrap one float in a ``tf.train.Feature`` holding a single-entry ``FloatList``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap one integer in a ``tf.train.Feature`` holding a single-entry ``Int64List``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

In [None]:
def serialize_example(image, label):
    """Encode an image/label array pair as a serialized ``tf.train.Example``.

    Both arrays are stored as raw bytes via ``tobytes()``; neither dtype nor
    shape is recorded, so the reader must know them (``deserialize_example``
    decodes the image as float32 and the label as uint8).
    """
    features = tf.train.Features(feature={
        'image': _bytes_feature(image.tobytes()),
        'label': _bytes_feature(label.tobytes()),
    })
    return tf.train.Example(features=features).SerializeToString()

In [None]:
%%time
! mkdir -p ./tfrecords/
outpath = "./tfrecords"
with tf.io.TFRecordWriter(os.path.join(outpath,'sartorius.tfrec'), options=tf.io.TFRecordOptions(compression_type="GZIP")) as writer:
    for i in tqdm(uuids, colour="#73d315", ncols=100):
        img_path = os.path.join("../input/sartorius-cell-instance-segmentation/train", f"{i}.png")
        img = cv2.imread(img_path)
        img = (img/255.).astype('float32')
        mask = build_masks(i, shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1))
        example = serialize_example(img, mask)
        writer.write(example)

# Verify the Written TFRecord

## Deserialize TFRecord

In [None]:
def deserialize_example(serialized_string):
    """Parse one serialized ``tf.train.Example`` back into ``(image, label)``.

    ``image`` is decoded as float32 and reshaped to
    ``(IMAGE_HEIGHT, IMAGE_WIDTH, 3)``; ``label`` is decoded as uint8 and
    reshaped to ``(IMAGE_HEIGHT, IMAGE_WIDTH, -1)``, so its trailing
    (instance) axis length is inferred from the byte count.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string) for key in ('image', 'label')
    }
    record = tf.io.parse_single_example(serialized_string, feature_spec)
    raw_image = tf.io.decode_raw(record['image'], tf.float32)
    raw_label = tf.io.decode_raw(record['label'], tf.uint8)
    image = tf.reshape(raw_image, (IMAGE_HEIGHT, IMAGE_WIDTH, 3))
    label = tf.reshape(raw_label, (IMAGE_HEIGHT, IMAGE_WIDTH, -1))
    return image, label

In [None]:
# Stream the GZIP TFRecord back and parse each record into (image, label).
train_set = (
    tf.data.TFRecordDataset(
        os.path.join(outpath, "sartorius.tfrec"),
        compression_type="GZIP",
    )
    .map(deserialize_example)
)

## Plot Data

In [None]:
# Pull a single parsed record and display it next to its per-pixel mask count.
ds = train_set.take(1)
examples = list(ds)
if not examples:
    raise RuntimeError("The TFRecord dataset is empty; nothing to plot.")
image, label = examples[0]
print(image.shape)
print(label.shape)

fig, axs = plt.subplots(1, 2, figsize=(20, 20))
# Stored images came from cv2.imread (BGR); reverse channels to show RGB.
axs[0].imshow(image.numpy()[..., ::-1])
axs[0].axis('off')
# Each mask is 0/1, so summing over the instance axis counts masks per pixel.
axs[1].imshow(np.sum(label.numpy(), axis=-1))
axs[1].axis('off')
plt.show()