# Example usage of vrn-unguided-keras.h5
## Modified by Qhan
* Rendering Texture
    * matplotlib.pyplot
    * visvis.mesh
    
* Download the `.h5` model file here:
https://drive.google.com/file/d/1oh8Zpe4wh00iXcm8ztRsi5ZL6GMkHdjj/view?usp=sharing

In [None]:
from keras.models import load_model
import cv2
import visvis as vv
import numpy as np
from skimage import measure

import os
import os.path as osp

from matplotlib import pyplot as plt
from matplotlib.collections import PolyCollection
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import mcubes
import matplotlib
matplotlib.rcParams['figure.figsize'] = [10,10]

import custom_layers

In [None]:
# Layer classes from the local custom_layers module, keyed by class name and
# handed to load_model together with the saved network file.
custom_objects = {
    layer_name: getattr(custom_layers, layer_name)
    for layer_name in ('Conv', 'BatchNorm', 'UpSamplingBilinear')
}
model = load_model('vrn-unguided-keras.h5', custom_objects=custom_objects)

## Color Interpolation

In [None]:
def interp(x, v):
    """Linearly blend the two values in ``v``.

    Args:
        x: blend fraction; 0 returns the first value, 1 returns the second.
        v: a pair ``(first, second)`` of numbers or arrays.

    Returns:
        ``first * (1 - x) + second * x``.
    """
    first, second = v
    return first * (1 - x) + second * x

def interp2d(xy, v):
    """Bilinearly blend a 2x2 neighbourhood.

    Args:
        xy: pair ``(x, y)`` of blend fractions; ``x`` is applied within each
            row of ``v``, ``y`` between the two rows.
        v: two rows, each holding two values (``v[0]`` and ``v[1]``).

    Returns:
        The value blended first along each row and then across the rows.
    """
    frac_x, frac_y = xy
    upper_row, lower_row = v[0], v[1]
    return interp(frac_y, [interp(frac_x, upper_row), interp(frac_x, lower_row)])

def interpColors(verts, image):
    """Sample one colour per vertex from ``image`` by bilinear interpolation.

    The colour is read from the ``image`` argument (not from a module-level
    variable), and vertices lying on or beyond the last row/column are clamped
    to the image so that they still get a valid colour.

    Args:
        verts: array-like of shape (N, >=2). Column 0 is the x coordinate
            (image column) and column 1 the y coordinate (image row), in
            pixel units. Any further columns are ignored.
        image: array of shape (H, W) or (H, W, C) holding the pixel values.

    Returns:
        A list with N entries, one interpolated pixel per vertex, with the
        pixel values divided by 256. An empty ``verts`` yields ``[]``.
    """
    coords = np.asarray(verts, dtype=float)
    if coords.size == 0:
        return []

    pixels = np.asarray(image, dtype=float) / 256
    n_rows, n_cols = pixels.shape[:2]

    # Clamp coordinates into the image so every vertex has a usable cell.
    x = np.clip(coords[:, 0], 0, n_cols - 1)
    y = np.clip(coords[:, 1], 0, n_rows - 1)

    # Top-left pixel of each cell and its right/bottom neighbours; at the last
    # row/column the neighbour collapses onto the pixel itself.
    col0 = x.astype(int)
    row0 = y.astype(int)
    col1 = np.minimum(col0 + 1, n_cols - 1)
    row1 = np.minimum(row0 + 1, n_rows - 1)

    # Fractional offsets, shaped to broadcast over any channel axis.
    weight_shape = (-1,) + (1,) * (pixels.ndim - 2)
    frac_x = (x - col0).reshape(weight_shape)
    frac_y = (y - row0).reshape(weight_shape)

    upper = pixels[row0, col0] * (1 - frac_x) + pixels[row0, col1] * frac_x
    lower = pixels[row1, col0] * (1 - frac_x) + pixels[row1, col1] * frac_x
    return list(upper * (1 - frac_y) + lower * frac_y)

## Read Image & Get 3D Model

In [None]:
# Load the input photo, resize it to 192x192, convert BGR -> RGB and build the
# single-image batch that is passed to model.predict.
image_path = 'images/qhan-head-2.png'
im = cv2.imread(image_path)
if im is None:
    # cv2.imread reports a missing or unreadable file by returning None
    # instead of raising; fail here with a clear message rather than inside
    # cv2.resize.
    raise FileNotFoundError('cannot read image: %s' % image_path)
