**A tool to visualize the segmentation model inference output.**\
This tool is used to verify that the exported TFLite model produces the expected segmentation results.



In [None]:
# Path of the TFLite segmentation model to evaluate.
MODEL='gs://**/placeholder_for_edgetpu_models/autoseg/segmentation_search_edgetpu_s_not_fused.tflite'#@param
# Directory that the test image names below are relative to.
IMAGE_HOME = 'gs://**/PS_Compare/20190711'#@param
# Image file names relative to IMAGE_HOME, separated by commas.
TEST_IMAGES = 'ADE_val_00001626.jpg,ADE_val_00001471.jpg,ADE_val_00000557.jpg'#@param
# Size that test images are resized/cropped to. NOTE(review): other cells
# assume a 512x512 model input/output, so change these only together with them.
IMAGE_WIDTH = 512 #@param
IMAGE_HEIGHT = 512 #@param

In [None]:
import numpy as np
import tensorflow as tf
from PIL import Image as PILImage
import matplotlib.pyplot as plt
from scipy import ndimage

In [None]:
# This block creates local copies of /cns and /x20 files.
TEST_IMAGES=','.join([IMAGE_HOME+'/'+image for image in TEST_IMAGES.split(',')])

# The tflite interpreter only accepts model in local path.
def local_copy(awaypath):
  localpath = '/tmp/' + awaypath.split('/')[-1]
  !rm -f {localpath}
  !fileutil cp -f {awaypath} {localpath}
  !ls -lht {localpath}
  %download_file {localpath}
  return localpath

IMAGES = [local_copy(image) for image in TEST_IMAGES.split(',')]
MODEL_COPY=local_copy(MODEL)

In [None]:
# Creates a 0/1 int8 mask that highlights the boundaries of a segmentation map.
def edge(mydata, height=512, width=512, dilation_iterations=3):
  """Returns an int8 0/1 mask marking label boundaries in a segmentation map.

  A pixel is marked when its two horizontal neighbours, or its two vertical
  neighbours, carry different labels. Comparisons stay within each row/column:
  the map is padded by repeating its border pixels, so no edge wraps around
  from one row to the next and a uniform map produces no edges at all. The
  raw boundary (about 2px wide at a straight label change) is then thickened
  by `dilation_iterations` binary dilations, each adding one pixel on every
  side (the default of 3 yields lines roughly 8px wide).

  Args:
    mydata: label map with height * width elements, e.g. shaped (512, 512) or
      (512, 512, 1); it is reshaped to (height, width).
    height: number of rows of the label map.
    width: number of columns of the label map.
    dilation_iterations: number of dilations applied to the raw edge mask;
      0 or less returns the undilated mask.

  Returns:
    np.int8 array of shape (height, width) with 1 on (thickened) edges.
  """
  labels = np.asarray(mydata).reshape(height, width)
  # Repeat border pixels so the central difference is defined everywhere
  # without comparing against pixels of a neighbouring row/column.
  padded = np.pad(labels, 1, mode='edge')
  horizontal = padded[1:-1, :-2] != padded[1:-1, 2:]
  vertical = padded[:-2, 1:-1] != padded[2:, 1:-1]
  mask = horizontal | vertical
  # binary_dilation treats iterations < 1 as "repeat until stable", so skip it.
  if dilation_iterations > 0:
    mask = ndimage.binary_dilation(mask, iterations=dilation_iterations)
  return mask.astype(np.int8)

In [None]:
def run_model(input_data):
  """Runs the local TFLite model copy (MODEL_COPY) on one input batch.

  Args:
    input_data: image batch that is shifted by -128 and cast to int8 before
      being fed to the model's first input (presumably a (1, H, W, 3) uint8
      batch — TODO confirm against the model's input details).

  Returns:
    The model's first output tensor, reshaped to (512, 512, 1).
  """
  # A new interpreter is constructed and allocated on every call.
  interpreter = tf.lite.Interpreter(model_path=MODEL_COPY)
  interpreter.allocate_tensors()
  input_index = interpreter.get_input_details()[0]['index']
  output_index = interpreter.get_output_details()[0]['index']

  # Shift pixel values down by 128 and store them as int8 (presumably the
  # model's quantized input format — verify).
  shifted = (input_data - 128).astype(np.int8)
  interpreter.set_tensor(input_index, shifted)
  interpreter.invoke()

  result = interpreter.get_tensor(output_index)
  return result.reshape((512, 512, 1))

In [None]:
# Set the visualization window size: one row of three panels per test image.
# squeeze=False keeps `ax` two-dimensional even for a single image.
num_rows = max(len(IMAGES), 1)
fig, ax = plt.subplots(num_rows, 3, squeeze=False)
fig.set_figwidth(30)
fig.set_figheight(10 * num_rows)

# Read and test each image.
for r, image in enumerate(IMAGES):
  # The context manager closes the source file once the RGB copy exists.
  with PILImage.open(image) as src:
    im = src.convert('RGB')
  # Scale so the shorter side reaches the target size; the excess of the
  # longer side is cropped away below.
  min_dim = min(im.size[0], im.size[1])
  im = im.resize((IMAGE_WIDTH * im.size[0] // min_dim,
                  IMAGE_HEIGHT * im.size[1] // min_dim))
  # PIL sizes are (width, height) but the array is (batch, rows, cols, 3):
  # rows are cropped to IMAGE_HEIGHT and columns to IMAGE_WIDTH.
  input_data = np.expand_dims(im, axis=0)[:, :IMAGE_HEIGHT, :IMAGE_WIDTH]
  rgb = input_data.reshape([IMAGE_HEIGHT, IMAGE_WIDTH, 3]).astype(np.uint8)
  ax[r, 0].imshow(rgb)
  ax[r, 0].set_title('Original')
  ax[r, 0].grid(False)

  # Run the model on the cropped test image.
  output_data = run_model(input_data)
  ax[r, 1].imshow(output_data, vmin=0, vmax=32)
  ax[r, 1].set_title('Segmentation')
  ax[r, 1].grid(False)

  labels = np.reshape(np.minimum(output_data, 32), [IMAGE_HEIGHT, IMAGE_WIDTH])
  output_edge = edge(labels).reshape(IMAGE_HEIGHT, IMAGE_WIDTH, 1)
  # Color each label by its base-3 digits, one digit per RGB channel. Digits
  # 0/1/2 map to 0/127.5/255 so every channel stays inside the uint8 range
  # (scaling by 255 would reach 510 and overflow the uint8 cast).
  colors = np.stack([labels % 3, (labels // 3) % 3, (labels // 9) % 3],
                    axis=-1) * 127.5
  # Keep the original pixels away from edges; paint edge pixels with colors.
  overlay = rgb.astype(np.float32) * (1 - output_edge) + colors * output_edge
  ax[r, 2].imshow(np.clip(overlay, 0, 255).astype(np.uint8))
  ax[r, 2].set_title('Segmentation & original')
  ax[r, 2].grid(False)
