Set path to where `mlruns` directory is located (usually, the `CardiacCOMA` repository)

In [5]:
CARDIAC_COMA_REPO = "/home/rodrigo/CISTIB/repos/CardiacCOMA/"

In [6]:
import mlflow
import os, sys

import torch
import torch.nn.functional as F

import os; os.chdir(CARDIAC_COMA_REPO)
from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image
from mlflow.tracking import MlflowClient

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
import pandas as pd
from IPython import embed
sys.path.insert(0, '..')

import model.Model3D
from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

from copy import deepcopy
from typing import List

In [7]:
import matplotlib.pyplot as plt
import seaborn as sns

In [8]:
from mlflow_helpers import \
    list_artifacts,\
    get_significant_loci,\
    get_metrics_cols, \
    get_params_cols, \
    get_runs_df, \
    get_good_runs,\
    summarize_loci_across_runs,\
    get_model_pretrained_weights

In [9]:
TRACKING_URI = f"file://{CARDIAC_COMA_REPO}/mlruns"
mlflow.set_tracking_uri(TRACKING_URI)

# Select MLflow experiment

In [10]:
client = MlflowClient()

In [11]:
def experiment_selection_widget():
    options = [exp.name for exp in mlflow.list_experiments()]

    experiment_w = widgets.Select(
      options=options,
      value=options[1]
    )
    
    return experiment_w

exp_w = experiment_selection_widget()

@interact
def get_runs(exp_name=exp_w):  
  try:
    exp_id = mlflow.get_experiment_by_name(exp_name).experiment_id
    runs_df = get_runs_df(exp_name=exp_name, only_finished=True)
    metrics, params = get_metrics_cols(runs_df), get_params_cols(runs_df)  
    # display(runs_df.loc[:, [*metrics, *params]].drop("params.platform", axis=1).head(10))
    return runs_df
  except:
    pass

