In [None]:
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import json

In [None]:
def json_to_df(source: str, stat: str, measurement: str, hypothesis: str) -> pd.DataFrame:
    """Load one stats JSON file into a long-format DataFrame.

    The file is read as ``model -> dataset -> {"δ": matrix, "p": matrix}``, where
    each matrix is a dict-of-dicts passed to ``pd.DataFrame``. Its outer keys end
    up in the ``tgt`` column and its inner keys in the ``src`` column.

    Parameters
    ----------
    source : str
        Path to the JSON file (decoded as UTF-8, as JSON requires).
    stat, measurement, hypothesis : str
        Labels attached to every row and used as index levels.

    Returns
    -------
    pd.DataFrame
        Indexed by ``(model, stat, measurement, hypothesis, dataset)``, with
        columns ``src``, ``tgt``, ``δ``, ``p``. There is one row per
        ``(src, tgt)`` pair present in both the δ and the p matrix (inner merge).
        A file containing an empty mapping yields an empty frame with the same
        index levels and columns instead of raising ``KeyError``.
    """
    # Keys contain the non-ASCII "δ"; do not rely on the platform locale encoding.
    with open(source, encoding="utf-8") as f:
        data = json.load(f)

    def _to_long(matrix, value_name: str) -> pd.DataFrame:
        # dict-of-dicts -> wide frame (outer keys = columns) -> long (src, tgt, value)
        return (
            pd.DataFrame(matrix)
            .reset_index()
            .rename({"index": "src"}, axis=1)
            .melt(id_vars=["src"], var_name="tgt", value_name=value_name)
        )

    # Collect per-(model, dataset) frames and concatenate once; growing a
    # DataFrame with pd.concat inside the loop copies everything each time.
    frames = []
    for model, analysis in data.items():
        for dataset, metrics in analysis.items():
            pair_df = _to_long(metrics["δ"], "δ").merge(_to_long(metrics["p"], "p"))
            pair_df["dataset"] = dataset
            pair_df["model"] = model
            frames.append(pair_df)

    if frames:
        _df = pd.concat(frames, ignore_index=True)
    else:
        # Keep the schema so set_index below still works on an empty file.
        _df = pd.DataFrame(columns=["src", "tgt", "δ", "p", "dataset", "model"])
    _df["stat"] = stat
    _df["measurement"] = measurement
    _df["hypothesis"] = hypothesis
    return _df.set_index(["model", "stat", "measurement", "hypothesis", "dataset"])

from itertools import product

# Aggregate every (stat, hypothesis, experiment, measurement) result file into a
# single long table and cache it as CSV for the plotting cell below.
from pathlib import Path

basepath = "tests_20240320"   # root directory of this test run
norm = "valset"               # normalisation tag in the run-folder name
perms = "3000"                # permutation-count tag in the run-folder name
sampsize = "100"              # sample-size tag in the run-folder name

# Build a list and concatenate once instead of calling pd.concat inside the loop.
stat_frames = []
for stat, hypothesis, experiment, measurement in product(["auc", "mrpp"], ["h1", "h2"], ["toy", "dom"], ["metrics", "embeddings"]):
    fp = f"{basepath}/{stat}_{norm}_p{perms}_s{sampsize}_{hypothesis}/stats_{experiment}_{measurement}.json"
    stat_frames.append(json_to_df(fp, stat=stat, measurement=measurement, hypothesis=hypothesis))
df_stat = pd.concat(stat_frames)

# to_csv does not create missing directories; make sure the output folder exists.
out_path = Path("agg_stats") / f"{basepath}.csv"
out_path.parent.mkdir(parents=True, exist_ok=True)
df_stat.to_csv(out_path)

In [None]:
# Plot the δ and p matrices of every (stat, model, measurement) combination,
# one figure per dataset group (toy datasets, bacteria/AMRB datasets).
from itertools import product  # repeated from the previous cell so this cell is self-contained

df = pd.read_csv(f"agg_stats/{basepath}.csv").sort_index().reset_index()

