Set the path to the directory that contains the `mlruns` folder (usually the root of the `CardiacCOMA` repository).

In [None]:
import os

# Root of the CardiacCOMA checkout (the directory expected to contain `mlruns`).
# Set the CARDIAC_COMA_REPO environment variable to use a different location
# without editing the notebook; when it is unset or empty the original default
# path is used.
CARDIAC_COMA_REPO = os.environ.get("CARDIAC_COMA_REPO") or "/home/rodrigo/CISTIB/repos/CardiacCOMA/"

In [None]:
import mlflow
import os, sys

import torch
import torch.nn.functional as F

# Switch the working directory to the repository BEFORE the project imports below.
# Relative paths used later in this notebook (e.g. "config_files/config.yaml") are
# resolved from this directory.
# NOTE(review): presumably this is also what makes the project packages
# (`config`, `model`, `utils`) importable — confirm.
# NOTE(review): `os` is already imported above; the inline re-import is redundant.
import os; os.chdir(CARDIAC_COMA_REPO)
from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image
from mlflow.tracking import MlflowClient

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
import pandas as pd
# NOTE(review): `embed` is not used in the cells visible here; looks like a
# leftover debugging helper — confirm before removing.
from IPython import embed
# Put '..' first on the module search path.
# NOTE(review): this runs after the chdir above, so '..' is relative to the
# repository directory, not to the notebook's location — confirm this is intended.
sys.path.insert(0, '..')

import model.Model3D
from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

# NOTE(review): duplicate of the `deepcopy` import a few lines above.
from copy import deepcopy
from typing import List

In [None]:
from mlflow_helpers import \
    list_artifacts,\
    get_significant_loci,\
    get_metrics_cols, \
    get_params_cols, \
    get_runs_df, \
    get_good_runs,\
    summarize_loci_across_runs,\
    get_model_pretrained_weights

In [None]:
# Point MLflow at the file-based tracking store inside the repository.
# os.path.join yields a single path separator whether or not CARDIAC_COMA_REPO
# ends with "/", so the URI never contains a doubled slash before "mlruns".
TRACKING_URI = "file://" + os.path.join(CARDIAC_COMA_REPO, "mlruns")
mlflow.set_tracking_uri(TRACKING_URI)

# Select MLflow experiment

In [None]:
# Client handle for the MLflow tracking API (not referenced by the other cells
# visible in this part of the notebook).
client = MlflowClient()

In [None]:
# Offer every experiment known to the tracking store in a selection widget.
# NOTE(review): `mlflow.list_experiments` is no longer available in MLflow >= 2.0
# (`mlflow.search_experiments` replaces it) — verify against the installed version.
options = [exp.name for exp in mlflow.list_experiments()]

# Preselect the second experiment when there is one (the original default).
# Fall back to the first experiment, or to no selection at all, so that a store
# with fewer than two experiments does not raise IndexError.
if len(options) > 1:
    default_experiment = options[1]
elif options:
    default_experiment = options[0]
else:
    default_experiment = None

experiment_w = widgets.Select(
    options=options,
    value=default_experiment,
)
display(experiment_w)

Retrieve run data from MLflow for the chosen experiment

In [None]:
# Resolve the experiment picked in the widget and fetch its finished runs.
exp_id = mlflow.get_experiment_by_name(experiment_w.value).experiment_id
runs_df = get_runs_df(exp_name=experiment_w.value, only_finished=True)

# Column subsets used for the preview below.
# NOTE(review): presumably the "metrics.*" / "params.*" column names — confirm in mlflow_helpers.
metrics = get_metrics_cols(runs_df)
params = get_params_cols(runs_df)

# Preview the metric and parameter columns of the first 10 runs, hiding the
# platform parameter. errors="ignore" keeps the preview working for experiments
# whose runs never logged `params.platform` (a plain drop would raise KeyError).
runs_df.loc[:, [*metrics, *params]].drop(columns="params.platform", errors="ignore").head(10)

In [None]:
# Upper bound on the test reconstruction loss for a run to be offered (MSE, mm2).
RECON_LOSS_THRES = 1

# Keep only the runs below the threshold. Each index entry is indexed at
# position 1 to obtain the run ID (presumably an (experiment_id, run_id) pair — confirm).
is_good_run = runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES
run_ids = sorted(index_key[1] for index_key in runs_df[is_good_run].index)

# Labels are the first 10 characters of each run ID; the widget value is the full ID.
run_id_labels = {full_id[:10]: full_id for full_id in run_ids}
run_ids_w = widgets.Select(description="Choose run:", options=run_id_labels)
display(run_ids_w)

In [None]:
# Take the run currently selected in the widget and flatten its row of
# `runs_df` into a plain dict of column -> value.
run_id = run_ids_w.value
selected_run = runs_df.loc[exp_id, run_id]
run_info = selected_run.to_dict()

# Drop the "file://" scheme so the artifact location can be used as a local path.
artifact_uri = run_info["artifact_uri"].replace("file://", "")

