In [None]:
# Enable IPython's autoreload extension in mode 2 so that edited project
# modules are re-imported before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import re
import warnings
from functools import partial
from pathlib import Path

# Installed before the third-party imports below, so warnings raised while
# importing them are suppressed as well.
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
from addict import Dict

from mmfm.data import dgp_iccite_data
from mmfm.models import VectorFieldModel, MultiVectorFieldModelTCFM
from mmfm.utils import COLORMAP10, ThickerLine2D
from mmfm.utils_eval import (
    load_all_fm_models,
    process_all_fm_models,
    get_model_battery_of_best_models,
    predict_on_testset_mmfm,
    predict_on_testset_fsi,
)


# plt.style.use(["science", "no-latex"])

In [None]:
# --- Experiment configuration -------------------------------------------------
# These values select which trained runs and which cached dataset the cells
# below evaluate; they are matched against the columns of the results table and
# interpolated into dataset / result file names.
PROD = True
DGP = "a"
coupling = "cot"
minimum_seeds = 1
batch_size = None
hvg = None
train_test_split = 0.8
subsample_frac = None
top_n_effects = None
leave_out_middle = None
leave_out_end = None
preset = "z"
n_samples_per_c_in_b = 100
use_pca = 10
embedding_type = "free"
classifier_free = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path_name_fsi = "/home/rohbeckm/scratch/results/dgp_iccite/results_fsi"
verbose = False

# Each preset maps to (number of condition labels, experiment-definition CSV).
# NOTE(review): later cells also branch on preset "y", but no experiment file is
# defined for it here -- add an entry before using that preset.
EXPERIMENT_DIR = Path("/home/rohbeckm/code/mmfm/benchmark/dgp_iccite/data")
PRESET_EXPERIMENTS = {
    "a": (50, "experiment_50_20_random.csv"),
    "b": (70, "experiment_70_30_random.csv"),
    "c": (90, "experiment_90_40_random.csv"),
    "z": (60, "experiment_4t_60_30_random.csv"),
}

# Fail loudly on an unknown preset instead of silently leaving n_classes,
# label_list and df_experiment undefined (or stale from an earlier run).
if preset not in PRESET_EXPERIMENTS:
    raise ValueError(f"Unknown preset {preset!r}; expected one of {sorted(PRESET_EXPERIMENTS)}")

n_classes, experiment_csv = PRESET_EXPERIMENTS[preset]
label_list = list(range(1, n_classes + 1))
df_experiment = pd.read_csv(EXPERIMENT_DIR / experiment_csv, index_col=0)


In [None]:
def load_all_files(DGP, PROD, load_from_parquet=True, filter_values=None):
    """Load the aggregated results table for one DGP, cached under ``data/``.

    When ``load_from_parquet`` is true and ``data/df_gt_{DGP}.parquet`` exists,
    the table and the two pickled column lists are read from the cache.
    Otherwise everything is rebuilt with ``load_all_fm_models`` and written to
    the cache (creating ``data/`` if necessary).

    Args:
        DGP: Identifier interpolated into the cache file names and forwarded to
            ``load_all_fm_models`` as ``dgp``.
        PROD: Forwarded to ``load_all_fm_models`` as ``production``.
        load_from_parquet: Whether an existing cache may be used.
        filter_values: Forwarded to ``load_all_fm_models`` when the table is rebuilt.

    Returns:
        Tuple ``(df, grouping_columns, performance_columns)``.

    NOTE(review): the cache file names do not depend on ``filter_values``, so a
    cache written by one call is served to later calls with a different filter.
    Pass ``load_from_parquet=False`` if the filter must take effect.
    """
    cache_dir = Path("data")
    df_cache = cache_dir / f"df_gt_{DGP}.parquet"
    grouping_cache = cache_dir / f"grouping_columns_{DGP}.pkl"
    performance_cache = cache_dir / f"performance_columns_{DGP}.pkl"

    def _restore_flow_variance(frame):
        # flow_variance is stored as strings in the parquet file. Convert the
        # column back to float when every value parses; otherwise keep it as-is.
        try:
            frame["flow_variance"] = frame["flow_variance"].astype(float)
        except ValueError:
            pass
        return frame

    if load_from_parquet and df_cache.exists():
        df = _restore_flow_variance(pd.read_parquet(df_cache))
        # The pickles are produced by this function itself (see below).
        with open(grouping_cache, "rb") as f:
            grouping_columns = pickle.load(f)
        with open(performance_cache, "rb") as f:
            performance_columns = pickle.load(f)
        return df, grouping_columns, performance_columns

    df, grouping_columns, performance_columns = load_all_fm_models(
        path="/home/rohbeckm/scratch/results/dgp_iccite/results_mmfm",
        production=PROD,
        dgp=DGP,
        filter_values=filter_values,
    )

    # Persist everything. The cache directory may not exist on a fresh checkout.
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Cast to str before writing so a column mixing numbers and strings can be serialized.
    df["flow_variance"] = df["flow_variance"].astype(str)
    df.to_parquet(df_cache)
    df = _restore_flow_variance(df)

    with open(grouping_cache, "wb") as f:
        pickle.dump(grouping_columns, f)
    with open(performance_cache, "wb") as f:
        pickle.dump(performance_columns, f)

    return df, grouping_columns, performance_columns

