# Demo of 3D Generative Model Latent Disentanglement via Local Eigenprojection

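To run the demo, download the `demo_files` folder and copy it into the project directory. The cells below load the configuration, checkpoints and normalisation statistics from `demo_tog/<net_id>/`.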

---

### Network initialisation

To select another network, change the `net_id` variable (either `"sd_vae"` or `"led_vae"`).

In [None]:
%matplotlib notebook

import os
import trimesh
import torch
import numpy as np

import utils
from model_manager import get_model_manager

net_id = "led_vae"  # ["sd_vae", "led_vae"]
demo_directory = os.path.join("demo_tog", net_id)
configurations = utils.get_config(os.path.join(demo_directory, "config.yaml"))

# Prefer the GPU, falling back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("GPU not available, running on CPU")

manager = get_model_manager(
    configurations=configurations, device=device,
    precomputed_storage_path=configurations['data']['precomputed_path'])
manager.resume(os.path.join(demo_directory, "checkpoints"))

# Statistics used below to de-normalise generated vertices
# (accessed as normalization_dict['mean'] and normalization_dict['std']).
# map_location remaps stored tensors onto the selected device, so a file
# written from a CUDA session still loads on a CPU-only machine.
normalization_dict_path = os.path.join(demo_directory, "norm.pt")
normalization_dict = torch.load(normalization_dict_path, map_location=device)

### Randomly generate a shape and create a scene

Every time you run this cell, a new shape is generated. Run it until you are satisfied with the generated shape.

In [None]:
from pythreejs import *
from IPython.display import display
import ipywidgets

view_width = view_height = 800

# Draw one latent code from the standard normal prior. It stays on the CPU and
# is moved to the model device whenever it is decoded.
z = torch.randn([1, manager.model_latent_size])

# Decode the first (only) sample and undo the training-time normalisation.
gen_verts = (
    manager.generate(z.to(device))[0, :, :]
    * normalization_dict['std'].to(device)
    + normalization_dict['mean'].to(device)
)

# Reference shape against which later displacements are coloured.
initially_gen_verts = gen_verts.clone()
faces = manager.template.face

def compute_vertex_colours(current_verts):
    """Colour each vertex by its displacement from the initially generated shape.

    The per-vertex error between ``current_verts`` and the global
    ``initially_gen_verts`` (from ``manager.compute_vertex_errors``) is mapped
    through ``utils.errors_to_colors`` with the 'plasma' colormap and
    ``min_value=0, max_value=10``, then divided by 255.

    Returns:
        A NumPy array of colours, squeezed and detached onto the CPU.
    """
    displacement = manager.compute_vertex_errors(current_verts, initially_gen_verts)
    raw_colours = utils.errors_to_colors(
        displacement.unsqueeze(dim=0),
        min_value=0,
        max_value=10,
        cmap='plasma',
    )
    scaled = raw_colours / 255
    return scaled.squeeze().cpu().detach().numpy()


# Three.js buffers shared by both meshes: vertex positions, flattened triangle
# indices and per-vertex displacement colours.
buffer_verts = BufferAttribute(gen_verts.detach().cpu().numpy().tolist(), normalized=False)
# Move the faces to the CPU before `.numpy()` (as the save cell does) so this
# also works when the template tensor lives on the GPU. Assumes `faces` is laid
# out as (3, n_faces) -- TODO confirm -- so `.t()` yields one row per triangle.
buffer_faces = BufferAttribute(faces.t().cpu().numpy().astype(np.uint32).ravel(), normalized=False)
buffer_colours = BufferAttribute(compute_vertex_colours(gen_verts).tolist(), normalized=False)

geometry = BufferGeometry(
    attributes={
        'position': buffer_verts,
        'index': buffer_faces,
        'color': buffer_colours
    })
geometry.exec_three_obj_method('computeVertexNormals')

# Plain green shaded material for the main view; vertex-coloured material for
# the displacement map.
material = MeshPhongMaterial(color="#34eb46", specular="#222222", shininess=15)
material_colours = MeshPhongMaterial(specular="#222222", shininess=15, vertexColors='VertexColors')

# Both meshes share `geometry`, so updating its buffers refreshes both views.
mesh = Mesh(geometry, material=material)
mesh_colours = Mesh(geometry, material=material_colours)

camera = PerspectiveCamera(position=[2, 0, 3], aspect=view_width/view_height)
ambient_light = AmbientLight(intensity=0.2)
# Full-intensity ambient light so the displacement colours are shown unshaded-ish.
ambient_light_dispmap = AmbientLight(intensity=1)
key_light = SpotLight(position=[0, 10, 10], angle=0.3, penumbra=0.1)

key_light.target = mesh
mesh.castShadow = True
mesh.receiveShadow = True
mesh_colours.castShadow = True
mesh_colours.receiveShadow = True

scene = Scene(children=[mesh, camera, key_light, ambient_light])
scene_colours = Scene(children=[mesh_colours, camera, ambient_light_dispmap])

# A single camera/controller drives both renderers so the views stay in sync.
controller = OrbitControls(controlling=camera)
renderer = Renderer(camera=camera, scene=scene, controls=[controller],
                    width=view_width, height=view_height, antialias=True)
# Displacement map at a third of the main view size; integer division keeps
# the pixel dimensions integral.
renderer_colours = Renderer(camera=camera, scene=scene_colours, controls=[controller],
                            width=view_width // 3, height=view_height // 3, antialias=True)
renderers_pair = ipywidgets.HBox([renderer_colours, renderer])
renderers_pair.layout.align_items = "center"
display(renderers_pair)

### Create GUI

In [None]:
def update_vertices(z_i_value, z_i_index: int):
    """Slider callback: set one latent variable and redraw both meshes.

    Args:
        z_i_value: the change notification passed by ``observe``; its
            ``.new`` attribute holds the slider's new value.
        z_i_index: column of the global latent code ``z`` that the slider
            controls.

    Modifies the global ``z`` in place, regenerates and de-normalises the
    shape, and overwrites the position and colour buffers of the shared
    ``geometry`` so that both ``mesh`` and ``mesh_colours`` are refreshed.
    """
    z[0, z_i_index] = z_i_value.new
    # Decode the updated latent code and undo the training-time normalisation.
    verts = manager.generate(z.to(device))[0, :, :]
    verts = verts * normalization_dict['std'].to(device) + \
        normalization_dict['mean'].to(device)
    # Colours are computed from the torch tensor before it is converted to NumPy.
    colours = compute_vertex_colours(verts)
    verts = verts.detach().cpu().numpy()
    v = verts.astype("float32", copy=False)
    # Replace the vertex positions and recompute normals for correct shading.
    geometry.attributes["position"].array = v
    geometry.attributes["position"].needsUpdate = True
    mesh.geometry.exec_three_obj_method('computeVertexNormals')
    geometry.attributes["color"].array = colours.astype("float32", copy=False)
    geometry.attributes["color"].needsUpdate = True
    # NOTE(review): the *NeedUpdate flags below and the 'update' method calls on
    # the meshes/scenes may be no-ops for pythreejs BufferGeometry and three.js
    # Mesh/Scene objects; kept as-is -- confirm before removing.
    mesh.geometry.verticesNeedUpdate = True
    mesh.geometry.elementsNeedUpdate = True
    mesh.geometry.colorsNeedUpdate = True
    mesh.exec_three_obj_method('update')
    mesh_colours.geometry.verticesNeedUpdate = True
    mesh_colours.geometry.elementsNeedUpdate = True
    mesh_colours.geometry.colorsNeedUpdate = True
    mesh_colours.exec_three_obj_method('update')
    # Refresh the shared camera controls and projection.
    controller.exec_three_obj_method('update')
    camera.exec_three_obj_method('updateProjectionMatrix')
    scene.exec_three_obj_method('update')
    scene_colours.exec_three_obj_method('update')
    

# Maps the string form of each region's colour (the keys of
# manager.latent_regions) to a human-readable anatomical region name.
color2name_dict = {'[ 70  82 164 255]': "eyes", '[ 12 120  60 255]': "ears", '[242 235 119 255]': "temporal",
                   '[ 57  78 162 255]': "neck", '[255 255 255 255]': "back", '[246 133  31 255]': "mouth",
                   '[110 202 206 255]': "chin", '[136  92 167 255]': "cheeks", '[109 190  69 255]': "cheekbones",
                   '[160  80 159 255]': "forehead", '[ 35  41  73 255]': "jaw", '[237  26  77 255]': "nose"}

region_sliders_names = []
region_sliders = []
grouped_region_indices = []
for r_name, r_range in manager.latent_regions.items():
    # Each region owns the latent indices in [r_range[0], r_range[1]).
    region_indices = list(range(r_range[0], r_range[1]))
    grouped_region_indices.append(region_indices)
    sl = []
    for i in region_indices:
        initial_value = float(z[0, i])
        # z is drawn from N(0, 1) and can fall outside the default +/-2.5 range.
        # Widen the slider (to a whole number, keeping the 0.05 step grid
        # aligned) rather than letting ipywidgets clamp the displayed value
        # away from the value actually stored in z.
        limit = max(2.5, float(np.ceil(abs(initial_value))))
        current_s = ipywidgets.FloatSlider(min=-limit, max=limit, step=0.05, value=initial_value,
                                           description=f"z_{str(i)}")
        # Bind i as a default argument so each slider updates its own index.
        current_s.observe(lambda x, y=i: update_vertices(x, y), names='value')
        sl.append(current_s)
    region_sliders.append(ipywidgets.VBox(sl))
    # Fall back to the raw key if a region colour has no entry in the table.
    region_sliders_names.append(color2name_dict.get(r_name, r_name))

In [None]:
# Directory where meshes saved from the GUI are written.
out_mesh_dir = os.path.join(demo_directory, "out_meshes")
# exist_ok replaces the isdir/mkdir check-then-act and also creates any
# missing parent directories.
os.makedirs(out_mesh_dir, exist_ok=True)

# File-name entry and a log area for save status messages.
fname_widget = ipywidgets.Text(placeholder="mesh_name.ply", description="Filename:", disabled=False)
out_message_widget = ipywidgets.Output()


# `compute_vertex_colours` (displacement colouring relative to
# `initially_gen_verts`) is defined in the scene-building cell above and is
# reused unchanged by `save_current_mesh` below; no redefinition is needed.


def save_current_mesh(b):
    """Export the shape for the current latent code ``z`` to ``out_mesh_dir``.

    Callback registered with the Save button's ``on_click``; ``b`` is the
    clicked button and is not used. The output file name is read from
    ``fname_widget`` and must end in ``.ply`` or ``.obj``. The mesh is written
    with per-vertex displacement colours, and a status message is printed
    into ``out_message_widget``.
    """
    fname = fname_widget.value
    # Validate the name first so an invalid one does not run the decoder.
    # The original check tested ".ply" twice and therefore rejected ".obj"
    # even though the error message advertises it.
    if not fname.endswith((".ply", ".obj")):
        with out_message_widget:
            print(f"'{fname}' is not a valid meshfile name. Make sure it finishes in '.ply' or '.obj'")
        return

    verts = manager.generate(z.to(device))[0, :, :]
    verts = verts * normalization_dict['std'].to(device) + \
        normalization_dict['mean'].to(device)
    v_col = compute_vertex_colours(verts)
    # Named out_mesh so it does not shadow the global scene `mesh`.
    # NOTE(review): trimesh.Trimesh processes (merges) vertices by default;
    # pass process=False if vertex order/correspondence must be preserved.
    out_mesh = trimesh.Trimesh(
        verts.cpu().detach().numpy(),
        manager.template.face.t().cpu().numpy(),
        vertex_colors=v_col)
    out_mesh.export(os.path.join(out_mesh_dir, fname))
    with out_message_widget:
        print("Mesh saved!")

    
# "Save" button that exports the current shape when clicked.
save_button_widget = ipywidgets.Button(
    description="Save",
    disabled=False,
    button_style='info',
)
save_button_widget.on_click(save_current_mesh)

# File name, save button and status output laid out on a single row.
saving_widgets = ipywidgets.HBox(
    children=[fname_widget, save_button_widget, out_message_widget]
)

### Run the GUI 

Each slider corresponds to a latent variable. When a slider is changed, the VAE generates a new shape that is displayed in real time. Sliders are grouped according to the anatomical region they influence. On the left side, a displacement map shows the regions that were altered while manipulating the latent variables with the sliders. Displacements are computed with respect to the random mesh that was initially generated. When you are satisfied with the final result, you can also save the mesh.

In [None]:
# One collapsible pane of sliders per latent region, titled with its name.
accordion = ipywidgets.Accordion(children=region_sliders, titles=region_sliders_names)
# Renderers beside the slider accordion, with the save controls underneath.
viewer_row = ipywidgets.HBox([renderers_pair, accordion])
# Left as the cell's final expression so Jupyter displays the whole GUI.
ipywidgets.VBox([viewer_row, saving_widgets])