In [None]:
import mlflow
from mlflow.tracking import MlflowClient

import os, sys

import torch
import torch.nn.functional as F

# Locate the CardiacCOMA checkout under $HOME and make it the working directory
# before the project-local imports below (config.*, utils.*) — presumably so
# that they resolve relative to the repo root; TODO confirm.
HOME = os.environ["HOME"]
CARDIAC_COMA_REPO = "{home}/01_repos/CardiacCOMA/".format(home=HOME)
os.chdir(CARDIAC_COMA_REPO)

from config.cli_args import overwrite_config_items
from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image

import pandas as pd
import shlex
from subprocess import check_output

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
import pandas as pd
from IPython import embed
# Prepend the parent of the current working directory (the CardiacCOMA repo
# root, after the os.chdir above) to the module search path.
sys.path.insert(0, '..')

# import model.Model3D
# from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

from typing import List
from tqdm import tqdm

import pandas as pd

import pyvista as pv
from ipywidgets import interact, interactive, fixed, interact_manual

In [None]:
from utils.helpers import get_datamodule, get_lightning_module

In [None]:
from utils.mlflow_helpers import get_model_pretrained_weights

In [None]:
# Repository locations. HOME is normally already an absolute path, so no extra
# leading "/" is prepended (same value as CARDIAC_COMA_REPO in the import cell).
CARDIAC_GWAS_REPO = f"{HOME}/01_repos/CardiacGWAS"
CARDIAC_COMA_REPO = f"{HOME}/01_repos/CardiacCOMA/"
# MLflow tracking directory inside the CardiacGWAS repo.
MLRUNS_DIR = f"{CARDIAC_GWAS_REPO}/mlruns"
os.chdir(CARDIAC_COMA_REPO)

In [None]:
def overwrite_ref_config(ref_config, run_info, sample_sizes=(100, 100, 100, 100)):
    '''
    Build a run-specific configuration from a reference configuration.

    This is a workaround for adjusting the configuration of those runs that
    didn't have a YAML configuration file logged as an artifact: the
    hyperparameters that vary between runs are read from `run_info` and
    written onto a deep copy of `ref_config`.

    Args:
        ref_config: reference configuration object; it is not modified.
        run_info: mapping with the keys "params.latent_dim", "params.w_kl"
            and "params.lr". Values are cast with int/float, so string
            values are accepted.
        sample_sizes: sequence assigned (as a new list) to
            `config.sample_sizes`. Defaults to four entries of 100.

    Returns:
        The adjusted deep copy of `ref_config`.
    '''
    config = deepcopy(ref_config)
    config.network_architecture.latent_dim = int(run_info["params.latent_dim"])
    config.loss.regularization.weight = float(run_info["params.w_kl"])
    config.optimizer.parameters.lr = float(run_info["params.lr"])
    # A fresh list on every call, so returned configs never share it.
    config.sample_sizes = list(sample_sizes)

    return config


In [None]:
# Table of selected runs; its `run_id` column feeds the run picker widget.
good_runs_csv = f"{CARDIAC_GWAS_REPO}/good_runs.csv"
good_runs_df = pd.read_csv(good_runs_csv)

In [None]:
def get_model_pretrained_weights(exp_id, run_id, repo_dir=None):
    '''
    Load the state_dict from a run's checkpoint and strip the leading "model."
    from its keys so it can be loaded into the bare (unwrapped) network.

    NOTE: this definition shadows the `get_model_pretrained_weights` imported
    from utils.mlflow_helpers earlier in the notebook.

    Args:
        exp_id: MLflow experiment id (directory name under `mlruns`).
        run_id: MLflow run id (directory name under the experiment).
        repo_dir: directory containing `mlruns`; defaults to CARDIAC_COMA_REPO.

    Returns:
        dict mapping parameter names to the loaded values, mapped to CPU.

    Raises:
        FileNotFoundError: if no checkpoint file exists for the run.
    '''
    if repo_dir is None:
        repo_dir = CARDIAC_COMA_REPO
    run_dir = os.path.join(repo_dir, "mlruns", str(exp_id), str(run_id))

    chkpt_dir = os.path.join(run_dir, "checkpoints")
    if not os.path.exists(chkpt_dir):
        chkpt_dir = os.path.join(run_dir, "artifacts", "restored_model_checkpoint")

    # os.listdir order is arbitrary; sort so the chosen checkpoint is deterministic.
    chkpt_files = sorted(os.listdir(chkpt_dir)) if os.path.isdir(chkpt_dir) else []
    if not chkpt_files:
        raise FileNotFoundError(f"No checkpoint found for run {run_id} (looked in {chkpt_dir})")
    chkpt_file = os.path.join(chkpt_dir, chkpt_files[0])

    # torch.load unpickles the file: only use it on trusted checkpoints.
    model_pretrained_weights = torch.load(chkpt_file, map_location=torch.device('cpu'))["state_dict"]

    # Remove only the leading "model." — str.replace would also rewrite it
    # inside keys (e.g. "decoder.model.bias" or "submodel.w").
    prefix = "model."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in model_pretrained_weights.items()
    }

In [None]:
run_w = widgets.Select(options=sorted(good_runs_df.run_id.to_list()))

