In [1]:
import os, shlex
from subprocess import check_output

# Go to the root of the repository, so that the relative paths used further
# down (e.g. "output") do not depend on where the notebook itself is stored.
# `git rev-parse --show-toplevel` prints the top-level directory as bytes;
# os.fsdecode() decodes it with the filesystem encoding, so a checkout located
# under a path containing non-ASCII characters no longer raises
# UnicodeDecodeError (which .decode('ascii') did).
repo_rootdir = os.fsdecode(check_output(shlex.split("git rev-parse --show-toplevel")).strip())
os.chdir(repo_rootdir)

import os
import sys
import torch

from config.config_parser import read_config
from utils.helpers import *

import utils.mesh_operations
device = get_device()
from IPython.display import display

import ipywidgets as widgets

import pyvista as pv
from ipywidgets import interact, interactive, fixed, interact_manual
import numpy
from pprint import pprint

2020-10-09 14:58:26,079 [INFO] No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'


# Select experiment
An "experiment" is one training run of a specific model. Each experiment has its own folder inside the "output" folder; by default, these folders are simply named with a timestamp. The configuration used for the run (network architecture, input data, random seed, training parameters, etc.) is stored in a file called config.json located inside the experiment's folder.

In [2]:
output_dir = "output"

# Keep only the runs whose folder contains a ".finished" file
# (just a workaround to check if the training finished).
experiments = [
    run_name
    for run_name in sorted(os.listdir(output_dir))
    if os.path.exists(os.path.join(output_dir, run_name, ".finished"))
]

# Dropdown from which the experiment to inspect is picked; its current
# selection is read later through `w.value`.
w = widgets.Dropdown(description='Experiment:', options=experiments, disabled=False)

display(w)