In [None]:
def get_model_battery_of_best_models(
    df,
    dgp,
    grouping_columns,
    select_by="mmd_mean",
    device="cuda",
    n_top_models=1,
    n_top_seeds=5,
    model_string="dgp_waves",
    label_list=None,
    verbose=True,
    average_out_seed=True,
    model_type="mmfm",
):
    """Select the best model configuration(s) and load their checkpoints for every seed.

    NOTE: this notebook-local definition shadows the function of the same name
    imported from ``mmfm.utils_eval`` at the top of the notebook.

    The rows of ``df`` are grouped by all ``grouping_columns`` except
    ``marginal``/``c``/``time`` (and ``seed`` when ``average_out_seed``), the
    metrics are aggregated per group, and the ``n_top_models`` groups with the
    smallest ``select_by`` value are kept. For each of them the checkpoint
    ``model.pt`` next to ``filename_first`` is loaded once per seed, with the
    seed in the run-directory name swapped in.

    Args:
        df: Results table. Must contain the metric columns aggregated below, a
            ``weight`` column, ``filename`` and the grouping columns.
        dgp: DGP identifier as it appears in the run-directory name.
        grouping_columns: Columns identifying one evaluated configuration.
        select_by: Aggregated column used for ranking (smaller is better).
        device: Device the checkpoints are mapped to and the models moved to.
        n_top_models: Number of best configurations to load.
        n_top_seeds: Seeds ``0 .. n_top_seeds - 1`` are tried for each configuration;
            seeds without a checkpoint on disk are skipped.
        model_string: Run-directory prefix preceding ``_{dgp}_{seed}``.
        label_list: Forwarded to the model constructor.
        verbose: Print the selected run and which seeds were found.
        average_out_seed: If true, seeds are pooled when ranking configurations.
        model_type: ``"mmfm"`` (``VectorFieldModel``) or ``"totcfm"``
            (``MultiVectorFieldModelTCFM``).

    Returns:
        ``(df_top_valid, model_battery, model_states, model_guidances)`` where
        ``df_top_valid`` holds the selected aggregated rows and the three
        containers are keyed by rank (0 = best); ``model_battery`` and
        ``model_states`` are additionally keyed by seed.

    Raises:
        ValueError: If ``model_type`` is unknown or ``df`` has no rows.
    """
    if model_type not in ("mmfm", "totcfm"):
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'mmfm' or 'totcfm'")
    if len(df) == 0:
        raise ValueError("No rows to select a model from; check the filters applied to `df`")

    # The aggregation lambdas receive one group's values; the matching weights
    # are looked up in the full frame via the group's index.
    def weighted_avg(group_df, whole_df, values, weights):
        v = whole_df.loc[group_df.index, values]
        w = whole_df.loc[group_df.index, weights]
        return (v * w).sum() / w.sum()

    def weighted_max(group_df, whole_df, values, weights):
        v = whole_df.loc[group_df.index, values]
        w = whole_df.loc[group_df.index, weights]
        return (v * w).max()

    not_grouped = ["marginal", "c", "time"] + (["seed"] if average_out_seed else [])
    grouping = [x for x in grouping_columns if x not in not_grouped]

    df_agg = (
        df.drop(columns=["marginal", "c", "time"])
        .groupby(grouping)
        .agg(
            mmd_mean=("mmd", lambda x: weighted_avg(x, df, "mmd", "weight")),
            mmd_max=("mmd", lambda x: weighted_max(x, df, "mmd", "weight")),
            mmd_std=("mmd", "std"),
            mmd_median_mean=("mmd_median", lambda x: weighted_avg(x, df, "mmd_median", "weight")),
            mmd_median_max=("mmd_median", lambda x: weighted_max(x, df, "mmd_median", "weight")),
            mmd_median_std=("mmd_median", "std"),
            wasserstein_mean=("wasserstein", lambda x: weighted_avg(x, df, "wasserstein", "weight")),
            wasserstein_max=("wasserstein", lambda x: weighted_max(x, df, "wasserstein", "weight")),
            wasserstein_std=("wasserstein", "std"),
            mean_diff_l1_mean=("mean_diff_l1", lambda x: weighted_avg(x, df, "mean_diff_l1", "weight")),
            mean_diff_l1_max=("mean_diff_l1", lambda x: weighted_max(x, df, "mean_diff_l1", "weight")),
            mean_diff_l1_std=("mean_diff_l1", "std"),
            mean_diff_l2_mean=("mean_diff_l2", lambda x: weighted_avg(x, df, "mean_diff_l2", "weight")),
            mean_diff_l2_max=("mean_diff_l2", lambda x: weighted_max(x, df, "mean_diff_l2", "weight")),
            mean_diff_l2_std=("mean_diff_l2", "std"),
            filename_first=("filename", "first"),
        )
    ).reset_index()

    # Rank once and keep the n best rows. (Previously only one row was kept while
    # being indexed by rank, which failed for n_top_models > 1.)
    df_top_valid = df_agg.sort_values(by=select_by, ascending=True).head(n_top_models).reset_index(drop=True)

    # Matches "<model_string>_<dgp>_<seed>" in the checkpoint path regardless of
    # how many underscores the surrounding directories contain.
    seed_pattern = re.compile(rf"{re.escape(model_string)}_{re.escape(str(dgp))}_\d+(?=_|/|$)")

    model_battery = Dict()
    model_guidances = Dict()
    model_states = Dict()
    for rank in range(len(df_top_valid)):
        best_file = df_top_valid["filename_first"].values[rank]
        if verbose:
            print(f"Best model: {best_file}")
        model_guidances[rank] = df_top_valid["guidance"].values[rank]

        # The checkpoint lives next to the results file of the selected run.
        base_path = best_file.replace("df_results.csv", "model.pt")

        for seed in range(n_top_seeds):
            seeded_name = f"{model_string}_{dgp}_{seed}"
            model_path = seed_pattern.sub(lambda _match: seeded_name, base_path)
            filename = model_path.split("/")[-2]
            try:
                # map_location lets checkpoints saved on GPU load on CPU-only hosts.
                state = torch.load(model_path, weights_only=True, map_location=device)
                if verbose:
                    print(f"✓ {filename}")
            except FileNotFoundError:
                if verbose:
                    print(f"✗ {filename}")
                continue

            # Hyper-parameters are stored alongside the weights in the checkpoint.
            shared_kwargs = dict(
                data_dim=state["dimension"] if "dimension" in state else state["use_pca"],
                x_latent_dim=state["x_latent_dim"],
                time_embed_dim=state["time_embed_dim"],
                cond_embed_dim=state["cond_embed_dim"],
                conditional_model=state["conditional_model"],
                embedding_type=state["embedding_type"],
                n_classes=state["n_classes"],
                label_list=label_list,
                normalization=state["normalization"],
                activation=state["activation"],
                affine_transform=state["affine_transform"],
                sum_time_embed=state["sum_time_embed"],
                sum_cond_embed=state["sum_cond_embed"],
                max_norm_embedding=state["max_norm_embedding"],
                num_out_layers=state["num_out_layers"],
                spectral_norm=state["spectral_norm"],
                dropout=state["dropout"],
                conditional_bias=state["conditional_bias"],
                keep_constants=state["keep_constants"],
            )
            if model_type == "mmfm":
                mmfm_model = VectorFieldModel(**shared_kwargs)
            else:
                # NOTE(review): model_list is hard-coded; it presumably has to match
                # the time grid the checkpoint was trained on -- confirm.
                mmfm_model = MultiVectorFieldModelTCFM(model_list=[0, 0.33, 0.67, 1.0], **shared_kwargs)
            mmfm_model = mmfm_model.to(device)

            mmfm_model.load_state_dict(state["state_dict"], strict=True)
            model_battery[rank][seed] = mmfm_model
            model_states[rank][seed] = state

    return df_top_valid, model_battery, model_states, model_guidances

