# Use the image generator to feed a TFRecord file
[tensorflow Images](https://tensorflow.org/api_guides/python/image) <br>
[planspace blog: tfrecords for humans](https://planspace.org/20170323-tfrecords_for_humans/) <br>
```python
im_gen_obj = PatchImageGenerator()
patch_image = im_gen_obj.next_patch()
```
****
### Fix and test:
* digipath_mltk.py and/or digipath_toolkit.py
    * error checking of input parameters

In [None]:
import time
nb_start_time = time.time()  # notebook start time; not referenced by the visible cells

import os
import tempfile
import sys

from collections import OrderedDict
import argparse

import tensorflow as tf
from tensorflow import io as tf_io

import numpy as np
import pandas as pd
import yaml

from skimage.filters import threshold_otsu
from skimage.color import rgb2lab

from PIL import ImageDraw
from PIL import TiffImagePlugin as tip

import IPython.display as ip_display

import openslide

# make the project's src/python modules importable from this notebook directory
sys.path.insert(0, '../src/python')
# presumably supplies the helpers the cells below call without defining them
# (get_file_size_ordered_dict, wsi_file_to_patches_tfrecord, get_iterable_tfrecord,
# PatchImageGenerator, dict_to_patch_name, ...) — verify against digipath_toolkit
from digipath_toolkit import *

In [None]:
COMMON_THUMBNAIL_DIVISOR = 20

# List the whole-slide image files found in data_dir with a running index.
# get_file_size_ordered_dict maps file name -> size (printed as an int);
# the ordering comes from that helper (presumably by file size — verify).
data_dir = '../../DigiPath_MLTK_data/Aperio'
file_type_list = ['.svs', '.tif', '.tiff']
fs_od = get_file_size_ordered_dict(data_dir, file_type_list)
for list_number, (k, v) in enumerate(fs_od.items()):
    print('%3i %30s: %i' % (list_number, k, v))

In [None]:
data_dir = '../../DigiPath_MLTK_data/Aperio'
image_file_name = 'CMU-1.svs'

# Parameters consumed by wsi_file_to_patches_tfrecord (key list in its docstring).
run_parameters = {
    'wsi_filename': os.path.join(data_dir, image_file_name),
    'thumbnail_divisor': COMMON_THUMBNAIL_DIVISOR,
    'patch_select_method': 'threshold_otsu',  # alternative: 'threshold_rgb2lab'
    'patch_height': 224,
    'patch_width': 224,
    'threshold': 0,
    'image_level': 2,
    'class_label': 'class_label_test_str',
    'output_dir': '../../run_dir/tfrecord_result',
}
print('Image File:\n', run_parameters['wsi_filename'])

for param_name, param_value in run_parameters.items():
    print('%25s: %s' % (param_name, param_value))

print('\nCalling wsi_file_to_patches_tfrecord\n')
call_start_time = time.time()
wsi_file_to_patches_tfrecord(run_parameters)
elapsed_seconds = time.time() - call_start_time
print('\nTotal run time for wsi_file_to_patches_tfrecord: %0.3f' % elapsed_seconds)

In [None]:
# Read back the first few records of the TFRecord file written above and display
# each JPEG patch. The path is rebuilt from run_parameters with the same naming
# used by wsi_file_to_patches_tfrecord (<output_dir>/<wsi base name>.tfrecords),
# so it follows any change to wsi_filename / output_dir instead of going stale.
wsi_base_name = os.path.splitext(os.path.basename(run_parameters['wsi_filename']))[0]
tfrecord_file_name = os.path.join(run_parameters['output_dir'], wsi_base_name + '.tfrecords')
iterable_tfrecord = get_iterable_tfrecord(tfrecord_file_name)

n_to_show = 10
for dakine in iterable_tfrecord.take(n_to_show):
    print(dakine['label'], dakine['image_name'])
    image_raw = dakine['image_raw'].numpy()
    print(type(image_raw))
    ip_display.display(ip_display.Image(data=image_raw))

## View the TFRecord:

In [None]:
# Outline every patch stored in the TFRecord file on a thumbnail of the WSI.
# NOTE: these assignments mutate run_parameters in place; later cells that build
# PatchImageGenerator(run_parameters) will see thumbnail_divisor = 5
# (presumably read by the generator — verify before comparing patch counts).
run_parameters['thumbnail_divisor'] = 5
run_parameters['border_color'] = 'green'
# same <output_dir>/<wsi base name>.tfrecords naming as wsi_file_to_patches_tfrecord
wsi_base_name = os.path.splitext(os.path.basename(run_parameters['wsi_filename']))[0]
run_parameters['tfrecord_file_name'] = os.path.join(run_parameters['output_dir'],
                                                    wsi_base_name + '.tfrecords')
thumb_preview = tf_record_to_marked_thumbnail_image(run_parameters)
ip_display.display(thumb_preview)

In [None]:
# Write the marked thumbnail (using the settings from the preview cell) to a JPEG file.
run_parameters.update(output_file_name='test_jpg_write.jpg')
write_tfrecord_marked_thumbnail_image(run_parameters)

```python
"""
            copy - import from src/python/openslide_2_tfrecord.py
"""

def _bytes_feature(value):
    """Wrap one bytes value in a tf.train.Feature (bytes_list)."""
    single_value_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=single_value_list)


def _float_feature(value):
    """Wrap one float / double value in a tf.train.Feature (float_list)."""
    single_value_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=single_value_list)


def _int64_feature(value):
    """Wrap one bool / enum / int / uint value in a tf.train.Feature (int64_list)."""
    single_value_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=single_value_list)


def tf_imp_dict(image_string, label, image_name, class_label='class_label'):
    """ one_tf_train_example = tf_imp_dict(image_string, label, image_name, class_label='class_label')
        Create a tf.train.Example holding one JPEG patch and its metadata

    Args:
        image_string:  JPEG-encoded image bytes
        label:         sequence number     (this is not the label you are looking for)
        image_name:    bytes(image_name)   (this is the label); a str is UTF-8 encoded
        class_label:   bytes(class_label); a str (including the default) is UTF-8 encoded

    Returns:
        one_tf_train_example: tf.train.Example with features height, width, depth,
                              label, class_label, image_name, image_raw
    """
    # BytesList only accepts bytes, so the str default could never be stored
    # as-is; encode str arguments before building the features.
    if isinstance(image_name, str):
        image_name = image_name.encode('utf-8')
    if isinstance(class_label, str):
        class_label = class_label.encode('utf-8')

    # decode the JPEG only to read its (height, width, channels) shape
    image_shape = tf.image.decode_jpeg(image_string).shape
    feature = {'height': _int64_feature(image_shape[0]),
               'width': _int64_feature(image_shape[1]),
               'depth': _int64_feature(image_shape[2]),
               'label': _int64_feature(label),
               'class_label': _bytes_feature(class_label),
               'image_name': _bytes_feature(image_name),
               'image_raw': _bytes_feature(image_string)}

    return tf.train.Example(features=tf.train.Features(feature=feature))

def _parse_tf_imp_dict(example_proto):
    """ parsed_features = _parse_tf_imp_dict(example_proto)
        Decoder for one serialized record written with tf_imp_dict()

    Args:
        example_proto:   one serialized tf.train.Example (a scalar string tensor)

    Returns:
        parsed_features: dict of tensors keyed height, width, depth, label (int64)
                         and class_label, image_name, image_raw (string)
    """
    int64_keys = ('height', 'width', 'depth', 'label')
    string_keys = ('class_label', 'image_name', 'image_raw')

    feature_spec = {key: tf.io.FixedLenFeature([], tf.int64) for key in int64_keys}
    feature_spec.update({key: tf.io.FixedLenFeature([], tf.string) for key in string_keys})

    return tf.io.parse_single_example(example_proto, feature_spec)


def get_iterable_tfrecord(tfr_name):
    """ iterable_tfrecord = get_iterable_tfrecord(tfr_name)
        Open a .tfrecords file written with tf_imp_dict() records

    Args:
        tfr_name:           path of the TFRecord file

    Returns:
        iterable_tfrecord:  TFRecordDataset whose elements are the feature dicts
                            produced by _parse_tf_imp_dict
    """
    raw_records = tf.data.TFRecordDataset(tfr_name)
    parsed_records = raw_records.map(_parse_tf_imp_dict)
    return parsed_records

def wsi_file_to_patches_tfrecord(run_parameters):
    """ Usage: wsi_file_to_patches_tfrecord(run_parameters)
        Write every patch produced by PatchImageGenerator(run_parameters) to
        <output_dir>/<wsi file base name>.tfrecords as JPEG-encoded tf.train.Example
        records (built by tf_imp_dict). Patches that fail to JPEG-encode are
        reported and skipped.

    Args:
        run_parameters:         with keys:
                                    output_dir
                                    wsi_filename
                                    class_label
                                    patch_height
                                    patch_width
                                    thumbnail_divisor
                                    patch_select_method
                                    threshold
                                    image_level

                                (optional)
                                    file_ext
    Returns:
        None:                    prints number of images and output file name if successful

    """
    import io

    file_name_base = os.path.splitext(os.path.basename(run_parameters['wsi_filename']))[0]
    class_label = run_parameters['class_label']
    output_dir = run_parameters['output_dir']
    file_ext = run_parameters.get('file_ext', '')

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
        print('created new dir:', output_dir)

    tfrecord_file_name = os.path.join(output_dir, file_name_base + '.tfrecords')

    patch_image_name_dict = {'case_id': file_name_base, 'class_label': class_label, 'file_ext': file_ext}
    class_label_bytes = bytes(class_label, 'utf8')

    patch_generator = PatchImageGenerator(run_parameters)

    with tf_io.TFRecordWriter(tfrecord_file_name) as writer:
        seq_number = 0
        while True:
            # only the generator signals end-of-patches; keep the try that narrow
            try:
                patch_dict = patch_generator.next_patch()
            except StopIteration:
                break

            patch_image_name_dict['location_x'] = patch_dict['image_level_x']
            patch_image_name_dict['location_y'] = patch_dict['image_level_y']
            patch_name = dict_to_patch_name(patch_image_name_dict)

            # JPEG-encode in memory: same PIL encoder/defaults as saving to a .jpg
            # file, without a temp file, a leaked read handle, or the Windows
            # "file in use" failure of saving to an open NamedTemporaryFile.
            jpeg_buffer = io.BytesIO()
            try:
                patch_dict['patch_image'].convert('RGB').save(jpeg_buffer, format='JPEG')
            except (OSError, ValueError):
                print('Image write-read exception with patch # %i, named:\n%s' % (seq_number, patch_name))
                continue  # skip the patch rather than write a non-JPEG record
            image_string = jpeg_buffer.getvalue()

            tf_example_obj = tf_imp_dict(image_string,
                                         label=seq_number,
                                         image_name=bytes(patch_name, 'utf8'),
                                         class_label=class_label_bytes)

            writer.write(tf_example_obj.SerializeToString())
            seq_number += 1

    print('%5i images written to %s' % (seq_number, tfrecord_file_name))
                
def tf_record_to_marked_thumbnail_image(run_parameters):
    """ Usage: thumb_preview = tf_record_to_marked_thumbnail_image(run_parameters)
        Draw an outline on a thumbnail of the WSI for every patch recorded in a
        .tfrecords file (patch locations are parsed from each record's image_name).

    Args:
        run_parameters:         with keys:
                                    tfrecord_file_name
                                    wsi_filename
                                    thumbnail_divisor
                                    border_color
                                    patch_height
                                    patch_width
                                (optional)
                                    image_level     (default 0)
    Returns:
        thumb_preview:          the OpenSlide thumbnail image with one rectangle per patch
    """
    #                   unpack - name the variables
    tfrecord_file_name = run_parameters['tfrecord_file_name']
    wsi_filename = run_parameters['wsi_filename']
    thumbnail_divisor = run_parameters['thumbnail_divisor']
    border_color = run_parameters['border_color']
    image_level = run_parameters.get('image_level', 0)

    #                   scale the patch size to the thumbnail image
    #                   (-1 presumably because the lower-right corner is drawn inclusively)
    scaled_patch_height = run_parameters['patch_height'] // thumbnail_divisor - 1
    scaled_patch_width = run_parameters['patch_width'] // thumbnail_divisor - 1

    #                   OpenSlide open - closed even if reading the slide fails
    os_im_obj = openslide.OpenSlide(wsi_filename)
    try:
        #               level_dimensions entries are (width, height) at each image level
        pixels_width, pixels_height = os_im_obj.level_dimensions[image_level]

        #               thumbnail scaled to the thumbnail divisor
        thumbnail_size = (pixels_width // thumbnail_divisor, pixels_height // thumbnail_divisor)
        thumb_preview = os_im_obj.get_thumbnail(thumbnail_size)
    finally:
        os_im_obj.close()
    #                   OpenSlide close

    #                   rectangle-drawing object for the thumbnail preview image
    thumb_draw = ImageDraw.Draw(thumb_preview)

    iterable_tfrecord = get_iterable_tfrecord(tfrecord_file_name)
    for patch_dict in iterable_tfrecord:
        im_name = patch_dict['image_name'].numpy().decode('utf-8')
        patch_name_dict = patch_name_to_dict(im_name)
        c = patch_name_dict['location_x']
        r = patch_name_dict['location_y']

        # patch location by upper left corner = (column, row), scaled to the thumbnail
        ulc = (c // thumbnail_divisor, r // thumbnail_divisor)

        #               lower right corner = upper left corner + scaled patch sizes
        lrc = (ulc[0] + scaled_patch_width, ulc[1] + scaled_patch_height)

        #               draw the rectangle from the upper left corner to the lower right corner
        thumb_draw.rectangle((ulc, lrc), outline=border_color, fill=None)

    return thumb_preview

```

In [None]:
"""
                optimistic first hak
"""
# Inline version of wsi_file_to_patches_tfrecord; it leaves patch_generator,
# tfrecord_file_name and seq_number defined for the cells below.
import io

file_name_base = os.path.splitext(os.path.basename(run_parameters['wsi_filename']))[0]
class_label = run_parameters['class_label']
output_dir = run_parameters['output_dir']

file_ext = ''

if not os.path.isdir(output_dir):
    os.makedirs(output_dir)
    print('created new dir:', output_dir)

tfrecord_file_name = os.path.join(output_dir, file_name_base + '.tfrecords')

patch_image_name_dict = {'case_id': file_name_base, 'class_label': class_label, 'file_ext': file_ext}

patch_generator = PatchImageGenerator(run_parameters)
with tf_io.TFRecordWriter(tfrecord_file_name) as writer:
    seq_number = 0
    while True:
        try:
            patch_dict = patch_generator.next_patch()
        except StopIteration:
            print('Iteration Stopped by Image Generator Signal: StopIteration')
            break

        patch_image_name_dict['location_x'] = patch_dict['image_level_x']
        patch_image_name_dict['location_y'] = patch_dict['image_level_y']
        patch_name = dict_to_patch_name(patch_image_name_dict)

        # JPEG-encode in memory (same PIL encoder/defaults as saving to a .jpg
        # temp file, without the temp file or an unclosed read handle)
        jpeg_buffer = io.BytesIO()
        patch_dict['patch_image'].convert('RGB').save(jpeg_buffer, format='JPEG')
        image_string = jpeg_buffer.getvalue()

        tf_example_obj = tf_imp_dict(image_string,
                                     label=seq_number,
                                     image_name=bytes(patch_name, 'utf8'),
                                     class_label=bytes(class_label, 'utf8'))

        writer.write(tf_example_obj.SerializeToString())
        seq_number += 1

print(seq_number)

In [None]:
# Show the first n_to_show records of the TFRecord file written by the cell above.
iterable_tfrecord = get_iterable_tfrecord(tfrecord_file_name)

n_to_show = 10
for record in iterable_tfrecord.take(n_to_show):
    jpeg_bytes = record['image_raw'].numpy()
    print(record['label'], record['image_name'])
    print(type(jpeg_bytes))
    ip_display.display(ip_display.Image(data=jpeg_bytes))

In [None]:
# Inspect the keys and image of one patch dict from the generator.
# The tfrecord-writing cell above runs patch_generator until next_patch()
# raises StopIteration, so restart the generator if it is already exhausted.
try:
    nxt_dict = patch_generator.next_patch()
except StopIteration:
    patch_generator = PatchImageGenerator(run_parameters)
    nxt_dict = patch_generator.next_patch()

for k in nxt_dict:
    print(k)

print('\n# %04i, %i row, %i col' % (nxt_dict['patch_number'], nxt_dict['image_level_y'], nxt_dict['image_level_x']))
ip_display.display(nxt_dict['patch_image'])

In [None]:
# Walk a fresh generator over every patch, printing each patch's image-level
# location and size, then report the last patch_number and the patch count.
patch_generator = PatchImageGenerator(run_parameters)
print()
patch_count = 0
# reset so an empty generator cannot report a stale patch_dict from an earlier cell
patch_dict = None
while True:
    try:
        patch_dict = patch_generator.next_patch()
    except StopIteration:
        print('StopIteration Exception thrown & caught')
        break
    x = patch_dict['image_level_x']
    y = patch_dict['image_level_y']
    print('%8i x,\t%8i y\tsize = ' % (x, y), patch_dict['patch_image'].size)
    # to view each patch, add: ip_display.display(patch_dict['patch_image'])
    patch_count += 1

last_patch_number = None if patch_dict is None else patch_dict['patch_number']
print('patch_number:', last_patch_number, '\npatch_count:', patch_count)

In [None]:
# Show the pydoc help (class docstring and methods) for the generator built above.
help(patch_generator)

In [None]:
# import digipath_toolkit
# help(digipath_toolkit)

In [None]:
# List the entries of the lost_data directory; as the last expression of the
# cell, the returned list is displayed by the notebook.
# os.listdir raises FileNotFoundError if the directory does not exist.
lost_data_dir = '../../DigiPath_MLTK_data/lost_data/'
os.listdir(lost_data_dir)