In [1]:
import glob
import math
import matplotlib.pyplot as plt
import multiprocessing
import numpy as np
import openslide
import os
import PIL
from PIL import Image, ImageDraw, ImageFont
import re
import sys
import numpy as np
import scipy.ndimage.morphology as sc_morph
import skimage.color as sk_color
import skimage.exposure as sk_exposure
import skimage.feature as sk_feature
import skimage.filters as sk_filters
import skimage.future as sk_future
import skimage.morphology as sk_morphology
import skimage.segmentation as sk_segmentation
from enum import Enum
import colorsys
import skimage.io
# Disable PIL's decompression-bomb pixel limit so very large images are not rejected.
# NOTE(review): this removes a safety check; PIL enforces the limit when opening image
# files — confirm it is still needed here (slides are read via openslide) and only use
# it with trusted inputs.
Image.MAX_IMAGE_PIXELS = None

In [2]:
# --- Reproducibility -------------------------------------------------------
np.random.seed(0)

# --- Input locations -------------------------------------------------------
BASE_DIR = os.path.join(".", "test_data")
SLIDE_IMG = os.path.join(
    ".",
    "tcga_images",
    "TCGA-3P-A9WA-01A-01-TS1.EC84245E-56A9-46C1-9496-CCF2587D318E.svs",
)
SRC_TRAIN_DIR = os.path.join(".", "training_png")

# Down-sampling factor between the level-0 slide and the working thumbnail.
SCALE_FACTOR = 64

# --- Output formats / directories -----------------------------------------
SRC_TRAIN_EXT, DEST_TRAIN_EXT = "svs", "png"
DEST_TRAIN_DIR = "training_" + DEST_TRAIN_EXT

FILTER_RESULT_TEXT = "filtered"
FILTER_DIR = "filter_" + DEST_TRAIN_EXT
FILTER_PAGINATION_SIZE, FILTER_PAGINATE = 50, True

TILE_SUMMARY_DIR = os.path.join(".", "tile_summary_" + DEST_TRAIN_EXT)
TILE_SUMMARY_PAGINATION_SIZE, TILE_SUMMARY_PAGINATE = 50, True
TILE_SUMMARY_HTML_DIR = BASE_DIR

TILE_DATA_DIR = os.path.join(".", "tile_data")
TILE_DATA_SUFFIX = "tile_data"
TILE_DIR = os.path.join(".", "tile_images")

In [3]:
# Tissue-percentage cut-offs that bucket tiles into high / medium / low tissue.
TISSUE_HIGH_THRESH, TISSUE_LOW_THRESH = 80, 10

# Tile geometry in level-0 pixels and how many top tiles a summary keeps.
ROW_TILE_SIZE = COL_TILE_SIZE = 299
NUM_TOP_TILES = 50

# Tile-summary labelling switches.
DISPLAY_TILE_SUMMARY_LABELS = False
TILE_LABEL_TEXT_SIZE = 10
LABEL_ALL_TILES_IN_TOP_TILE_SUMMARY = BORDER_ALL_TILES_IN_TOP_TILE_SUMMARY = False
TILE_BORDER_SIZE = 2

# RGB colours per tissue bucket (high, medium, low, none) ...
HIGH_COLOR, MEDIUM_COLOR = (0, 255, 0), (255, 255, 0)
LOW_COLOR, NONE_COLOR = (255, 165, 0), (255, 0, 0)
# ... and their faded variants.
FADED_THRESH_COLOR, FADED_MEDIUM_COLOR = (128, 255, 128), (255, 255, 128)
FADED_LOW_COLOR, FADED_NONE_COLOR = (255, 210, 128), (255, 128, 128)

# Text rendering. NOTE(review): font paths are macOS-specific.
FONT_PATH = "/Library/Fonts/Arial Bold.ttf"
SUMMARY_TITLE_FONT_PATH = "/Library/Fonts/Courier New Bold.ttf"
SUMMARY_TITLE_TEXT_COLOR = (0, 0, 0)
SUMMARY_TITLE_TEXT_SIZE = 24
SUMMARY_TILE_TEXT_COLOR = (255, 255, 255)
TILE_TEXT_COLOR = (0, 0, 0)
TILE_TEXT_SIZE = 36
TILE_TEXT_BACKGROUND_COLOR = (255, 255, 255)
TILE_TEXT_W_BORDER, TILE_TEXT_H_BORDER = 5, 4

# Reference hues (degrees) used when scoring tile colour.
HSV_PURPLE, HSV_PINK = 270, 330

In [4]:
# Peek at one entry of the raw slide directory.
# NOTE(review): absolute, user-specific path.
files = [entry for entry in os.listdir('/home/roshan/tcga_images')]
files[0]

'TCGA-61-1903-11A-01-TS1.2ca2cf14-9a01-408d-be23-73dffb093ead.svs'

