In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory lazily and print the full path of every file found.
input_files = (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_files:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import tifffile as tiff
import cv2
import os
from tqdm.notebook import tqdm
import tensorflow as tf

In [None]:
# Tiling configuration used by the cells below.
# `orig` is reused as WINDOW and `sz` as NEW_SIZE / DIM further down.
orig = 1024  # side length (pixels) of the window cut from the source image
sz = 512 # side length (pixels) of a tile after resizing; 128 and 256 are alternative values left by the author
reduce = orig//sz  # integer down-scaling factor between the window and the stored tile
# Paths into the competition dataset (train.csv and the train/ image folder).
MASKS = '../input/hubmap-kidney-segmentation/train.csv'
DATA = '../input/hubmap-kidney-segmentation/train/'
s_th = 40  # saturation blanking threshold (not referenced by the visible cells)
p_th = 200*sz//256 # threshold for the minimum number of pixels, scaled with the tile size (not referenced by the visible cells)

In [None]:
#functions to convert encoding to mask and mask to encoding
def enc2mask(encs, shape):
    """Decode run-length encodings into a single labelled mask.

    Parameters
    ----------
    encs : iterable of (str | float)
        One run-length encoding per object, formatted as
        ``"start length start length ..."`` with 1-based starts into the
        flattened mask. A NaN entry means "no mask" and is skipped.
    shape : tuple of int
        Two-element shape; the flat buffer is reshaped to ``shape`` and then
        transposed, so the result has shape ``(shape[1], shape[0])``.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array in which the pixels of the m-th encoding (0-based)
        carry the value ``m + 1`` and the background is 0.
    """
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # `np.float` was only an alias of the builtin `float` and has been
        # removed from NumPy; the builtin also matches np.float64 NaNs because
        # np.float64 subclasses float.
        if isinstance(enc, float) and np.isnan(enc):
            continue
        s = enc.split()
        # Tokens come in (start, length) pairs; a trailing unpaired token is ignored.
        for start_tok, length_tok in zip(s[0::2], s[1::2]):
            start = int(start_tok) - 1  # encoding is 1-based
            length = int(length_tok)
            img[start:start+length] = 1 + m
    return img.reshape(shape).T

def mask2enc(mask, n=1):
    """Run-length encode the labels 1..n of a mask.

    The mask is transposed and flattened first. For every label the result
    holds a string of space-separated ``start length`` pairs (1-based starts),
    or NaN when the label does not occur in the mask.
    """
    flat = mask.T.flatten()
    encodings = []
    for label in range(1, n + 1):
        hits = (flat == label).astype(np.int8)
        if hits.sum() == 0:
            # Label absent: represented as NaN rather than an empty string.
            encodings.append(np.nan)
            continue
        # Pad with zeros so runs touching either end still produce two edges.
        padded = np.concatenate([[0], hits, [0]])
        edges = np.where(padded[1:] != padded[:-1])[0] + 1
        # Odd positions hold run ends; subtracting the starts turns them into lengths.
        edges[1::2] -= edges[::2]
        encodings.append(' '.join(str(v) for v in edges))
    return encodings

# Read the CSV at MASKS (train.csv) and index the rows by their 'id' column.
df_masks = pd.read_csv(MASKS).set_index('id')
# Last expression of the cell: preview the first rows of the table.
df_masks.head()

In [None]:
# The following function can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Wrap a byte string (or an eager tensor holding one) in a bytes_list Feature."""
  eager_tensor_type = type(tf.constant(0))
  # BytesList cannot unpack an EagerTensor, so extract its numpy value first.
  payload = value.numpy() if isinstance(value, eager_tensor_type) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def serialize_example(image, mask):
  """
  Build a tf.train.Example holding an image and its mask and return it serialized.

  Both arguments are passed through `_bytes_feature`, so they are stored as
  bytes features under the keys 'image' and 'mask'.

  Note: a later cell of this notebook defines another `serialize_example`
  with the signature (image, x1, y1); once that cell has run, the name refers
  to that later definition.
  """
  # Map each feature name to its tf.train.Example-compatible wrapper.
  features = tf.train.Features(feature={
      'image': _bytes_feature(image),
      'mask': _bytes_feature(mask),
  })
  return tf.train.Example(features=features).SerializeToString()

In [None]:
# Grid parameters passed to make_grid and used when resizing tiles.
WINDOW = orig # side length of the window read from the source image (1024)
MIN_OVERLAP = 300 # min_overlap argument passed to make_grid
NEW_SIZE = sz # side length each window is resized to before it is stored (512)

import numpy as np
import pandas as pd
import os
import glob
import gc

import rasterio
from rasterio.windows import Window

import pathlib
from tqdm.notebook import tqdm
import cv2

import tensorflow as tf

def _axis_tiles(length, window, min_overlap):
    """Return (count, starts, stops) of the tiles laid along one axis of size `length`."""
    count = length // (window - min_overlap) + 1
    starts = np.linspace(0, length, num=count, endpoint=False, dtype=np.int64)
    # Pin the last tile to the far edge. When the axis is shorter than the
    # window, `length - window` is negative, so clamp the start to 0 instead of
    # emitting a negative offset.
    starts[-1] = max(length - window, 0)
    stops = (starts + window).clip(0, length)
    return count, starts, stops


def make_grid(shape, window=256, min_overlap=32):
    """
        Return Array of size (N,4), where N - number of tiles,
        2nd axis represents slices: x1,x2,y1,y2

        Parameters
        ----------
        shape : tuple of int
            Two-element size ``(x, y)`` of the area to tile; ``x`` is the
            first axis and ``y`` the second.
        window : int
            Side length of a tile.
        min_overlap : int
            Value subtracted from ``window`` to get the largest allowed
            stride; must be smaller than ``window``.

        Returns
        -------
        numpy.ndarray
            ``int64`` array of shape ``(nx*ny, 4)``, ordered with the first
            axis varying slowest. Tiles are clipped to the area, so a tile is
            smaller than ``window`` where the window would overrun the edge,
            and an axis shorter than ``window`` yields tiles starting at 0.
    """
    x, y = shape
    nx, x1, x2 = _axis_tiles(x, window, min_overlap)
    ny, y1, y2 = _axis_tiles(y, window, min_overlap)

    # Build every (x-tile, y-tile) combination without a Python double loop.
    xi, yi = np.meshgrid(np.arange(nx), np.arange(ny), indexing='ij')
    slices = np.stack([x1[xi], x2[xi], y1[yi], y2[yi]], axis=-1)
    return slices.reshape(nx*ny, 4)

def _bytes_feature(value):
  """Return a tf.train.Feature whose bytes_list holds `value`.

  `value` may be a byte string or an eager tensor holding one. This notebook
  also defines `_bytes_feature` in an earlier cell; running this cell rebinds
  the name to this definition.
  """
  is_eager = isinstance(value, type(tf.constant(0)))
  # An EagerTensor must be converted first: BytesList will not unpack it.
  raw = value.numpy() if is_eager else value
  bytes_list = tf.train.BytesList(value=[raw])
  return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
  """Wrap a single bool / enum / int / uint in an int64_list Feature."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

def serialize_example(image, x1, y1):
  """Serialize one tile and its grid offsets as a tf.train.Example.

  `image` is stored as a bytes feature under 'image'; `x1` and `y1` are stored
  as int64 features under 'x1' and 'y1'. Returns the serialized message.
  """
  feature = {}
  feature['image'] = _bytes_feature(image)
  feature['x1'] = _int64_feature(x1)
  feature['y1'] = _int64_feature(y1)
  features = tf.train.Features(feature=feature)
  return tf.train.Example(features=features).SerializeToString()

In [None]:
p = pathlib.Path('../input/hubmap-kidney-segmentation')
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
os.makedirs('test', exist_ok = True)

# Upper bound on the number of tiles written to one .tfrec part file.
MAX_TILES_PER_PART = 1000

# List the input images once, so the progress-bar total and the loop see the same files.
test_tiffs = list(p.glob('test/*.tiff'))

for i, filename in tqdm(enumerate(test_tiffs), total = len(test_tiffs)):

    print(f'{i+1} Creating tfrecords for image: {filename.stem}')
    # The context manager closes the raster when the image is done, even if a read fails.
    with rasterio.open(filename.as_posix(), transform = identity) as dataset:
        slices = make_grid(dataset.shape, window=WINDOW, min_overlap=MIN_OVERLAP)
        print(slices.shape[0])

        # Write the tiles in consecutive parts of at most MAX_TILES_PER_PART records.
        for part, start in enumerate(range(0, len(slices), MAX_TILES_PER_PART)):
            chunk = slices[start:start + MAX_TILES_PER_PART]
            fname = f'test/{filename.stem}-part{part}.tfrec'
            # The writer is closed on leaving the `with` block, including on error.
            with tf.io.TFRecordWriter(fname) as writer:
                for (x1,x2,y1,y2) in chunk:
                    # Read bands 1-3 for this window; the result is band-first.
                    image = dataset.read([1,2,3],
                                window=Window.from_slices((x1,x2),(y1,y2)))
                    image = np.moveaxis(image, 0, -1)  # move bands to the last axis
                    image = cv2.resize(image, (NEW_SIZE, NEW_SIZE),interpolation = cv2.INTER_AREA)
                    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                    writer.write(serialize_example(image.tobytes(),x1,y1))
            # Append the record count to the file name; count_data_items parses it back.
            os.rename(fname, f'test/{filename.stem}-part{part}-{len(chunk)}.tfrec')
    gc.collect();

In [None]:
import re
import glob
def count_data_items(filenames):
    """Sum the record counts embedded in the file names.

    Each name is expected to contain ``-<digits>.`` (for example
    ``...-part0-1000.tfrec``); the digits of the first such match are taken as
    that file's record count.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for name in filenames:
        counts.append(int(count_pattern.search(name).group(1)))
    return np.sum(counts)

In [None]:
# glob does not guarantee any ordering, so sort the paths: a later cell picks
# test_images[0] and should select the same file on every run.
test_images = sorted(glob.glob('test/*.tfrec'))
# Total number of records, recovered from the counts embedded in the file names.
ctesti = count_data_items(test_images)
print(f'Num test images: {ctesti}')

In [None]:
DIM = sz          # side length of the tiles stored in the TFRecords
mini_size = 64    # side length of the thumbnails produced for preview
def _parse_image_function(example_proto):
    """Decode one serialized example into (thumbnail, x1, y1).

    The 'image' bytes are decoded as uint8 and reshaped to (DIM, DIM, 3), then
    resized to (mini_size, mini_size) and divided by 255.0. 'x1' and 'y1' are
    returned as stored.
    """
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'x1': tf.io.FixedLenFeature([], tf.int64),
        'y1': tf.io.FixedLenFeature([], tf.int64)
    }
    parsed = tf.io.parse_single_example(example_proto, schema)
    raw_pixels = tf.io.decode_raw(parsed['image'], out_type=np.dtype('uint8'))
    tile = tf.reshape(raw_pixels, (DIM, DIM, 3))
    thumbnail = tf.image.resize(tile, (mini_size, mini_size)) / 255.0
    return thumbnail, parsed['x1'], parsed['y1']


