VAE viewer notebook for JetBot
===

This notebook visualizes a trained VAE: it shows the model input, its reconstruction, and the latent vector, using live frames from the JetBot's real camera.

In [None]:
import sys
import PIL
import numpy as np
import cv2
import traitlets
import ipywidgets.widgets as widgets
from IPython.display import display
import torch
from torchvision.transforms import transforms
sys.path.append('../../../vae')
from vae import VAE
from jetcam.csi_camera import CSICamera
from jetcam.utils import bgr8_to_jpeg

## Setting parameters

|Name | Description| Default|
|:----|:-----------|:-------|
|IMAGE_CHANNELS | Number of image channels (3 for RGB) | 3 (do not change)|
|VARIANTS_SIZE  | Latent vector size (z_dim) of the VAE | 32          |
|MODEL_PATH     | Trained VAE model file path | ../../../vae.torch|

In [None]:
# Number of image channels the VAE is built for (passed as image_channels).
IMAGE_CHANNELS = 3
# Size of the latent vector (passed to the VAE as z_dim).
VARIANTS_SIZE = 32
# Path of the trained VAE weights file read with torch.load.
MODEL_PATH = '../../../vae.torch'

## Load trained VAE model.
Load the trained VAE model into GPU memory.

In [None]:
# All inference in this notebook runs on the CUDA device.
device = torch.device('cuda')

# Build the network with the configured channel count and latent size,
# then restore the trained weights, mapping them straight onto the device.
vae = VAE(image_channels=IMAGE_CHANNELS, z_dim=VARIANTS_SIZE)
trained_weights = torch.load(MODEL_PATH, map_location=device)
vae.load_state_dict(trained_weights)

# Move the model to the device and switch it to evaluation mode. This is the
# last expression of the cell, so the notebook also displays the result.
vae.to(device).eval()

## Create camera 

Capture size is W=320, H=240. 

In [None]:
# Capture resolution and frame rate requested from the CSI camera.
CAMERA_WIDTH = 320
CAMERA_HEIGHT = 240
FPS = 60

# Capture directly at the output resolution so no extra scaling is requested.
camera = CSICamera(width=CAMERA_WIDTH, height=CAMERA_HEIGHT, capture_width=CAMERA_WIDTH,
                   capture_height=CAMERA_HEIGHT, capture_fps=FPS)
# Start continuous capture; later cells read and observe camera.value.
camera.running = True

# NOTE: the preview widget and its camera link are intentionally NOT created
# here. The "Create GUI" cell creates and displays them; creating a second,
# never-displayed widget/link pair in this cell left a link that the cleanup
# cell could not unlink (its reference was overwritten), so it kept
# JPEG-encoding every frame for an invisible widget.

## Define preprocess and postprocess

In [None]:
def preprocess(image):
    """Turn a camera frame into the tensor fed to the VAE.

    The frame array is wrapped in a PIL image, resized to 160x120
    (width x height), and cropped with the box (0, 40, 160, 120), which
    drops the top 40 rows and leaves a 160x80 region. The crop is then
    converted with torchvision's ``ToTensor``.

    Args:
        image: Frame as an array accepted by ``PIL.Image.fromarray``
            (presumably the HxWx3 uint8 array in ``camera.value`` -- confirm).

    Returns:
        The tensor produced by ``transforms.ToTensor()`` for the cropped image.
    """
    shrunk = PIL.Image.fromarray(image).resize((160, 120))
    # Crop box is (left, upper, right, lower).
    lower_region = shrunk.crop((0, 40, 160, 120))
    to_tensor = transforms.ToTensor()
    return to_tensor(lower_region)
    

def rgb8_to_jpeg(image):
    """Encode an image array as JPEG and return the encoded bytes.

    NOTE(review): the name says RGB, but the array is handed to
    ``cv2.imencode`` unchanged, and OpenCV conventionally treats 3-channel
    data as BGR -- confirm the channel order callers pass in.

    Args:
        image: Image array accepted by ``cv2.imencode``.

    Returns:
        bytes: The JPEG-encoded buffer.
    """
    # imencode returns (success_flag, buffer); only the buffer is used.
    _, jpeg_buffer = cv2.imencode('.jpg', image)
    return bytes(jpeg_buffer)

## Visualize latent space function

In [None]:
# Latent values are normalized over [-ABS_LATENT_MAX_VALUE, +ABS_LATENT_MAX_VALUE]
# before being mapped to a color.
ABS_LATENT_MAX_VALUE = 10
# Height and width, in pixels, of the color tile drawn for each latent value.
PANEL_HEIGHT = 10
PANEL_WIDTH = 10

def sigmoid(x, gain=1, offset_x=0):
    """Logistic sigmoid of ``gain * (x + offset_x)``, evaluated through tanh.

    Uses the identity ``sigmoid(t) = (tanh(t / 2) + 1) / 2``.

    Args:
        x: Scalar or NumPy array input.
        gain: Steepness of the curve.
        offset_x: Value added to ``x`` before scaling (shifts the curve).

    Returns:
        Value(s) between 0 and 1, with the same shape as ``x``.
    """
    scaled = (x + offset_x) * gain
    return (np.tanh(scaled / 2) + 1) / 2