In [5]:
def slide_to_scaled_pil_image(f0):
    """Open a whole-slide image and return a down-scaled RGB thumbnail.

    Args:
        f0: Path of the slide: a scalar string tensor (as passed by
            ``tf.py_function``), ``bytes`` or ``str``.

    Returns:
        Tuple ``(img, large_w, large_h, new_w, new_h)``:
        ``img`` is a uint8 numpy array of the RGB thumbnail resized to
        ``(new_w, new_h)``; ``large_w``/``large_h`` are the level-0 slide
        dimensions; ``new_w``/``new_h`` are those divided by ``SCALE_FACTOR``
        and floored.
    """
    path = f0.numpy() if hasattr(f0, "numpy") else f0
    if isinstance(path, bytes):
        path = path.decode('utf-8')

    slide = openslide.open_slide(path)
    try:
        large_w, large_h = slide.dimensions
        new_w = math.floor(large_w / SCALE_FACTOR)
        new_h = math.floor(large_h / SCALE_FACTOR)
        # Read the pyramid level best suited to the target downsample instead of
        # level 0, then resize the remaining way.
        level = slide.get_best_level_for_downsample(SCALE_FACTOR)
        whole_slide_image = slide.read_region((0, 0), level, slide.level_dimensions[level])
        whole_slide_image = whole_slide_image.convert("RGB")
    finally:
        # Release the slide's file handle even if reading fails.
        slide.close()

    img = whole_slide_image.resize((new_w, new_h), PIL.Image.BILINEAR)
    return np.array(img), large_w, large_h, new_w, new_h

In [6]:
def mask_percent(np_img):
    """Return the percentage of masked (zero) pixels in an image or mask.

    For an (H, W, 3) image a pixel counts as masked when all three channels are
    zero. For any other shape (e.g. a 2-D boolean mask) each element is counted
    individually.

    Args:
        np_img: numpy array or tensor.

    Returns:
        Scalar float tensor in [0, 100].
    """
    if (len(np_img.shape) == 3) and (np_img.shape[2] == 3):
        # Widen before summing: adding uint8 channels wraps modulo 256, so a
        # tissue pixel such as (128, 64, 64) would sum to 0 and be miscounted
        # as masked.
        channels = tf.cast(np_img, tf.float64)
        np_sum = channels[:, :, 0] + channels[:, :, 1] + channels[:, :, 2]
        mask_percentage = 100 - tf.math.count_nonzero(np_sum, dtype=tf.int32) / tf.size(np_sum) * 100
    else:
        mask_percentage = 100 - tf.math.count_nonzero(np_img, dtype=tf.int32) / tf.size(np_img) * 100
    return mask_percentage

In [7]:
def filter_grays(rgb, tolerance=15):
    """Mask out gray-ish pixels.

    A pixel is considered gray when every pair of its R, G, B channels differs
    by at most ``tolerance``.

    Args:
        rgb: (H, W, 3) RGB image (array or tensor).
        tolerance: Maximum absolute pairwise channel difference for "gray".

    Returns:
        (H, W) boolean tensor, True where the pixel is NOT gray.
    """
    # Cast to a signed type first: subtracting uint8 channels would wrap around.
    rgb = tf.cast(rgb, tf.int32)
    red, green, blue = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    is_gray = ((tf.math.abs(red - green) <= tolerance)
               & (tf.math.abs(red - blue) <= tolerance)
               & (tf.math.abs(green - blue) <= tolerance))
    return ~is_gray

In [8]:
def filter_red(rgb, red_lower_thresh, green_upper_thresh, blue_upper_thresh):
    """Mask out red-ish pixels.

    A pixel is "red" when R > ``red_lower_thresh``, G < ``green_upper_thresh``
    and B < ``blue_upper_thresh``.

    Returns:
        Boolean mask, True where the pixel is NOT red.
    """
    is_red = (
        (rgb[:, :, 0] > red_lower_thresh)
        & (rgb[:, :, 1] < green_upper_thresh)
        & (rgb[:, :, 2] < blue_upper_thresh)
    )
    return ~is_red

In [9]:
def filter_red_pen(rgb):
    """Combined mask removing several shades of red pen ink.

    Each shade is a (red lower bound, green upper bound, blue upper bound)
    triple passed to ``filter_red``.

    Returns:
        Boolean mask, True where the pixel matches none of the shades.
    """
    red_pen_shades = (
        (150, 80, 90),
        (110, 20, 30),
        (185, 65, 105),
        (195, 85, 125),
        (220, 115, 145),
        (125, 40, 70),
        (200, 120, 150),
        (100, 50, 65),
        (85, 25, 45),
    )
    combined = None
    for red_lo, green_hi, blue_hi in red_pen_shades:
        shade_free = filter_red(rgb, red_lower_thresh=red_lo,
                                green_upper_thresh=green_hi, blue_upper_thresh=blue_hi)
        combined = shade_free if combined is None else combined & shade_free
    return combined

