# Generate figures for the joint simulation experiment
Here, we analyze results of running models on simulated datasets with a combination of linear and nonlinear signal.

Prerequisites: 
- you ran the joint simulation experiment
- the results are saved as a CSV (you can generate this by running the relevant portion of the `notebooks/analysis/make_supplementary_tables.ipynb` notebook)

In [None]:
# Import Required Libraries
import ast
import logging
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

sys.path.insert(0, "../..")  # add project_config to path
import project_config

# Configure root logging for the notebook. force=True removes and closes any
# handlers already attached to the root logger (e.g. from a previous run of
# this cell), so re-running the cell does not duplicate log lines.
logging.basicConfig(
    level=logging.INFO,
    force=True,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)-8s [%(name)s] %(message)s",
)
# Logger used by the helper functions in this notebook.
logger = logging.getLogger(name=__name__)

## Configuration
Define the path to the per-run metrics CSV, the directories for saving results and figures, and the figure parameters shared by all plots.

In [None]:
# Path to the CSV with performance metrics and feature-importance file paths for each run
METRIC_CSV_PATH = project_config.RESULT_DIR / "joint_simulation_prediction_metrics_per_run_full.csv"

# Define directories for saving results
FIGDIR = project_config.FIGURE_1_DIR
RESULTS_DIR = "../../results/joint_simulation_results/"  # NOTE(review): not referenced in the visible cells
# Make the figure directory if it doesn't exist (os.makedirs accepts str or path-like objects)
os.makedirs(FIGDIR, exist_ok=True)

# Figure parameters shared by all plots
figsize_tuple = (2.4, 2.1)  # (width, height) in inches of a single panel; grids multiply it
aspect_ratio_value = 0.5  # passed to Axes.set_aspect for every heatmap
cmap_style = "viridis"

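## Load per-run performance metrics (grouped by model_type, n_samples, OR, sigma, and sample_binary)
Load the per-run metrics CSV (generated from the WandB runs; see Prerequisites above). Then compute the mean and standard deviation of each numeric metric over runs grouped by `model_type`, `n_samples`, `odds_ratio`, `sigma`, and `sample_binary`, to assess the 'power curve' of the model. The results are stored in the DataFrames `df_avg` and `df_stdev`.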

In [None]:
df = pd.read_csv(
    METRIC_CSV_PATH,
    # These columns store Python-literal lists as strings; parse them back into lists.
    converters={
        "deltaMuGenes": ast.literal_eval,
        "mod0_genes": ast.literal_eval,
        "mod1_genes": ast.literal_eval,
        "datasets": ast.literal_eval,
    },
)

# Number of runs per group (counted over all model types, before the pnet filter below)
group_cols = ["model_type", "n_samples", "odds_ratio", "sigma", "sample_binary"]
df_counts = df.groupby(group_cols).size().reset_index(name="count")

# Keep only runs with model_type == "pnet"
df = df[df["model_type"] == "pnet"].reset_index(drop=True)

# In the "save_dir" column, rewrite "../../results/" to the absolute results location.
# NOTE(review): machine-specific path; adjust when running on another host.
df["save_dir"] = df["save_dir"].str.replace("../../results/", "/mnt/disks/gmiller_data1/pnet/results/", regex=False)

# Unique group identifier built from odds_ratio, sigma, n_samples and sample_binary
df["group_identifier"] = df.apply(
    lambda row: f"OR-{row['odds_ratio']}_sigma-{row['sigma']}_nSamples-{row['n_samples']}_sampleBinary-{row['sample_binary']}",
    axis=1,
)

# Mean / standard deviation over runs with the same group values.
# numeric_only=True is required: df still holds list and string columns
# (gene lists, paths, group_identifier), which pandas >= 2.0 refuses to
# aggregate (older pandas silently dropped them, so results are unchanged).
df_avg = df.groupby(group_cols, as_index=False).mean(numeric_only=True)
df_avg = df_avg.merge(df_counts, on=group_cols)

df_stdev = df.groupby(group_cols, as_index=False).std(numeric_only=True)
df_stdev = df_stdev.merge(df_counts, on=group_cols)

In [None]:
# Sanity check: which tables carry a run_id column? Both results are shown as a
# tuple because a notebook cell only displays its last bare expression, so the
# first check used to be silently discarded.
("run_id" in df.columns, "run_id" in df_counts.columns)

In [None]:
len(df), len(df.columns)  # (rows, columns) of df after the pnet filter

## helper functions

In [None]:
def smart_ticklabels(labels, decimals=1):
    """Format numeric tick values compactly.

    Whole numbers are rendered without a decimal point ("2"); all other values
    keep ``decimals`` digits after the point ("1.5" for decimals=1).

    Args:
        labels: Iterable of numbers or numeric strings (e.g. "1.5").
        decimals (int): Digits kept for non-integer values. The default of 1
            reproduces the original formatting.

    Returns:
        list[str]: Formatted labels, in input order.
    """
    formatted = []
    for label in labels:
        value = float(label)  # normalise first so numeric strings are accepted too
        formatted.append(f"{int(value)}" if value.is_integer() else f"{value:.{decimals}f}")
    return formatted

def get_smart_odds_ratio_labels(axessubplot_object):
    """Return compact y tick labels for a heatmap whose rows are odds ratios.

    Reads the current y tick label texts of the given matplotlib Axes, parses
    each as a float and reformats them with ``smart_ticklabels``.
    """
    tick_values = []
    for tick_label in axessubplot_object.get_yticklabels():
        tick_values.append(float(tick_label.get_text()))
    return smart_ticklabels(tick_values)


## Four subsets of the results
Here we make performance heatmaps for four different subsets of the data:

- (A) sample_binary = True, n_samples=1000
- (B) sample_binary = False, n_samples=1000
- (C) sample_binary = True, n_samples=10000
- (D) sample_binary = False, n_samples=10000

Plot these as a 2x2 grid in one figure: "Mean Average Precision under varied joint sampling strategies" (saved with the title "AUPRC on all simulated datasets"). A second 2x2 figure shows the standard deviation across runs.

In [None]:
# current best: one 2x2 heatmap
def heatmap_no_counts(df, title, ax, value_col=None):
    """
    Draw an annotated odds-ratio x sigma heatmap of one metric on ``ax``.

    Args:
        df (pd.DataFrame): One row per (odds_ratio, sigma) cell, e.g. a slice
            of df_avg or df_stdev.
        title (str): Panel title.
        ax (matplotlib.axes.Axes): Axes to draw on.
        value_col (str, optional): Column to plot. When omitted, uses
            f"{eval_set}_avg_precision" with the notebook-level ``eval_set``
            read at call time (the original behavior).
    """
    if value_col is None:
        value_col = f"{eval_set}_avg_precision"
    if df.empty:
        # sns.heatmap cannot draw a zero-size matrix; show a labelled blank
        # panel instead of aborting the whole figure when a subset has no runs.
        ax.text(0.5, 0.5, "no runs", ha="center", va="center", transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        precision = df.pivot(index="odds_ratio", columns="sigma", values=value_col)
        sns.heatmap(
            precision, cmap=cmap_style, annot=True, fmt=".2f",
            cbar=False, ax=ax
        )
        ax.set_aspect(aspect_ratio_value, adjustable='box')
        # Reformat odds-ratio tick labels compactly ("2" instead of "2.0")
        ax.set_yticklabels(get_smart_odds_ratio_labels(ax))
    ax.set_title(title)
    ax.set_xlabel("sigma")
    ax.set_ylabel("odds ratio")

eval_set = "test"
fig, axes = plt.subplots(2, 2, figsize=(figsize_tuple[0]*2, figsize_tuple[1]*2))

# One panel per subset: rows are N=1k / N=10k, columns are binary / continuous sampling.
panel_specs = [
    (0, 0, True, 1000, "Binary, N=1k"),
    (0, 1, False, 1000, "Continuous, N=1k"),
    (1, 0, True, 10000, "Binary, N=10k"),
    (1, 1, False, 10000, "Continuous, N=10k"),
]
for row_idx, col_idx, is_binary, n_runs_samples, panel_title in panel_specs:
    subset_mask = (df_avg["sample_binary"] == is_binary) & (df_avg["n_samples"] == n_runs_samples)
    heatmap_no_counts(df_avg[subset_mask], panel_title, ax=axes[row_idx, col_idx])

# Hide the redundant y label and tick labels in the right-hand column.
for right_ax in axes[:, 1]:
    right_ax.set_ylabel("")
    right_ax.set_yticklabels([])

fig.suptitle(f"AUPRC on all simulated datasets ({eval_set} set)")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, f"2D_simulations_heatmap_auprc_all_simulated_datasets_{eval_set}.png"), format='png', dpi=600)
plt.show()


In [None]:
# current best: one 2x2 heatmap
def heatmap_no_counts(df, title, ax, value_col=None):
    """
    Draw an annotated odds-ratio x sigma heatmap of one metric on ``ax``.

    Args:
        df (pd.DataFrame): One row per (odds_ratio, sigma) cell, e.g. a slice
            of df_avg or df_stdev.
        title (str): Panel title.
        ax (matplotlib.axes.Axes): Axes to draw on.
        value_col (str, optional): Column to plot. When omitted, uses
            f"{eval_set}_avg_precision" with the notebook-level ``eval_set``
            read at call time (the original behavior).
    """
    if value_col is None:
        value_col = f"{eval_set}_avg_precision"
    if df.empty:
        # sns.heatmap cannot draw a zero-size matrix; show a labelled blank
        # panel instead of aborting the whole figure when a subset has no runs.
        ax.text(0.5, 0.5, "no runs", ha="center", va="center", transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        precision = df.pivot(index="odds_ratio", columns="sigma", values=value_col)
        sns.heatmap(
            precision, cmap=cmap_style, annot=True, fmt=".2f",
            cbar=False, ax=ax
        )
        ax.set_aspect(aspect_ratio_value, adjustable='box')
        # Reformat odds-ratio tick labels compactly ("2" instead of "2.0")
        ax.set_yticklabels(get_smart_odds_ratio_labels(ax))
    ax.set_title(title)
    ax.set_xlabel("sigma")
    ax.set_ylabel("odds ratio")

eval_set = "test"
fig, axes = plt.subplots(2, 2, figsize=(figsize_tuple[0]*2, figsize_tuple[1]*2))

# One panel per subset: rows are N=1k / N=10k, columns are binary / continuous sampling.
panel_specs = [
    (0, 0, True, 1000, "Binary, N=1k"),
    (0, 1, False, 1000, "Continuous, N=1k"),
    (1, 0, True, 10000, "Binary, N=10k"),
    (1, 1, False, 10000, "Continuous, N=10k"),
]
for row_idx, col_idx, is_binary, n_runs_samples, panel_title in panel_specs:
    subset_mask = (df_stdev["sample_binary"] == is_binary) & (df_stdev["n_samples"] == n_runs_samples)
    heatmap_no_counts(df_stdev[subset_mask], panel_title, ax=axes[row_idx, col_idx])

# Hide the redundant y label and tick labels in the right-hand column.
for right_ax in axes[:, 1]:
    right_ax.set_ylabel("")
    right_ax.set_yticklabels([])

fig.suptitle(f"Std dev of AUPRC on all simulated datasets ({eval_set} set)")
plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, f"2D_simulations_heatmap_auprc_stddev_all_simulated_datasets_{eval_set}.png"), format='png', dpi=600)
plt.show()


## Scratch (WIP): gene-level feature importances

Goal: a heatmap-style plot with the mean/median rank (importance) of the:

- deltaMu genes
- genes only in the corr module
- genes in the corr module that also have deltaMu

In [None]:
########## Functions to extract gene imp and rank ##########
def load_feature_importances(importances_path):
    """
    Read a feature-importance table from CSV.

    The column pandas labels 'Unnamed: 0' (a header cell left blank, typically
    a saved index) becomes the index of the returned frame.

    Args:
        importances_path (str): Location of the importance CSV.

    Returns:
        pd.DataFrame: The importance table.

    Raises:
        FileNotFoundError: If nothing exists at ``importances_path``.
    """
    if os.path.exists(importances_path):
        logger.debug(f"Loading feature importances from {importances_path}")
        table = pd.read_csv(importances_path)
        table = table.set_index('Unnamed: 0')
        logger.debug(f"Loaded imps with shape {table.shape}")
        return table
    raise FileNotFoundError(f"File not found: {importances_path}")


def process_importances(imps, response_df):
    """
    Compute the per-feature difference in mean importance between response classes.

    ``imps`` is joined with ``response_df`` on the index, averaged within each
    ``response`` value (groups sorted ascending), and consecutive rows are
    differenced. With 0/1 responses the result is mean(class 1) - mean(class 0).

    Args:
        imps (pd.DataFrame): Feature importances, one column per feature; the
            index presumably holds sample names matching ``response_df`` (TODO confirm).
        response_df (pd.DataFrame): Must contain a 'response' column.

    Returns:
        pd.Series: Importance difference per feature, indexed by feature name.

    Raises:
        ValueError: If fewer than two response classes remain after the join.
    """
    # Build the grouped frame once; it used to be rebuilt twice more just to
    # format the debug messages, even when DEBUG logging was disabled.
    class_means = imps.join(response_df).groupby('response').mean()
    logger.debug("head of imps.join(response_df).groupby('response').mean(): %s", class_means.head())
    logger.debug("shape of imps.join(response_df).groupby('response').mean(): %s", class_means.shape)
    if len(class_means) < 2:
        raise ValueError(
            f"Need at least two response classes to compute an importance difference, "
            f"got {list(class_means.index)}"
        )
    return class_means.diff(axis=0).iloc[1]


def _build_response_df(n_class1, n_class0):
    """Return a one-column 'response' DataFrame: first ``n_class1`` samples are 1, the remaining ``n_class0`` are 0.

    Rows are indexed "Sample_0", "Sample_1", ...
    """
    responses = np.concatenate([np.ones(n_class1), np.zeros(n_class0)]).astype(int)
    sample_names = [f"Sample_{i}" for i in range(n_class1 + n_class0)]
    return pd.DataFrame(responses, index=sample_names, columns=["response"])


def build_runwise_gene_importance_rank_df(df, importance_path_column='gene_importances_path'):
    """
    Build a long-format DataFrame of gene importance and rank for each run and gene set.

    For each run (row) in ``df`` this function:
      - Parses the gene sets "deltaMuGenes", "mod1_genes", "mod0_genes"
        (raises ValueError if any is empty).
      - Rebuilds per-sample response labels from num_class1_samples /
        num_class0_samples (the first num_class1_samples samples are class 1).
      - Loads the importance file, computes the per-gene class-difference
        importance with process_importances, and ranks genes by absolute
        importance (rank 1 = largest |importance|; ties get the average rank).
      - Emits one record per (gene set, gene).

    Runs whose importance file cannot be loaded or processed are skipped with a warning.

    Args:
        df (pd.DataFrame): One row per run with the gene-set columns,
            num_class1_samples, num_class0_samples, the importance-path column,
            and optionally run_id / wandb_run_id (the row index is used otherwise).
        importance_path_column (str): Column in ``df`` with the path to the gene importances CSV.

    Returns:
        pd.DataFrame: Columns [run_id, gene_perturbation_group, gene, rank, imp].
        A gene that belongs to several gene sets appears once per set. If every
        run is skipped, the frame is empty but still has these columns.

    Raises:
        ValueError: If a run has an empty gene set, lacks its class sizes, or a
            gene from its gene sets is missing from the processed importances.
    """
    output_columns = ["run_id", "gene_perturbation_group", "gene", "rank", "imp"]
    logger.info(f"For each of the {df.shape[0]} model runs, we are getting gene importance and rank for each gene set")
    records = []
    for idx, row in df.iterrows():
        # Use index as run_id if no explicit run_id column
        run_id = row.get("run_id") or row.get("wandb_run_id") or idx
        if logger.isEnabledFor(logging.DEBUG):
            # row.to_dict() is costly for wide rows; only build it when it will be logged
            logger.debug(f"Processing run {run_id} with row: {row.to_dict()}")

        gene_perturbation_groups = {
            "deltaMuGenes": set(row["deltaMuGenes"]),
            "mod1_genes": set(row["mod1_genes"]),
            "mod0_genes": set(row["mod0_genes"]),
        }
        logger.debug(
            "genes in both deltaMuGenes and mod1_genes: %s",
            gene_perturbation_groups["deltaMuGenes"] & gene_perturbation_groups["mod1_genes"],
        )
        for group_name, genes in gene_perturbation_groups.items():
            if not genes:
                raise ValueError(f"Gene set '{group_name}' is empty for run {run_id}, cannot proceed.")
            logger.debug("Gene set '%s' has %d genes: %s", group_name, len(genes), genes)

        # Class sizes come back as floats (e.g. 500.0) when the CSV column has any
        # missing value; np.ones/range need ints, and a missing count should fail
        # with a clear message rather than an opaque TypeError.
        n1 = row.get("num_class1_samples")
        n0 = row.get("num_class0_samples")
        if pd.isna(n1) or pd.isna(n0):
            raise ValueError(
                f"Run {run_id} is missing num_class1_samples / num_class0_samples (got {n1!r}, {n0!r})"
            )
        response_df = _build_response_df(int(n1), int(n0))
        logger.debug("constructed response_df with shape %s", response_df.shape)

        try:
            imps = load_feature_importances(row[importance_path_column])
            processed_imps = process_importances(imps, response_df)  # pd.Series indexed by gene name
            ranks = processed_imps.abs().rank(ascending=False)
        except Exception as e:
            # Deliberate best-effort: one unreadable or malformed run must not stop the sweep.
            logger.warning(f"Skipping run {run_id} due to error: {e}. \nCheck: does a file exist at {row[importance_path_column]}?")
            continue

        for group_name, genes in gene_perturbation_groups.items():
            for gene in genes:
                # ranks shares processed_imps' index, so one membership check covers both
                if gene not in processed_imps.index:
                    raise ValueError(
                        f"Gene {gene} not found in importances for run {run_id}. "
                        f"Importances cover {len(processed_imps)} genes, e.g. {processed_imps.index[:10].tolist()}"
                    )
                records.append({
                    "run_id": run_id,
                    "gene_perturbation_group": group_name,
                    "gene": gene,
                    "rank": ranks[gene],
                    "imp": processed_imps[gene],
                })

    # Explicit columns keep downstream column access working even when no records were collected
    return pd.DataFrame(records, columns=output_columns)


########## Functions to plot top-k recovery ##########
def filter_runs(df, sample_binary=False, n_samples=10000, model_type="pnet"):
    """Return a copy of the rows of ``df`` matching sample_binary, n_samples and model_type."""
    keep = df["sample_binary"] == sample_binary
    keep &= df["n_samples"] == n_samples
    keep &= df["model_type"] == model_type
    return df.loc[keep].copy()

def get_group_identifier(row):
    """Build the string key naming a run's simulation group (odds ratio, sigma, sample size, sampling mode)."""
    odds_ratio, sigma = row['odds_ratio'], row['sigma']
    n_samples, sample_binary = row['n_samples'], row['sample_binary']
    return f"OR-{odds_ratio}_sigma-{sigma}_nSamples-{n_samples}_sampleBinary-{sample_binary}"

def assign_group_identifier(df):
    """Return a copy of ``df`` with a 'group_identifier' column computed row by row via get_group_identifier."""
    return df.assign(group_identifier=df.apply(get_group_identifier, axis=1))

def get_gene_sets(row):
    """Split a run's signal genes into three disjoint categories.

    "only linear": deltaMuGenes that are in neither mod0_genes nor mod1_genes;
    "only nonlinear": mod0/mod1 genes that are not in deltaMuGenes;
    "both": deltaMuGenes that are also in mod0_genes or mod1_genes.
    """
    linear = set(row["deltaMuGenes"])
    nonlinear = set(row["mod0_genes"]).union(row["mod1_genes"])
    return {
        "only linear": linear.difference(nonlinear),
        "only nonlinear": nonlinear.difference(linear),
        "both": linear.intersection(nonlinear),
    }

def add_gene_set_column(per_run_imps_df, df):
    """Return a copy of per_run_imps_df with a 'gene_set_type' column.

    Each (run_id, gene) row is labelled "only linear", "only nonlinear" or
    "both" using get_gene_sets on the matching run in ``df``; rows whose run
    is not in ``df`` or whose gene is in none of the sets get None.
    """
    sets_by_run = {run_row["run_id"]: get_gene_sets(run_row) for _, run_row in df.iterrows()}

    def _classify(imp_row):
        for category, members in sets_by_run.get(imp_row["run_id"], {}).items():
            if imp_row["gene"] in members:
                return category
        return None

    labelled = per_run_imps_df.copy()
    labelled["gene_set_type"] = labelled.apply(_classify, axis=1)
    return labelled

def aggregate_heatmap(per_run_imps_df, df, value_col, gene_set_type):
    """
    Median of ``value_col`` (e.g. "rank" or "imp") per (odds_ratio, sigma) cell for one gene_set_type.

    Rows of per_run_imps_df are inner-joined to the runs in ``df`` on run_id,
    restricted to ``gene_set_type``, and the median is taken over all genes
    and runs that fall in the same (odds_ratio, sigma) cell.

    Returns:
        pd.DataFrame: odds_ratio as index, sigma as columns.
    """
    run_meta = df[["run_id", "group_identifier", "odds_ratio", "sigma"]]
    selected = per_run_imps_df.merge(run_meta, on="run_id")
    selected = selected.loc[selected["gene_set_type"] == gene_set_type]
    cell_medians = selected.groupby(["odds_ratio", "sigma"])[value_col].median()
    return cell_medians.unstack()


def aggregate_topk_recovery(per_run_imps_df, df, gene_set_type):
    """
    Top-K recovery of one gene category, per run and per (odds_ratio, sigma) cell.

    For each run, K is the total number of signal genes (the union of
    deltaMuGenes, mod0_genes and mod1_genes). Recovery is the fraction of the
    run's ``gene_set_type`` genes whose rank is <= K; a gene without a rank
    entry counts as not recovered. Runs with no genes in the category are skipped.

    Args:
        per_run_imps_df (pd.DataFrame): Columns run_id, gene, rank (rank 1 = most important).
        df (pd.DataFrame): One row per run with run_id, group_identifier, odds_ratio,
            sigma, deltaMuGenes, mod0_genes, mod1_genes.
        gene_set_type (str): "only linear", "only nonlinear" or "both".

    Returns:
        tuple: (median_df, stdev_df, per_run_df). The first two have odds_ratio as
        index and sigma as columns (median / sample std of recovery across runs).
        per_run_df has columns [run_id, gene_set_type, odds_ratio, sigma, topk_recovery].
        All three are empty when no run has genes in this category.

    Raises:
        ValueError: If gene_set_type is not one of the three categories.
    """
    valid_types = ("only linear", "only nonlinear", "both")
    if gene_set_type not in valid_types:
        # Previously an unknown type silently produced no records and then a KeyError.
        raise ValueError(f"Unknown gene_set_type {gene_set_type!r}; expected one of {valid_types}")

    meta = df[["run_id", "group_identifier", "odds_ratio", "sigma", "deltaMuGenes", "mod0_genes", "mod1_genes"]]
    merged = per_run_imps_df.merge(meta, on="run_id")
    records = []
    for run_id, group in merged.groupby("run_id"):
        row = group.iloc[0]  # run-level columns are identical within the group
        delta = set(row["deltaMuGenes"])
        corr = set(row["mod0_genes"]) | set(row["mod1_genes"])
        categories = {
            "only linear": delta - corr,
            "only nonlinear": corr - delta,
            "both": delta & corr,
        }
        signal_genes = categories[gene_set_type]
        if not signal_genes:
            continue
        K = len(delta | corr)  # all signal genes of this run; >= 1 since signal_genes is non-empty
        run_ranks = group[["gene", "rank"]].drop_duplicates().set_index("gene")["rank"]
        n_in_topk = sum(run_ranks.get(gene, K + 1) <= K for gene in signal_genes)
        records.append({
            "run_id": run_id,
            "gene_set_type": gene_set_type,
            "odds_ratio": row["odds_ratio"],
            "sigma": row["sigma"],
            "topk_recovery": n_in_topk / len(signal_genes),
        })

    per_run_df = pd.DataFrame(
        records, columns=["run_id", "gene_set_type", "odds_ratio", "sigma", "topk_recovery"]
    )
    if per_run_df.empty:
        # No run contributed: return empty tables instead of failing on the groupby.
        return pd.DataFrame(), pd.DataFrame(), per_run_df
    by_cell = per_run_df.groupby(["odds_ratio", "sigma"])["topk_recovery"]
    return by_cell.median().unstack(), by_cell.std().unstack(), per_run_df

## Get run-level rank and importance information for all the relevant gene sets
Each row in `df` corresponds to a run. The output is a DataFrame with columns [run_id, gene_perturbation_group, gene, rank, imp]. `run_id` is the identifier for a given run (WandB run id). `gene_perturbation_group` is one of "deltaMuGenes", "mod1_genes", "mod0_genes". `imp` is the difference in mean importance of the gene between class-1 and class-0 samples in that run. `rank` is the gene's rank by absolute `imp` among all genes in that run (1 = most important).

In [None]:
logger.setLevel(level=logging.INFO)  # switch to logging.DEBUG to trace per-run processing

In [None]:
# Use the importance files recorded for the chosen evaluation split.
eval_set = "test"
importance_col = f"{eval_set}_gene_importances_path"
per_run_imps_df = build_runwise_gene_importance_rank_df(df, importance_path_column=importance_col)
per_run_imps_df

## Build heatmap of top-K recovery
Switch the metric from median rank to top-K recovery (top-K accuracy): "How many of the true signal genes did my model recover among the top K?" Here K is the total number of signal genes in the run.
Alternatives: recall at K (the same quantity, plotted as a curve over varied K) or precision at K ("how 'pure' is my top-K list?").

In [None]:
# --- MAIN LOGIC ---
# For each sampling subset, draw two 1x3 figures (median and std dev of per-run
# top-K recovery), with one panel per gene-set category.

# 1. Filter runs
df_filts = {
    "Binary, N=1k": filter_runs(df, sample_binary=True, n_samples=1000),
    "Binary, N=10k": filter_runs(df, sample_binary=True, n_samples=10000),
    "Continuous, N=1k": filter_runs(df, sample_binary=False, n_samples=1000),
    "Continuous, N=10k": filter_runs(df, sample_binary=False, n_samples=10000),
}
topk_gene_set_types = ["only linear", "only nonlinear", "both"]
# (index into aggregate_topk_recovery's return tuple, suptitle prefix, filename prefix)
topk_figure_specs = [
    (0, "Top-K recovery of perturbed gene sets", "2D_simulations_heatmap_topk_recovery"),
    (1, "Std dev of top-K Recovery", "2D_simulations_heatmap_topk_recovery_stddev"),
]

for df_filt_name, df_filt in df_filts.items():
    if df_filt.shape[0] == 0:
        logger.info(f"{df_filt_name} had nothing left in the filtered runs")
        continue  # skip this figure
    run_ids = set(df_filt["run_id"])
    per_run_filt = per_run_imps_df[per_run_imps_df["run_id"].isin(run_ids)].copy()

    # 2. Assign group_identifier
    df_filt = assign_group_identifier(df_filt)

    # 3. Label each gene with its category; count each gene only once per run and category
    per_run_filt = add_gene_set_column(per_run_filt, df_filt)
    per_run_filt = per_run_filt.drop_duplicates(subset=["run_id", "gene", "gene_set_type"])

    # 4. Aggregate once per category; both figures below reuse these tables
    recovery_tables = {
        gene_set: aggregate_topk_recovery(per_run_filt, df_filt, gene_set)[:2]
        for gene_set in topk_gene_set_types
    }
    file_tag = df_filt_name.replace(',', '').replace(' ', '_')

    # 5. Median-recovery heatmap, then std-dev heatmap
    for table_idx, title_prefix, file_prefix in topk_figure_specs:
        fig, axes = plt.subplots(1, 3, figsize=(figsize_tuple[0]*3, figsize_tuple[1]), sharey=True)
        for col_idx, gene_set in enumerate(topk_gene_set_types):
            ax = axes[col_idx]
            table = recovery_tables[gene_set][table_idx]
            if table.empty:
                # No run in this subset has genes in this category
                ax.text(0.5, 0.5, "no runs", ha="center", va="center", transform=ax.transAxes)
            else:
                sns.heatmap(table, annot=True, cmap=cmap_style, ax=ax, fmt=".2f", vmin=0, vmax=1, cbar=False)
                ax.set_aspect(aspect_ratio_value)
            ax.set_title(gene_set)
            ax.set_xlabel("sigma")
            ax.set_ylabel("odds ratio" if col_idx == 0 else "")
        if not recovery_tables[topk_gene_set_types[0]][table_idx].empty:
            axes[0].set_yticklabels(get_smart_odds_ratio_labels(axes[0]))
        # Set the suptitle before tight_layout so the layout can account for it
        # (it was previously added afterwards and could overlap the panel titles).
        fig.suptitle(f"{title_prefix} ({df_filt_name}, {eval_set} set)")
        fig.tight_layout()
        fig.savefig(os.path.join(FIGDIR, f"{file_prefix}_{file_tag}_{eval_set}.png"), format='png', dpi=600)
        plt.show()

In [None]:
# Per-run top-K recovery at a single odds ratio, for the box plot in the next cell.
# The subset is rebuilt explicitly instead of relying on variables left over from
# the last iteration of the heatmap loop; those would silently change if the subset
# order changed or the last subset was empty.
df_filt_name = "Continuous, N=10k"
df_filt = assign_group_identifier(filter_runs(df, sample_binary=False, n_samples=10000))
per_run_filt = per_run_imps_df[per_run_imps_df["run_id"].isin(set(df_filt["run_id"]))].copy()
per_run_filt = add_gene_set_column(per_run_filt, df_filt)
per_run_filt = per_run_filt.drop_duplicates(subset=["run_id", "gene", "gene_set_type"])

boxplot_odds_ratio = 1.1
df_topk = pd.concat(
    [
        aggregate_topk_recovery(per_run_filt, df_filt, gene_set)[2]
        for gene_set in ["both", "only nonlinear", "only linear"]
    ],
    ignore_index=True,
)
# Tolerant float comparison; odds ratios are read from CSV as floats
df_topk = df_topk[np.isclose(df_topk["odds_ratio"], boxplot_odds_ratio)]
df_topk

In [None]:


# 5. Plot: per-run top-K recovery by sigma, one box per gene-set category.
# Alternative palette tried earlier: only linear '#0173B2', only nonlinear '#029E73',
# both '#E69F00' (or '#CC79A7' for both).
color_mapping = {
    "only linear": '#56B4E9',
    "only nonlinear": '#009E73',
    "both": '#E69F00',
}

fig, ax = plt.subplots(figsize=(figsize_tuple[0] * 2, figsize_tuple[1] * 1.))
sns.boxplot(
    data=df_topk,
    x="sigma",
    y="topk_recovery",
    hue="gene_set_type",
    showfliers=True,
    palette=color_mapping,
    ax=ax,
)
# Optional: overlay individual runs with
#   sns.stripplot(data=df_topk, x="sigma", y="topk_recovery", hue="gene_set_type",
#                 dodge=True, color="black", alpha=0.25, legend=False, ax=ax)

ax.set_ylim(0, 1.05)
ax.set_title(f"Top-K recovery at OR=1.1 ({df_filt_name}, {eval_set} set)")
ax.set_ylabel("Top-K Recovery")
ax.set_xlabel("sigma")
ax.legend(title="Gene Set", bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.)
fig.tight_layout()
# To save the figure, uncomment:
# plt.savefig(os.path.join(FIGDIR, f"2D_simulations_boxplot_topk_recovery_or1.1_{df_filt_name.replace(',', '').replace(' ', '_')}_{eval_set}.png"), format='png', dpi=600)
plt.show()


## Logic for generating feature importance and rank heatmaps

1. Filter the relevant runs
- Start by filtering df_feature_importance_paths for the subset you want (e.g., Continuous, N=10k).
- Make this filtering modular so you can easily swap to other subsets (e.g., N=1k).

2. Extract gene sets for each run
- For each row/run, parse the gene lists:
    - deltaMuGenes
    - mod0_genes and mod1_genes (combine for "corr module")
- Compute the three gene sets for each run:
    - only deltaMu (deltaMuGenes minus mod0+mod1)
    - only in corr module (mod0+mod1 minus deltaMu)
    - intersection (deltaMu ∩ (mod0+mod1))

3. Aggregate gene ranks/importances

- For each run, use the corresponding group_identifier to get the gene importance/rank DataFrame from df_ranks_by_key (or df_imps_by_key).
- For each gene set, extract the ranks/importances for the genes in that set.
- Aggregate (e.g., median) across genes in the set, for each run.

4. Aggregate across runs with the same group_identifier
- For each group_identifier (i.e., each cell in the heatmap), aggregate the median ranks/importances across all runs with that group_identifier.

5. Build heatmap DataFrames
- For each gene set, build a DataFrame with odds_ratio as rows and sigma as columns, values are the top K recovery (alternatives: aggregated median ranks/importances.)

6. Plot heatmaps
- Use the same plotting style as elsewhere in the notebook.