# TIF to TFRecords in multiple resolutions
In this notebook we transform the huge training images into TFRecord format.
When creating TFRecords for a dataset, the input pipeline is most efficient when each TFRecord file has a reasonable size, generally >10 MB, so that it can benefit from I/O prefetching.  

**Updated with the latest dataset!**

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import glob
from PIL import Image
import hashlib
from io import BytesIO
from skimage import io
import contextlib2
import json
import cv2
import os, shutil  
from functools import partial
%matplotlib inline  

## References
This notebook uses/modifies some code snippets from these notebooks:
* [Global Wheat to TFRecords](https://www.kaggle.com/mistag/global-wheat-to-tfrecords) (own work)
* [HuBMap: Read data and build TFRecords](https://www.kaggle.com/marcosnovaes/hubmap-read-data-and-build-tfrecords)   

In addition, some sample code is taken from the Keras documentation.

# TFRecords creation using patches
The very high-resolution tissue scans need to be split into patches (or tiles, if you like).   

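Here we use 1024x1024 as the primary tile (patch) size; adjust the IMG_SIZE variable further below to suit your needs. Neighbouring tiles overlap by half a tile (`OVERLAP = IMG_SIZE//2`). We also create downscaled versions of 512x512 and 256x256 (with the default `SCALES = 2` and `SCALE_FACTOR = 2`).   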

In [None]:
# A few helper functions for reading in image and masks (from json)

def read_tif_file(fname):
    """Read a TIFF scan and return it with the channel axis last.

    Singleton dimensions are squeezed away first. If the leading axis then has
    length 3 it is treated as a channel axis and moved to position 2,
    i.e. (3, H, W) -> (H, W, 3); otherwise the array is returned unchanged.
    """
    scan = np.squeeze(io.imread(fname))
    if scan.shape[0] != 3:
        return scan
    # channels-first -> channels-last (same result as swapping axes 0<->1, then 1<->2)
    return np.moveaxis(scan, 0, 2)

def read_mask_file(fname, mshape):
    """Rasterise the glomerulus annotations of a JSON file into a boolean mask.

    Parameters
    ----------
    fname : str
        Path to the annotation file: a list of GeoJSON-style features, each with
        ``properties.classification.name`` and ``geometry.coordinates``
        (a FeatureCollection dict with a ``features`` list is accepted too).
    mshape : tuple or array of int
        Shape of the mask to create, e.g. ``img.shape[:2]`` (height, width).

    Returns
    -------
    numpy.ndarray of bool with shape ``mshape``; True inside glomerulus polygons.

    Only features whose classification name is 'glomerulus' are drawn. For a
    Polygon the first ring is filled and any further rings (holes) are cleared;
    a MultiPolygon is handled the same way for each of its parts. Features
    without a classification or with malformed coordinates are skipped.
    """
    with open(fname) as f:
        mdata = json.load(f)
    if isinstance(mdata, dict):
        mdata = mdata.get('features', [])

    def _points(ring):
        # One ring ([[x, y], ...]) -> int32 (N, 2) point array for cv2; None if malformed.
        pts = np.asarray(ring)
        if pts.ndim != 2 or pts.shape[0] == 0 or pts.shape[1] < 2:
            return None
        return pts[:, :2].astype(np.int32, order='C')

    outers, holes = [], []
    for feature in mdata:
        classification = (feature.get('properties') or {}).get('classification') or {}
        if classification.get('name') != 'glomerulus':
            continue
        geometry = feature.get('geometry') or {}
        coords = geometry.get('coordinates') or []
        # A Polygon is a list of rings; a MultiPolygon is a list of Polygons.
        # Iterating rings explicitly avoids np.array() on ragged (holed) polygons,
        # which raises in recent NumPy versions.
        polygons = coords if geometry.get('type') == 'MultiPolygon' else [coords]
        for rings in polygons:
            if not rings:
                continue
            outer = _points(rings[0])
            if outer is None:
                continue
            outers.append(outer)
            holes.extend(p for p in map(_points, rings[1:]) if p is not None)

    mask = np.zeros(mshape, dtype=np.uint8)
    if outers:
        cv2.fillPoly(mask, outers, 1)
    if holes:
        # clear interior rings after all outlines have been filled
        cv2.fillPoly(mask, holes, 0)
    return mask.astype(bool, copy=False)

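The function below creates a `tf.train.Example` (a single TFRecord entry) from one patch. The image is stored as JPEG (lossy), while the mask is stored as PNG (lossless). We also add some extra metadata (source id, tile coordinates and a SHA-256 key of the encoded image).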

In [None]:
# Image patches are JPEG-encoded, masks PNG-encoded (lossless)
def create_tf_example(patch, m_patch, fid, x, y, size):
    """Build a tf.train.Example holding one image tile and its mask.

    Parameters
    ----------
    patch : array-like
        Image tile; converted with np.uint8 and stored as JPEG.
    m_patch : array-like
        Mask tile; converted with np.uint8 and stored as PNG.
    fid : str
        Image id; used for 'image/filename' (fid + '.tiff') and the source id.
    x, y : int
        Tile grid indices, stored as 'image/patch-x|y' and 'mask/patch-x|y'.
    size : int
        Stored as both 'image/height' and 'image/width'.
    """
    def _encode(arr, fmt):
        # Encode an array to an in-memory image file of the given PIL format
        stream = BytesIO()
        Image.fromarray(np.uint8(arr)).save(stream, format=fmt)
        return stream.getvalue()

    def _int64(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    jpeg_bytes = _encode(patch, 'JPEG')
    png_bytes = _encode(m_patch, 'PNG')
    # A hash of the image is used as a key by some frameworks
    digest = hashlib.sha256(jpeg_bytes).hexdigest()
    # source id must be unique per tile
    source_id = '-'.join([fid, str(x), str(y)])

    features = {
        'image/height': _int64(size),
        'image/width': _int64(size),
        'image/filename': _bytes((fid + '.tiff').encode()),
        'image/source_id': _bytes(source_id.encode()),
        'image/encoded': _bytes(jpeg_bytes),
        'image/key/sha256': _bytes(digest.encode()),
        'image/format': _bytes(b'jpeg'),
        'image/patch-x': _int64(x),
        'image/patch-y': _int64(y),
        'mask/patch-x': _int64(x),
        'mask/patch-y': _int64(y),
        'mask/encoded': _bytes(png_bytes),
        'mask/format': _bytes(b'png'),
    }
    return tf.train.Example(features=tf.train.Features(feature=features))

## Create TFRecord per image
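We create a separate TFRecord file for each image and each scale, named `<image id>-<tile size>.tfrecord`. Every tile containing glomeruli ("tissue") is written, while only a subsample of the remaining background tiles (controlled by `BKGND_1_OF_X`) is kept. All scales of an image contain the same set of tiles.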

In [None]:
PATH = '/kaggle/input/hubmap-kidney-segmentation/train/'
# every training TIFF scan in PATH
filelist = glob.glob(os.path.join(PATH, '*.tiff'))
filelist

In [None]:
# Image ids (file name without directory and extension), in the same order as
# `filelist`. os.path is used instead of splitting on 'train/' so this works for
# any directory layout and path separator.
fnames = [os.path.splitext(os.path.basename(f))[0] for f in filelist]
fnames

In [None]:
#%%time
PATH = '/kaggle/input/hubmap-kidney-segmentation/train/'
IMG_SIZE = 1024 # adjust according to desired tile size
OVERLAP = IMG_SIZE//2 # overlap between each tile
STEP = IMG_SIZE-OVERLAP
SCALES = 2 # number of times to downscale each patch (generate SCALES+1 sets of images)
SCALE_FACTOR = 2 # scale factor to use for each set, adjust according to needs
BKGND_1_OF_X = 15 # keep every n'th background tile of each image (0 or less: keep none)

# img_sizes[0] is the full tile size; each further entry is the previous one divided by SCALE_FACTOR
img_sizes = np.zeros(SCALES+1, dtype=int)
img_sizes[0] = IMG_SIZE
for i in range(SCALES):
    img_sizes[i+1] = int(img_sizes[i]/SCALE_FACTOR)

# Sort for a deterministic order and derive the shard names from this very list,
# so shard i, fnames[i] and gcnt[i] always refer to the same image.
filelist = sorted(glob.glob(PATH+'*.tiff'))
FCNT = len(filelist)
fnames = [os.path.splitext(os.path.basename(f))[0] for f in filelist]

def open_sharded_tfrecords(exit_stack, names, size):
    """Open one TFRecordWriter per name ('<name>-<size>.tfrecord'), registered on exit_stack."""
    return [
        exit_stack.enter_context(tf.io.TFRecordWriter('{}-{}.tfrecord'.format(name, size)))
        for name in names
    ]

# gcnt[i]: number of tiles written for image i (the same for every scale)
gcnt = np.zeros(len(filelist), dtype=int)

# A contextlib2.ExitStack closes all TFRecord writers, even if processing fails midway
with contextlib2.ExitStack() as tf_record_close_stack:
    # output_tfrecords1[s][i] is the writer for scale s (tile size img_sizes[s]) of image i
    output_tfrecords1 = [
        open_sharded_tfrecords(tf_record_close_stack, fnames, img_sizes[scnt])
        for scnt in range(SCALES+1)
    ]
    # process images w/overlapped tiles
    for output_shard_index, file in enumerate(filelist):
        print(file)
        fid = fnames[output_shard_index]
        # drop the previous image/mask before loading the next huge scan
        img = mask = None
        img = read_tif_file(file)
        mask = read_mask_file(os.path.splitext(file)[0]+'.json', img.shape[:2])
        bkgnd_seen = 0 # background tiles encountered so far in this image
        # the ranges below only produce tiles that lie fully inside the image
        for x in range((img.shape[0]-OVERLAP)//STEP):
            for y in range((img.shape[1]-OVERLAP)//STEP):
                rows = slice(x*STEP, x*STEP+IMG_SIZE)
                cols = slice(y*STEP, y*STEP+IMG_SIZE)
                # a tile is "tissue" if its mask has at least one glomerulus pixel
                is_tissue = bool(mask[rows, cols].any())
                if not is_tissue:
                    bkgnd_seen += 1
                    # keep only every BKGND_1_OF_X-th background tile; decide this
                    # before any JPEG/PNG encoding or resizing is done
                    if BKGND_1_OF_X <= 0 or bkgnd_seen % BKGND_1_OF_X != 0:
                        continue
                patch = img[rows, cols]
                m_patch = mask[rows, cols]*255
                gcnt[output_shard_index] += 1
                # write the tile at full size and at every downscaled size
                for s, size in enumerate(img_sizes):
                    if s == 0:
                        spatch, sm_patch = patch, m_patch
                    else:
                        dsize = (int(size), int(size))
                        spatch = cv2.resize(patch, dsize=dsize, interpolation=cv2.INTER_AREA)
                        # nearest-neighbour only picks existing values, so the mask stays 0/255
                        sm_patch = cv2.resize(m_patch.astype(int), dsize=dsize, interpolation=cv2.INTER_NEAREST)
                    tf_record = create_tf_example(spatch, sm_patch, fid, x, y, size=int(size))
                    output_tfrecords1[s][output_shard_index].write(tf_record.SerializeToString())


Create a .json file with a few useful parameters we might need during training/inference:

In [None]:
# Tiling parameters needed later to interpret the records (tile size, downscale factor)
dparams = {"IMG_SIZE": IMG_SIZE, "SCALE_FACTOR": SCALE_FACTOR}
with open("dparams.json", "w") as json_file:
    json.dump(dparams, json_file, indent = 4)

Export a table with the number of images per TFRecord file (saved as `record_stats.pkl`).

In [None]:
import pandas as pd

# One row per TFRecord file. Every scale of an image holds the same tiles,
# so the per-image count gcnt[i] is reused for each tile size.
num_shards = len(filelist)
recsizes = [
    ['{}-{}.tfrecord'.format(fnames[i], size), gcnt[i]]
    for i in range(num_shards)
    for size in img_sizes
]
df = pd.DataFrame(recsizes, columns=['File', 'ImgCount'])
df.to_pickle('./record_stats.pkl')
df.head(5)

# Check the output files
The final check is to load and plot a few records from one of the TFRecord files and verify that everything loads correctly.

In [None]:
def plot_imgs(dataset):
    """Plot up to 36 records of a TFRecord dataset on a 6x6 grid.

    Each record's 'image/encoded' bytes are decoded and drawn, with the
    'mask/encoded' image overlaid at alpha 0.25. Axis ticks are hidden.
    """
    fig = plt.figure(figsize=(18,18))
    for pos, raw in enumerate(dataset.take(36), start=1):
        ax = fig.add_subplot(6, 6, pos)
        plt.setp(ax, xticks=[], yticks=[])
        feats = tf.train.Example.FromString(raw.numpy()).features.feature
        tile = Image.open(BytesIO(feats['image/encoded'].bytes_list.value[0]))
        overlay = Image.open(BytesIO(feats['mask/encoded'].bytes_list.value[0]))
        # plt.imshow draws on the subplot just added (the current axes)
        plt.imshow(tile)
        plt.imshow(overlay, alpha=0.25)

In [None]:
# Inspect one image at the smallest generated scale (img_sizes[-1]; 256 with the
# default settings), so this cell keeps working when IMG_SIZE / SCALES /
# SCALE_FACTOR are changed. Falls back to the first image if the preferred id
# is not part of the current dataset.
sample_id = '095bf7a1f' if '095bf7a1f' in fnames else fnames[0]
fname = './{}-{}.tfrecord'.format(sample_id, img_sizes[-1])
dataset = tf.data.TFRecordDataset(fname)
dataset = dataset.shuffle(2048, reshuffle_each_iteration=True)
plot_imgs(dataset)