In [None]:
# Load the filtered icCITE-plex dataset and apply total-count normalization
# followed by a log1p transform.
adata = sc.read("/home/rohbeckm/code/mmfm/data/icCITE-plex_filtered_top_drugs.h5ad")
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

if preset in ("z", "y"):
    # Pull the leading number out of the timepoint label and use it as an integer.
    adata.obs["timepoint"] = adata.obs["timepoint"].astype(str).str.extract(r"(\d+)").astype(int)

    # All non-stimulated cells form the artificial timepoint 0.
    unstimulated = ["No stim_1uM", "No stim_100nM", "No stim_10uM"]
    is_unstimulated = adata.obs["treatment"].isin(unstimulated)
    adata.obs.loc[is_unstimulated, "timepoint"] = 0

    # Rescale by 72 and round to two digits (the values 0.33 / 0.67 / 1.0 are
    # what the later cells compare against).
    adata.obs["timepoint"] = (adata.obs["timepoint"] / 72).round(2)

    # Keep the string-valued "Timepoint" column in sync with the new timepoint 0.
    adata.obs["Timepoint"] = adata.obs["Timepoint"].astype(str)
    adata.obs.loc[adata.obs["timepoint"] == 0.0, "Timepoint"] = "0h"

# Gene set used for the per-cell activation score.
activation_gene = [
    "TNFRSF18", "TNFRSF4", "IL12RB2", "LMNA", "RRM2", "DUSP2", "GBE1", "ZBED2", "IER3",
    "LTA", "CD109", "TNFAIP3", "SYTL3", "GARS", "SNHG15", "NAMPT", "HILPDA", "DUSP4",
    "RNF19A", "NINJ1", "IL2RA", "DDIT4", "PGAM1", "MICAL2", "SLC43A3", "SLC3A2", "LAG3",
    "LINC02341", "GNA15", "ZBTB32", "MIR155HG", "PIM3", "GK",
]