# Shorter display names for the QPM/AMRB datasets.
AMRB_RENAMES = {
    "QPM_species_A_strain-level": "AMRB_ST_A",
    "QPM_species_B_strain-level": "AMRB_ST_B",
    "QPM_species_A_species-level": "AMRB_SP_A",
    "QPM_species_B_species-level": "AMRB_SP_B",
}
df["dataset"] = df["dataset"].replace(AMRB_RENAMES)

models = df.model.unique()
stats = df.stat.unique()
measurements = df.measurement.unique()
hypotheses = df.hypothesis.unique()
# Base names: drop the last two characters, i.e. the "_A"/"_B" suffix that the
# plotting code below appends again.
datasets = df.dataset.str[:-2].unique()

Nt = [x for x in datasets if not x.startswith("AMRB")]  # toy datasets
Nd = [x for x in datasets if x.startswith("AMRB")]      # bacteria (AMRB) datasets


def _pivot_side_by_side(data_a: pd.DataFrame, data_b: pd.DataFrame, value: str):
    """Pivot the A and B rows to src x tgt matrices and place them side by side.

    Returns the combined matrix (rows reordered to follow its columns) and the
    number of tgt columns contributed by A and by B.
    """
    piv_a = pd.pivot(data_a, index="src", columns="tgt", values=value).sort_index()
    piv_b = pd.pivot(data_b, index="src", columns="tgt", values=value).sort_index()
    piv = pd.concat([piv_a, piv_b], axis=1)
    return piv.loc[piv.columns], len(piv_a.columns), len(piv_b.columns)


def plot_group_heatmaps(df_statmodel: pd.DataFrame, bases, hypothesis: str = "h2"):
    """Draw δ (top row) and p (bottom row) heatmaps, one column per dataset base name.

    For each base name ``b``, rows for ``f"{b}_A"`` and ``f"{b}_B"`` under
    ``hypothesis`` are combined side by side, and the A and B diagonal blocks
    are outlined. Bases lacking A or B rows get hidden panels.
    """
    if len(bases) == 0:
        return
    # squeeze=False keeps `axs` 2-D even for a single dataset column; with the
    # default squeeze=True, nrows=2/ncols=1 returns a 1-D array and
    # axs[0][i] raises TypeError.
    fig, axs = plt.subplots(nrows=2, ncols=len(bases), dpi=150, figsize=(5 * len(bases), 8),
                            sharex="col", squeeze=False)
    cmap = sns.color_palette("Blues", as_cmap=True)
    by_hyp = df_statmodel[df_statmodel.hypothesis == hypothesis]
    for i, base in enumerate(bases):
        ax_δ, ax_p = axs[0, i], axs[1, i]
        data_a = by_hyp[by_hyp.dataset == f"{base}_A"][["src", "tgt", "δ", "p"]]
        data_b = by_hyp[by_hyp.dataset == f"{base}_B"][["src", "tgt", "δ", "p"]]
        if data_a.empty or data_b.empty:
            # Nothing to draw for this base; hide the panels instead of leaving blank axes.
            ax_δ.set_axis_off()
            ax_p.set_axis_off()
            continue
        piv_δ, ka, kb = _pivot_side_by_side(data_a, data_b, "δ")
        piv_p, _, _ = _pivot_side_by_side(data_a, data_b, "p")
        sns.heatmap(piv_δ, ax=ax_δ, cmap=cmap)
        sns.heatmap(piv_p, ax=ax_p, cmap=cmap, vmin=0, vmax=1)
        for ax in (ax_δ, ax_p):
            # A patch can belong to only one Axes, so fresh Rectangles are made per axis.
            ax.add_patch(plt.Rectangle((0, 0), ka, ka, fill=False, color="gray"))
            ax.add_patch(plt.Rectangle((ka, ka), kb, kb, fill=False, color="gray"))
            ax.set(xlabel="", ylabel="")
        ax_δ.set_title(base)
    plt.show()
    plt.close(fig)


for (stat, model, measurement) in product(stats, models, measurements):
    df_statmodel = df[(df.model == model) & (df.stat == stat) & (df.measurement == measurement)][["hypothesis", "dataset", "src", "tgt", "δ", "p"]]
    for bases, label in ((Nt, "toy datasets"), (Nd, "bacteria datasets")):
        if len(bases) > 0:
            print(stat, model, measurement, label)
            plot_group_heatmaps(df_statmodel, bases)