# Development

In [None]:
import sys, os

# Work from the repository root (assumes this notebook sits one directory
# below it) so the relative paths used throughout -- "config", "data/...",
# "utils/VTKHelpers/" -- resolve. Guarded because notebook cells are often
# re-run: an unconditional os.chdir("..") would climb one more level on every
# execution and break all of those paths.
if not os.path.isdir("config"):
    os.chdir("..")

import utils.VTKHelpers

# Expose the modules inside utils/VTKHelpers (e.g. CardiacMesh, imported
# below) as top-level imports. Skip the append on re-runs so sys.path does not
# accumulate duplicate entries.
_vtk_helpers_path = "utils/VTKHelpers/"
if _vtk_helpers_path not in sys.path:
    sys.path.append(_vtk_helpers_path)

from config.load_config import load_yaml_config
from CardiacMesh import Cardiac3DMesh, Cardiac4DMesh, CardiacMeshPopulation
from models import layers

import pickle as pkl
import yaml
from pprint import pprint
from argparse import Namespace
import logging

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
import pytorch_lightning as pl

import ipywidgets as widgets
from IPython.display import display, HTML

import os
import pickle as pkl
from utils import mesh_operations
from utils.helpers import *

In [None]:
import mlflow.pytorch
from mlflow.tracking import MlflowClient

In [None]:
%%javascript
// Insert a fixed-position <div id="toc"> at the left edge of the page, then
// load a third-party script that presumably fills it with a table of contents
// built from the notebook headings.
// NOTE(review): this downloads and executes JavaScript from an external site
// every time the cell runs; consider vendoring the script or using the
// notebook front-end's built-in TOC instead.
$('<div id="toc"></div>').css({position: 'fixed', top: '120px', left: 0}).appendTo(document.body);
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js');

Select the configuration file to load:

In [None]:
# Offer every YAML file in config/ as a choice. os.listdir returns entries in
# arbitrary order, so sort them for a stable dropdown. Match the full
# extension so names that merely end in "yaml" are not picked up, and include
# ".yml" files as well.
config_options = sorted(
    fname for fname in os.listdir("config") if fname.endswith((".yaml", ".yml"))
)
default_config = "config_folded_c_and_s.yaml"
config_files_w = widgets.Dropdown(
    options=config_options,
    # ipywidgets rejects a value that is not one of the options, so fall back
    # to the first available file (or None if there are none) when the
    # preferred default is missing.
    value=(
        default_config
        if default_config in config_options
        else next(iter(config_options), None)
    ),
)
display(config_files_w)

In [None]:
# Parse the YAML configuration file currently selected in the dropdown widget.
config_path = os.path.join("config", config_files_w.value)
config = load_yaml_config(config_path)

## PyTorch Lightning DataModule

In [None]:
# To cache data
# popu = CardiacMeshPopulation(config.root_folder, N_subj=200)
# kk = {"meshes": popu.as_numpy_array(), "ids": popu.ids}
# with open("data/cached/cardiac_population_200_meshes.pkl", "wb") as ff:
#     pkl.dump(kk, ff)

### Synthetic meshes

In [None]:
from data.SyntheticDataModules import SyntheticMeshesDM
from data.DataModules import CardiacMeshPopulationDataset, CardiacMeshPopulationDM
import vedo
from utils import mesh_operations

# Save a default vedo sphere as an ASCII (binary=False) VTK file; that file is
# the template mesh from which the pooling hierarchy is built.
template_mesh = "data/vedo_sphere_template.vtk"
sphere = vedo.Sphere()
sphere.write(template_mesh, binary=False)

# Build the multi-resolution mesh hierarchy from the template, using the
# downsampling factors in the config. Presumably (CoMA convention -- confirm
# against mesh_operations): M = meshes per level, A = adjacency matrices,
# D / U = down- / up-sampling matrices, each a list of scipy sparse matrices.
M, A, D, U = mesh_operations.generate_transform_matrices(
    Cardiac3DMesh(template_mesh),
    config.network_architecture.pooling.parameters.downsampling_factors,
)

# Convert each list of scipy sparse matrices into torch sparse tensors.
A_t, D_t, U_t = (
    [scipy_to_torch_sparse(mat) for mat in mats] for mats in (A, D, U)
)

# Vertex count of every mesh in the hierarchy.
n_nodes = [len(mesh.v) for mesh in M]

# Synthetic-mesh data module configured from the dataset section of the YAML
# config; the parameters sub-config is passed as its attribute dict.
dataset_cfg = config.dataset
dm = SyntheticMeshesDM(
    batch_size=config.batch_size,
    data_params=dataset_cfg.parameters.__dict__,
    preprocessing_params=dataset_cfg.preprocessing,
)
# Run the data module's setup step right away, before dm is handed to
# get_coma_args and the Trainer further down.
dm.setup()

### Cardiac meshes

In [None]:
# with open("data/cached/cardiac_population_200_meshes.pkl", "rb") as ff:
#     kk = pkl.load(ff)
#     
# dm = CardiacMeshPopulationDM(
#     cardiac_population=kk, 
#     batch_size=2
# )
# 
# A_t, D_t, U_t, n_nodes = pkl.load(open("data/cached/matrices.pkl", "rb"))

In [None]:
from main import get_coma_matrices, get_coma_args, get_dm_model_trainer

In [None]:
# Derive the keyword arguments for the network constructor (unpacked into
# Coma4D_C_and_S below) from the config and the data module; see
# main.get_coma_args for how they are computed.
coma_args = get_coma_args(config, dm)

In [None]:
from models.model import Coma4D
from models.model_c_and_s import Coma4D_C_and_S
from models.coma_ml_module import CoMA
# Instantiate the Coma4D_C_and_S network from the derived arguments, then wrap
# it, together with the config, in the CoMA module that the Trainer fits below.
coma4D = Coma4D_C_and_S(**coma_args)
model = CoMA(coma4D, config)

## PyTorch Lightning Module

### Testing PL module

In [None]:
# Fit the model on `dm` (the synthetic-mesh data module, unless the commented
# cardiac-mesh cell above is enabled instead).
# NOTE(review): pl.Trainer() with no arguments uses Lightning's defaults
# (e.g. max_epochs falls back to a large value when neither max_epochs nor
# max_steps is set -- check the installed version). Since this section is
# "Testing PL module", fast_dev_run=True may be what is intended -- confirm.
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)

## Load data

In [None]:
# import objgraph
# objgraph.show_refs(config, max_depth=2, )

In [None]:
# Load the cardiac mesh population from config.root_folder and print the shape
# of its array representation.
# NOTE(review): unlike the caching snippet earlier, no N_subj limit is passed
# here, so presumably every subject is loaded -- this may be slow/large.
popu = CardiacMeshPopulation(config.root_folder)
print(popu.as_numpy_array().shape)