In [None]:
import os
import numpy as np
import pandas as pd
import cv2
import tifffile
import matplotlib.pyplot as plt
from tqdm import notebook as tqdm
import tensorflow as tf

In [None]:
# Training metadata table; the 'id' column is read here to enumerate the images.
df = pd.read_csv(
    '../input/hubmap-kidney-segmentation/train.csv'
)
# Image ids as a plain Python list, in the order they appear in the CSV.
image_list_train = [image_id for image_id in df['id']]

In [None]:
def rle2mask(mask_rle, shape):
    """Decode a run-length-encoded string into a binary ``uint8`` mask.

    Args:
        mask_rle: Whitespace-separated ``start length start length ...``
            tokens. Starts are 1-based offsets into a flat buffer.
        shape: Two-element sequence. The flat buffer has
            ``shape[0] * shape[1]`` elements and is reshaped to ``shape``
            before being transposed.

    Returns:
        Array of shape ``(shape[1], shape[0])`` holding 1 inside every run
        and 0 elsewhere.
    """
    tokens = mask_rle.split()
    # Even positions are run starts (shifted to 0-based), odd positions are lengths.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(run_starts, run_ends):
        flat[begin:stop] = 1

    # The runs index the buffer laid out as `shape`; transposing swaps the axes.
    return flat.reshape(shape).T


In [None]:
import gc
def _bytes_feature(value):
    """Wrap one string / bytes value in a ``tf.train.Feature`` backed by a BytesList.

    A value of the same type as ``tf.constant(0)`` is converted with
    ``.numpy()`` first, because BytesList won't unpack a string from an
    EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def _float_feature(value):
    """Wrap one float / double in a ``tf.train.Feature`` backed by a single-element FloatList."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

def _int64_feature(value):
    """Wrap one bool / enum / int / uint in a ``tf.train.Feature`` backed by a single-element Int64List."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

def image_example(image,mask):
    """Build a ``tf.train.Example`` holding one image tile and its mask.

    Args:
        image: Object exposing ``.tobytes()`` (a NumPy array in this notebook).
        mask: Object exposing ``.tobytes()``; stored alongside the image.

    Returns:
        A ``tf.train.Example`` with two bytes features: ``'img_bytes'`` and
        ``'mask'``.

    Note:
        Only the raw buffers are stored. Neither shape nor dtype is written,
        so whoever decodes the record must already know both.
    """
    feature = {
        'img_bytes': _bytes_feature(image.tobytes()),
        'mask' : _bytes_feature(mask.tobytes()),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

input_dir_train = '/kaggle/input/hubmap-kidney-segmentation/train'

# Tiling and sharding parameters (formerly inline magic numbers).
PATCH_SIZE = 520            # side length, in pixels, of each square tile
SHARD_SIZE = 128            # maximum number of examples written to one TFRecord file
# Tiles whose summed pixel values reach this bound are skipped
# (presumably near-white background -- TODO confirm the intent).
MAX_PATCH_SUM = 179000000
# A non-empty mask must sum to more than this; mask pixels are 0 or 255.
MIN_MASK_SUM = 25000


def _pad_top_left(tile, size):
    """Zero-pad ``tile`` at the top/left so its first two axes are ``size`` long."""
    pad_width = [(size - tile.shape[0], 0), (size - tile.shape[1], 0)]
    pad_width.extend([(0, 0)] * (tile.ndim - 2))
    return np.pad(tile, pad_width, 'constant')


def _select_patches(im, im0, size=PATCH_SIZE):
    """Cut an image and its mask into aligned tiles and keep the qualifying ones.

    Args:
        im: Channel-last image array indexed as ``[row, col, channel]``.
        im0: 2-D mask indexed as ``[row, col]`` with the same spatial extent.
        size: Tile side length in pixels.

    Returns:
        ``(im_list, mask_list)``: parallel lists of image tiles and mask tiles.
        A tile is kept when its pixel sum is strictly between 0 and
        ``MAX_PATCH_SUM`` and its mask is either empty or sums to more than
        ``MIN_MASK_SUM``.
    """
    im_list, mask_list = [], []
    for i in range(0, im.shape[0], size):
        for j in range(0, im.shape[1], size):
            patch = im[i:i + size, j:j + size, :]
            mask = im0[i:i + size, j:j + size]

            # Only tiles on the bottom/right border come out short. Compare the
            # spatial axes alone: the previous check against (520, 520, 4) never
            # matched 3-channel images, so every tile was needlessly re-copied
            # through np.pad.
            if patch.shape[:2] != (size, size):
                patch = _pad_top_left(patch, size)
                mask = _pad_top_left(mask, size)

            mask_sum = np.sum(mask)
            if (0 < np.sum(patch) < MAX_PATCH_SUM) and (mask_sum == 0 or mask_sum > MIN_MASK_SUM):
                im_list.append(patch)
                mask_list.append(mask)
    return im_list, mask_list


def _write_image_shards(tiff_name, part_num):
    """Tile one training image and write its tiles as GZIP TFRecord shards.

    Doing the work inside a function means the full-size image, its mask and
    all tiles are released when it returns, with no ``del`` list that can
    raise ``NameError`` when an image yields no qualifying tiles.

    Args:
        tiff_name: Image id; ``<tiff_name>.tiff`` is read from ``input_dir_train``
            and its RLE mask is looked up in ``df``.
        part_num: Number of the last shard written so far.

    Returns:
        The updated shard counter.
    """
    image_path = os.path.join(input_dir_train, tiff_name + ".tiff")
    print("Start -okay:", tiff_name)
    im = np.squeeze(tifffile.imread(image_path))
    print("Read -okay:", tiff_name)
    if im.shape[0] == 3:
        # Channel-first (3, H, W) -> channel-last (H, W, 3).
        im = im.transpose(1, 2, 0)

    encoding = df.loc[df["id"] == tiff_name, "encoding"].values[0]
    # rle2mask takes (width, height) and returns the transpose, i.e. (height, width).
    im0 = rle2mask(encoding, (im.shape[1], im.shape[0])) * 255
    print("Mask -okay:", tiff_name)

    im_list, mask_list = _select_patches(im, im0)
    print("Patching -okay:", tiff_name)

    for start in tqdm.tqdm(range(0, len(im_list), SHARD_SIZE), desc="Writing Shards ..."):
        end = start + SHARD_SIZE
        part_num += 1
        with tf.io.TFRecordWriter('Bad_kidney_val%i'%(part_num), tf.io.TFRecordOptions(compression_type="GZIP")) as writer:
            for image, mask in zip(im_list[start:end], mask_list[start:end]):
                tf_example = image_example(image, mask)
                writer.write(tf_example.SerializeToString())
    return part_num


pbar = tqdm.tqdm(image_list_train[-2:])
part_num = 0
for tiff_name in pbar:
    pbar.set_description('Writing File '+tiff_name +' ...')
    part_num = _write_image_shards(tiff_name, part_num)
    # Reclaim the per-image arrays before the next large TIFF is loaded.
    gc.collect()