def load_dataset(filenames):
    """Open the given TFRecord file(s) and parse every record with _parse_image_function."""
    records = tf.data.TFRecordDataset(filenames)
    return records.map(_parse_image_function)

N = 8  # the preview grid is N x N, so one batch holds N*N tiles
def get_dataset(FILENAME):
    """Return the parsed dataset for FILENAME, batched in groups of N*N elements."""
    return load_dataset(FILENAME).batch(N*N)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# Fetch the first batch explicitly. Unlike a `for ... pass` loop, this fails
# right here on an empty dataset instead of leaving imgs/x1/y1 unset or stale.
imgs, x1, y1 = next(iter(get_dataset(test_images[0]).take(1)))

# The selected file may hold fewer than N*N tiles, so draw only what the batch contains.
n_shown = min(N*N, int(imgs.shape[0]))

plt.figure(figsize = (N,N))
gs1 = gridspec.GridSpec(N,N)

for i in range(n_shown):
    ax1 = plt.subplot(gs1[i])
    ax1.axis('on')
    ax1.set_xticklabels([])
    ax1.set_yticklabels([])
    ax1.set_aspect('equal')
    # Title each thumbnail with the x1; y1 offsets stored alongside the tile.
    ax1.set_title(f'{x1[i]}; {y1[i]}', fontsize=6)
    ax1.imshow(imgs[i])

plt.show()