# Development: training a CoMA mesh autoencoder with PyTorch Lightning

In [6]:
import sys, os

import utils.VTKHelpers
sys.path.append("utils/VTKHelpers/")

from config.load_config import load_config
from CardiacMesh import Cardiac3DMesh, Cardiac4DMesh, CardiacMeshPopulation
from models import layers

import pickle as pkl
import yaml
from pprint import pprint
from argparse import Namespace
import logging

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
import pytorch_lightning as pl

import ipywidgets as widgets
from IPython.display import display, HTML

import os
import pickle as pkl
from utils import mesh_operations
from utils.helpers import *

In [7]:
import mlflow.pytorch
from mlflow.tracking import MlflowClient

In [8]:
%%javascript
// Append a fixed-position <div id="toc"> to the page, then load a third-party script
// that presumably renders a table of contents into it (per the script's name).
// NOTE(review): this downloads and executes JavaScript from an external URL
// (kmahelona.github.io) every time the cell runs. Consider vendoring the script or
// using the front-end's built-in table-of-contents panel instead.
$('<div id="toc"></div>').css({position: 'fixed', top: '120px', left: 0}).appendTo(document.body);
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js');

<IPython.core.display.Javascript object>

Select a configuration file from the `config/` directory:

In [5]:
# Collect the YAML files in config/. Sort them so the dropdown order does not depend on
# the filesystem's directory-listing order.
config_candidates = sorted(
    fname for fname in os.listdir("config") if fname.endswith((".yaml", ".yml"))
)
if not config_candidates:
    raise FileNotFoundError("No .yaml/.yml configuration files found in 'config/'")

# Preselect config.yaml when it exists. Otherwise fall back to the first file, because a
# Dropdown rejects a `value` that is not among its `options`.
# NOTE: later cells depend on this widget's selection; for a reproducible Run All,
# keep the default selection.
config_files_w = widgets.Dropdown(
    options=config_candidates,
    value="config.yaml" if "config.yaml" in config_candidates else config_candidates[0],
)
display(config_files_w)

Dropdown(index=1, options=('config_template.yaml', 'config.yaml'), value='config.yaml')

In [6]:
# Load whichever configuration file is currently selected in the dropdown above.
selected_config_path = os.path.join("config", config_files_w.value)
config = load_config(selected_config_path)

## PyTorch Lightning DataModule

In [8]:
# Cache builder (disabled): uncomment to load 200 subjects (N_subj=200) from
# config.root_folder and pickle their meshes and ids to
# data/cached/cardiac_population_200_meshes.pkl, the file read by the
# commented-out "Cardiac meshes" cell below. Presumably left commented out because
# loading the raw population is slow.
# popu = CardiacMeshPopulation(config.root_folder, N_subj=200)
# kk = {"meshes": popu.as_numpy_array(), "ids": popu.ids}
# with open("data/cached/cardiac_population_200_meshes.pkl", "wb") as ff:
#     pkl.dump(kk, ff)

### Synthetic Meshes

In [12]:
from data.SyntheticDataModules import SyntheticMeshesDM
from data.DataModules import CardiacMeshPopulationDataset, CardiacMeshPopulationDM
import vedo
from utils import mesh_operations

# Use a vedo sphere as the template mesh. It is written to disk as ASCII VTK
# (binary=False) and read back through Cardiac3DMesh below.
sphere = vedo.Sphere()
template_mesh = "data/vedo_sphere_template.vtk"
sphere.write(template_mesh, binary=False)

# Build the mesh hierarchy for the configured downsampling factors.
# M: meshes, A: adjacency matrices, D: downsampling matrices, U: upsampling matrices.
M, A, D, U = mesh_operations.generate_transform_matrices(
    Cardiac3DMesh(template_mesh),
    config.network_architecture.pooling.parameters.downsampling_factors,
)

# Convert each scipy sparse matrix into a torch sparse tensor, one list per family.
A_t = [scipy_to_torch_sparse(mat) for mat in A]
D_t = [scipy_to_torch_sparse(mat) for mat in D]
U_t = [scipy_to_torch_sparse(mat) for mat in U]

# Node count of every mesh in the hierarchy (`.v` is presumably the vertex array).
n_nodes = [len(mesh.v) for mesh in M]

# NOTE(review): the matrices above come from the vedo sphere template; confirm that
# SyntheticMeshesDM yields meshes with the same topology (n_nodes[0] vertices).
dm = SyntheticMeshesDM()

### Cardiac meshes

In [13]:
# Alternative to the synthetic setup above (disabled): build `dm` from the cached cardiac
# population and load precomputed A_t, D_t, U_t and n_nodes. Uncomment to use.
# NOTE(review): the last line's `pkl.load(open(...))` never closes the file; prefer a
# `with open(...)` block. Only unpickle files you generated yourself.
# with open("data/cached/cardiac_population_200_meshes.pkl", "rb") as ff:
#     kk = pkl.load(ff)
#     
# dm = CardiacMeshPopulationDM(
#     cardiac_population=kk, 
#     batch_size=2
# )
# 
# A_t, D_t, U_t, n_nodes = pkl.load(open("data/cached/matrices.pkl", "rb"))

In [None]:
# Keyword arguments for Coma4D. They come from the loaded config plus the mesh
# hierarchy (D_t, U_t, A_t, n_nodes) prepared above.
coma_args = dict(
    num_features=config.network_architecture.n_features,
    # REDUNDANT: always equals the number of entries in num_conv_filters.
    n_layers=len(config.network_architecture.convolution.parameters.channels),
    num_conv_filters=config.network_architecture.convolution.parameters.channels,
    polygon_order=config.network_architecture.convolution.parameters.polynomial_degree,
    latent_dim=config.network_architecture.latent_dim,
    # True exactly when the regularization loss carries a non-zero weight.
    is_variational=config.loss.regularization_loss.weight != 0,
    downsample_matrices=D_t,
    upsample_matrices=U_t,
    adjacency_matrices=A_t,
    n_nodes=n_nodes,
    mode="testing",
)

from models.model import Coma4D
from models.coma_ml_module import CoMA

# Wrap the network in the CoMA LightningModule together with the config.
coma4D = Coma4D(**coma_args)
model = CoMA(coma4D, config)

# Train with the Trainer's default settings on the datamodule `dm` built above.
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type   | Params
---------------------------------
0 | model | Coma4D | 412 K 
---------------------------------
412 K     Trainable params
0         Non-trainable params
412 K     Total params
1.652     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

## PyTorch Lightning Module

### Testing PL module

In [None]:
# Train a freshly constructed model. Reusing `model` would continue from the weights
# left by the training run in the previous section, so this cell's result would depend
# on whether (and how often) that earlier cell had been executed.
model = CoMA(Coma4D(**coma_args), config)
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)

## Load data

In [None]:
# Debug helper (disabled): uncomment to draw the object-reference graph of `config`
# up to depth 2. Requires the optional `objgraph` package.
# import objgraph
# objgraph.show_refs(config, max_depth=2, )

In [None]:
# Load the cardiac population from the configured root folder and show the dimensions
# of its stacked array representation.
# NOTE(review): unlike the cache-builder cell, N_subj is not passed here, so this
# presumably loads every subject under root_folder, which may be slow.
popu = CardiacMeshPopulation(config.root_folder)
popu.as_numpy_array().shape