In [10]:
def filter_green(rgb, red_upper_thresh, green_lower_thresh, blue_lower_thresh):
    """Mask out green-ish pixels.

    A pixel is "green" when R < ``red_upper_thresh``, G > ``green_lower_thresh``
    and B > ``blue_lower_thresh``.

    Returns:
        Boolean mask, True where the pixel is NOT green.
    """
    is_green = (
        (rgb[:, :, 0] < red_upper_thresh)
        & (rgb[:, :, 1] > green_lower_thresh)
        & (rgb[:, :, 2] > blue_lower_thresh)
    )
    return ~is_green

In [11]:
def filter_green_pen(rgb):
    """Combined mask removing several shades of green pen ink.

    Each shade is a (red upper bound, green lower bound, blue lower bound)
    triple passed to ``filter_green``.

    Returns:
        Boolean mask, True where the pixel matches none of the shades.
    """
    green_pen_shades = [
        (150, 160, 140),
        (70, 110, 110),
        (45, 115, 100),
        (30, 75, 60),
        (195, 220, 210),
        (225, 230, 225),
        (170, 210, 200),
        (20, 30, 20),
        (50, 60, 40),
        (30, 50, 35),
        (65, 70, 60),
        (100, 110, 105),
        (165, 180, 180),
        (140, 140, 150),
        (185, 195, 195),
    ]
    masks = [
        filter_green(rgb, red_upper_thresh=r_hi, green_lower_thresh=g_lo, blue_lower_thresh=b_lo)
        for r_hi, g_lo, b_lo in green_pen_shades
    ]
    combined = masks[0]
    for mask in masks[1:]:
        combined = combined & mask
    return combined

In [12]:
def filter_blue(rgb, red_upper_thresh, green_upper_thresh, blue_lower_thresh):
    """Mask out blue-ish pixels.

    A pixel is "blue" when R < ``red_upper_thresh``, G < ``green_upper_thresh``
    and B > ``blue_lower_thresh``.

    Returns:
        Boolean mask, True where the pixel is NOT blue.
    """
    is_blue = (
        (rgb[:, :, 0] < red_upper_thresh)
        & (rgb[:, :, 1] < green_upper_thresh)
        & (rgb[:, :, 2] > blue_lower_thresh)
    )
    return ~is_blue

In [13]:
def filter_blue_pen(rgb):
    """Combined mask removing several shades of blue pen ink.

    Each shade is a (red upper bound, green upper bound, blue lower bound)
    triple passed to ``filter_blue``.

    Returns:
        Boolean mask, True where the pixel matches none of the shades.
    """
    blue_pen_shades = (
        (60, 120, 190),
        (120, 170, 200),
        (175, 210, 230),
        (145, 180, 210),
        (37, 95, 160),
        (30, 65, 130),
        (130, 155, 180),
        (40, 35, 85),
        (30, 20, 65),
        (90, 90, 140),
        (60, 60, 120),
        (110, 110, 175),
    )
    combined = None
    for r_hi, g_hi, b_lo in blue_pen_shades:
        shade_free = filter_blue(rgb, red_upper_thresh=r_hi,
                                 green_upper_thresh=g_hi, blue_lower_thresh=b_lo)
        if combined is None:
            combined = shade_free
        else:
            combined = combined & shade_free
    return combined

In [14]:
def filter_remove_small_objects(np_img, min_size=3000, avoid_overmask=True, overmask_thresh=95):
    """Remove connected components smaller than ``min_size`` from a mask.

    If the cleaned mask masks ``overmask_thresh`` percent or more of the image
    (and ``avoid_overmask`` is True), ``min_size`` is halved and the ORIGINAL
    mask is cleaned again. This repeats until the overmask check passes or
    ``min_size`` drops below 1.

    Returns:
        Boolean numpy array with small objects removed.
    """
    bool_mask = tf.cast(np_img, tf.bool).numpy()
    size = min_size
    while True:
        cleaned = sk_morphology.remove_small_objects(bool_mask, min_size=size)
        too_much_masked = mask_percent(cleaned) >= overmask_thresh
        if not (too_much_masked and (size >= 1) and (avoid_overmask is True)):
            return cleaned
        size = size / 2


In [15]:
def mask_rgb(rgb, mask):
    """Zero out pixels of ``rgb`` where ``mask`` is False.

    The 2-D ``mask`` is replicated across three channels, cast to uint8 (0/1)
    and multiplied into ``rgb``.
    """
    three_channel = tf.cast(tf.stack([mask] * 3, axis=-1), tf.uint8)
    return rgb * three_channel

