Set path to where `mlruns` directory is located (usually, the `CardiacCOMA` repository)

In [1]:
CARDIAC_COMA_REPO = "/home/rodrigo/CISTIB/repos/CardiacCOMA/"

In [2]:
import mlflow
import os, sys

import torch
import torch.nn.functional as F

import os; os.chdir(CARDIAC_COMA_REPO)
from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image
from mlflow.tracking import MlflowClient

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
import pandas as pd
from IPython import embed
sys.path.insert(0, '..')

import model.Model3D
from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

from copy import deepcopy
from typing import List

In [3]:
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
from mlflow_helpers import \
    list_artifacts,\
    get_significant_loci,\
    get_metrics_cols, \
    get_params_cols, \
    get_runs_df, \
    get_good_runs,\
    summarize_loci_across_runs,\
    get_model_pretrained_weights

In [5]:
TRACKING_URI = f"file://{CARDIAC_COMA_REPO}/mlruns"
mlflow.set_tracking_uri(TRACKING_URI)

# Select MLflow experiment

In [6]:
client = MlflowClient()

In [32]:
def experiment_selection_widget():
    options = [exp.name for exp in mlflow.list_experiments()]

    experiment_w = widgets.Select(
      options=options,
      value=options[0]
    )
    
    return experiment_w

@interact
def get_runs(exp_name=experiment_selection_widget()):  
  exp_id = mlflow.get_experiment_by_name(exp_name).experiment_id
  runs_df = get_runs_df(exp_name=exp_name, only_finished=True)
  metrics, params = get_metrics_cols(runs_df), get_params_cols(runs_df)  
  display(runs_df.loc[:, [*metrics, *params]].drop("params.platform", axis=1).head(10))

interactive(children=(Select(description='exp_name', options=('Cardiac - ED', 'Default'), value='Cardiac - ED'…

Retrieve run data from MLflow for the chosen experiment

In [35]:
RECON_LOSS_THRES = 1. # performance threshold for MSE mm2.
run_ids = sorted([x[1] for x in runs_df[runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES].index])

In [10]:
run_ids_w = widgets.Select(description="Choose run:", options={x[:10]: x for x in run_ids})
display(run_ids_w)
run_id = run_ids_w.value
run_info = runs_df.loc[exp_id, run_id].to_dict()
artifact_uri = run_info["artifact_uri"].replace("file://", "")

In [36]:
#runs_df_ = runs_df[runs_df["metrics.recon_loss"].astype(float) < 0.4]
ORDER_BY = {"by":"count", "ascending":False}
ORDER_BY = {"by":"min_P", "ascending":True}
ORDER_BY = {"by":"-log10(min_P)", "ascending":False}

runs_df_ = runs_df[runs_df["params.w_kl"].astype(float) == 0]
loci_summary_df = summarize_loci_across_runs(runs_df)
loci_summary_df["-log10(min_P)"] = loci_summary_df.apply(lambda row: -np.log10(row["min_P"]), axis=1)
loci_summary_df.sort_values(**ORDER_BY, axis=0)

Unnamed: 0_level_0,Unnamed: 1_level_0,count,min_P,-log10(min_P)
region,locus_name,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
chr6_78,PLN,186,1.794734e-22,21.746
chr6_79,PLN,159,4.518559e-20,19.345
chr17_27,GOSR2,76,3.926449e-16,15.406
chr2_108,TTN,321,5.321083e-14,13.274
chr12_19,CCDC91*,9,7.585776e-13,12.12
chr11_2,LSP1*,12,1.648162e-12,11.783
chr12_69,TBX5,162,5.520774e-12,11.258
chr21_10,NCSTNP1*,15,8.729714e-12,11.059
chr12_67,Unnamed,30,1.119438e-11,10.951
chr6_26,Unnamed,11,1.164126e-11,10.934


# Statistics on the GWAS loci counts

In [13]:
signif_loci_dfs = {}
dd = []

def loci_count(run_df):
    from collections import Counter
    return dict(Counter([x[1] for x in run_df.index]))

for run in runs_df.index:
    
    try:     
      
      pp = get_significant_loci(runs_df[runs_df["metrics.val_recon_loss"] < 2], exp_id, run[1]) #.sort_values(by=["CHR", "BP"], axis=0)
      n_distinct_loci = len(loci_cnt.keys())
      n_hits_with_duplication = sum(loci_cnt.values())
      
      ff = [  run[1], 
         runs_df.loc[run, "metrics.test_recon_loss"], 
         runs_df.loc[run, "metrics.test_kld_loss"], 
         runs_df.loc[run, "params.latent_dim"], 
         runs_df.loc[run, "params.w_kl"],
         n_distinct_loci, 
         n_hits_with_duplication, 
         n_hits_with_duplication / n_distinct_loci             
      ]
      
      signif_loci_dfs[run[1]] = pp
      loci_cnt = loci_count(signif_loci_dfs[run[1]])
      dd.append(ff)
    except:
      pass

kk = pd.DataFrame(dd)

kk.columns = [
    "run_id",
    "test_mse",
    "kld",    
    "lat_dim",
    "w_kl",
    "n_loci",
    "n_loci_dupl",
    "ratio"    
]

ValueError: Length mismatch: Expected axis has 0 elements, new values have 8 elements

In [None]:
interact(
    lambda xcol, ycol: sns.boxplot(x=xcol, y=ycol, data=kk),
    xcol = widgets.Select(options=kk.columns),
    ycol = widgets.Select(options=kk.columns)
);

In [22]:
@interact
def show_signif_loci(run_id=run_ids_w):
    return get_significant_loci(runs_df, exp_id, run_id)

interactive(children=(Select(description='Choose run:', options={'0285fa2356': '0285fa2356fd454e88e3c30d6b63f1…

In [None]:
def overwrite_ref_config(ref_config, run_info):
    
    '''
    This is a workaround for adjusting the configuration of those runs that didn't have a YAML configuration file logged as an artifact.
    '''
    
    config = deepcopy(ref_config)
    config.network_architecture.latent_dim = int(run_info["params.latent_dim"])
    config.loss.regularization.weight = float(run_info["params.w_kl"])
    config.optimizer.parameters.lr = float(run_info["params.lr"])
    config.sample_sizes = [100, 100, 100, 100]
    
    return config


ref_config = load_yaml_config("config_files/config.yaml")
config = overwrite_ref_config(ref_config, run_info)
pprint(to_dict(config))

In [None]:
pl.utilities.seed.reset_seed() # seed_everything(seed=None)
pl.utilities.seed.seed_everything(seed=None)

In [None]:
dm = get_datamodule(config, perform_setup=True)

In [None]:
model = get_lightning_module(config, dm)

In [None]:
weights = get_model_pretrained_weights(runs_df, exp_id, run_id)

In [None]:
model.model.load_state_dict(_model_pretrained_weights)

Assess perfomance of model

In [None]:
def mse(s1, s2=None):
    if s2 is None:
        s2 = torch.zeros_like(s1)
    return ((s1-s2)**2).sum(-1).mean(-1)

In [None]:
s = dm.dataset[1]['s']
s_hat = model(s)[0][0]
mse(s, s_hat)