# Demo of "Latent Disentanglement in Mesh Variational Autoencoders Improves the Diagnosis of Craniofacial Syndromes and Aids Surgical Planning"

### Simone Foti, Alexander J. Rickart, Bongjin Koo, Eimear O’ Sullivan, Lara S. van de Lande, Athanasios Papaioannou, Roman Khonsari, Danail Stoyanov, N. u. Owase Jeelani, Silvia Schievano, David J. Dunaway, Matthew J. Clarkson

Before running this notebook, make sure you have followed the installation instructions detailed in the README.md file.

----

## Import all necessary libraries and initialise the model 

In [None]:
%matplotlib notebook

import os
import pickle
import trimesh
import torch
import scipy.stats
import scipy.linalg
import numpy as np

import utils
from model_manager import ModelManager


demo_directory = "demo_files"
configurations = utils.get_config(os.path.join(demo_directory, "config.yaml"))


if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("GPU not available, running on CPU")
else:
    device = torch.device("cuda")

manager = ModelManager(
    configurations=configurations, device=device,
    precomputed_storage_path=configurations['data']['precomputed_path'])
manager.resume(os.path.join(demo_directory, "checkpoints"))
manager.set_class_conversions({'a': 0, 'm': 1, 'c': 2, 'n': 3, 'b': 4})  # b and n are merged for classification

label2name_dict = {'a': "Apert", 'b': "Healthy", 'c': "Crouzon", 
                   'm': "Muenke", 'n': "Healthy"}

## Load the demo meshes
Note that these meshes are not from real subjects and were obtained with our data augmentation based on spectral interpolation. 

In [None]:
demo_meshes = []
demo_meshes_labels = []
meshes_directory = os.path.join(demo_directory, "meshes")
for dirpath, _, fnames in os.walk(meshes_directory):
    for f in sorted(fnames):  # sort so that mesh indices are reproducible
        if f.endswith(('.ply', '.obj')):
            mesh_path = os.path.join(dirpath, f)  # dirpath also covers sub-folders
            demo_meshes.append(trimesh.load_mesh(mesh_path, process=False))
            demo_meshes_labels.append(f[0])  # first letter of the file name is the class label

print(f"{len(demo_meshes)} meshes available. Use indices between 0 and",
      f"{len(demo_meshes) - 1} to select them")

Run the next cell to visualize a mesh. Set `mesh_id` to any index in the range printed above.

In [None]:
mesh_id = 0

print(f"mesh class: {label2name_dict[demo_meshes_labels[mesh_id]]}")
demo_meshes[mesh_id].show()

## Example of data augmentation

Computing the eigendecomposition may take up to a few minutes, depending on the computer running the code. For this reason, we recommend running the next cell once and then experimenting with the data augmentation in the following cell.

In [None]:
eigd = utils.compute_laplacian_eigendecomposition(manager.template, k=1000)
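`utils.compute_laplacian_eigendecomposition` is not listed in this notebook. The following is a minimal sketch of the idea, assuming a combinatorial graph Laplacian $L = D - A$ built from the mesh edges (the actual helper may use a different Laplacian, for example cotangent weights); `graph_laplacian` and `laplacian_eigenbasis` are hypothetical names. The eigenvectors with the smallest eigenvalues form a low-frequency basis for functions defined on the vertices, and requesting many of them (`k=1000` on the full template above) is why the cell can take a while.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh


def graph_laplacian(faces, n_vertices):
    """Combinatorial Laplacian L = D - A of the mesh edge graph (hypothetical helper)."""
    faces = np.asarray(faces)
    # Every triangle (a, b, c) contributes the edges a-b, b-c and c-a.
    src = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
    dst = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
    rows = np.concatenate([src, dst])  # store both directions -> symmetric A
    cols = np.concatenate([dst, src])
    adj = sp.coo_matrix((np.ones(rows.size), (rows, cols)),
                        shape=(n_vertices, n_vertices)).tocsr()
    adj = (adj > 0).astype(np.float64)  # edges shared by two faces -> weight 1
    degree = sp.diags(np.asarray(adj.sum(axis=1)).ravel())
    return (degree - adj).tocsc()


def laplacian_eigenbasis(faces, n_vertices, k):
    """Return the k smallest eigenvalues and their eigenvectors (as columns)."""
    lap = graph_laplacian(faces, n_vertices)
    # Shift-invert around a small negative sigma: L - sigma*I is positive definite,
    # and ARPACK converges to the eigenvalues closest to sigma (the smallest ones).
    eigval, eigvec = eigsh(lap, k=k, sigma=-1e-3, which="LM")
    order = np.argsort(eigval)
    return eigval[order], eigvec[:, order]


# Usage on a toy triangle strip with 30 vertices (not the notebook's template).
strip_faces = np.array([[i, i + 1, i + 2] for i in range(28)])
toy_eigval, toy_eigvec = laplacian_eigenbasis(strip_faces, n_vertices=30, k=5)
print(np.round(toy_eigval, 4))  # the first eigenvalue is ~0 (constant eigenvector)
```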

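Change the values of `mesh1_id` and `mesh2_id` to test the data augmentation with different mesh pairs. In this demo you can create augmented samples even from subjects in different age groups and with different syndromes, but the cell prints a warning when the two meshes belong to different classes, because augmented samples are meant to be generated from pairs of the same class.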

In [None]:
mesh1_id = 1
mesh2_id = 3

mesh1_class = label2name_dict[demo_meshes_labels[mesh1_id]]
mesh2_class = label2name_dict[demo_meshes_labels[mesh2_id]]
assert mesh1_id != mesh2_id, "Select two different meshes"
if mesh1_class != mesh2_class:
    print("The selected meshes come from different classes:",
          f"{mesh1_class} and {mesh2_class}.",
          "You should select meshes with the same class")

mesh1 = demo_meshes[mesh1_id].copy()
mesh2 = demo_meshes[mesh2_id].copy()
x1 = np.array(mesh1.vertices)
x2 = np.array(mesh2.vertices)

x_aug = utils.spectral_interpolation(x1, x2, eigd)
mesh_aug = demo_meshes[mesh1_id].copy()
mesh_aug.vertices = x_aug

print("Mesh 1 depicted on the left, \nMesh 2 on the right, \nAugmented mesh in the middle")
scene = trimesh.scene.scene.Scene()
mesh1.vertices[:, 0] = mesh1.vertices[:, 0] - 2
scene.add_geometry(mesh1)
mesh2.vertices[:, 0] = mesh2.vertices[:, 0] + 2
scene.add_geometry(mesh2)
scene.add_geometry(mesh_aug)
trimesh.scene.lighting.autolight(scene)
scene.show()
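The body of `utils.spectral_interpolation` is not shown either. One plausible reading, given that it receives the Laplacian eigendecomposition, is sketched below purely as an assumption: both vertex arrays are projected onto the eigenvectors, each spectral coefficient is moved towards the second shape by its own random weight in [0, 1], and the change is added back to the first shape, so detail outside the truncated basis comes from the first mesh. `spectral_blend` and its `weights` argument are hypothetical, and the authors' augmentation may weight frequencies differently. To stay self-contained, the example uses a random orthonormal basis from a QR decomposition in place of real Laplacian eigenvectors.

```python
import numpy as np


def spectral_blend(shape_a, shape_b, eigvec, weights=None, rng=None):
    """Mix two vertex arrays (n_verts, 3) in a spectral basis (hypothetical helper).

    eigvec: (n_verts, k) matrix with orthonormal columns, e.g. Laplacian eigenvectors.
    weights: optional (k,) blending weights; drawn uniformly from [0, 1] if omitted.
    """
    if weights is None:
        rng = np.random.default_rng() if rng is None else rng
        weights = rng.uniform(0.0, 1.0, size=eigvec.shape[1])
    coeffs_a = eigvec.T @ shape_a  # (k, 3) spectral coefficients per coordinate
    coeffs_b = eigvec.T @ shape_b
    # Move each frequency of shape_a a fraction weights[i] of the way to shape_b;
    # components outside the span of eigvec are kept from shape_a.
    return shape_a + eigvec @ (weights[:, None] * (coeffs_b - coeffs_a))


# Usage with toy data: 50 vertices and a random 20-column orthonormal basis.
rng = np.random.default_rng(0)
basis, _ = np.linalg.qr(rng.normal(size=(50, 20)))
shape_a = rng.normal(size=(50, 3))
shape_b = rng.normal(size=(50, 3))
shape_aug = spectral_blend(shape_a, shape_b, basis, rng=rng)
print(shape_aug.shape)  # (50, 3)
```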

## Diagnose demo meshes

#### Patient classification

The mesh is encoded with the SD-VAE, and its latent vector is classified with quadratic discriminant analysis (QDA). Change `mesh_id` to see the classification results for other meshes.

In [None]:
mesh_id = 2

normalization_dict_path = os.path.join(demo_directory, "norm.pt")
# Load on the CPU: the mesh vertices below are CPU tensors
normalization_dict = torch.load(normalization_dict_path, map_location='cpu')

mesh_class = label2name_dict[demo_meshes_labels[mesh_id]]
mesh_verts = torch.tensor(demo_meshes[mesh_id].vertices, 
                          dtype=torch.float,
                          requires_grad=False, device='cpu')
v_p = (mesh_verts - normalization_dict['mean']) / normalization_dict['std']
z_p = manager.encode(v_p.unsqueeze(0).to(device))
classification_result = label2name_dict[manager.classify_latent(z_p, 'qda')[0]]

print(f"The selected mesh was a {mesh_class} patient") 
print(f"It was classified as a {classification_result} patient")

demo_meshes[mesh_id].show()
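`manager.classify_latent(z_p, 'qda')` is not defined in this notebook, but the surgical-planning cell below reads `manager.qda.means_` and `manager.qda.covariance_`, which match the attributes of scikit-learn's `QuadraticDiscriminantAnalysis` (`covariance_` only exists when the model is fitted with `store_covariance=True`). Assuming that is the underlying model, the sketch below fits QDA on synthetic latent vectors: QDA models each class as a Gaussian with its own mean and covariance and assigns a sample to the class with the highest posterior probability. The toy data and the 8-dimensional latent size are made up; the real model works on the 75-dimensional SD-VAE latents.

```python
import numpy as np
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

rng = np.random.default_rng(0)
n_per_class, latent_dim = 200, 8
class_names = ["Apert", "Muenke", "Crouzon", "Healthy"]

# Synthetic "latent vectors": one Gaussian blob per class with its own spread.
toy_latents, toy_labels = [], []
for label, name in enumerate(class_names):
    centre = np.zeros(latent_dim)
    centre[label] = 6.0  # separate the class means along different axes
    scale = 0.5 + 0.5 * label  # a different covariance for every class
    toy_latents.append(rng.normal(centre, scale, size=(n_per_class, latent_dim)))
    toy_labels.append(np.full(n_per_class, label))
toy_latents = np.concatenate(toy_latents)
toy_labels = np.concatenate(toy_labels)

# store_covariance=True keeps the per-class covariances in qda.covariance_.
qda = QuadraticDiscriminantAnalysis(store_covariance=True).fit(toy_latents, toy_labels)

# Classify a new sample placed at the "Crouzon" mean.
new_z = np.zeros((1, latent_dim))
new_z[0, 2] = 6.0
predicted = qda.predict(new_z)[0]
print(class_names[predicted])  # Crouzon
print(qda.means_.shape, len(qda.covariance_))  # (4, 8) 4
```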

#### Project onto the global manifold visualisation

The 75-dimensional latent vector of the patient's head is projected to two dimensions with linear discriminant analysis (LDA). Because the plot of the latent distributions is computed from the entire dataset, here we load a precomputed plot and project new samples onto it.

To see the plot for a different subject, change `mesh_id` in the **patient classification** section and re-run that cell first.

In [None]:
import seaborn as sns 

fig_entire_z_name = os.path.join(demo_directory, "lda_emb_distributions.pkl")

with open(fig_entire_z_name, 'rb') as f:
    fig_entire_z = pickle.load(f)

z_proj = manager.lda_project_latents_in_2d(z_p.detach().cpu().numpy())

ax = fig_entire_z.gca()
sns.scatterplot(x=z_proj[:, 0], y=z_proj[:, 1], ax=ax, c=['#e881a7'])
fig_entire_z.show()
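`manager.lda_project_latents_in_2d` is not defined in this notebook. Assuming it wraps a scikit-learn `LinearDiscriminantAnalysis` fitted on the training latents and their labels, the projection amounts to the sketch below: LDA finds at most `n_classes - 1` directions that best separate the class means relative to the shared within-class covariance, and keeping the first two gives 2D coordinates. New samples are only transformed, never refitted, which is what makes it possible to reuse a precomputed plot. All data here are synthetic.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(1)
latent_dim, n_per_class, n_classes = 12, 100, 4

# Synthetic training latents: four classes with shifted means.
class_means = rng.normal(0.0, 3.0, size=(n_classes, latent_dim))
train_z = np.concatenate(
    [rng.normal(m, 1.0, size=(n_per_class, latent_dim)) for m in class_means])
train_y = np.repeat(np.arange(n_classes), n_per_class)

# Fit once on the whole training set (the precomputed plot depends on this) ...
lda = LinearDiscriminantAnalysis(n_components=2).fit(train_z, train_y)
train_2d = lda.transform(train_z)  # coordinates of the background distribution

# ... then project new latent vectors without refitting.
new_z = rng.normal(class_means[0], 1.0, size=(1, latent_dim))
new_2d = lda.transform(new_z)
print(train_2d.shape, new_2d.shape)  # (400, 2) (1, 2)
```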

#### Project onto the local manifold visualisations

As with the global manifold visualisation, the local distribution plots are precomputed. The projections are performed with the pre-trained attribute-specific LDA models: each LDA model projects a 5-dimensional latent vector onto a 2-dimensional space.

To see the plot for a different subject, change `mesh_id` in the **patient classification** section and re-run that cell first.

In [None]:
fig_regions_z_name = os.path.join(demo_directory, "emb_all_train_dist.pkl")
region_ldas_name = os.path.join(demo_directory, "region_ldas.pkl")

with open(fig_regions_z_name, 'rb') as f:
    fig_fgrid_regions_z = pickle.load(f)
with open(region_ldas_name, 'rb') as f:
    region_ldas = pickle.load(f)
    

z_p_np = z_p.detach().cpu().numpy()
r_proj = {}
for key, z_region in manager.latent_regions.items():
    z_p_region = z_p_np[:, z_region[0]:z_region[1]]
    z_r_embeddings = region_ldas[key].transform(z_p_region)
    r_proj[key] = z_r_embeddings
    x1, x2 = z_r_embeddings[:, 0], z_r_embeddings[:, 1]
    fig_fgrid_regions_z.axes_dict[utils.colour2attribute_dict[key]].scatter(
        x1, x2, c=['#e881a7'], s=10)
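The loop above slices the full latent vector with the `(start, end)` pairs in `manager.latent_regions` and transforms each slice with its own model from `region_ldas`. A self-contained sketch of the same pattern on toy data follows, assuming the region models are scikit-learn `LinearDiscriminantAnalysis` objects (they are only used through `transform`, which is consistent with this); the region names and the 15-dimensional toy latent made of three 5-dimensional regions are hypothetical.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(2)
# Hypothetical attribute regions: (start, end) column ranges of a 15-D latent vector.
toy_regions = {"region_a": (0, 5), "region_b": (5, 10), "region_c": (10, 15)}
n_classes, n_per_class = 4, 80

class_means = rng.normal(0.0, 2.0, size=(n_classes, 15))
train_z = np.concatenate([rng.normal(m, 1.0, size=(n_per_class, 15)) for m in class_means])
train_y = np.repeat(np.arange(n_classes), n_per_class)

# One LDA per region, each fitted only on that region's 5 latent dimensions.
toy_region_ldas = {
    name: LinearDiscriminantAnalysis(n_components=2).fit(train_z[:, s:e], train_y)
    for name, (s, e) in toy_regions.items()
}

# Project a new latent vector region by region, as the notebook does for z_p.
new_z = rng.normal(class_means[1], 1.0, size=(1, 15))
toy_proj = {name: toy_region_ldas[name].transform(new_z[:, s:e])
            for name, (s, e) in toy_regions.items()}
for name, xy in toy_proj.items():
    print(name, xy.shape)  # each region yields one 2-D point: (1, 2)
```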

## Surgical planning of demo meshes

Run the next cell to display the keywords associated with the different surgical procedures.

In [None]:
print("Use the following keys to select a procedure in the next cell:")
print(list(utils.procedures2attributes_dict.keys()))

Choose a patient and a procedure to start the surgical planning by modifying the `mesh_id` and `procedure_id` variables in the next code cell. The procedure keywords are listed above.

#### Global interpolation trajectory
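The following cell encodes the selected patient and moves the latent regions associated with the chosen procedure towards the mean of the healthy (`'n'`) class of the QDA model: first in 8 evenly spaced steps up to the point where the trajectory enters 3 standard deviations of that distribution, then to the 2-std and 1-std points, and finally to the mean. Every step is projected onto the plot of the latent distributions of the whole latent vector.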

In [None]:
mesh_id = 0
procedure_id = 'monobloc'

def vector_linspace(start, finish, steps):
    ls = []
    for s, f in zip(start[0], finish[0]):
        ls.append(torch.linspace(s, f, steps))
    res = torch.stack(ls)
    return res.t()

mesh_class = label2name_dict[demo_meshes_labels[mesh_id]]
assert mesh_class != "Healthy", "The patient is already healthy!"

mesh_verts = torch.tensor(demo_meshes[mesh_id].vertices, 
                          dtype=torch.float,
                          requires_grad=False, device='cpu')
v_p = (mesh_verts - normalization_dict['mean']) / normalization_dict['std']
z_p = manager.encode(v_p.unsqueeze(0).to(device))

# Find healthy patients latent vectors
normal_p_index = manager.class2idx('n')
normal_p_mean = manager.qda.means_[normal_p_index]

# Walk from z_p towards the healthy mean and find where the trajectory enters
# the 3, 2 and 1 std regions of the healthy distribution. The eigenvalues of
# the covariance matrix form the diagonal covariance of the axis-aligned
# (rotated) distribution; its negative log-density at k std along the first
# axis is the threshold for the k-std region.
normal_p_covariance = manager.qda.covariance_[normal_p_index]
multi_normal_dist = scipy.stats.multivariate_normal(
    mean=normal_p_mean, cov=normal_p_covariance)
eigenval, eigenvec = scipy.linalg.eigh(normal_p_covariance)
reference_dist = scipy.stats.multivariate_normal(
    mean=np.zeros_like(normal_p_mean), cov=np.diag(eigenval))
reference_std_on_x1 = np.sqrt(reference_dist.cov[0, 0])
reference_std_vec_on_x1 = np.zeros_like(normal_p_mean)
reference_std_vec_on_x1[0] = reference_std_on_x1

# Negative log-densities at 1, 2 and 3 std (despite "pdf" in the names)
reference_pdf_1std = - reference_dist.logpdf(reference_std_vec_on_x1)
reference_pdf_2std = - reference_dist.logpdf(2 * reference_std_vec_on_x1)
reference_pdf_3std = - reference_dist.logpdf(3 * reference_std_vec_on_x1)

z_mean_target = torch.tensor(normal_p_mean).unsqueeze(0)
z_interp_full = vector_linspace(z_p, z_mean_target, 5000)

# For each threshold, take the first interpolation step whose -log pdf is within it
pdf_intermediate = [-multi_normal_dist.logpdf(z.detach().cpu().numpy())
                    for z in z_interp_full]
pdf_lt_3std = [p <= reference_pdf_3std for p in pdf_intermediate]
pdf_lt_2std = [p <= reference_pdf_2std for p in pdf_intermediate]
pdf_lt_1std = [p <= reference_pdf_1std for p in pdf_intermediate]

z_3std_target = z_interp_full[pdf_lt_3std.index(True), :].unsqueeze(0)
z_2std_target = z_interp_full[pdf_lt_2std.index(True), :].unsqueeze(0)
z_1std_target = z_interp_full[pdf_lt_1std.index(True), :].unsqueeze(0)



# Interpolate subsets of attributes ####################################

attributes = utils.procedures2attributes_dict[procedure_id]
n_p_to_3std = 8
z_interp = z_p.repeat(n_p_to_3std + 3, 1)
for attr in attributes:
    zf_idxs = manager.latent_regions[attr]
    z_pf = z_p[:, zf_idxs[0]:zf_idxs[1]].to(device)
    z_3f = z_3std_target[:, zf_idxs[0]:zf_idxs[1]].to(device)
    z_interp[:n_p_to_3std, zf_idxs[0]:zf_idxs[1]] = \
        vector_linspace(z_pf, z_3f, n_p_to_3std)
    z_2f = z_2std_target[:, zf_idxs[0]:zf_idxs[1]].to(device)
    z_1f = z_1std_target[:, zf_idxs[0]:zf_idxs[1]].to(device)
    z_mf = z_mean_target[:, zf_idxs[0]:zf_idxs[1]].to(device)
    z_interp[n_p_to_3std, zf_idxs[0]:zf_idxs[1]] = z_2f
    z_interp[n_p_to_3std + 1, zf_idxs[0]:zf_idxs[1]] = z_1f
    z_interp[n_p_to_3std + 2, zf_idxs[0]:zf_idxs[1]] = z_mf


z_interp_proj = manager.lda_project_latents_in_2d(z_interp.detach().cpu().numpy())

with open(fig_entire_z_name, 'rb') as f:
    fig_entire_z = pickle.load(f)
ax = fig_entire_z.gca()
sns.scatterplot(x=z_interp_proj[:, 0], y=z_interp_proj[:, 1], ax=ax, c=['#e881a7'])
fig_entire_z.show()
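Why comparing `-logpdf` values implements a standard-deviation threshold: for a Gaussian with mean $\mu$ and covariance $\Sigma$, $-\log p(z) = \tfrac{1}{2}(z-\mu)^\top\Sigma^{-1}(z-\mu) + \tfrac{1}{2}\log\det(2\pi\Sigma)$. The axis-aligned reference distribution $\mathcal{N}(0, \operatorname{diag}(\lambda))$, built from the eigenvalues $\lambda$ of $\Sigma$, has the same determinant, and at the point $k\sqrt{\lambda_1}\,e_1$ its negative log-density is $\tfrac{1}{2}k^2 + \tfrac{1}{2}\log\det(2\pi\Sigma)$. The test `p <= reference_pdf_kstd` is therefore equivalent to the Mahalanobis distance of $z$ from the healthy mean being at most $k$. The sketch below (hypothetical helper `first_index_within_k_std`, synthetic data) uses that distance directly and checks it against the log-density comparison used in the cell above.

```python
import numpy as np
import scipy.linalg
import scipy.stats


def first_index_within_k_std(path, mean, cov, k):
    """Index of the first row of `path` whose Mahalanobis distance from `mean` is <= k."""
    diff = path - mean
    # Squared Mahalanobis distance (z - mu)^T Sigma^-1 (z - mu) of every row.
    d2 = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)
    inside = np.flatnonzero(d2 <= k ** 2)
    if inside.size == 0:
        raise ValueError(f"the path never gets within {k} std of the mean")
    return int(inside[0])


# Synthetic 5-D Gaussian and a straight 500-step path from a distant point to its mean.
rng = np.random.default_rng(3)
a = rng.normal(size=(5, 5))
toy_cov = a @ a.T + 5.0 * np.eye(5)
toy_mean = rng.normal(size=5)
start = toy_mean + 50.0 * rng.normal(size=5)
path = start + np.linspace(0.0, 1.0, 500)[:, None] * (toy_mean - start)

# The notebook's test: -logpdf below the reference value at k std on the first axis.
toy_dist = scipy.stats.multivariate_normal(mean=toy_mean, cov=toy_cov)
toy_eigenval = scipy.linalg.eigh(toy_cov, eigvals_only=True)
reference = scipy.stats.multivariate_normal(mean=np.zeros(5), cov=np.diag(toy_eigenval))
neg_logpdf = -toy_dist.logpdf(path)

results = {}
for k in (3, 2, 1):
    point_k_std = np.zeros(5)
    point_k_std[0] = k * np.sqrt(toy_eigenval[0])
    via_logpdf = int(np.flatnonzero(neg_logpdf <= -reference.logpdf(point_k_std))[0])
    via_mahalanobis = first_index_within_k_std(path, toy_mean, toy_cov, k)
    results[k] = (via_logpdf, via_mahalanobis)
    print(k, via_logpdf, via_mahalanobis)  # the two indices should match
```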

#### Local interpolation trajectories

The following cell projects each interpolation step onto the plots of the latent distributions of the 5-dimensional attribute-specific latent vectors.

To change the patient or the procedure, modify the `mesh_id` and `procedure_id` variables in the **global interpolation trajectory** section and re-run all cells from there.

In [None]:
with open(fig_regions_z_name, 'rb') as f:
    fig_fgrid_regions_z = pickle.load(f)

z_interp_np = z_interp.detach().cpu().numpy()
r_proj = {}
for key, z_region in manager.latent_regions.items():
    z_interp_region = z_interp_np[:, z_region[0]:z_region[1]]
    z_r_embeddings = region_ldas[key].transform(z_interp_region)
    r_proj[key] = z_r_embeddings
    x1, x2 = z_r_embeddings[:, 0], z_r_embeddings[:, 1]
    fig_fgrid_regions_z.axes_dict[utils.colour2attribute_dict[key]].scatter(
        x1, x2, c=['#e881a7'], s=2)

#### Show results of latent interpolation as a static image

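Every rendered shape corresponds to one of the pink dots in the previous plots. The image below each shape colours its vertices by their distance from the first shape of the trajectory, with the colour scale capped at 10.

To change the patient or the procedure, modify the `mesh_id` and `procedure_id` variables in the **global interpolation trajectory** section and re-run all cells from there.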

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation

from torchvision.transforms.functional import to_pil_image
from pytorch3d.renderer import BlendParams
from torchvision.utils import make_grid

manager.renderer.rasterizer.raster_settings.image_size = 512
blend_params = BlendParams(background_color=[1, 1, 1])
manager.default_shader.blend_params = blend_params
manager.simple_shader.blend_params = blend_params

v_interp = manager.generate(z_interp.to(device))
v_interp = v_interp * normalization_dict['std'].to(device) + \
    normalization_dict['mean'].to(device)


source_dist = manager.compute_vertex_errors(
    v_interp, v_interp[0, ::].expand(v_interp.shape[0], -1, -1))
source_colours = utils.errors_to_colors(
    source_dist, min_value=0, max_value=10, cmap='plasma') / 255

renderings = manager.render(v_interp).cpu()
renderings_dist = manager.render(v_interp, source_dist, error_max_scale=10).cpu()

im = make_grid(torch.cat([renderings, renderings_dist], dim=-2),
               padding=10, pad_value=1, nrow=v_interp.shape[0])
plt.figure()
plt.imshow(np.asarray(to_pil_image(im)))
plt.axis('off')
plt.show()
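`manager.compute_vertex_errors` and `utils.errors_to_colors` are not shown in this notebook. Assuming they compute per-vertex Euclidean distances from a reference shape (here `v_interp[0]`) and map them linearly onto a Matplotlib colormap between `min_value` and `max_value`, a NumPy sketch of that colouring looks like this; the function names are hypothetical and the distances are in the units of the mesh coordinates.

```python
import numpy as np
import matplotlib


def vertex_displacements(shapes, reference):
    """Euclidean distance of every vertex from `reference`.

    shapes: (n_shapes, n_verts, 3); reference: (n_verts, 3) -> (n_shapes, n_verts)
    """
    return np.linalg.norm(shapes - reference[None], axis=-1)


def displacements_to_rgb(errors, min_value=0.0, max_value=10.0, cmap="plasma"):
    """Map distances linearly to RGB in [0, 1], clipping outside [min_value, max_value]."""
    normalised = np.clip((errors - min_value) / (max_value - min_value), 0.0, 1.0)
    return matplotlib.colormaps[cmap](normalised)[..., :3]  # drop the alpha channel


# Toy trajectory: 3 shapes of 4 vertices, progressively translated along x.
base = np.zeros((4, 3))
shapes = np.stack([base + np.array([dx, 0.0, 0.0]) for dx in (0.0, 5.0, 20.0)])
dist = vertex_displacements(shapes, shapes[0])
colours = displacements_to_rgb(dist)
print(dist[:, 0])  # [ 0.  5. 20.]
print(colours.shape)  # (3, 4, 3)
```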

#### Show results of latent interpolation as video

Every frame corresponds to one of the pink dots in the previous plots and shows the rendered shape (left) next to its per-vertex distance from the first shape of the trajectory (right).

To change the patient or the procedure, modify the `mesh_id` and `procedure_id` variables in the **global interpolation trajectory** section and re-run all cells from there.

In [None]:
rend_comb = torch.cat([renderings, renderings_dist], dim=-1)
fig = plt.figure()
plt.axis('off')
frames = [[plt.imshow(np.asarray(to_pil_image(rend_comb[i])), animated=True)]
          for i in range(rend_comb.shape[0])]
ani = matplotlib.animation.ArtistAnimation(
    fig, frames, interval=200, blit=True, repeat_delay=500)
plt.show()
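The cell above relies on the interactive `%matplotlib notebook` backend to play the animation. Where that backend is not available, or to keep the animation in the saved notebook, Matplotlib's `Animation.to_jshtml()` can render the same frames into a standalone HTML/JavaScript player that `IPython.display.HTML` displays inline. A minimal sketch with random stand-in frames (in the notebook, the frames would come from `rend_comb`):

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation

# Stand-in for the rendered frames: 4 random RGB images of 32 x 32 pixels.
frames_rgb = np.random.default_rng(0).random((4, 32, 32, 3))

fig = plt.figure()
plt.axis("off")
artists = [[plt.imshow(frame, animated=True)] for frame in frames_rgb]
ani = matplotlib.animation.ArtistAnimation(
    fig, artists, interval=200, blit=True, repeat_delay=500)

html = ani.to_jshtml()  # self-contained HTML/JavaScript player as a string
plt.close(fig)  # avoid also showing the static figure

# In a notebook cell:
# from IPython.display import HTML
# HTML(html)
```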