[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/edgetpu/vision/serving/inference_visualization_tool.ipynb)

# Visualizing segmentation outputs using colab.

This file is located on [GitHub](https://github.com/tensorflow/models/blob/master/official/projects/edgetpu/vision/serving/inference_visualization_tool.ipynb) and uses [Colab integration](https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/edgetpu/vision/serving/inference_visualization_tool.ipynb) to seamlessly show [segmentation model](https://github.com/tensorflow/models/blob/master/official/projects/edgetpu/vision/README.md) outputs.

## Setup sandbox

Import the required libraries and get ready to load data.

In [None]:
from google.colab import auth # access to saved model in tflite format
auth.authenticate_user()
from PIL import Image # used to read images as arrays
import tensorflow as tf # runs tested model
import numpy as np # postprocessing for render.
from scipy import ndimage # postprocessing for render.
import matplotlib.pyplot as plt # render

# Copies a remote file (GCS or HTTP(S)) into the Colab sandbox.
def copy_to_sandbox(web_path):
  """Downloads a file into the sandbox's current working directory.

  Args:
    web_path: Either a Google Cloud Storage path (``gs://...``), fetched with
      ``gsutil cp``, or an HTTP(S) URL, fetched with ``wget``.

  Returns:
    The local file name: the last path component of ``web_path``, without any
    query string or fragment.

  Raises:
    ValueError: If no file name can be derived from ``web_path``.
    subprocess.CalledProcessError: If the download command fails.
  """
  import os
  import posixpath
  import subprocess
  from urllib.parse import urlsplit

  sandbox_path = posixpath.basename(urlsplit(web_path).path)
  if not sandbox_path:
    raise ValueError(f'Cannot derive a local file name from {web_path!r}')
  # Remove any stale copy so an old file is never mistaken for a fresh download.
  if os.path.exists(sandbox_path):
    os.remove(sandbox_path)
  # Passing an argument list (no shell) keeps URLs with '&', '?' or spaces
  # intact. check=True turns a failed download into an error instead of
  # returning the name of a file that does not exist.
  if web_path.startswith('gs://'):
    command = ['gsutil', 'cp', web_path, sandbox_path]
  else:
    # -O pins the output name to the value this function returns.
    command = ['wget', '-O', sandbox_path, web_path]
  subprocess.run(command, check=True)
  return sandbox_path


## Prepare sandbox images

Running this notebook shows sample segmentations of 3 pictures from the [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/) dataset. You can try it on other pictures by adding your own URLs to the `IMAGE_URLS` list.

In [None]:
# URL template for the sample ADE20K validation images; `{name}` is the image id.
_IMAGE_URL_PATTERN = (
    'https://raw.githubusercontent.com/tensorflow/models/master/official/'
    'projects/edgetpu/vision/serving/testdata/ADE_val_{name}.jpg')
# Ids of the sample images to download.
_IMAGE_NAMES = ['00001626', '00001471', '00000557']
# Full download URL for every sample image, in the same order as the ids.
IMAGE_URLS = [_IMAGE_URL_PATTERN.format(name=name) for name in _IMAGE_NAMES]
# To try your own picture, append its URL, e.g.:
# IMAGE_URLS.append('your URL')

In [None]:
# Download every sample image; IMAGES holds the local file names in URL order.
IMAGES = list(map(copy_to_sandbox, IMAGE_URLS))

## Prepare sandbox model

By default, the visualization runs the M-size model. The model is copied to the sandbox before it is run. You can use another model from the list printed below.

In [None]:
# GCS directory (the `default_argmax` variant) holding the TFLite segmentation models.
MODEL_HOME='gs://tf_model_garden/models/edgetpu/checkpoint_and_tflite/vision/segmentation-edgetpu/tflite/default_argmax'
# List the available model files so another one can be picked in the next cell.
!gsutil ls {MODEL_HOME}

In [None]:
# Name of the TFLite file to run; any other file from the list above works too.
MODEL_NAME = 'deeplabv3plus_mobilenet_edgetpuv2_m_ade20k_32.tflite'  #@param
# Local path of the downloaded model.
MODEL = copy_to_sandbox(f'{MODEL_HOME}/{MODEL_NAME}')

In [None]:
# Spatial input size (in pixels) the model expects; input images are resized
# and cropped to this size before inference.
MODEL_IMAGE_HEIGHT = 512
MODEL_IMAGE_WIDTH = 512

## Image preprocess

This function defines how an image is preprocessed before running inference.

In [None]:
def read_image(image):
  """Loads an image and fits it to the model's input size.

  The image is converted to RGB and scaled, preserving its aspect ratio, so
  that it covers a MODEL_IMAGE_WIDTH x MODEL_IMAGE_HEIGHT area. The top-left
  region of exactly that size is then kept.

  Args:
    image: Anything accepted by `PIL.Image.open` (e.g. a file path).

  Returns:
    A uint8 array of shape (1, MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, 3).
  """
  # The context manager closes the underlying file; convert() returns an
  # in-memory copy that stays valid afterwards.
  with Image.open(image) as raw:
    im = raw.convert('RGB')
  # PIL sizes are (width, height).
  width, height = im.size
  # Scale to "outer fit": use the larger of the two scale factors so both sides
  # are at least as large as the model input. Comparing integer cross products
  # avoids float rounding, and the floor division cannot drop below the target.
  if MODEL_IMAGE_WIDTH * height >= MODEL_IMAGE_HEIGHT * width:
    new_size = (MODEL_IMAGE_WIDTH, MODEL_IMAGE_WIDTH * height // width)
  else:
    new_size = (MODEL_IMAGE_HEIGHT * width // height, MODEL_IMAGE_HEIGHT)
  im = im.resize(new_size)
  input_data = np.expand_dims(im, axis=0)
  # Crop to the model size, keeping the top-left corner.
  return input_data[:, :MODEL_IMAGE_HEIGHT, :MODEL_IMAGE_WIDTH]


## Model runner.

A simple wrapper around the TFLite interpreter's `invoke`.

In [None]:
def run_model(input_data, model_data):
  """Runs a single TFLite inference and returns the first output as a 2-D map.

  Args:
    input_data: Image batch with pixel values 0..255, e.g. from `read_image`.
    model_data: Path to the .tflite model file.

  Returns:
    The first output tensor reshaped to (MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH).
  """
  # Shift pixel values down by 128 and feed them to the model as int8.
  model_input = (input_data - 128).astype(np.int8)

  interpreter = tf.lite.Interpreter(model_path=model_data)
  interpreter.allocate_tensors()
  input_index = interpreter.get_input_details()[0]['index']
  output_index = interpreter.get_output_details()[0]['index']

  interpreter.set_tensor(input_index, model_input)
  interpreter.invoke()
  prediction = interpreter.get_tensor(output_index)
  return prediction.reshape((MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH))

## Wide edge highlighter.

The first function below finds the boundaries between classes and marks them with a thick band (about 8px wide for a straight boundary). The second function overlays that band, coloured by class, on the original image.

In [None]:
# Creates a thick 0/1 (int8) edge mask outlining the segmentation classes.
def edge(mydata):
  """Marks pixels lying on a boundary between different labels.

  A pixel is an edge pixel when its left/right neighbours, or its top/bottom
  neighbours, carry different labels. Neighbours outside the image are taken
  to equal the nearest border pixel, so the image border is only marked where
  the labels actually change. The mask is then dilated three times, which
  turns a straight class boundary into a band about 8 pixels wide.

  Args:
    mydata: Label map. A 2-D array is used as is; any other shape is reshaped
      to (MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH).

  Returns:
    An int8 array with the label map's 2-D shape: 1 on edges, 0 elsewhere.
  """
  labels = np.asarray(mydata)
  if labels.ndim != 2:
    labels = labels.reshape((MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH))
  # Compare neighbours within each row and column rather than along the
  # flattened array, so the last pixel of one row is never compared with the
  # first pixel of the next one (which drew false edges on the image borders).
  padded = np.pad(labels, 1, mode='edge')
  horizontal = padded[1:-1, :-2] != padded[1:-1, 2:]
  vertical = padded[:-2, 1:-1] != padded[2:, 1:-1]
  # Three single-step dilations with the default cross-shaped structure.
  mask = ndimage.binary_dilation(horizontal | vertical, iterations=3)
  return mask.astype(np.int8)

In [None]:
def fancy_edge_overlay(input_data, output_data):
  """Paints class-coloured edges over the original image.

  Args:
    input_data: RGB image holding MODEL_IMAGE_HEIGHT * MODEL_IMAGE_WIDTH * 3
      values, e.g. the (1, H, W, 3) batch returned by `read_image`.
    output_data: Per-pixel class labels, e.g. from `run_model`.

  Returns:
    A float array of shape (MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, 3). Away
    from edges it holds the original pixels; on edge pixels it holds a
    label-dependent colour with channels in [0, 255].
  """
  # Labels above 32 share the colour of label 32.
  labels = np.reshape(np.minimum(output_data, 32),
                      (MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH))
  output_edge = edge(labels).reshape((MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, 1))
  # Encode each label as three base-3 digits, one per RGB channel. Labels
  # 27..32 wrap around and reuse the colours of labels 0..5.
  digits = np.stack([labels % 3, (labels // 3) % 3, (labels // 9) % 3], axis=-1)
  # Map digits 0/1/2 to 0/127.5/255. Multiplying by 255 gave 510 for digit 2,
  # which overflows when the result is cast to uint8 for display.
  colours = digits * (255 / 2)
  image = input_data.reshape(
      (MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, 3)).astype(np.float32)
  return image * (1 - output_edge) + colours * output_edge


## Visualize!

In [None]:
# One row per image with three panels: original, segmentation and overlay.
# squeeze=False keeps `ax` two-dimensional even when there is a single image.
num_rows = max(len(IMAGES), 1)
fig, ax = plt.subplots(num_rows, 3, squeeze=False)
fig.set_figwidth(30)
fig.set_figheight(10 * num_rows)

for r, image in enumerate(IMAGES):
  # Read and preprocess the image.
  input_data = read_image(image)
  ax[r, 0].imshow(input_data.reshape((MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, 3)).astype(np.uint8))
  ax[r, 0].set_title('Original')
  ax[r, 0].grid(False)

  # Run the model; labels are shown on a fixed 0..32 colour scale.
  output_data = run_model(input_data, MODEL)
  ax[r, 1].imshow(output_data, vmin=0, vmax=32)
  ax[r, 1].set_title('Segmentation')
  ax[r, 1].grid(False)

  # Matplotlib ignores vmin/vmax for RGB data, so they are not passed here.
  # Clipping keeps the uint8 cast from wrapping out-of-range values.
  fancy_data = fancy_edge_overlay(input_data, output_data)
  ax[r, 2].imshow(np.clip(fancy_data, 0, 255).astype(np.uint8))
  ax[r, 2].set_title('Segmentation & original')
  ax[r, 2].grid(False)