@interact
def load_model(run_id=run_w):
    '''
    Widget callback: load the YAML config and the pretrained weights of the
    selected run into the globals `_run_id`, `config` and `weights`, which the
    following cells read.
    '''
    global _run_id, config, weights
    exp_id = "1"
    _run_id = run_id
    # MLRUNS_DIR is the run store declared in the path config cell
    # (a lowercase `mlruns_dir` is not defined anywhere in the notebook).
    # NOTE(review): configs are read from MLRUNS_DIR (CardiacGWAS) while
    # get_model_pretrained_weights reads from CARDIAC_COMA_REPO/mlruns —
    # confirm both stores hold the same runs.
    config_file = f"{MLRUNS_DIR}/{exp_id}/{run_id}/artifacts/config.yaml"
    config = load_yaml_config(config_file)
    config.sample_sizes = [100, 100, 100, 100]

    weights = get_model_pretrained_weights(exp_id, _run_id)

In [None]:
# Build the datamodule and Lightning module from the selected run's config,
# then load the pretrained weights into the wrapped network. The value returned
# by load_state_dict (presumably torch's missing/unexpected-keys report) is
# shown as the cell output.
dm = get_datamodule(config, perform_setup=True)
model = get_lightning_module(config, dm)
weights_load_report = model.model.load_state_dict(weights)
weights_load_report

In [None]:
# 1. get z dataframe
# 2. get mean and std
# 3. pass z_mean through model

exp_id = '1'
run_id = _run_id

# Latent variable associated with each locus of interest. Select the locus here
# instead of overwriting `z_var` several times (only the last assignment counted).
# NOTE(review): PLN and TTN map to the same variable — confirm this is intended.
Z_VAR_BY_LOCUS = {
    "GOSR2": "z001",
    "PLN": "z003",
    "TTN": "z003",
}
LOCUS = "TTN"  # change to pick another locus
z_var = Z_VAR_BY_LOCUS[LOCUS]

In [None]:
# Latent vectors written by the selected run, indexed by the `ID` column.
latent_vector_csv = f"{CARDIAC_COMA_REPO}/mlruns/{exp_id}/{_run_id}/artifacts/output/latent_vector.csv"
df = pd.read_csv(latent_vector_csv).set_index("ID")
df

In [None]:
def find_nearest_subject(latent_df, z_query):
    '''
    Find the row of `latent_df` closest (squared Euclidean distance) to `z_query`.

    Args:
        latent_df: DataFrame of latent vectors, one row per index label.
        z_query: array-like with one entry per column of `latent_df`.

    Returns:
        (index label, squared distance, deviation row - z_query) of the nearest
        row; on ties the first such row wins.

    Raises:
        ValueError: if `latent_df` has no rows.
    '''
    if len(latent_df) == 0:
        raise ValueError("latent_df is empty")
    deviations = latent_df.to_numpy(dtype=float) - np.asarray(z_query, dtype=float)
    sq_dist = np.sum(deviations ** 2, axis=1)
    best = int(np.argmin(sq_dist))  # argmin returns the first minimum
    return latent_df.index[best], float(sq_dist[best]), deviations[best]


# NOTE(review): `z` is not assigned at module level anywhere in this notebook
# (only as a local inside the plotting callback), so on a fresh kernel this
# raises NameError — define the query latent vector before running this cell.
id_min, min_mae, dev_min = find_nearest_subject(df, z)
# `min_mae` holds a sum of squared deviations, not a mean absolute error.
print(id_min, min_mae, dev_min)

In [None]:
import random
pv.set_plot_theme("document")

# Mesh connectivity for the decoder output.
# NOTE(review): the pickle is unpacked positionally as (faces, <unused>), which
# relies on it holding exactly two values in that order.
with open("data/cardio/faces_and_downsampling_mtx_frac_0.1_LV.pkl", "rb") as fh:
    faces, _ = pkl.load(fh).values()
# Prefix every face with its vertex count (3), the layout pv.PolyData expects.
faces = np.c_[np.ones(faces.shape[0]) * 3, faces].astype(int)

color_palette = list(pv.colors.color_names.values())
random.shuffle(color_palette)

# Latent statistics of the selected run, computed once instead of on every slider move.
_latent_df = pd.read_csv(
    f"{CARDIAC_COMA_REPO}/mlruns/{exp_id}/{_run_id}/artifacts/output/latent_vector.csv"
).drop("ID", axis=1)
_z_mean, _z_std = _latent_df.mean(), _latent_df.std()


def f(z_dev=widgets.IntSlider(min=-3, max=3)):
    '''
    Decode z_mean + z_dev * z_std (every latent dimension shifted by the same
    number of standard deviations) and render the resulting mesh.
    '''
    z = torch.tensor((_z_mean + z_dev * _z_std).to_numpy(dtype=np.float32))
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        s = model.model.decoder(z).detach().numpy()[0]

    # Named `plotter` to avoid confusion with the module-level `pl` alias (pytorch_lightning).
    plotter = pv.Plotter(notebook=True, off_screen=False, polygon_smoothing=False)
    mesh = pv.PolyData(s, faces)
    plotter.add_mesh(mesh, show_edges=False, point_size=1.5, color=color_palette[0], opacity=0.5)
    plotter.show(interactive=True, interactive_update=True)


interact(f);

In [None]:
# Freeze the decoder. requires_grad_ is a method and must be called: assigning
# `requires_grad_ = False` would only shadow the method with a bool and freeze nothing.
# The trailing ';' hides the returned module's repr.
model.model.decoder.requires_grad_(False);

In [None]:
# Reconstruction error for a single dataset sample.
# NOTE(review): no `mse` function is defined in this notebook; torch's mean
# squared error is used instead. Assumes `s` is a tensor with the same shape as
# `s_hat` (model output indexed [0][0]) — confirm, otherwise mse_loss broadcasts
# with a warning.
s = dm.dataset[1]['s']
with torch.no_grad():
    s_hat = model(s)[0][0]
F.mse_loss(s_hat, s)