# Stores the score in adata.obs["score_activation_new"].
sc.tl.score_genes(adata, activation_gene, score_name="score_activation_new")

In [None]:
seed = 0  # The holdout is constant across seeds, so the seed-0 dataset is used throughout.

# Presets "z" and "y" read the datasets with the "iccite_4t" prefix; the rest of
# the file name is identical for all presets.
# NOTE(review): the three literal "None" slots presumably correspond to
# top_n_effects / leave_out_middle / leave_out_end -- confirm against the writer.
data_dir = Path("/home/rohbeckm/code/mmfm/benchmark/dgp_iccite/data")
file_prefix = "iccite_4t" if preset in ["z", "y"] else "iccite"
filename_full_data = str(
    data_dir
    / (
        f"{file_prefix}_{hvg}_{use_pca}_{subsample_frac}_{coupling}_{batch_size}"
        f"_{n_samples_per_c_in_b}_{train_test_split}_None_None_None_{seed}_{preset}.pt"
    )
)

data = torch.load(filename_full_data)

# Treatments that were held out at the beginning, middle or end timepoint.
names_holdout = df_experiment.loc[df_experiment["leave_out"].isin(["beg", "mid", "end"]), "treatment"].values

In [None]:
# Pre-bind the options shared by every model family below, so each section only
# passes its own results table and column lists.
# (`partial` is imported in the import cell at the top of the notebook.)
process_fm_models_partial = partial(
    process_all_fm_models,
    plot=False,
    verbose=False,
    minimum_seeds=minimum_seeds,
    data_cols=[
        "coupling",
        "dgp",
        "interpolation",
        "guidance",
        "embedding_type",
        "classifier_free",
    ],
)

## COT-MMFM

In [None]:
# Load the aggregated results table and its column lists; served from the
# parquet/pickle cache under data/ when that cache exists.
df, grouping_columns, performance_columns = load_all_files(DGP, PROD, load_from_parquet=True)

In [None]:
# Keep only the MMFM runs that match the configuration selected above.
df = df.loc[
    (df["preset"] == preset)
    & (df["coupling"] == coupling)
    & (df["train_test_split"] == train_test_split)
    & (df["n_samples_per_c_in_b"] == n_samples_per_c_in_b)
    & (df["use_pca"] == use_pca)
    & (df["model_type"] == "mmfm")
]
if not classifier_free:
    df = df.loc[~df["classifier_free"]]

# Keep validation rows only. `assign` builds a new frame instead of writing
# into the filtered slice above.
# NOTE(review): astype(bool) maps every non-empty string to True -- verify that
# "train" is stored as bool/0-1 and not as "True"/"False" text.
df = df.assign(train=df["train"].astype(bool))
df = df.loc[~df["train"]]

df, grouping_columns, performance_columns = process_fm_models_partial(df, grouping_columns, performance_columns)

# We only compare to natural cubic splines on FSI later, so let's focus on
# these interpolations for MMFM as well.
# `.copy()` makes the subset independent so the column added below is not a
# write into a slice of `df`.
df_cubic = df.loc[df["interpolation"] == "cubic"].copy()

print(df_cubic.shape)

