In [None]:
import mlflow
import os, sys

import torch
import torch.nn.functional as F

HOME = os.environ["HOME"]
CARDIAC_COMA_REPO = HOME + "/01_repos/CardiacCOMA/"
# Work from the repository root: the project packages imported below (`config`, `utils`)
# and the relative data paths used later (e.g. "data/cardio/...") presumably resolve against it.
import os
os.chdir(CARDIAC_COMA_REPO)

from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image
from mlflow.tracking import MlflowClient

import pandas as pd
import shlex
from subprocess import check_output

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
import pandas as pd
from IPython import embed
# NOTE(review): this runs after os.chdir(CARDIAC_COMA_REPO) above, so the relative '..'
# presumably points at the repository's *parent* directory — possibly a leftover from
# running the notebook from a subfolder of the repo. Confirm it is still needed.
sys.path.insert(0, '..')

# import model.Model3D
# from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

from typing import List
from tqdm import tqdm

import pyvista as pv
from ipywidgets import interact, interactive, fixed, interact_manual

In [None]:
from utils.mlflow_queries import \
    list_artifacts,\
    get_significant_loci,\
    get_metrics_cols, \
    get_params_cols, \
    get_runs_df, \
    get_good_runs,\
    summarize_loci_across_runs,\
    get_model_pretrained_weights

In [None]:
# MLflow runs live on the local file system, in the `mlruns/` folder of the repository.
TRACKING_URI = "file://" + CARDIAC_COMA_REPO + "/mlruns"
mlflow.set_tracking_uri(TRACKING_URI)
client = MlflowClient()

In [None]:
def experiment_selection_widget(default_experiment="Cardiac - ED"):
    '''
    Returns a selection widget listing the names of all MLflow experiments.

    Args:
        default_experiment: name of the experiment pre-selected in the widget.
            If no experiment with that name exists, the first listed experiment
            is selected instead (or nothing, when there are no experiments at all).

    Returns:
        ipywidgets.Select whose options are the experiment names.
    '''
    # `mlflow.list_experiments` was deprecated in MLflow 1.28 and removed in 2.0 in
    # favour of `mlflow.search_experiments`; support both.
    list_experiments = getattr(mlflow, "search_experiments", None) or mlflow.list_experiments
    options = [exp.name for exp in list_experiments()]

    # widgets.Select rejects a `value` that is not among its `options`, so fall back
    # instead of crashing when the default experiment does not exist.
    if default_experiment in options:
        value = default_experiment
    elif options:
        value = options[0]
    else:
        value = None

    experiment_w = widgets.Select(
        options=options,
        value=value
    )

    return experiment_w

In [None]:
exp_w = experiment_selection_widget()

@interact
def get_runs(exp_name=exp_w):
    '''
    Load the finished runs of the selected MLflow experiment into the global `runs_df`.

    `@interact` calls this once when the cell runs and again whenever the selection in
    `exp_w` changes; the cells below read `runs_df`, so they reflect whichever experiment
    was selected when they were executed.

    On failure the error is printed (rather than silently swallowed) and `runs_df` keeps
    its previous value, if it had one.
    '''
    global runs_df
    try:
        if mlflow.get_experiment_by_name(exp_name) is None:
            print(f"MLflow experiment {exp_name!r} does not exist.")
            return
        runs_df = get_runs_df(exp_name=exp_name, only_finished=True)
    except Exception as err:
        print(f"Could not load runs for experiment {exp_name!r}: {type(err).__name__}: {err}")

In [None]:
def change_col_names(exper_id, run_id, df, run_id_chars=5):
    '''
    Return a copy of `df` whose column names are prefixed with the run they come from.

    Each column `col` becomes ``f"{exper_id}_{run_id[:run_id_chars]}_{col}"`` so that
    latent variables from different runs can be concatenated side by side and still be
    told apart.

    Args:
        exper_id: MLflow experiment id.
        run_id: MLflow run id; only its first `run_id_chars` characters are used.
        df: table whose columns are renamed. It is NOT modified.
        run_id_chars: number of leading characters of `run_id` kept in the prefix (default 5).

    Returns:
        A new DataFrame with the renamed columns (same index and data).
    '''
    prefix = f"{exper_id}_{run_id[:run_id_chars]}_"
    # `rename` returns a new frame; assigning to `df.columns` would silently mutate the
    # caller's DataFrame (and double-prefix it if called twice on the same object).
    return df.rename(columns=lambda col: f"{prefix}{col}")

def path_to_z(row):
    '''
    Map one row of the runs table to the location of that run's latent-vector file.

    Returns the pair ``((experiment_id, run_id), path)``, where ``path`` is
    ``output/z_adj_<experiment_id>_<run_id>.tsv`` joined onto the run's artifact URI
    with the ``file://`` scheme stripped.
    '''
    exp_id, rid = row.experiment_id, row.run_id
    local_artifacts = row.artifact_uri.replace("file://", "")
    filename = "z_adj_{}_{}.tsv".format(exp_id, rid)
    return (exp_id, rid), os.path.join(local_artifacts, "output/" + filename)

