In [None]:
import os

import pandas as pd
from mlflow.tracking import MlflowClient
# Local MLflow file store that holds the runs analysed below. Set the
# CRRT_MLRUNS_DIR environment variable to point at your own checkout; the
# default is the original author's absolute path, kept for backward compatibility.
mlflow_path = os.environ.get("CRRT_MLRUNS_DIR", "/home/davina/Private/repos/CRRT/mlruns")
client = MlflowClient(mlflow_path)
# Run IDs inside that store. Only adult_run_id is read by the cells shown here;
# NOTE(review): peds_run_id is currently unused — confirm whether a pediatric
# analysis was intended or drop it.
adult_run_id = "ad70333e489c457386c53d079028454b"
peds_run_id = "ed0cfbb4ff654c78b61a2299137e5864"

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Pull the run identified by adult_run_id; its metrics mapping (metric key -> value)
# is what the plotting and table cells below index into.
run = client.get_run(run_id=adult_run_id)
metrics = run.data.metrics

In [None]:
from os.path import join
# Take the experiment ID from the run itself instead of looking it up by a
# hardcoded experiment name. A name lookup returns None when no such
# experiment exists, which surfaces as an opaque AttributeError. It also
# builds a wrong path if the run was logged under a different experiment.
experiment_id = run.info.experiment_id
# File-store artifact layout: <mlruns>/<experiment_id>/<run_id>/artifacts/<file>
dist_table_path = join(
    mlflow_path,
    experiment_id,
    run.info.run_id,
    "artifacts",
    "xgb_test__dist_comparison_table.txt",
)
pd.read_csv(dist_table_path, sep="\t")

In [None]:
for subgroup in ["", "_heart", "_liver", "_infection"]:
    # Overall test-set keys have no subgroup suffix ("xgb_test__TP");
    # subgroup keys look like "xgb_test_heart__TP".
    prefix = f"xgb_test{subgroup}"
    # Human-readable label for the title; a separate name so the loop
    # variable is not overwritten mid-iteration.
    label = subgroup.lstrip("_") or "all"

    ### Confusion Matrix ###
    # Rows = predicted class, columns = actual class:
    #   [[TP, FP],
    #    [FN, TN]]
    confusion_matrix = np.array(
        [
            [metrics[f"{prefix}__TP"], metrics[f"{prefix}__FP"]],
            [metrics[f"{prefix}__FN"], metrics[f"{prefix}__TN"]],
        ]
    )

    # Explicit figure/axes so each subgroup gets its own plot independent of
    # pyplot's implicit "current axes" state.
    # ref: https://stackoverflow.com/a/29648332/1888794
    fig, ax = plt.subplots()
    sns.heatmap(confusion_matrix, annot=True, fmt="g", cmap="Blues", cbar=False, ax=ax)
    ax.set_title(f"Confusion Matrix: {label}\n")
    ax.set_ylabel("\nPredicted Values")
    ax.set_xlabel("Actual Values ")

    # Tick order must match the row/column order of confusion_matrix above.
    ax.xaxis.set_ticklabels(["Recommend", "Not Recommend"])
    ax.yaxis.set_ticklabels(["Recommend", "Not Recommend"])

    ## Display the visualization of the Confusion Matrix.
    plt.show()
    # One figure is created per loop iteration; release each explicitly.
    plt.close(fig)

In [None]:
# Summary table: one column per metric, one row per test subgroup.
# "all" labels the whole test set, whose metric keys carry no subgroup suffix
# (e.g. "xgb_test__auroc" vs "xgb_test_heart__auroc").
table = {
    metric_name: {
        (subgroup.lstrip("_") or "all"): metrics[f"xgb_test{subgroup}__{metric_name}"]
        for subgroup in ["", "_heart", "_liver", "_infection"]
    }
    for metric_name in ["accuracy", "ap", "auroc", "brier", "precision", "recall"]
}
pd.DataFrame(table)