Set the path to the directory that contains the `mlruns` folder (usually the root of the `CardiacCOMA` repository).

In [1]:
import os

# Root of the CardiacCOMA checkout (the directory that contains `mlruns`).
# Export CARDIAC_COMA_REPO before starting the kernel to point elsewhere; the
# fallback is the previously hard-coded location, so existing setups keep working.
CARDIAC_COMA_REPO = os.environ.get(
    "CARDIAC_COMA_REPO", "/home/rodrigo/CISTIB/repos/CardiacCOMA/"
)

In [2]:
import mlflow
import os, sys

import torch
import torch.nn.functional as F

import os; os.chdir(CARDIAC_COMA_REPO)
from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image
from mlflow.tracking import MlflowClient

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
import pandas as pd
from IPython import embed
sys.path.insert(0, '..')

import model.Model3D
from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

from copy import deepcopy
from typing import List

In [28]:
from functools import partial

In [3]:
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
from mlflow_helpers import \
    list_artifacts,\
    get_significant_loci,\
    get_metrics_cols, \
    get_params_cols, \
    get_runs_df, \
    get_good_runs,\
    summarize_loci_across_runs,\
    get_model_pretrained_weights

In [5]:
# Point MLflow at the local file store inside the repository.
# os.path.join avoids the doubled slash ("...CardiacCOMA//mlruns") that plain
# interpolation produced because CARDIAC_COMA_REPO already ends with "/", and
# it also works when the path is given without a trailing slash.
TRACKING_URI = f"file://{os.path.join(CARDIAC_COMA_REPO, 'mlruns')}"
mlflow.set_tracking_uri(TRACKING_URI)

# Select MLflow experiment

In [6]:
# Low-level MLflow client. Created without arguments, so (per the MLflow docs)
# it uses the tracking URI registered through mlflow.set_tracking_uri.
client = MlflowClient()

In [27]:
def experiment_selection_widget():
    """Build a ``Select`` widget listing the MLflow experiments by name.

    The second experiment is pre-selected when at least two exist (the
    original default). With a single experiment that one is selected, and
    with none the widget is created empty instead of raising ``IndexError``.

    Returns:
        widgets.Select: selection widget whose options are experiment names.
    """
    # `mlflow.list_experiments` no longer exists in MLflow 2.x. Prefer it when
    # available so the option order is unchanged on older installs, and fall
    # back to its replacement otherwise.
    list_experiments = getattr(mlflow, "list_experiments", None) or mlflow.search_experiments
    options = [exp.name for exp in list_experiments()]

    if len(options) > 1:
        default = options[1]
    elif options:
        default = options[0]
    else:
        default = None

    experiment_w = widgets.Select(
      options=options,
      value=default
    )

    return experiment_w

exp_w = experiment_selection_widget()

@interact
def get_runs(exp_name=exp_w):
  """Load the finished runs of the selected experiment into notebook globals.

  Sets the module-level ``exp_id`` and ``runs_df`` that later cells read.
  Failures are reported instead of being silently ignored, so a stale or
  missing ``runs_df`` is visible right at the widget.
  """
  global exp_id, runs_df
  try:
    experiment = mlflow.get_experiment_by_name(exp_name)
    if experiment is None:
      raise ValueError(f"MLflow experiment not found: {exp_name!r}")
    exp_id = experiment.experiment_id
    runs_df = get_runs_df(exp_name=exp_name, only_finished=True, sort_by=None)
  except Exception as err:
    # Keep the widget usable (interact calls this again on every selection
    # change), but surface the problem instead of swallowing it.
    print(f"Could not load runs for experiment {exp_name!r}: {err!r}")

A Jupyter Widget

Retrieve run data from MLflow for the chosen experiment

In [29]:
# Upper bound on the test reconstruction loss (MSE, mm2) for a run to be kept.
RECON_LOSS_THRES = 1.
_below_thres = runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES
# The second element of every index entry is collected as the run id.
run_ids = sorted(idx[1] for idx in runs_df[_below_thres].index)

In [30]:
# Run selector: options are labelled with the first 10 characters of each run id.
run_ids_w = widgets.Select(description="Choose run:", options={x[:10]: x for x in run_ids})
display(run_ids_w)
run_id = run_ids_w.value
if run_id is None:
    # Nothing to select from, so fail with a readable message instead of an
    # opaque KeyError on the lookup below.
    raise ValueError("No run is selected: `run_ids` is empty for the chosen experiment.")