In [16]:
def filter_green_channel(np_img, green_thresh=200, avoid_overmask=True, overmask_thresh=90):
    """Mask pixels whose green channel is 0 or at/above ``green_thresh``.

    If the mask masks ``overmask_thresh`` percent or more of the image (and
    ``avoid_overmask`` is True), the threshold is moved halfway towards 255
    (rounded up) and the mask is recomputed. This repeats until the check
    passes or the threshold reaches 255.

    Returns:
        Boolean mask, True where 0 < green < threshold.
    """
    green = np_img[:, :, 1]
    thresh = green_thresh
    while True:
        keep = (green < thresh) & (green > 0)
        overmasked = mask_percent(keep) >= overmask_thresh
        if not (overmasked and (thresh < 255) and (avoid_overmask is True)):
            return keep
        thresh = math.ceil((255 - thresh) / 2 + thresh)

In [17]:
def apply_image_filters(rgb):
    """Remove background, pen marks and debris from a slide thumbnail.

    Steps:
      1. Combine the green-channel, gray and red/green/blue-pen masks.
      2. Drop connected components smaller than 500 pixels.
      3. Apply the result, then remove remaining greenish pixels and grayish
         pixels (with a looser tolerance of 30).

    Args:
        rgb: RGB thumbnail (uint8 array or tensor).

    Returns:
        The filtered image with masked pixels set to zero.
    """
    mask_not_green = filter_green_channel(rgb)
    mask_not_gray = filter_grays(rgb)
    mask_no_red_pen = filter_red_pen(rgb)
    mask_no_green_pen = filter_green_pen(rgb)
    mask_no_blue_pen = filter_blue_pen(rgb)
    # Only the combined mask is needed, so no per-filter masked images are built.
    mask_gray_green_pens = mask_not_gray & mask_not_green & mask_no_red_pen & mask_no_green_pen & mask_no_blue_pen

    mask_remove_small = filter_remove_small_objects(mask_gray_green_pens, min_size=500)
    rgb_remove_small = mask_rgb(rgb, mask_remove_small)

    # Second pass on the cleaned image.
    not_greenish = filter_green(rgb_remove_small, red_upper_thresh=125, green_lower_thresh=30, blue_lower_thresh=30)
    not_grayish = filter_grays(rgb_remove_small, tolerance=30)
    return mask_rgb(rgb_remove_small, not_greenish & not_grayish)

In [18]:
import tensorflow as tf
import matplotlib.pyplot as plt

In [113]:
# Absolute paths of every file under the slide directory (recursive).
# Sorted so the seeded, stratified train/valid/test split below does not
# depend on filesystem enumeration order (os.walk order is unspecified).
# NOTE(review): absolute, user-specific path.
file_paths = sorted(
    os.path.abspath(os.path.join(folder, filename))
    for folder, _subdirs, filenames in os.walk('/home/roshan/tcga_images')
    for filename in filenames
)

In [120]:
import pandas as pd

# Show every row when a frame is displayed.
# NOTE(review): with many slides this floods the notebook output.
pd.set_option('display.max_rows', None)
df = pd.DataFrame({'filename': file_paths})

In [121]:
def _sample_type_code(path):
    """Return the two-digit TCGA sample-type code from a slide path.

    The code is the first two characters of the fourth ``-``-separated field of
    the file name (e.g. ``TCGA-61-1903-11A-...`` -> ``"11"``). Parsing the
    basename, rather than slicing fixed offsets of the full path, keeps this
    independent of the directory the slides live in.
    """
    return os.path.basename(path).split('-')[3][:2]


df['label'] = df['filename'].apply(_sample_type_code)

In [123]:
# Two-digit sample-type strings ("01", "11", ...) -> 32-bit integers.
df['label'] = df['label'].astype(np.int32)

In [124]:
# Re-code sample types into class indices: 1 ("01") -> 0, then 11 ("11") -> 1.
# The order of the two replacements matters: doing 11 -> 1 first would then turn it into 0.
# Any other sample-type value is left unchanged.
# NOTE(review): such values fall outside the 0/1 scheme; confirm they are absent or
# intended (labels are one-hot encoded with depth=3 further down).
df['label'] = df['label'].replace(1,0).replace(11,1)

In [127]:
# Inputs (slide paths) and targets (class indices) for splitting.
X, y = df['filename'], df['label']

In [128]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the slides as the test set, stratified by label.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.1,
    random_state=123,
    stratify=y,
)

In [130]:
# Carve a further 10% of the remaining slides off as the validation set.
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train,
    y_train,
    test_size=0.1,
    random_state=123,
    stratify=y_train,
)

In [166]:
# Split to serialize below.
# NOTE(review): despite its name, ``train_labels`` holds the VALIDATION labels here
# (the records are written to valid.tfrecord).
svs_names = list(X_valid)
train_labels = list(y_valid)

In [167]:
# Number of slides about to be serialized.
len(svs_names)

124