def color_bar_rgb(x):
    """Map a normalized value to a blue-green-red color-bar color.

    ``x`` is rescaled as ``2 * x - 1`` (so 0..1 becomes -1..1). Red rises
    with the rescaled value, blue falls with it, and green forms a bump
    around the middle, built from two opposing sigmoids.

    NOTE(review): despite the ``rgb`` in the name, the returned list is
    ordered blue, green, red.

    Args:
        x: Normalized value (0..1 spans the full color bar).

    Returns:
        list: ``[blue, green, red]``, each channel scaled by 255.
    """
    steepness = 10
    red_blue_shift = 0.2
    green_shift = 0.6

    centered = (x * 2) - 1

    red = sigmoid(centered, steepness, -1 * red_blue_shift)
    blue = 1 - sigmoid(centered, steepness, red_blue_shift)

    # Rising edge plus falling edge, minus 1, leaves a bump around the center.
    green_rise = sigmoid(centered, steepness, green_shift)
    green_fall = 1 - sigmoid(centered, steepness, -1 * green_shift)
    green = (green_rise + green_fall) - 1.0

    return [blue * 255, green * 255, red * 255]

def _get_color(value):
    """Return the color-bar color for one latent value.

    The value is normalized so that ``-ABS_LATENT_MAX_VALUE`` maps to 0 and
    ``+ABS_LATENT_MAX_VALUE`` maps to 1, then passed to ``color_bar_rgb``.
    Values outside that range are not clamped here.
    """
    full_span = ABS_LATENT_MAX_VALUE * 2.0
    normalized = (value + ABS_LATENT_MAX_VALUE) / full_span
    return color_bar_rgb(normalized)

def create_color_panel(latent_spaces):
    """Render a latent vector as a horizontal strip of solid-color tiles.

    Each latent value becomes one ``PANEL_HEIGHT x PANEL_WIDTH`` tile filled
    with the color from ``_get_color``, clipped to 0..255. The color list is
    reversed before filling, so the tile channels are ordered red, green, blue.

    NOTE(review): the caller JPEG-encodes this panel through ``cv2.imencode``;
    confirm the channel order gives the intended low/high color orientation.

    Args:
        latent_spaces: Iterable of latent values; must not be empty
            (``np.concatenate`` raises ``ValueError`` on an empty list).

    Returns:
        numpy.ndarray: Float array of shape
        ``(PANEL_HEIGHT, PANEL_WIDTH * len(latent_spaces), 3)``.
    """
    tile_shape = (PANEL_HEIGHT, PANEL_WIDTH, 3)
    tiles = [
        # Broadcasting the 3-element color over a zero tile fills it uniformly.
        np.clip(np.zeros(tile_shape) + _get_color(latent)[::-1], 0, 255)
        for latent in latent_spaces
    ]
    # Lay the tiles out side by side along the width axis.
    return np.concatenate(tiles, axis=1)

## Create GUI

In [None]:
# Live camera preview, sized like the captured frames and fed with
# JPEG-encoded frames through a one-way link from camera.value.
image = widgets.Image(format='jpeg', width=CAMERA_WIDTH, height=CAMERA_HEIGHT)
camera_link = traitlets.dlink((camera,'value'), (image,'value'), transform=bgr8_to_jpeg)

# VAE input (resized and cropped frame) and its reconstruction, side by side.
resize = widgets.Image(format='jpeg', width=160, height=80)
result = widgets.Image(format='jpeg', width=160, height=80)

# Latent-space color bar: one PANEL_WIDTH-wide tile per latent dimension, so
# the width follows VARIANTS_SIZE instead of a hard-coded 32.
color_bar = widgets.Image(format='jpeg', width=VARIANTS_SIZE*PANEL_WIDTH, height=10*PANEL_HEIGHT)

display(image)
display(widgets.HBox([resize,result]))
display(color_bar)

## Start main process

In [None]:
def vae_process(change):
    """Camera observer: run one frame through the VAE and refresh the widgets.

    Updates three widgets: ``resize`` (the preprocessed VAE input), ``result``
    (the decoded reconstruction) and ``color_bar`` (the latent vector rendered
    as color tiles).

    Args:
        change: Change dictionary from the camera's ``value`` trait; only
            ``change['new']`` (the new frame) is read.
    """
    frame = preprocess(change['new'])
    # The tensor is channel-first; transpose to height x width x channel
    # and scale to 0..255 for JPEG encoding.
    resize.value = rgb8_to_jpeg(np.transpose(np.uint8(frame*255),[1,2,0]))

    # Inference only: without no_grad every frame would build an autograd
    # graph that is never used, wasting memory and time on the device.
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension (batch of one frame).
        z, _, _ = vae.encode(frame.unsqueeze(0).to(device))
        reconst = vae.decode(z)

    reconst = reconst.detach().cpu()[0].numpy()
    reconst = np.transpose(np.uint8(reconst*255),[1,2,0])
    result.value = rgb8_to_jpeg(reconst)

    # First (and only) row of the batch is the latent vector of this frame.
    latent_space = z.detach().cpu().numpy()[0]
    color_bar.value = rgb8_to_jpeg(create_color_panel(latent_space))
# Process the current frame once so the widgets are filled immediately.
vae_process({'new': camera.value})
# Then re-run the pipeline every time the camera's value trait changes.
camera.observe(vae_process, names='value')

## Cleanup process

In [None]:
# Stop feeding frames to the VAE pipeline and to the preview widget.
camera.unobserve(vae_process, names='value')
camera_link.unlink()
# Stop the capture that was started with camera.running = True; otherwise
# the camera keeps grabbing frames after the notebook is "cleaned up".
camera.running = False
# NOTE(review): if the camera device must also be released for reuse by
# another notebook, confirm the release call against the jetcam API.