# Explicit (experiment id, run id) row key plus a column slice, so the lookup
# cannot be read as "row exp_id, column run_id".
# NOTE(review): a KeyError here means the (exp_id, run_id) pair is absent from
# runs_df's index; check that exp_id and runs_df come from the same experiment.
run_info = runs_df.loc[(exp_id, run_id), :].to_dict()
# Strip only a leading "file://" scheme; the rest of the path is left untouched.
artifact_uri = run_info["artifact_uri"]
if artifact_uri.startswith("file://"):
    artifact_uri = artifact_uri[len("file://"):]

A Jupyter Widget

KeyError: '999'

In [None]:
# Sort key for the summary table. Alternatives that were tried before:
#   {"by": "count", "ascending": False}   and   {"by": "min_P", "ascending": True}
ORDER_BY = {"by":"-log10(min_P)", "ascending":False}

# NOTE(review): `runs_df_` (runs with w_kl == 0) is built here, but the summary
# below is computed from the unfiltered `runs_df`. Confirm which is intended.
runs_df_ = runs_df[runs_df["params.w_kl"].astype(float) == 0]
loci_summary_df = summarize_loci_across_runs(runs_df)
# Vectorised over the whole column instead of a row-wise apply.
loci_summary_df["-log10(min_P)"] = -np.log10(loci_summary_df["min_P"].astype(float))
loci_summary_df.sort_values(**ORDER_BY, axis=0)

In [None]:
def summarize_loci_across_runs(runs_df: pd.DataFrame, exp_id: str = "1"):

    '''
    Stack the frames returned by `get_significant_loci` for every run in `runs_df`.

    NOTE(review): this definition shadows the `summarize_loci_across_runs`
    imported from `mlflow_helpers` for all cells executed after it.

    Parameters:
      runs_df: runs table; the second element of each index entry is used as the run id.
      exp_id: experiment id forwarded to `get_significant_loci`. Defaults to "1",
              the value that used to be hard-coded.
    Return: pd.DataFrame indexed by ["run", "pheno", "region"], holding the rows
            of all runs one after another (no aggregation is performed).
    Raises: ValueError (from pd.concat) when `runs_df` contains no runs.
    '''

    run_ids = sorted(x[1] for x in runs_df.index)

    per_run_loci = [
      get_significant_loci(runs_df, exp_id, run).\
        assign(run=run).\
        reset_index().\
        set_index(["run", "pheno", "region"])
      for run in run_ids
    ]

    return pd.concat(per_run_loci)

In [None]:
# Flatten the per-run loci table and drop the leftover "index" column.
_loci_by_run = summarize_loci_across_runs(runs_df)
kk = _loci_by_run.reset_index().drop(columns="index")
# Rewrite each phenotype label as "1_<first 5 chars of run id>_<pheno>".
kk["pheno"] = kk.apply(lambda row: f"1_{row.run[:5]}_{row.pheno}", axis=1)

In [None]:
# Correlation table keyed by phenotype. The path is relative to the current
# working directory.
z_corr = pd.read_csv("data/cardio/corr_z_vs_indices.csv", index_col="phenotype")

In [None]:
# One list of correlation values per row of `pp`, visited in region order.
# NOTE(review): `pp` is not assigned in any cell visible above this one (only
# inside the loop of the loci-count statistics cell further down), so this
# fails on a fresh kernel. It presumably should be the loci table `kk`.
# NOTE(review): rows are appended in region-sorted order; check that this
# matches the row order of the frame these values are later joined to.
corrs = []

for _idx, row in pp.sort_values(by="region").iterrows():
    try:
        corrs.append(list(z_corr.loc[row.pheno]))
    except KeyError:
        # Phenotype not present in the correlation table: keep a placeholder row.
        corrs.append([pd.NA]*4)

In [None]:
# Column labels for the four correlation values stored per row of `corrs`.
_corr_columns = ["LVEDV_corr", "LVM_corr", "RVEDV_corr", "LVSph_corr"]
corrs_df = pd.DataFrame(data=corrs, columns=_corr_columns)
# Display only: set_index returns a new frame, so `corrs_df` keeps its own index.
corrs_df.set_index(keys=pp.index)