In [168]:
# Dataset of (slide path, label) pairs, one element per slide.
list_ds = tf.data.Dataset.from_tensor_slices((svs_names,train_labels))

In [169]:
def _bytes_feature(value):
    """Wrap a bytes value (or an eager string tensor) in a ``tf.train.Feature``."""
    eager_tensor_type = type(tf.constant(0))
    # BytesList cannot take an EagerTensor directly; unwrap it to bytes first.
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _float_feature(value):
    """Wrap a single float in a ``tf.train.Feature``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

In [170]:
def serialize_example(feature0, feature1):
    """Serialize one (slide path, label) pair to a ``tf.train.Example`` byte string.

    Args:
        feature0: Slide path (bytes or eager string tensor), stored as ``svs_name``.
        feature1: Integer label, stored as ``label``.
    """
    features = tf.train.Features(feature={
        'label': _int64_feature(feature1),
        'svs_name': _bytes_feature(feature0),
    })
    return tf.train.Example(features=features).SerializeToString()

In [171]:
def tf_serialize_example(f0,f1):
    """Run ``serialize_example`` inside a tf.data map via ``tf.py_function``.

    Returns:
        Scalar ``tf.string`` tensor holding the serialized example.
    """
    serialized = tf.py_function(serialize_example, (f0, f1), tf.string)
    # py_function output carries no static shape; the record is a scalar.
    return tf.reshape(serialized, ())

In [172]:
# Lazily serialize every (slide path, label) pair to a tf.train.Example string.
serialized_features_dataset = list_ds.map(tf_serialize_example)

In [173]:
# Write the serialized (slide path, label) records to disk.
# tf.io.TFRecordWriter replaces the deprecated tf.data.experimental.TFRecordWriter;
# the context manager guarantees the file is flushed and closed.
filename = 'valid.tfrecord'
with tf.io.TFRecordWriter(filename) as writer:
    for record in serialized_features_dataset:
        writer.write(record.numpy())

In [19]:
def get_num_tiles(rows, cols, row_tile_size, col_tile_size):
    """Number of tiles along each axis, counting a partial tile as a full one.

    Returns:
        ``(num_row_tiles, num_col_tiles)`` as tensors.
    """
    return (tf.math.ceil(rows / row_tile_size),
            tf.math.ceil(cols / col_tile_size))

In [20]:
def get_tile_indices(rows, cols, row_tile_size, col_tile_size):
    """Enumerate tile boundaries covering a ``rows`` x ``cols`` image.

    All inputs are cast to float64. The last tile along each axis ends at the
    image edge.

    Returns:
        List of ``(start_r, end_r, start_c, end_c, row_num, col_num)`` tuples of
        tensors in row-major order, with 1-based ``row_num``/``col_num``.
    """
    rows, cols, row_tile_size, col_tile_size = (
        tf.cast(value, tf.float64) for value in (rows, cols, row_tile_size, col_tile_size)
    )
    num_row_tiles, num_col_tiles = get_num_tiles(rows, cols, row_tile_size, col_tile_size)
    tiles = []
    for r in tf.range(num_row_tiles):
        start_r = r * row_tile_size
        if r < num_row_tiles - 1:
            end_r = (r + 1) * row_tile_size
        else:
            end_r = rows
        for c in tf.range(num_col_tiles):
            start_c = c * col_tile_size
            if c < num_col_tiles - 1:
                end_c = (c + 1) * col_tile_size
            else:
                end_c = cols
            tiles.append((start_r, end_r, start_c, end_c, r + 1, c + 1))
    return tiles

In [21]:
def tissue_percent(np_img):
    """Percentage of non-masked (tissue) pixels: the complement of ``mask_percent``."""
    masked = mask_percent(np_img)
    return 100 - masked

In [22]:
def tissue_quantity(tissue_percentage):
    """Bucket a tissue percentage: 3 = high, 2 = medium, 1 = low, 0 = none.

    The buckets are bounded by ``TISSUE_HIGH_THRESH`` and ``TISSUE_LOW_THRESH``.
    Values <= 0 (and NaN) fall into bucket 0.
    """
    if tissue_percentage >= TISSUE_HIGH_THRESH:
        return 3
    if tissue_percentage >= TISSUE_LOW_THRESH:
        return 2
    if tissue_percentage > 0:
        return 1
    return 0

In [23]:
def small_to_large_mapping(small_pixel, large_dimensions):
    """Map an (x, y) coordinate in the scaled image back to level-0 coordinates.

    Each axis is scaled by ``(large / SCALE_FACTOR) / floor(large / SCALE_FACTOR)``
    times ``SCALE_FACTOR``, which compensates for the floor used when sizing
    the thumbnail. The result is rounded.
    """
    small_x, small_y = small_pixel
    large_w, large_h = large_dimensions

    def _to_large(small, large):
        scaled = large / SCALE_FACTOR
        return tf.math.round(scaled / tf.math.floor(scaled) * (SCALE_FACTOR * small))

    return _to_large(small_x, large_w), _to_large(small_y, large_h)

In [24]:
def filter_rgb_to_hsv(np_img):
    """Convert an RGB image to HSV using ``skimage.color.rgb2hsv``."""
    return sk_color.rgb2hsv(np_img)

In [25]:
def filter_hsv_to_h(hsv, output_type="int"):
    """Return the hue channel of an HSV image as a flat array.

    With ``output_type="int"`` the hues are multiplied by 360 and truncated to
    integers; otherwise the raw hue values are returned. The input is not
    modified.
    """
    hue = hsv[:, :, 0].flatten()
    if output_type != "int":
        return hue
    return (hue * 360).astype("int")

In [26]:
def rgb_to_hues(rgb):
    """Flat array of integer hues (degrees) for every pixel of an RGB image."""
    return filter_hsv_to_h(filter_rgb_to_hsv(rgb))

In [27]:
def hsv_purple_deviation(hsv_hues):
    """Root-mean-square distance of the hues from ``HSV_PURPLE``."""
    squared = np.abs(hsv_hues - HSV_PURPLE) ** 2
    return np.sqrt(squared.mean())

In [28]:
def hsv_pink_deviation(hsv_hues):
    """Root-mean-square distance of the hues from ``HSV_PINK``."""
    squared = np.abs(hsv_hues - HSV_PINK) ** 2
    return np.sqrt(squared.mean())

In [29]:
def hsv_purple_pink_factor(rgb):
    """Colour factor favouring purple over pink hues.

    Only hues in [260, 340] degrees are considered. The factor is
    ``pink_dev / purple_dev * (340 - mean_hue) ** 2``. It is 0 when no hue
    falls in the band or the purple deviation is zero.
    """
    hues = rgb_to_hues(rgb)
    in_band = hues[(hues >= 260) & (hues <= 340)]
    if len(in_band) == 0:
        return 0
    pu_dev = hsv_purple_deviation(in_band)
    pi_dev = hsv_pink_deviation(in_band)
    if pu_dev == 0:
        return 0
    avg_factor = (340 - np.average(in_band)) ** 2
    return pi_dev / pu_dev * avg_factor

In [30]:
def filter_hsv_to_s(hsv):
    """Return the saturation channel of an HSV image as a flat copy."""
    saturation = hsv[:, :, 1]
    return saturation.flatten()

In [31]:
def filter_hsv_to_v(hsv):
    """Return the value channel of an HSV image as a flat copy."""
    value = hsv[:, :, 2]
    return value.flatten()

In [32]:
def hsv_saturation_and_value_factor(rgb):
    """Penalty for tiles with little saturation/value variation.

    Returns ``0.4 ** 2`` when both the saturation and value standard deviations
    are below 0.05, ``0.7 ** 2`` when exactly one is, and ``1`` otherwise.
    """
    hsv = filter_rgb_to_hsv(rgb)
    flat_s = filter_hsv_to_s(hsv)
    flat_v = filter_hsv_to_v(hsv)
    low_s = np.std(flat_s) < 0.05
    low_v = np.std(flat_v) < 0.05
    if low_s and low_v:
        base = 0.4
    elif low_s or low_v:
        base = 0.7
    else:
        base = 1
    return base ** 2

In [33]:
def tissue_quantity_factor(amount):
    """Weight for a tissue bucket: 3 -> 1.0, 2 -> 0.2, 1 -> 0.1, anything else -> 0.0."""
    for level, weight in ((3, 1.0), (2, 0.2), (1, 0.1)):
        if amount == level:
            return weight
    return 0.0

In [34]:
def score_tile(np_tile, tissue_percent,row, col):
    """Score a tile from its tissue amount, colour and saturation/value spread.

    Args:
        np_tile: RGB tile.
        tissue_percent: Percentage of tissue in the tile (shadows the
            module-level ``tissue_percent`` function inside this body).
        row, col: Tile position; not used by the scoring.

    Returns:
        ``(score, color_factor, s_and_v_factor, quantity_factor)``.
    """
    color_factor = hsv_purple_pink_factor(np_tile)
    s_and_v_factor = hsv_saturation_and_value_factor(np_tile)
    quantity_factor = tissue_quantity_factor(tissue_quantity(tissue_percent))
    combined_factor = color_factor * s_and_v_factor * quantity_factor
    raw_score = (tissue_percent ** 2) * np.log(1 + combined_factor) / 1000.0
    # Squash the unbounded raw score with 1 - 10 / (10 + raw).
    score = 1.0 - (10.0 / (10.0 + raw_score))
    return score, color_factor, s_and_v_factor, quantity_factor

In [35]:
def scores(filtered_image,f0,f1,f2,f3,f4,num_tiles=32):
    """Score every tile of a filtered thumbnail and extract the best ones.

    Called through ``tf.py_function`` with eager tensors.

    Args:
        filtered_image: Filtered thumbnail with a leading batch axis of size 1.
        f0, f1: Level-0 slide width and height.
        f2, f3: Thumbnail width and height.
        f4: Slide path as a scalar string tensor.
        num_tiles: Maximum number of top-scoring tiles to return.

    Returns:
        uint8 tensor of shape ``(k, 299, 299, 3)``, highest score first, where
        ``k <= num_tiles``. ``k`` is smaller when fewer tiles score above 0.
    """
    batchless_image = tf.squeeze(filtered_image, axis=0)
    o_w, o_h, w, h = f0, f1, f2, f3
    scored = _score_tiles(batchless_image, o_w, o_h, w, h)
    top_tiles = sorted(scored, key=lambda tile: tile['score'], reverse=True)[:num_tiles]
    return _read_tile_patches(f4.numpy().decode('utf-8'), top_tiles)


def _score_tiles(image, o_w, o_h, w, h):
    """Return level-0 bounds and score of every thumbnail tile scoring above 0."""
    row_tile_size = tf.math.round(ROW_TILE_SIZE / SCALE_FACTOR)
    col_tile_size = tf.math.round(COL_TILE_SIZE / SCALE_FACTOR)
    result = []
    for r_s, r_e, c_s, c_e, r, c in get_tile_indices(h, w, row_tile_size, col_tile_size):
        small_tile = tf.slice(image,
                              [tf.cast(r_s, tf.int64), tf.cast(c_s, tf.int64), 0],
                              [tf.cast(r_e - r_s, tf.int64), tf.cast(c_e - c_s, tf.int64), 3])
        t_p = tissue_percent(small_tile)
        o_c_s, o_r_s = small_to_large_mapping((c_s, r_s), (o_w, o_h))
        o_c_e, o_r_e = small_to_large_mapping((c_e, r_e), (o_w, o_h))
        # Trim one pixel when the mapped tile exceeds the nominal tile size.
        # NOTE(review): presumably rounding slack — confirm intended at this scale.
        if (o_c_e - o_c_s) > COL_TILE_SIZE:
            o_c_e -= 1
        if (o_r_e - o_r_s) > ROW_TILE_SIZE:
            o_r_e -= 1
        score, _, _, _ = score_tile(small_tile, t_p, r, c)
        if score > 0:
            result.append({'o_c_s': o_c_s, 'o_r_s': o_r_s,
                           'o_c_e': o_c_e, 'o_r_e': o_r_e, 'score': score})
    return result


def _read_tile_patches(slide_path, tiles, patch_size=299):
    """Read each tile's level-0 region and resize it to ``patch_size`` square RGB."""
    if not tiles:
        return tf.zeros([0, patch_size, patch_size, 3], dtype=tf.uint8)
    # Open the slide once for all tiles and always release the handle.
    slide = openslide.open_slide(slide_path)
    try:
        patches = []
        for tile in tiles:
            x, y = tile['o_c_s'], tile['o_r_s']
            width = tile['o_c_e'] - tile['o_c_s']
            height = tile['o_r_e'] - tile['o_r_s']
            region = slide.read_region((int(x), int(y)), 0, (int(width), int(height)))
            patches.append(np.array(region.convert("RGB").resize((patch_size, patch_size))))
    finally:
        slide.close()
    # Stack once instead of concatenating inside the loop (quadratic copying).
    return tf.convert_to_tensor(np.stack(patches), dtype=tf.uint8)