Dropdown(description='Experiment:', options=('2020-09-03_16-58-02', '2020-09-09_13-24-32', '2020-09-10_02-46-2…

In [6]:
import ExperimentClass

# Load experiment: the folder is the one currently selected in the dropdown `w`.
experiment_dir = os.path.join(output_dir, w.value)
experiment = ExperimentClass.ComaExperiment(experiment_dir)
experiment.load_model()  # load a trained model

2020-10-09 14:58:49,189 [ERROR] Unsupported point attribute type: point_data for file: ./template/template.vtk


In [7]:
# Pretty-print the configuration of the selected experiment.
# Entries of this mapping are read again further down (e.g. experiment.config['z']).
pprint(experiment.config)

{'activation_function': 'relu',
 'batch_size': 16,
 'checkpoint_file': 'output/test_2020-09-11_01-49-33/checkpoints/checkpoint_117.pt',
 'comments': '',
 'data_dir': 'data/meshes/numpy_files/LV_all_subjects/train.npy',
 'downsampling_factors': [4, 4, 3, 2],
 'epoch': 118,
 'eval': False,
 'group_label': '',
 'ids_file': 'data/meshes/numpy_files/LV_all_subjects/LVED_all_subjects_subj_ids.txt',
 'kld_weight': 0.01,
 'label': '',
 'learning_rate': 0.5,
 'learning_rate_decay': 0.99,
 'nTraining': 5000,
 'nVal': 1000,
 'n_layers': 4,
 'num_conv_filters': [3, 16, 16, 16, 32, 32],
 'optimizer': 'adam',
 'output_dir': 'output/{TIMESTAMP}',
 'partition': 'LV',
 'polygon_order': [6, 6, 6, 6, 6],
 'preprocessed_data': 'data/transforms/cached/2ch_segmentation__LV__ED__scaled.pkl',
 'procrustes_scaling': True,
 'procrustes_type': 'generalized',
 'reconstruction_loss': 'l1',
 'run_id': '2020-09-11_02-13-41',
 'template_fname': './template/template.vtk',
 'test': False,
 'visual_output_dir': '',
 'vi

# Load mesh data used for the experiment

In [8]:
# Load pre-aligned meshes
prealigned_meshes = experiment.load_prealigned_meshes()

# Wrap the point clouds and their subject ids in a loader, without shuffling.
# NOTE(review): the cells below index `dataloader.dataset` by the position of an
# id in `prealigned_meshes.ids`, so get_loader presumably preserves that order
# -- confirm against utils.helpers.
dataloader = get_loader(
    prealigned_meshes.point_clouds,
    prealigned_meshes.ids,
    batch_size=8,
    num_workers=4,
    shuffle=False,
)

# Reconstruct meshes from specific subjects

In [9]:
# Lookup table: subject ID -> position of that ID in `prealigned_meshes.ids`
# (used below as the position of the corresponding mesh in the dataset).
id_dict = {
    subject_id: position
    for position, subject_id in enumerate(prealigned_meshes.ids)
}

def get_mesh_for_id(dataloader, id, id_dict):
    """Fetch the mesh of one subject as a batch of size one.

    Parameters
    ----------
    dataloader : object exposing an indexable ``dataset`` attribute.
    id : key of ``id_dict`` identifying the subject.
    id_dict : mapping from subject id to position in ``dataloader.dataset``.

    Returns
    -------
    The first element of the dataset item stored at the subject's position,
    with a leading dimension added through ``unsqueeze(0)``.

    A lookup failure in ``id_dict`` (``KeyError`` for a plain dict) is not
    handled here and propagates to the caller.
    """
    sample = dataloader.dataset[id_dict[id]]
    first_item = sample[0]
    return first_item.unsqueeze(0)

In [10]:
def reconstruct(experiment, mesh):
    """Encode ``mesh`` with the experiment's model and decode it again.

    Parameters
    ----------
    experiment : object whose ``model`` attribute exposes ``is_variational``,
        ``encoder`` and ``decoder``.
    mesh : input passed to the encoder as the keyword argument ``x``.

    Returns
    -------
    The decoder output. For a variational model the encoder returns a pair
    and only its first element (``mu``) is decoded; the second (``log_var``)
    is ignored, i.e. no sampling takes place.

    Everything runs under ``torch.no_grad()``.
    """
    net = experiment.model
    with torch.no_grad():
        if not net.is_variational:
            latent = net.encoder(x=mesh)
        else:
            # Use the mean of the latent distribution as the code.
            latent, _ = net.encoder(x=mesh)
        return net.decoder(latent)



def plot_mesh(mesh):
    """Show ``mesh`` in a PyVista notebook plotter.

    The mesh is added with a point size of 5 and points rendered as spheres,
    then the plotter is shown interactively. Nothing is returned.
    """
    plotter = pv.Plotter(notebook=True)
    plotter.add_mesh(mesh, point_size=5, render_points_as_spheres=True)
    plotter.show(interactive=True)
    # NOTE(review): enable() is invoked after show(); confirm it is still needed.
    plotter.enable()



def f(id, reconstructed=True, deviation=True):
    """Plot the mesh of subject ``id``; used as the callback of ``interact``.

    Parameters
    ----------
    id : key of ``id_dict`` identifying the subject.
    reconstructed : bool
        If True, the mesh is passed through ``reconstruct`` before plotting;
        otherwise the input mesh is plotted as is.
    deviation : bool
        If True, the values are plotted as they are. If False, they are
        rescaled as ``prealigned_meshes.mean + prealigned_meshes.std * values``.

    An id that is not in ``id_dict`` is ignored silently (presumably the
    widget also fires while an id is still being typed -- TODO confirm).
    Any other failure is printed instead of being swallowed, so problems in
    the model or the plotter stay visible without breaking the widget.
    KeyboardInterrupt/SystemExit are no longer intercepted.
    """
    try:
        mesh = get_mesh_for_id(dataloader, id, id_dict)
    except KeyError:
        # Unknown subject id: nothing to show.
        return

    try:
        _mesh = reconstruct(experiment, mesh) if reconstructed else mesh
        _mesh = _mesh.detach().numpy()
        if not deviation:
            _mesh = prealigned_meshes.mean + prealigned_meshes.std * _mesh
        plot_mesh(_mesh.squeeze(0))
    except Exception as exc:
        print("Could not display the mesh for subject {!r}: {!r}".format(id, exc))

NameError: name 'wcb' is not defined

In [None]:
# Select a specific subject

In [None]:
# Text box offering the known subject ids as options.
wcb = widgets.Combobox(
    value="1000336",
    options=sorted(prealigned_meshes.ids),
    placeholder="Choose a subject",
)

# Call `f` with the current state of the three controls.
interact(f, id=wcb, reconstructed=True, deviation=True)

# Create synthetic meshes

In [None]:
def reconstruct_from_z(experiment, z):
    """Decode the latent code ``z`` with the experiment's model.

    Parameters
    ----------
    experiment : object whose ``model`` attribute exposes ``decoder``.
    z : value handed to ``model.decoder`` unchanged.

    Returns
    -------
    The decoder output, computed under ``torch.no_grad()``.
    """
    with torch.no_grad():
        return experiment.model.decoder(z)

In [None]:
def g(z0, z1, z2, z3, z4, z5, z6, z7, deviation=True):
    """Decode a hand-picked latent vector and plot the resulting mesh.

    Parameters
    ----------
    z0 ... z7 : the eight components of the latent vector.
    deviation : bool
        If True, the decoder output is plotted as is. If False, it is
        rescaled as ``prealigned_meshes.mean + prealigned_meshes.std * output``.
    """
    # Stack the components and add a leading dimension of size one.
    latent = torch.Tensor([z0, z1, z2, z3, z4, z5, z6, z7]).unsqueeze(0)
    decoded = reconstruct_from_z(experiment, latent).detach().numpy()
    if deviation:
        points = decoded
    else:
        points = prealigned_meshes.mean + prealigned_meshes.std * decoded
    plot_mesh(points.squeeze(0))

# One slider per latent dimension, keyed z0, z1, ... to match g's parameters.
# Note: g declares exactly z0..z7, so this expects experiment.config['z'] == 8.
z_sliders = {
    "z{}".format(dim): widgets.FloatSlider(min=-3, max=3, step=1, value=0)
    for dim in range(experiment.config['z'])
}

interact(g, **z_sliders, deviation=True)