In [None]:
# Build a summary of loci across all retrieved runs and preview its first 10 rows.
# NOTE(review): the exact contents of the summary are defined by
# mlflow_helpers.summarize_loci_across_runs — see that helper for column meanings.
loci_summary_df = summarize_loci_across_runs(runs_df)
loci_summary_df.head(10)

In [None]:
# Interactive view: call get_significant_loci for the run chosen in the widget.
# The existing `run_ids_w` widget is passed as the control, so this view uses
# the same selection widget that was displayed earlier in the notebook.
interact(
    lambda run_id: get_significant_loci(runs_df, exp_id, run_id), 
    run_id=run_ids_w
); 

In [None]:
def overwrite_ref_config(ref_config, run_info, sample_sizes=None):
    '''
    Build a run-specific configuration from a reference configuration.

    This is a workaround for adjusting the configuration of those runs that
    didn't have a YAML configuration file logged as an artifact: the
    hyperparameters recorded for the run are copied over a reference config.

    Parameters
    ----------
    ref_config
        Attribute-style configuration object exposing
        `network_architecture.latent_dim`, `loss.regularization.weight` and
        `optimizer.parameters.lr`. It is deep-copied and never modified.
    run_info
        Mapping with the keys "params.latent_dim", "params.w_kl" and
        "params.lr"; values are converted with `int` / `float`.
    sample_sizes
        Optional iterable stored (as a list) in `config.sample_sizes`.
        Defaults to `[100, 100, 100, 100]`, the value previously hard-coded.

    Returns
    -------
    A deep copy of `ref_config` with the fields above overwritten.

    Raises
    ------
    KeyError
        If `run_info` lacks one of the required keys.
    ValueError, TypeError
        If a recorded value cannot be converted to `int` / `float`.
    '''

    # Work on a copy so the reference config can be reused for other runs.
    config = deepcopy(ref_config)

    config.network_architecture.latent_dim = int(run_info["params.latent_dim"])
    config.loss.regularization.weight = float(run_info["params.w_kl"])
    config.optimizer.parameters.lr = float(run_info["params.lr"])

    # A fresh list is created on every call so callers never share the default.
    if sample_sizes is None:
        config.sample_sizes = [100, 100, 100, 100]
    else:
        config.sample_sizes = list(sample_sizes)

    return config


# Load the reference YAML config (the path is relative to the current working
# directory), overwrite it with the hyperparameters recorded for the selected
# run, and print the resulting configuration for inspection.
ref_config = load_yaml_config("config_files/config.yaml")
config = overwrite_ref_config(ref_config, run_info)
pprint(to_dict(config))

In [None]:
# NOTE(review): seed=None means this notebook does not pin a seed itself; pass an
# explicit integer if the results below must be reproducible across kernel
# restarts — confirm the intended behavior.
# NOTE(review): the `pl.utilities.seed` module path may depend on the installed
# pytorch_lightning version — verify.
pl.utilities.seed.reset_seed() # seed_everything(seed=None)
pl.utilities.seed.seed_everything(seed=None)

In [None]:
# Build the data module from the run-specific config.
# NOTE(review): perform_setup=True presumably runs the data module's setup step
# right away — confirm in utils.helpers.get_datamodule.
dm = get_datamodule(config, perform_setup=True)

In [None]:
# Instantiate the Lightning module for this config and data module; the run's
# pretrained weights are loaded separately in the following cells.
model = get_lightning_module(config, dm)

In [None]:
# Fetch the pretrained weights stored for the selected run.
# NOTE(review): presumably a state dict compatible with `model.model` — confirm
# in mlflow_helpers.get_model_pretrained_weights.
weights = get_model_pretrained_weights(runs_df, exp_id, run_id)

In [None]:
# Load the selected run's pretrained weights into the wrapped network.
# Uses `weights`, the variable bound by the get_model_pretrained_weights(...) call;
# the previously referenced `_model_pretrained_weights` is not defined anywhere in
# the notebook and raised NameError on a fresh kernel.
model.model.load_state_dict(weights)

Assess the performance of the model

In [None]:

def mse(s1, s2=None):
    """Squared error summed over the last axis, then averaged over the next one.

    Args:
        s1: Tensor to evaluate.
        s2: Optional tensor to compare against; when omitted, `s1` is compared
            against an all-zero tensor of the same shape and dtype.

    Returns:
        Tensor with the two trailing axes of the input reduced away.
    """
    reference = torch.zeros_like(s1) if s2 is None else s2
    squared_error = (s1 - reference) ** 2
    per_point_error = squared_error.sum(-1)
    return per_point_error.mean(-1)

In [None]:
# Reconstruct one dataset item (index 1) and display its error against the input.
# NOTE(review): assumes dataset items are dicts holding the input tensor under 's'
# and that `model(s)[0][0]` is a reconstruction comparable to `s` — confirm
# against the dataset and model definitions.
# NOTE(review): the model is called without `model.eval()` / `torch.no_grad()`;
# confirm whether inference mode is wanted for this assessment.
s = dm.dataset[1]['s']
s_hat = model(s)[0][0]
mse(s, s_hat)