# Uniform weights: every row contributes equally to the weighted metrics used
# for model selection.
df_cubic["weight"] = 1

# This is our extra validation timepoint.
# Note that for a fair comparison we provide this sample as training data to FSI.
add_time_cond = None

In [None]:
# Model selection for COT-MMFM.
# `valid_filter` (validation condition) and `select_by` (ranking metric) are
# defined explicitly because the OT-CFM, L-COT-MMFM and T-OT-CFM sections below
# reuse them.
# Alternatives for valid_filter: [i for i, k in data["ps"].items() if k in names_holdout]
# Alternatives for select_by: "mean_diff_l1_mean", "mean_diff_l2_mean", "mmd_mean", "mmd_median_mean"
valid_filter = 11
select_by = "wasserstein_mean"

print(f"Validating for condition {valid_filter}")
df_cubic_valid = df_cubic.loc[df_cubic["c"] == valid_filter]

print(f"Selecting by {select_by}")
_, model_battery, model_states, model_guidances = get_model_battery_of_best_models(
    df_cubic_valid,
    DGP,
    grouping_columns,
    select_by=select_by,
    device=device,  # use the configured device instead of assuming CUDA is available
    n_top_models=1,
    n_top_seeds=5,
    model_string="dgp_iccite",
    label_list=label_list,
)

# Test split shared by all model families below. Note that this rebinds n_classes.
X_test, y_test, t_test, n_classes, timepoints, all_classes, ps = dgp_iccite_data(
    hvg,
    subsample_frac,
    use_pca,
    coupling,
    batch_size,
    n_samples_per_c_in_b,
    train_test_split,
    DGP,
    top_n_effects,
    leave_out_middle,
    leave_out_end,
    preset=preset,
    return_data="test",
    seed=0,
)

df_cubic_results_mmfm, trajectories_test = predict_on_testset_mmfm(
    model_battery=model_battery,
    model_states=model_states,
    model_guidances=model_guidances,
    X_test=X_test,
    y_test=y_test,
    t_test=t_test,
    device=device,
    steps=101,
    method="rk4",
)
df_cubic_results_mmfm["model"] = "COT-MMFM"


## OT-CFM

In [None]:
# Reload the results table and keep the "fm" runs for the selected configuration.
df, grouping_columns, performance_columns = load_all_files(DGP, PROD, load_from_parquet=True)

df = df.loc[
    (df["preset"] == preset)
    & (df["coupling"] == coupling)
    & (df["train_test_split"] == train_test_split)
    & (df["n_samples_per_c_in_b"] == n_samples_per_c_in_b)
    & (df["use_pca"] == use_pca)
    & (df["model_type"] == "fm")
]
if not classifier_free:
    df = df.loc[~df["classifier_free"]]

# Keep validation rows only (new frame instead of writing into the filtered slice).
df = df.assign(train=df["train"].astype(bool))
df = df.loc[~df["train"]]

df, grouping_columns, performance_columns = process_fm_models_partial(df, grouping_columns, performance_columns)

# Independent copy so that adding the weight column is not a write into a slice.
df_cubic = df.loc[df["interpolation"] == "cubic"].copy()
df_cubic["weight"] = 1
# `valid_filter` and `select_by` come from the COT-MMFM selection cell above.
df_cubic_valid = df_cubic.loc[df_cubic["c"] == valid_filter]

print(df_cubic.shape)

_, model_battery, model_states, model_guidances = get_model_battery_of_best_models(
    df_cubic_valid,
    DGP,
    grouping_columns,
    select_by=select_by,
    device=device,  # use the configured device instead of assuming CUDA is available
    n_top_models=1,
    n_top_seeds=5,
    model_string="dgp_iccite",
    label_list=label_list,
)

# X_test / y_test / t_test are the test split created in the COT-MMFM section.
df_cubic_results_fm, trajectories_test = predict_on_testset_mmfm(
    model_battery=model_battery,
    model_states=model_states,
    model_guidances=model_guidances,
    X_test=X_test,
    y_test=y_test,
    t_test=t_test,
    device=device,
    steps=101,
    method="rk4",
)
df_cubic_results_fm["model"] = "OT-CFM"


## L-COT-MMFM

In [None]:
# Reload the results table and keep the MMFM runs for the selected configuration.
df, grouping_columns, performance_columns = load_all_files(DGP, PROD, load_from_parquet=True)

df = df.loc[
    (df["preset"] == preset)
    & (df["coupling"] == coupling)
    & (df["train_test_split"] == train_test_split)
    & (df["n_samples_per_c_in_b"] == n_samples_per_c_in_b)
    & (df["use_pca"] == use_pca)
    & (df["model_type"] == "mmfm")
]
if not classifier_free:
    df = df.loc[~df["classifier_free"]]