In [None]:
# {(experiment_id, run_id): path of that run's latent-vector file}, one entry per row of `runs_df`.
z_paths = dict(path_to_z(row) for _, row in runs_df.reset_index().iterrows())

In [None]:
# Read the latent vectors of every run (indexed by subject "ID") and put them side by side.
z_dfs = {}
missing_runs = []
for run_id, z_path in tqdm(z_paths.items()):
    try:
        z_dfs[run_id] = pd.read_csv(z_path, sep="\t").set_index("ID")
    except FileNotFoundError:
        # Not every run wrote a latent-vector file; skip it, but keep track of it.
        missing_runs.append(run_id)

if missing_runs:
    print(f"No latent-vector file for {len(missing_runs)} of {len(z_paths)} runs; they are skipped.")

# pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate".
if not z_dfs:
    raise ValueError("No latent-vector (z_adj) files were found for the selected runs.")

z_dfs_renamed = [change_col_names(expid, runid, z_df) for (expid, runid), z_df in z_dfs.items()]
z_all_df = pd.concat(z_dfs_renamed, axis=1)
z_all_df.head()

# Genomic PCA

In [None]:
# Table of genomic principal components produced by the GWAS pipeline repository.
GENOMIC_PC_FILE = os.environ["HOME"] + "/01_repos/GWAS_pipeline/data/transforms/GenomicPCA/pcs.txt"

In [None]:
# Genomic PCs indexed by the `IID` column; the `FID` column is not needed and is dropped.
genomic_pca_df = (
    pd.read_csv(GENOMIC_PC_FILE, sep="\t")
    .set_index("IID")
    .drop(columns="FID")
)

In [None]:
# Preview only the first rows instead of dumping the whole table into the notebook.
genomic_pca_df.head()

### Correlation between genomic PCs and latent variables

In [None]:
import statsmodels.api as sm

In [None]:
from scipy.stats import spearmanr

In [None]:
# Genomic PCs of the individuals that also have latent vectors. The IDs are intersected
# because `.loc` with a list containing labels absent from the index raises KeyError.
# An empty result here usually means the two indices have different dtypes.
genomic_pca_df.loc[z_all_df.index.intersection(genomic_pca_df.index)].head()

In [None]:
# Rank correlation between every latent variable and every genomic PC, computed on the
# individuals present in both tables (`.loc` with labels missing from the index raises
# KeyError, so the IDs are intersected first).
common_ids = z_all_df.index.intersection(genomic_pca_df.index)
if common_ids.empty:
    raise ValueError("No individual IDs in common between z_all_df and genomic_pca_df (check index dtypes).")
# NOTE(review): with spearmanr's default nan_policy="propagate", a latent column that contains
# NaN (e.g. from the outer concat of runs covering different individuals) yields NaN
# correlations — confirm whether complete-case filtering is wanted.
spearman_coef, spearman_pvalue = spearmanr(
    a=z_all_df.loc[common_ids],
    b=genomic_pca_df.loc[common_ids],
)

In [None]:
# spearmanr(a, b) correlates the columns of `a` followed by the columns of `b`, so the
# latent-variable x genomic-PC block is "all rows but the last n_pcs" x "last n_pcs columns".
# n_pcs is taken from the data instead of assuming exactly 10 genomic PCs.
n_pcs = genomic_pca_df.shape[1]
log10_pvalues_z_vs_pc = np.log10(pd.DataFrame(
    spearman_pvalue[:-n_pcs, -n_pcs:],
    index=z_all_df.columns,
    columns=genomic_pca_df.columns,
))
log10_pvalues_z_vs_pc.describe()

### Correlation between genomic PCs and traditional cardiac indices

In [None]:
timeframe = f"{1:03d}"  # "001" --> end-diastole
datafolder = "data/cardio/cardiac_indices"

# The per-case indices are stored in four group folders (G1 ... G4); stack them row-wise.
df = pd.concat(
    [
        pd.read_csv(f"{datafolder}/G{group}/LVRV_time{timeframe}.csv", index_col="case_id")
        for group in range(1, 5)
    ]
)

# String ids, so the index can be merged with the sphericity table's (also string) index.
df.index = df.index.astype(str)

df.head()

In [None]:
# Sphericity table keyed by its "id" column; ids are cast to str so they match `df.index`.
sph_df = (
    pd.read_csv("data/cardio/sphericity.csv")
    .set_index("id")
)
sph_df.index = sph_df.index.astype(str)

In [None]:
# Inner join on the index: only ids present in both the cardiac-indices and sphericity tables are kept.
cardiac_indices_df = pd.merge(df, sph_df, how="inner", left_index=True, right_index=True)