im = cv2.resize(im, (192, 192))
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# (rows, cols, channels) -> (channels, rows, cols), then wrap in a batch of one
img = np.array([np.transpose(im, (2, 0, 1))])

In [None]:
# Run the network on the one-image batch and keep the first result.
pred = model.predict(img)
volume = pred[0]
print(volume.shape)
# The iso-surface levels used further down (20, 10, 1) apply to this volume
# after scaling by 255.
vol = volume * 255

## pyplot trisurf

In [None]:
# Extract the iso-surface at level 20 and draw it as a triangulated surface.
plt.clf()
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
verts, faces = mcubes.marching_cubes(vol, 20)  # original note: verts are x, y, z

# The first two vertex columns are mirrored (192 - value) for display; the
# third column is used as the surface height.
mirrored_x = 192 - verts[:, 0]
mirrored_y = 192 - verts[:, 1]
height = verts[:, 2]
ax.plot_trisurf(mirrored_x, mirrored_y, faces, height, cmap='Spectral', lw=2)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()

## pyplot polygons

In [None]:
# scikit-image deprecated marching_cubes_lewiner in favour of marching_cubes
# (which uses the Lewiner algorithm by default) and later removed the old name.
# Use the old name where it still exists, otherwise fall back to the new one.
marching_cubes = getattr(measure, 'marching_cubes_lewiner', None) or measure.marching_cubes
verts, faces, normals, values = marching_cubes(vol, level=10, step_size=1)
print('vertices:', len(verts), '\nfaces:', len(faces))

# verts: (z, y, x) -> convert to (x, y, z) by swapping the first and last
# columns in place
verts[:, [0, 2]] = verts[:, [2, 0]]

# One colour per vertex, sampled from the RGB image at the vertex's (x, y)
colors = interpColors(verts, im)

In [None]:
# Draw the coloured vertices plus one flat-shaded polygon per mesh face.
plt.clf()
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(verts[:, 0], verts[:, 1], color=colors, s=5)

for face_no, corner_ids in enumerate(faces):
    # Progress counter, rewritten in place on a single output line
    print('\r%d' % face_no, end='')
    # Flat shading: the face takes the colour of its first vertex
    flat_color = colors[corner_ids[0]]
    ax.add_collection3d(Poly3DCollection([verts[corner_ids]], facecolor=flat_color))

# X and Y limits are given high-to-low, which reverses those axes
ax.set_xlim(192, 0)
ax.set_ylim(192, 0)
ax.set_zlim(0, 200)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()

## visvis
On macOS, the visvis demo has some display bugs. See this issue: https://github.com/almarklein/visvis/issues/97

In [None]:
# Set the visvis figureSize setting to 720 x 720
# (presumably width, height in pixels for new figures — TODO confirm).
vv.settings.figureSize = (720, 720)

**vv.mesh**

In [None]:
def switch_axis(verts, a1, a2):
    """Swap two columns of ``verts`` in place.

    Args:
        verts: 2-D array whose columns are coordinate axes.
        a1: index of the first column.
        a2: index of the second column.

    Returns:
        The same ``verts`` array, now with columns ``a1`` and ``a2`` exchanged.
    """
    # The right-hand side is copied before assignment, so this is a true swap.
    verts[:, [a1, a2]] = verts[:, [a2, a1]]
    return verts

In [None]:
# scikit-image deprecated marching_cubes_lewiner in favour of marching_cubes
# (which uses the Lewiner algorithm by default) and later removed the old name.
# Use the old name where it still exists, otherwise fall back to the new one.
marching_cubes = getattr(measure, 'marching_cubes_lewiner', None) or measure.marching_cubes
verts, faces, normals, values = marching_cubes(vol, level=1, step_size=1)
print('vertice:', len(verts), '\nfaces:', len(faces))

# verts: (z, y, x) -> convert to (x, y, z)
verts = switch_axis(verts, 0, 2)

# interpolate colors while the vertices are still in image (x, y) coordinates
colors = interpColors(verts, im)

# in visvis, y is depth, so the y and z axes are switched, and both are then
# mirrored (200 - depth, 192 - height)
# NOTE(review): normals are not reordered or mirrored along with verts —
# confirm this is acceptable for the shading mode used when rendering.
verts = switch_axis(verts, 1, 2)
verts[:, 1] = 200 - verts[:, 1]
verts[:, 2] = 192 - verts[:, 2]