In [36]:
@tf.function
def read_record(example_proto):
    """Parse one serialized ``tf.train.Example`` into ``{'label', 'svs_name'}`` tensors."""
    schema = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'svs_name': tf.io.FixedLenFeature([], tf.string),
    }
    return tf.io.parse_single_example(example_proto, schema)

In [37]:
@tf.function
def tf_scale_image(features):
    """Load and down-scale the slide named in ``features['svs_name']``.

    Returns the uint8 thumbnail, the level-0 and scaled dimensions (float64),
    and the unchanged ``features`` dict.
    """
    outputs = tf.py_function(
        slide_to_scaled_pil_image,
        [features['svs_name']],
        [tf.uint8, tf.float64, tf.float64, tf.float64, tf.float64],
    )
    image, large_w, large_h, new_w, new_h = outputs
    return image, large_w, large_h, new_w, new_h, features

In [38]:
@tf.function
def tf_filter_image(image,large_w, large_h, new_w, new_h, features):
    """Run ``apply_image_filters`` on the thumbnail; other fields pass through."""
    static_shape = image.shape
    [filtered] = tf.py_function(apply_image_filters, [image], [tf.uint8])
    # py_function drops static shape information; restore whatever was known.
    filtered.set_shape(static_shape)
    return filtered, large_w, large_h, new_w, new_h, features