# Keep validation rows only (new frame instead of writing into the filtered slice).
df = df.assign(train=df["train"].astype(bool))
df = df.loc[~df["train"]]

df, grouping_columns, performance_columns = process_fm_models_partial(df, grouping_columns, performance_columns)

# NOTE: the `df_cubic*` names are kept for symmetry with the other sections,
# but this section selects the *linear* interpolation runs.
df_cubic = df.loc[df["interpolation"] == "linear"].copy()
df_cubic["weight"] = 1
# `valid_filter` and `select_by` come from the COT-MMFM selection cell above.
df_cubic_valid = df_cubic.loc[df_cubic["c"] == valid_filter]

_, model_battery, model_states, model_guidances = get_model_battery_of_best_models(
    df_cubic_valid,
    DGP,
    grouping_columns,
    select_by=select_by,
    device=device,  # use the configured device instead of assuming CUDA is available
    n_top_models=1,
    n_top_seeds=5,
    model_string="dgp_iccite",
    label_list=label_list,
)

# X_test / y_test / t_test are the test split created in the COT-MMFM section.
df_cubic_results_pcfm, trajectories_test = predict_on_testset_mmfm(
    model_battery=model_battery,
    model_states=model_states,
    model_guidances=model_guidances,
    X_test=X_test,
    y_test=y_test,
    t_test=t_test,
    device=device,
    steps=101,
    method="rk4",
)
df_cubic_results_pcfm["model"] = "L-COT-MMFM"


## T-OT-CFM

In [None]:
# NOTE(review): with load_from_parquet=True, load_all_files only forwards
# `filter_values` when it rebuilds the table; if the parquet cache for this DGP
# already exists, it is returned as-is and the filter has no effect.
df, grouping_columns, performance_columns = load_all_files(DGP, PROD, load_from_parquet=True, filter_values="totcfm")

In [None]:
# Keep only the "totcfm" runs that match the configuration selected above.
df = df.loc[
    (df["preset"] == preset)
    & (df["coupling"] == coupling)
    & (df["train_test_split"] == train_test_split)
    & (df["n_samples_per_c_in_b"] == n_samples_per_c_in_b)
    & (df["use_pca"] == use_pca)
    & (df["model_type"] == "totcfm")
]
if not classifier_free:
    df = df.loc[~df["classifier_free"]]

# Keep validation rows only (new frame instead of writing into the filtered slice).
df = df.assign(train=df["train"].astype(bool))
df = df.loc[~df["train"]]

df, grouping_columns, performance_columns = process_fm_models_partial(df, grouping_columns, performance_columns)

# NOTE: the `df_cubic*` names are kept for symmetry with the other sections,
# but this section selects the *linear* interpolation runs.
df_cubic = df.loc[df["interpolation"] == "linear"].copy()
df_cubic["weight"] = 1
# `valid_filter` comes from the COT-MMFM selection cell above.
df_cubic_valid = df_cubic.loc[df_cubic["c"] == valid_filter]

In [None]:
# Sanity check: the validation subset must not be empty, otherwise the model
# selection below has nothing to rank.
print(df_cubic_valid.shape)

In [None]:
# Select the best T-OT-CFM configuration on the validation condition and load
# its checkpoints. `select_by` comes from the COT-MMFM selection cell above.
_, model_battery, model_states, model_guidances = get_model_battery_of_best_models(
    df_cubic_valid,
    DGP,
    grouping_columns,
    select_by=select_by,
    device=device,  # use the configured device instead of assuming CUDA is available
    n_top_models=1,
    n_top_seeds=5,
    model_string="dgp_iccite",
    label_list=label_list,
    model_type="totcfm",
)

# X_test / y_test / t_test are the test split created in the COT-MMFM section.
df_cubic_results_totcfm, trajectories_test = predict_on_testset_mmfm(
    model_battery=model_battery,
    model_states=model_states,
    model_guidances=model_guidances,
    X_test=X_test,
    y_test=y_test,
    t_test=t_test,
    device=device,
    steps=101,
    method="rk4",
)
df_cubic_results_totcfm["model"] = "T-OT-CFM"

## FSI

