In [1]:
import os
import glob

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import yaml

from src.da_utils import data_loading, evaluation

In [36]:
# Domain-adaptation models whose final results are summarised in this notebook.
MODEL_NAMES = (
    "ADDA",
    "CellDART",
    "CORAL",
    "DANN",
)

# Where configs and final results live, and which dataset's configs to read.
FINAL_RESULTS_FOLDER = "results_FINAL"
CONFIGS_DIR = "configs"
DSET = "spotless"

# True -> read the single "results_checkpt-*.csv" file per seed;
# False -> read "results.csv".
EARLY_STOPPING = False

# NOTE(review): PS_SEEDS is not referenced by the cells in this notebook; confirm it is still needed.
PS_SEEDS = (3679, 343, 25, 234, 98098)
# One results directory is loaded per model seed.
MODEL_SEEDS = (2353, 24385, 284, 86322, 98237)

def get_results_df(model_name):
    """Load one model's per-seed results and summarise them across model seeds.

    Reads ``{CONFIGS_DIR}/{model_name}/{model_name.lower()}-final-{DSET}-ht.yml``
    to get ``data_params`` and ``model_params``. Then, for every seed in
    ``MODEL_SEEDS``, it reads that seed's results CSV from
    ``{FINAL_RESULTS_FOLDER}/std/<model_rel_path>``. Metrics are first averaged
    within each ``(model_seed, da, SC Split)`` group. The mean and standard
    deviation of those per-seed averages are then taken across seeds.

    Args:
        model_name: Model name. It is used as the config subdirectory and,
            lower-cased, as the config filename prefix.

    Returns:
        Tuple ``(mean_df, std_df)``, both grouped by ``("da", "SC Split")``.
        They hold the across-seed mean and standard deviation of every metric
        column.

    Raises:
        OSError: If ``EARLY_STOPPING`` is true and a seed's results folder does
            not contain exactly one ``results_checkpt-*.csv`` file.
    """
    config_fname = f"{model_name.lower()}-final-{DSET}-ht.yml"
    with open(os.path.join(CONFIGS_DIR, model_name, config_fname), "r") as f:
        config = yaml.safe_load(f)

    data_params = config["data_params"]
    model_params = config["model_params"]

    results_dfs = []
    for model_seed in MODEL_SEEDS:
        model_rel_path = data_loading.get_model_rel_path(
            model_name,
            model_params["model_version"],
            lib_seed_path=str(model_seed),
            **data_params,
        )
        results_folder = os.path.join(FINAL_RESULTS_FOLDER, "std", model_rel_path)

        if EARLY_STOPPING:
            # Exactly one checkpoint results file is expected per seed.
            checkpt_paths = glob.glob(
                os.path.join(results_folder, "results_checkpt-*.csv"),
            )
            if len(checkpt_paths) != 1:
                raise OSError(
                    f"{len(checkpt_paths)} results_checkpt files found "
                    f"in {results_folder}; expected 1"
                )
            # glob returns paths already prefixed with results_folder.
            results_path = checkpt_paths[0]
        else:
            results_path = os.path.join(results_folder, "results.csv")

        results_dfs.append(
            pd.read_csv(results_path, header=[0, 1], index_col=[0, 1, 2])
        )

    # Prepend a "model_seed" index level so rows from different seeds stay distinct.
    results_df = pd.concat(results_dfs, keys=MODEL_SEEDS, names=["model_seed"])
    # Level 1 is the CSV's first index column; name it "da" so it can be grouped on.
    results_df.index = results_df.index.set_names("da", level=1)

    # Average within each seed first (collapsing any remaining index level),
    # then summarise those per-seed averages across seeds.
    per_seed_means = results_df.groupby(["model_seed", "da", "SC Split"]).mean()
    across_seeds = per_seed_means.groupby(["da", "SC Split"])
    return across_seeds.mean(), across_seeds.std()

In [41]:
def format_mean_std(mean_df, std_df, decimals=4):
    """Combine per-cell means and standard deviations into "mean (std)" strings.

    Args:
        mean_df: Frame of means, e.g. the first element returned by
            ``get_results_df``.
        std_df: Frame of standard deviations with the same index and columns
            as ``mean_df``. Cells are matched by label, not by position.
        decimals: Number of fixed decimal places shown for both numbers.

    Returns:
        A new frame with ``mean_df``'s index and columns. Each cell is a string
        like ``"0.0164 (0.0001)"``. The input frames are not modified.
    """
    fmt = f"{{:.{decimals}f}}".format
    mean_str = mean_df.apply(lambda col: col.map(fmt))
    std_str = std_df.apply(lambda col: col.map(fmt))
    # Object (string) frames concatenate element-wise, aligned on index and columns.
    return mean_str + " (" + std_str + ")"