In [None]:
# Render the coloured mesh with visvis, save one snapshot per camera pose into
# the test/ directory, then reset the camera and start the visvis application.
vv.clf()
f = vv.gcf()
ax = vv.gca()

# NOTE(review): despite its name this is an all-zero array shaped like im
# (i.e. black, not white); it is only referenced by the disabled imshow below.
white_bg = np.zeros_like(im)
#t = vv.imshow(white_bg, interpolate=True)

# Mesh built from the vertices, faces, normals and per-vertex colours
# computed in the previous cell
m = vv.mesh(verts, faces, normals, colors)
m.faceShading = 'plain'
m.edgeShading = 'plain'

light0 = ax.light0
light0.ambient = 0.9 # 0.2 is default for light 0
light0.diffuse = 1.0 # 1.0 is default

camera = ax.camera
camera.fov = 0 # orthographic
camera.zoom = 0.0045

# Camera poses to snapshot. Each row is [elevation, roll, azimuth], in the
# order the loop below assigns them (presumably degrees — TODO confirm).
perspectives = [
    [  0,  0,  0], # center
    [ 15,  0,  0], # down
    [-15,  0,  0], # up
    [  0,  0, 15], # left
    [  0,  0,-15], # right
    [  0, 15,  0], # counter clock wise
    [  0,-15,  0]  # clock wise
]

# Always-true switch: change to 0 to skip writing the snapshots
if 1:
    ax.axis.visible = False
    if not osp.exists('test'): os.mkdir('test')
    for i, (x, y, z) in enumerate(perspectives):
        camera.elevation = x
        camera.roll = y
        camera.azimuth = z
        # Redraw before grabbing the frame so the new pose is what gets saved
        ax.Draw()
        f.DrawNow()
        # The frame is scaled by 255 and converted RGB -> BGR before writing
        # (assumes getframe returns values in [0, 1] — TODO confirm)
        cv2.imwrite('test/%d.jpg' % i, cv2.cvtColor(vv.getframe(f) * 255, cv2.COLOR_RGB2BGR))

# Back to the centred pose for interactive viewing
camera.elevation = 0 # x
camera.roll = 0
camera.azimuth = 0 # z

ax.axis.xLabel = 'X width'
ax.axis.yLabel = 'Y depth'
ax.axis.zLabel = 'Z height'

# Start the visvis application (presumably blocks until the window is
# closed — TODO confirm)
app = vv.use()
app.Run()

**record the result**

In [None]:
# Record the axes while stepping the camera azimuth through a full turn, then
# export the recording as demo.gif. Uses ax, camera and f from the render cell.
rec = vv.record(ax)

Nangles = 4
for i in range(Nangles):
    # Evenly spaced azimuths: i / Nangles of 360
    camera.azimuth = 360 * float(i) / Nangles
    # If the stored azimuth is above 180, shift it down by 360
    if camera.azimuth>180:
        camera.azimuth -= 360
    ax.Draw() # Tell the axes to redraw
    f.DrawNow() # Draw the figure NOW, instead of waiting for GUI event loop

rec.Stop()
rec.Export('demo.gif')

**vv.volshow()**  *(old method)*

In [None]:
# Give every voxel above the threshold the image colour at the same
# (row, column); all other voxels become 0. The mask gets a trailing axis so it
# broadcasts against the three colour channels of im.
occupied = vol > 1
volRGB = occupied[..., np.newaxis] * im[:, :, :3]

vv.clf()

# Show the flat image together with the coloured volume as an iso-surface
t = vv.imshow(im, interpolate=True)
v = vv.volshow3(volRGB, renderStyle='iso')

# Lighting on the current axes
l = vv.gca()
l.light0.ambient = 0.9  # 0.2 is default for light 0
l.light0.diffuse = 1.0  # 1.0 is default

# Orthographic camera on the current axes
a = vv.gca()
a.camera.fov = 0

vv.use().Run()

## show z depth layers

In [None]:
plt.clf()

# Every 5th slice along the first axis of volRGB, from index 80 up to (not
# including) 180, laid out on a 5 x 4 grid of subplots.
for cell, layer in enumerate(volRGB[80:180:5], start=1):
    plt.subplot(5, 4, cell)
    plt.imshow(layer)

plt.show()