interactive(children=(Select(description='exp_name', index=1, options=('Cardiac - ED', 'Default'), value='Defa…

In [12]:
runs_df = get_runs_df(exp_name=exp_w.value, only_finished=True)

Retrieve run data from MLflow for the chosen experiment

In [13]:
RECON_LOSS_THRES = 1. # performance threshold for MSE mm2.
run_ids = sorted([x[1] for x in runs_df[runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES].index])

In [10]:
run_ids_w = widgets.Select(description="Choose run:", options={x[:10]: x for x in run_ids})
display(run_ids_w)
run_id = run_ids_w.value
run_info = runs_df.loc[exp_id, run_id].to_dict()
artifact_uri = run_info["artifact_uri"].replace("file://", "")

A Jupyter Widget

NameError: name 'exp_id' is not defined

In [None]:
#runs_df_ = runs_df[runs_df["metrics.recon_loss"].astype(float) < 0.4]
ORDER_BY = {"by":"count", "ascending":False}
ORDER_BY = {"by":"min_P", "ascending":True}
ORDER_BY = {"by":"-log10(min_P)", "ascending":False}

runs_df_ = runs_df[runs_df["params.w_kl"].astype(float) == 0]
loci_summary_df = summarize_loci_across_runs(runs_df)
loci_summary_df["-log10(min_P)"] = loci_summary_df.apply(lambda row: -np.log10(row["min_P"]), axis=1)
loci_summary_df.sort_values(**ORDER_BY, axis=0)

In [14]:
import pandas as pd

In [15]:
def summarize_loci_across_runs(runs_df: pd.DataFrame):

    '''
    Parameters: run_ids
    Return: pd.DataFrame with ["count", "min_P"].
    '''

    # run_ids = sorted([x[1] for x in runs_df[runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES].index])
    run_ids = sorted([x[1] for x in runs_df.index])

    all_signif_loci = pd.concat([
      get_significant_loci(runs_df, "1", run).\
        assign(run=run).\
        reset_index().\
        set_index(["run", "pheno", "region"]) 
      for run in run_ids
    ])
    
    return all_signif_loci

In [64]:
kk = summarize_loci_across_runs(runs_df).reset_index().drop("index", axis=1)
kk.pheno = kk.apply(lambda x: f"1_{x.run[:5]}_{x.pheno}", axis=1)

In [24]:
z_corr = pd.read_csv("data/cardio/corr_z_vs_indices.csv").set_index("phenotype")

Unnamed: 0,run,pheno,region,CHR,SNP,BP,AF,a_0,a_1,BETA,SE,T,P,locus_name
0,0285fa2356fd454e88e3c30d6b63f163,z007,chr6_78,6.0,rs11153730,118667522.0,0.49203,T,C,-0.076241,0.007928,-9.6162,7.311391e-22,PLN
1,0285fa2356fd454e88e3c30d6b63f163,z007,chr6_79,6.0,rs10872167,118988362.0,0.46056,A,G,-0.071901,0.007982,-9.0084,2.202926e-19,PLN
2,0285fa2356fd454e88e3c30d6b63f163,z012,chr2_108,2.0,rs2042995,179558366.0,0.22166,T,C,-0.071122,0.009542,-7.4534,9.332543e-14,TTN
3,0285fa2356fd454e88e3c30d6b63f163,z004,chr2_108,2.0,rs2042995,179558366.0,0.22166,T,C,0.069439,0.009544,7.2755,3.531832e-13,TTN
4,0285fa2356fd454e88e3c30d6b63f163,z006,chr2_108,2.0,rs2042995,179558366.0,0.22166,T,C,-0.069385,0.009546,-7.2685,3.723917e-13,TTN
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1201,ff7c594accb74ad7947621a6b3e2527a,z002,chr6_79,6.0,rs11756440,118993642.0,0.47505,C,A,0.046588,0.007990,5.8307,5.573141e-09,PLN
1202,ff7c594accb74ad7947621a6b3e2527a,z011,chr17_27,17.0,rs117953218,45244074.0,0.13850,T,C,-0.067495,0.011645,-5.7959,6.861199e-09,GOSR2
1203,ff7c594accb74ad7947621a6b3e2527a,z007,chr17_27,17.0,rs11570508,45228560.0,0.22405,C,A,-0.055592,0.009600,-5.7911,7.059924e-09,GOSR2
1204,ff7c594accb74ad7947621a6b3e2527a,z006,chr6_79,6.0,rs11756440,118993642.0,0.47505,C,A,-0.045982,0.007994,-5.7521,8.892011e-09,PLN


In [58]:
corrs = []

for index, row in pp.sort_values(by="region").iterrows():
    try:
        corrs.append(list(z_corr.loc[row.pheno]))
    except:
        corrs.append([pd.NA]*4)        

In [72]:
corrs_df = pd.DataFrame(corrs, columns=["LVEDV_corr", "LVM_corr", "RVEDV_corr", "LVSph_corr"])
corrs_df.set_index(pp.index)

Unnamed: 0,LVEDV_corr,LVM_corr,RVEDV_corr,LVSph_corr
0,0.621646,0.519715,0.530295,-0.088648
1,-0.641221,-0.480643,-0.528625,-0.02096
2,0.651826,0.483752,0.542813,0.055546
3,0.716259,0.546951,0.61227,0.246831
4,0.728424,0.538357,0.615668,0.172088
...,...,...,...,...
1201,0.657502,0.497687,0.568479,-0.07987
1202,-0.643888,-0.503688,-0.590995,0.00454
1203,-0.632191,-0.517196,-0.574951,-0.278471
1204,-0.623888,-0.479795,-0.545096,0.02543


In [None]:
pd.agg

In [91]:
kk_grouped = pd.concat([kk, corrs_df.abs()], axis=1).groupby("region")

In [81]:
from functools import partial
#s_na_mean = partial(pd.Series.mean, skipna = True)

In [92]:
mean_f = partial(pd.Series.mean, skipna = True)
std_f = partial(pd.Series.std, skipna = True)

In [102]:
counts = kk_grouped.agg("count")["LVEDV_corr"]

In [103]:
corr_per_locus = kk_grouped.aggregate(func={"LVEDV_corr": [mean_f, std_f], "LVM_corr": [mean_f, std_f], "RVEDV_corr": [mean_f, std_f], "LVSph_corr": [mean_f, std_f]})

In [105]:
corr_per_locus["counts"] = counts

In [108]:
corr_per_locus.sort_values(by="counts", ascending=False)

Unnamed: 0_level_0,LVEDV_corr,LVEDV_corr,LVM_corr,LVM_corr,RVEDV_corr,RVEDV_corr,LVSph_corr,LVSph_corr,counts
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,Unnamed: 9_level_1
region,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
chr2_108,0.554358,0.216404,0.414105,0.175413,0.469407,0.189154,0.247178,0.214509,311
chr6_78,0.541329,0.229387,0.407375,0.181504,0.459942,0.199748,0.268299,0.228139,184
chr12_69,0.540362,0.223193,0.39892,0.178696,0.454546,0.195408,0.261623,0.22152,159
chr6_79,0.547776,0.231587,0.407717,0.183816,0.462688,0.202715,0.263249,0.230405,156
chr17_27,0.542719,0.235192,0.405957,0.187949,0.462545,0.205292,0.26575,0.234328,74
chr5_103,0.541524,0.206925,0.424351,0.154115,0.469862,0.168764,0.189173,0.17733,44
chr12_67,0.50321,0.271738,0.3829,0.210662,0.42635,0.235438,0.291752,0.244292,29
chr6_84,0.595206,0.195004,0.440869,0.157941,0.504311,0.165896,0.255052,0.222846,26
chr1_124,0.569561,0.216011,0.436589,0.175478,0.49324,0.183317,0.22617,0.231334,21
chr10_69,0.567403,0.184704,0.424344,0.156419,0.476595,0.171887,0.267065,0.184819,16


# Statistics on the GWAS loci counts

In [None]:
signif_loci_dfs = {}
dd = []

def loci_count(run_df):
    from collections import Counter
    return dict(Counter([x[1] for x in run_df.index]))

for run in runs_df.index:
    
    try:     
      
      pp = get_significant_loci(runs_df[runs_df["metrics.val_recon_loss"] < 2], exp_id, run[1]) #.sort_values(by=["CHR", "BP"], axis=0)
      n_distinct_loci = len(loci_cnt.keys())
      n_hits_with_duplication = sum(loci_cnt.values())
      
      ff = [  run[1], 
         runs_df.loc[run, "metrics.test_recon_loss"], 
         runs_df.loc[run, "metrics.test_kld_loss"], 
         runs_df.loc[run, "params.latent_dim"], 
         runs_df.loc[run, "params.w_kl"],
         n_distinct_loci, 
         n_hits_with_duplication, 
         n_hits_with_duplication / n_distinct_loci             
      ]
      
      signif_loci_dfs[run[1]] = pp
      loci_cnt = loci_count(signif_loci_dfs[run[1]])
      dd.append(ff)
    except:
      pass

kk = pd.DataFrame(dd)

kk.columns = [
    "run_id",
    "test_mse",
    "kld",    
    "lat_dim",
    "w_kl",
    "n_loci",
    "n_loci_dupl",
    "ratio"    
]

In [None]:
interact(
    lambda xcol, ycol: sns.boxplot(x=xcol, y=ycol, data=kk),
    xcol = widgets.Select(options=kk.columns),
    ycol = widgets.Select(options=kk.columns)
);

In [None]:
@interact
def show_signif_loci(run_id=run_ids_w):
    return get_significant_loci(runs_df, exp_id, run_id)

In [None]:
def overwrite_ref_config(ref_config, run_info):
    
    '''
    This is a workaround for adjusting the configuration of those runs that didn't have a YAML configuration file logged as an artifact.
    '''
    
    config = deepcopy(ref_config)
    config.network_architecture.latent_dim = int(run_info["params.latent_dim"])
    config.loss.regularization.weight = float(run_info["params.w_kl"])
    config.optimizer.parameters.lr = float(run_info["params.lr"])
    config.sample_sizes = [100, 100, 100, 100]
    
    return config


ref_config = load_yaml_config("config_files/config.yaml")
config = overwrite_ref_config(ref_config, run_info)
pprint(to_dict(config))

In [None]:
pl.utilities.seed.reset_seed() # seed_everything(seed=None)
pl.utilities.seed.seed_everything(seed=None)

In [None]:
dm = get_datamodule(config, perform_setup=True)

In [None]:
model = get_lightning_module(config, dm)

In [None]:
weights = get_model_pretrained_weights(runs_df, exp_id, run_id)

In [None]:
model.model.load_state_dict(_model_pretrained_weights)

Assess perfomance of model

In [None]:
def mse(s1, s2=None):
    if s2 is None:
        s2 = torch.zeros_like(s1)
    return ((s1-s2)**2).sum(-1).mean(-1)

In [None]:
s = dm.dataset[1]['s']
s_hat = model(s)[0][0]
mse(s, s_hat)