# Raw (mean_df, std_df) pairs per model, kept separate from the display tables.
results_raw_d = {model_name: get_results_df(model_name) for model_name in MODEL_NAMES}

# "mean (std)" string tables per model. Building new frames avoids writing
# strings into float columns in place.
results_d = {
    model_name: format_mean_std(mean_df, std_df)
    for model_name, (mean_df, std_df) in results_raw_d.items()
}


In [46]:
# ADDA: full "mean (std)" summary table across model seeds.
adda_summary = results_d["ADDA"]
adda_summary

Unnamed: 0_level_0,Unnamed: 1_level_0,Pseudospots (Cosine Distance),Pseudospots (Cosine Distance),Pseudospots (Cosine Distance),RF50,RF50,RF50,miLISI (perplexity=5),miLISI (perplexity=5),miLISI (perplexity=5),Real Spots (Cosine Distance)
Unnamed: 0_level_1,Unnamed: 1_level_1,train,val,test,train,val,test,train,val,test,0
da,SC Split,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
After DA (final model),test,0.0164 (0.0001),0.0295 (0.0004),0.0279 (0.0001),0.8344 (0.2259),0.9744 (0.0093),0.9773 (0.0079),1.3532 (0.1328),1.4027 (0.2377),1.3524 (0.1598),0.3278 (0.0078)
After DA (final model),train,0.0164 (0.0001),0.0295 (0.0004),0.0279 (0.0001),0.9161 (0.0737),0.8389 (0.102),0.8346 (0.0377),1.3763 (0.0857),1.4098 (0.1027),1.3829 (0.0132),0.5832 (0.0219)
Before DA,test,0.0164 (0.0001),0.0295 (0.0004),0.0279 (0.0001),0.9872 (0.0041),0.9876 (0.0053),0.9895 (0.0026),1.0527 (0.031),1.0787 (0.0493),1.0475 (0.0106),0.3301 (0.0143)
Before DA,train,0.0164 (0.0001),0.0295 (0.0004),0.0279 (0.0001),0.9098 (0.0194),0.9077 (0.0376),0.9309 (0.0169),1.1227 (0.041),1.1519 (0.0508),1.1126 (0.0172),0.5975 (0.0031)


In [45]:
# Real-spot cosine distance of the final DA model, one block per model
# (columns: SC Split).
real_spots_after_da = {
    name: results_d[name].loc["After DA (final model)", "Real Spots (Cosine Distance)"].T
    for name in MODEL_NAMES
}
pd.concat(real_spots_after_da)

Unnamed: 0,SC Split,test,train
ADDA,0,0.3278 (0.0078),0.5832 (0.0219)
CellDART,0,0.3619 (0.0828),0.6527 (0.0433)
CORAL,0,0.292 (0.0266),0.6106 (0.0145)
DANN,0,0.2888 (0.0125),0.5641 (0.0144)


In [52]:
# Pseudospot cosine distance of the final DA model on the "test" SC split,
# one row per model (columns: train/val/test level of the results table).
pseudospots_by_model = {}
for name in MODEL_NAMES:
    pseudospots_by_model[name] = results_d[name].loc[
        ("After DA (final model)", "test"), "Pseudospots (Cosine Distance)"
    ]
pd.concat(pseudospots_by_model, axis=1).T

Unnamed: 0,train,val,test
ADDA,0.0164 (0.0001),0.0295 (0.0004),0.0279 (0.0001)
CellDART,0.0868 (0.0031),0.0873 (0.0033),0.0881 (0.0031)
CORAL,0.6176 (0.036),0.6179 (0.0367),0.6173 (0.0376)
DANN,0.1906 (0.0394),0.1905 (0.0396),0.1917 (0.0385)


In [54]:
# miLISI (perplexity=5) of the final DA model on the "test" SC split,
# one row per model (columns: train/val/test level of the results table).
row_key = ("After DA (final model)", "test")
milisi_by_model = {}
for name in MODEL_NAMES:
    milisi_by_model[name] = results_d[name].loc[row_key, "miLISI (perplexity=5)"]
pd.concat(milisi_by_model, axis=1).T

Unnamed: 0,train,val,test
ADDA,1.3532 (0.1328),1.4027 (0.2377),1.3524 (0.1598)
CellDART,1.0521 (0.0341),1.0481 (0.0393),1.0383 (0.0589)
CORAL,1.5847 (0.2054),1.5247 (0.0908),1.5464 (0.1523)
DANN,1.3246 (0.1064),1.2672 (0.1678),1.2894 (0.1959)