In [39]:
@tf.function
def tf_tile(image,large_w, large_h, new_w, new_h, features):
    """Extract top-scoring level-0 tiles for a batch holding one slide.

    The dataset is batched with ``batch(1)`` before this map, so element 0 of
    each field is forwarded to ``scores`` through ``tf.py_function``.

    Returns:
        ``(tiles, features)`` where ``tiles`` is a uint8 tensor of 299x299 RGB
        patches.
    """
    f0 = large_w[0]
    f1 = large_h[0]
    f2 = new_w[0]
    f3 = new_h[0]
    f4 = features['svs_name'][0]
    [final_img,] = tf.py_function(scores, [image,f0,f1,f2,f3,f4], [tf.uint8])
    # py_function output has unknown static shape; declare the patch shape so
    # downstream datasets and the model input see (None, 299, 299, 3).
    final_img.set_shape([None, 299, 299, 3])
    return final_img,features

In [40]:
def flat_map_impl(tiled_images,features):
    """Pair every extracted tile of one slide with that slide's one-hot label.

    Args:
        tiled_images: Stack of tiles for one slide (leading axis = tile count).
        features: Parsed record batched with ``batch(1)``, so
            ``features['label']`` has a leading axis of size 1.

    Returns:
        Dataset of ``(tile, one_hot_label)`` pairs, one per tile.
    """
    lab = tf.one_hot(features['label'], depth=3)
    # Repeat the label once per tile actually present. ``scores`` keeps only
    # tiles scoring above 0, so the count can be below its maximum; a fixed
    # repeat count would make from_tensor_slices fail on mismatched sizes.
    label = tf.repeat(lab, tf.shape(tiled_images)[0], axis=0)
    return tf.data.Dataset.from_tensor_slices((tiled_images, label))