In [None]:
# Evaluate FSI models on the same test data
df_cubic_results_fsi = []
for seed in range(3):
    filename = (
        f"dgp_iccite_{DGP}_{seed}_{hvg}_{subsample_frac}_{use_pca}_{batch_size}_{n_samples_per_c_in_b}"
        + f"_{train_test_split}_{top_n_effects}_{leave_out_middle}_{leave_out_end}_{preset}"
    )
    results_path = Path(path_name_fsi) / filename

    df_results_fsi = predict_on_testset_fsi(
        results_path,
        X_test,
        y_test,
        t_test,
        seed=seed,
        coupling=coupling,
        n_classes=len(label_list),
        plot_results=False,
        ncols=5,
    )
    df_cubic_results_fsi.append(df_results_fsi)

df_cubic_results_fsi = pd.concat(df_cubic_results_fsi).reset_index(drop=True)

In [None]:
# Plot styling. (`ThickerLine2D` is imported in the import cell at the top.)
plt.style.use(["science", "no-latex"])

params = {
    "axes.labelsize": 18,
    "axes.titlesize": 22,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 18,
    "legend.title_fontsize": 22,
    "figure.titlesize": 30,
}
plt.rcParams.update(params)

# One long table with the test results of every model family.
df_results = pd.concat(
    [df_cubic_results_mmfm, df_cubic_results_fsi, df_cubic_results_fm, df_cubic_results_pcfm, df_cubic_results_totcfm],
    ignore_index=True,
    axis=0,
).reset_index(drop=True)

# Condition indices (keys of `ps`) whose treatment was held out at the
# beginning / middle / end timepoint. The treatment sets are built once
# instead of re-filtering df_experiment for every entry of `ps`.
holdout_treatments = {
    kind: set(df_experiment.loc[df_experiment["leave_out"] == kind, "treatment"].values)
    for kind in ("beg", "mid", "end")
}
leave_out_beg_idx = [k for k, v in ps.items() if v in holdout_treatments["beg"]]
leave_out_mid_idx = [k for k, v in ps.items() if v in holdout_treatments["mid"]]
leave_out_end_idx = [k for k, v in ps.items() if v in holdout_treatments["end"]]

# A (condition, time) pair counts as training data unless it was held out.
df_results.loc[:, "training"] = True
df_results.loc[(df_results["c"].isin(leave_out_beg_idx)) & (df_results["time"] == 0.33), "training"] = False
df_results.loc[(df_results["c"].isin(leave_out_mid_idx)) & (df_results["time"] == 0.67), "training"] = False
df_results.loc[(df_results["c"].isin(leave_out_end_idx)) & (df_results["time"] == 1.00), "training"] = False

# Bin the treatment strength into four effect-size groups.
df_experiment["group"] = pd.cut(
    df_experiment["strength"],
    bins=[-np.inf, 1.75, 2.4, 2.9, np.inf],
    labels=["weak", "small", "medium", "strong"],
)

# Attach the treatment name and its effect-size group to every result row.
df_results["treatment"] = df_results["c"].map(ps)
df_results = df_results.merge(df_experiment[["treatment", "group"]], on="treatment", how="left")

# NOTE(review): none of the model labels assigned in the cells above is "FM"
# (they are "COT-MMFM", "OT-CFM", "L-COT-MMFM", "T-OT-CFM"; the FSI label is set
# inside predict_on_testset_fsi), so this mask may never match -- confirm the
# intended label.
df_results.loc[
    (df_results["c"].isin(leave_out_end_idx)) & (df_results["time"] == 1.00) & (df_results["model"] == "FM"),
    ["mean_diff_l1", "mean_diff_l2", "kl_div", "mmd", "mmd_median", "wasserstein"],
] = np.nan

# Human-readable timepoint label; times outside the mapping stay NaN.
# (`np.NaN` no longer exists in NumPy 2.0, and mapping avoids writing strings
# into a float column.)
df_results["Timepoint"] = df_results["time"].map({0.0: "0h", 0.33: "24h", 0.67: "48h", 1.0: "72h"})

# Merge score_activation_new into df_results.
# NOTE(review): adata.obs presumably holds one row per cell, so this left merge
# repeats each result row once per matching cell; all plots and summaries below
# are computed on the repeated rows -- confirm this weighting is intended.
df_results = df_results.merge(
    adata.obs[["treatment", "Timepoint", "score_activation_new"]],
    left_on=["treatment", "Timepoint"],
    right_on=["treatment", "Timepoint"],
    how="left",
)

# Figure 1: metric per timepoint, one panel per treatment-effect group.
effect_groups = ["weak", "small", "medium", "strong"]
metric_column = select_by.replace("_mean", "")