In [None]:
# Put the absolute correlations next to the loci table (column-wise concat,
# aligned on index labels), then group the combined frame by region.
_kk_with_corrs = pd.concat(objs=[kk, corrs_df.abs()], axis=1)
kk_grouped = _kk_with_corrs.groupby(by="region")

In [None]:
from functools import partial

# Mean / standard-deviation reducers with `skipna=True` bound, for use as
# aggregation functions.
mean_f, std_f = (
    partial(reducer, skipna=True) for reducer in (pd.Series.mean, pd.Series.std)
)

In [None]:
# Number of non-missing LVEDV correlations in each region group.
counts = kk_grouped["LVEDV_corr"].count()

In [None]:
phenos = ["LVEDV", "LVM", "RVEDV", "LVSph"]
# For every "<pheno>_corr" column, compute both reducers within each group.
_corr_cols_to_summarize = [f"{pheno}_corr" for pheno in phenos]
_agg_spec = {column: [mean_f, std_f] for column in _corr_cols_to_summarize}
corr_per_locus = kk_grouped.agg(_agg_spec)

In [None]:
# Attach the per-region `counts` Series as a new column; this modifies
# `corr_per_locus` in place, so re-running the cell just overwrites the column.
corr_per_locus["counts"] = counts

In [None]:
# Display only: sort_values returns a new frame ordered by `counts`, largest
# first; `corr_per_locus` itself is left unchanged.
corr_per_locus.sort_values(by="counts", ascending=False)

# Statistics on the GWAS loci counts

In [None]:
from collections import Counter

# Runs whose validation reconstruction loss is not below this value are left
# out of the frame handed to get_significant_loci.
VAL_RECON_LOSS_THRES = 2

signif_loci_dfs = {}   # run id -> frame returned by get_significant_loci
skipped_runs = {}      # run id -> exception that made the run unusable
dd = []                # one summary row per successfully processed run

def loci_count(run_df):
    """Count how often each value at position 1 of the index entries occurs.

    Returns a plain dict mapping that value to its number of occurrences.
    """
    # presumably position 1 of each index entry identifies the locus (TODO confirm)
    return dict(Counter(x[1] for x in run_df.index))

# Loop-invariant filter, computed once instead of once per run.
runs_below_val_thres = runs_df[runs_df["metrics.val_recon_loss"] < VAL_RECON_LOSS_THRES]

for run in runs_df.index:

    try:
        pp = get_significant_loci(runs_below_val_thres, exp_id, run[1])

        # Count from THIS run's loci. Previously the counts were read before
        # being computed for the current run, so the first run was dropped and
        # every later row carried the counts of the run before it.
        loci_cnt = loci_count(pp)
        n_distinct_loci = len(loci_cnt)
        n_hits_with_duplication = sum(loci_cnt.values())

        ff = [
            run[1],
            runs_df.loc[run, "metrics.test_recon_loss"],
            runs_df.loc[run, "metrics.test_kld_loss"],
            runs_df.loc[run, "params.latent_dim"],
            runs_df.loc[run, "params.w_kl"],
            n_distinct_loci,
            n_hits_with_duplication,
            # A run without loci raises ZeroDivisionError here and is skipped.
            n_hits_with_duplication / n_distinct_loci,
        ]
    except Exception as err:
        # Best effort: keep going, but remember why the run was left out.
        skipped_runs[run[1]] = err
        continue

    signif_loci_dfs[run[1]] = pp
    dd.append(ff)

# Passing the column names to the constructor also works when no run survived.
kk = pd.DataFrame(
    dd,
    columns=[
        "run_id",
        "test_mse",
        "kld",
        "lat_dim",
        "w_kl",
        "n_loci",
        "n_loci_dupl",
        "ratio",
    ],
)

In [None]:
def _boxplot_columns(xcol, ycol):
    """Draw a seaborn box plot of column `ycol` against column `xcol` of `kk`."""
    return sns.boxplot(x=xcol, y=ycol, data=kk)

# Two selectors over the columns of `kk`, one per axis.
_xcol_w = widgets.Select(options=kk.columns)
_ycol_w = widgets.Select(options=kk.columns)
interact(_boxplot_columns, xcol=_xcol_w, ycol=_ycol_w);

In [None]:
@interact
def show_signif_loci(run_id=run_ids_w):
    """Return what `get_significant_loci` yields for the run chosen in the run selector."""
    signif_loci = get_significant_loci(runs_df, exp_id, run_id)
    return signif_loci