In [51]:
# Glob for the training TFRecord file(s).
# Note: list_files shuffles the matched file names by default.
train_file_pattern = 'train.tfrecord'
train_files = tf.data.Dataset.list_files(file_pattern=train_file_pattern)

In [52]:
# Read records from up to four files at once, taking one record from each in turn.
train_dataset = train_files.interleave(
    map_func=tf.data.TFRecordDataset,
    cycle_length=4,
    block_length=1,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

In [53]:
# Parse each serialized record into {'label', 'svs_name'}.
# NOTE(review): this and the following cells rebind ``train_dataset``; re-running a
# cell alone stacks the stage twice.
train_dataset = train_dataset.map(read_record,num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [54]:
# Load each slide and produce its down-scaled thumbnail plus dimensions.
train_dataset = train_dataset.map(tf_scale_image,num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [55]:
# Filter background / pen marks, then batch one slide at a time (tf_tile indexes element 0).
train_dataset = train_dataset.map(tf_filter_image,num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(1)

In [56]:
# Extract top-scoring 299x299 tiles per slide and cache them after the first epoch.
# NOTE(review): .cache() with no filename keeps tiles in memory
# (up to 32 * 299 * 299 * 3 bytes, about 8.6 MB, per slide).
train_dataset = train_dataset.map(tf_tile,num_parallel_calls=tf.data.experimental.AUTOTUNE).cache()

In [57]:
# Flatten slides into individual (tile, one-hot label) examples, 10 per batch.
train_dataset = train_dataset.flat_map(flat_map_impl).batch(10)

In [58]:
# Overlap input preparation with training.
train_dataset = train_dataset.prefetch(buffer_size = tf.data.experimental.AUTOTUNE)

In [59]:
from tensorflow import keras

# InceptionV3 backbone trained from scratch (weights=None), without its classifier head.
base_model = tf.keras.applications.InceptionV3(
    include_top=False, weights=None, input_shape=(299, 299, 3)
)

# The head is a 3-way softmax trained on one-hot labels, so accuracy must compare
# argmax predictions (CategoricalAccuracy). BinaryAccuracy would threshold each of
# the three outputs independently and overstate accuracy.
# NOTE(review): the count/precision/recall/AUC metrics below also treat the three
# outputs as independent binary problems.
METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'),
      keras.metrics.CategoricalAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc')
]
x = base_model.output
x = keras.layers.GlobalAveragePooling2D()(x)
predictions = keras.layers.Dense(3, activation='softmax')(x)
model = keras.models.Model(inputs=base_model.input, outputs=predictions)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=METRICS)

In [60]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 149, 149, 32) 864         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 149, 149, 32) 96          conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 149, 149, 32) 0           batch_normalization[0][0]        
_______________________________________________________________________________________

In [None]:
# Train for two epochs on the tiled training pipeline.
# NOTE(review): no validation data is passed, although valid.tfrecord is written above.
model.fit(train_dataset,epochs=2)

Epoch 1/2