fig, ax = plt.subplots(1, 4, figsize=(15, 5))
for panel, group in zip(ax, effect_groups):
    in_group_after_t0 = (df_results["group"] == group) & (df_results["time"] > 0.0)
    sns.boxplot(
        data=df_results.loc[in_group_after_t0],
        x="time",
        y=metric_column,
        hue="model",
        ax=panel,
        legend=True,
        showfliers=False,
    )
    panel.set_title(f"Treatment effect: {group}")
plt.tight_layout()

# One shared legend for the whole figure, built from the current axes' handles.
handles, labels = fig.gca().get_legend_handles_labels()
by_label = {label: handle for label, handle in zip(labels, handles)}
fig.legend(
    by_label.values(),
    by_label.keys(),
    loc="center left",
    bbox_to_anchor=(1.0, 0.5),
    handler_map={plt.Line2D: ThickerLine2D()},
    title="Model",
)
plt.tight_layout()

for panel in ax:
    panel.set_xlabel("Timepoint")
    panel.set_ylabel("EMD")
    panel.grid(axis="y")
    # Drop the per-panel legend; the figure-level legend replaces it.
    panel.legend().remove()

# NOTE(review): the file is written before the dashed separators are drawn, so
# the saved image has none while the displayed figure does -- confirm intent.
plt.savefig("../../figures_paper/iccite_timecourse.png")

# Dashed vertical separators between the x-axis categories.
for panel in ax:
    for boundary in range(3):
        panel.axvline(boundary - 0.5, color="black", lw=1, alpha=0.25, linestyle="--")
plt.show()

# Figure 2: metric on held-out (condition, time) pairs versus binned activation
# score, one panel per timepoint.
fig, ax = plt.subplots(1, 3, figsize=(15, 5), sharey=False, sharex=True)
df_results["binned_score_activation_new"] = pd.cut(
    df_results["score_activation_new"],
    bins=[-np.inf, -0.1, 0.1, 0.3, 0.5, np.inf],
    labels=["low", "medium-low", "medium", "medium-high", "high"],
)

metric_column = select_by.replace("_mean", "")
panel_times = [0.33, 0.67, 1.0]
last_panel = len(panel_times) - 1
for panel_idx, time_value in enumerate(panel_times):
    heldout_at_time = (~df_results["training"]) & (df_results["time"] == time_value)
    sns.boxplot(
        data=df_results.loc[heldout_at_time],
        x="binned_score_activation_new",
        y=metric_column,
        hue="model",
        ax=ax[panel_idx],
        # Only the last panel draws a legend; its handles feed the shared legend below.
        legend=(panel_idx == last_panel),
        showfliers=False,
    )

# Dashed vertical separators between the activation bins.
for panel in ax:
    for boundary in range(5):
        panel.axvline(boundary - 0.5, color="black", lw=1, alpha=0.25, linestyle="--")

# One shared legend for the whole figure.
handles, labels = fig.gca().get_legend_handles_labels()
by_label = {label: handle for label, handle in zip(labels, handles)}
fig.legend(
    by_label.values(),
    by_label.keys(),
    loc="center left",
    bbox_to_anchor=(1.0, 0.5),
    handler_map={plt.Line2D: ThickerLine2D()},
    title="Model",
)

ax[0].set_title("Timepoint: 24h")
ax[1].set_title("Timepoint: 48h")
ax[2].set_title("Timepoint: 72h")
plt.tight_layout()
for panel in ax:
    panel.set_xlabel("Gene Activation Level")
    panel.set_ylabel("EMD")
    # Rotate the x tick labels for better readability.
    panel.set_xticklabels(panel.get_xticklabels(), rotation=45)
    panel.grid(axis="y")
# The figure-level legend replaces the one on the last panel.
ax[2].legend().remove()

plt.savefig("../../figures_paper/iccite_gene_activation.png")
plt.show()

# Wasserstein metric
# dgp_iccite_a_0_0.001_0.1_3_False_0.0_64_64_64_True_False_free_False_False_None_False_True_xavier_SELU_cosine_cubic_300_cot_None_0.8_None_10_100_None_adam_None_None_None_a_False_0.0_False_False_emd_mmfm

In [None]:
# Summary table: mean / std / max of the two reported metrics per model,
# training flag and timepoint. Only these two metrics were kept by the previous
# version as well, so the other aggregations (and a no-op column rename) are dropped.
df_r = (
    df_results.groupby(["model", "training", "time"])
    .agg(
        {
            "mean_diff_l2": ["mean", "std", "max"],
            "wasserstein": ["mean", "std", "max"],
        }
    )
    .round(2)
    .reset_index()
)
# Report held-out (condition, time) pairs only.
df_r = df_r.loc[~df_r["training"]]
df_r