# Generate Performance Heatmaps and Feature Importance Rankings from WandB Runs

## Import Required Libraries
Import the required libraries (`logging`, `pandas`, `matplotlib`, `seaborn`, and `wandb`) and the local `project_config` module.

In [None]:
# Import Required Libraries
import logging
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sys.path.insert(0, "../..")  # add project_config to path
import project_config

import wandb

# Setup Logging and Configuration
# Configure the root logger at INFO level with a timestamped format.
# force=True removes any handlers already attached to the root logger, so
# re-running this cell (or an earlier basicConfig call) does not block it.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s [%(name)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)
# Module-level logger used by the helper functions below.
logger = logging.getLogger(__name__)

## Setup Logging and Configuration
Define the input metrics file, the directories for saving figures and results, and shared figure parameters.

Here, we look at a single-gene perturbation on the P1000 somatic-mutation dataset backbone: we spiked a signal into a single gene (AR) and assigned the spiked-in samples to class 1.

In [None]:
# Path to file with performance metrics and feature importance paths for each run
METRIC_CSV_PATH = project_config.RESULT_DIR / "single_gene_spike_in_simulation_prediction_metrics_per_run_full.csv"

# Define directories for saving results
FIGDIR = project_config.FIGURE_2_DIR
RESULTS_DIR = project_config.RESULT_DIR_FIGURE_2

# Make the figure directory if it doesn't exist. os.makedirs accepts str and
# path-like objects directly, so no os.path.join wrapper is needed.
os.makedirs(FIGDIR, exist_ok=True)

# Figure parameters.
# figsize_tuple[0] is used as the FacetGrid facet height in the plotting cells.
figsize_tuple = (2.1, 2)

# Passed as FacetGrid(aspect=...) and to Axes.set_aspect in heatmap_no_counts.
aspect_ratio_value = 1.7
# Default colormap for the AUPRC heatmaps.
cmap_style = "viridis"

## Load Performance Metrics (grouped by model_type, OR, and control_frequency)
Load the per-run performance metrics (exported from the WandB runs) from a CSV. Count the runs in each group, then average the metrics over seeds, grouping by `model_type`, `odds_ratio`, `control_frequency`, and `n_features`. This lets us assess the model's 'power curve'. The results are stored in a DataFrame.

In [None]:
df = pd.read_csv(METRIC_CSV_PATH)

# Columns that identify one simulation setting; rows sharing these values are
# averaged together below (the runs differ by seed).
group_by_cols = ["model_type", "n_features", "odds_ratio", "control_frequency"]

# Number of runs contributing to each group
df_counts = df.groupby(group_by_cols).size().reset_index(name="count")

# Average over seeds. numeric_only=True because the CSV also carries string
# columns (e.g. save_dir, target_f): pandas >= 2.0 raises TypeError when asked
# to average them, while older versions silently dropped them.
df_avg = df.groupby(group_by_cols, as_index=False).mean(numeric_only=True)

# Attach the per-group run counts
df_avg = df_avg.merge(df_counts, on=group_by_cols)
df_avg

## Helper functions

In [None]:
def smart_ticklabels(labels):
    """Format numeric tick values as short strings.

    Whole numbers are rendered without a decimal part (``2.0`` -> ``"2"``);
    all other values are rendered with one decimal place (``1.5`` -> ``"1.5"``).
    """
    formatted = []
    for value in labels:
        if float(value).is_integer():
            formatted.append(f"{int(value)}")
        else:
            formatted.append(f"{value:.1f}")
    return formatted

def get_smart_odds_ratio_labels(axessubplot_object):
    """Return compact y tick label strings for a heatmap's odds-ratio axis.

    Reads the current y tick label texts of the given matplotlib Axes, parses
    each one as a float, and reformats them with ``smart_ticklabels``.
    """
    tick_texts = (tick.get_text() for tick in axessubplot_object.get_yticklabels())
    return smart_ticklabels([float(text) for text in tick_texts])

def heatmap_no_counts(df, title, ax, eval_set, **kwargs):
    """Draw an odds-ratio x control-frequency heatmap of average precision.

    Args:
        df: long-form DataFrame with ``odds_ratio``, ``control_frequency`` and
            ``{eval_set}_avg_precision`` columns; each (odds_ratio,
            control_frequency) pair must appear once, as ``pivot`` requires.
        title: axes title.
        ax: matplotlib Axes to draw into.
        eval_set: prefix selecting the metric column (e.g. "train" or "test").
        **kwargs: forwarded to ``sns.heatmap``.
    """
    value_column = f"{eval_set}_avg_precision"
    grid = df.pivot(index="odds_ratio", columns="control_frequency", values=value_column)
    sns.heatmap(grid, cmap=cmap_style, annot=True, fmt=".2f", cbar=False, ax=ax, **kwargs)
    ax.set_aspect(aspect_ratio_value, adjustable='box')
    # Replace e.g. "2.0" with "2" on the odds-ratio axis
    ax.set_yticklabels(get_smart_odds_ratio_labels(ax))
    ax.set_title(title)
    ax.set_xlabel("control_frequency")
    ax.set_ylabel("odds ratio")

In [None]:
# matches feature importance plots in shape and style
# adding x-axis labels to all
eval_set = "train"
# Optionally filter to a single n_features value, e.g.
#   df_avg[df_avg["n_features"] == 10].copy()
# .copy() so that relabelling model_type below does not mutate df_avg itself
# (and does not raise SettingWithCopyWarning when a filtered slice is used).
df_avg_to_plot = df_avg.copy()

# rename values in the model_type column: pnet --> P-NET, rf --> Random Forest
df_avg_to_plot["model_type"] = df_avg_to_plot["model_type"].replace({
    "pnet": "P-NET",
    "rf": "Random Forest"
})

metric = f"{eval_set}_avg_precision"

# Global color scale limits, shared by every facet so colors are comparable
vmin = df_avg_to_plot[metric].min()
vmax = df_avg_to_plot[metric].max()

# One facet per (n_features, model_type) combination
g = sns.FacetGrid(
    df_avg_to_plot,
    row="n_features",
    col="model_type",
    margin_titles=True,
    height=figsize_tuple[0],
    aspect=aspect_ratio_value
)

def draw_heatmap(data, **kwargs):
    """Render one FacetGrid facet as an annotated heatmap of ``metric``.

    Called via ``g.map_dataframe``: ``data`` is the facet's subset of rows, and
    any extra keyword arguments (presumably those FacetGrid adds, e.g. color)
    are forwarded to ``sns.heatmap``. Reads the notebook globals ``metric``,
    ``vmin``, ``vmax`` and ``cmap_style`` at call time.
    """
    pivoted = data.pivot(index="odds_ratio", columns="control_frequency", values=metric)
    # Shared vmin/vmax keep colors comparable across facets; cmap_style keeps
    # this figure consistent with the notebook-wide colormap setting.
    sns.heatmap(pivoted, cmap=cmap_style, annot=True, fmt=".2f", cbar=False, vmax=vmax, vmin=vmin,
                **kwargs)

g.map_dataframe(draw_heatmap)
g.set_titles(row_template="# features = {row_name}", col_template="model = {col_name}")
g.set_axis_labels("control frequency", "odds ratio")
# Adjust spacing
g.figure.subplots_adjust(top=0.85, hspace=0.1)

g.figure.suptitle(
    f"AUPRC on single-gene perturbation datasets ({eval_set} set)",
    fontsize=12
)

# Ensure all x-axis tick labels are shown
for ax in g.axes.flat:
    ax.set_xticklabels(ax.get_xticklabels())
    ax.tick_params(axis='x', labelbottom=True)

# Show compact odds-ratio labels on the left-most facet of EVERY row.
# g.axes is a 2-D (rows x cols) array, so g.axes[:, 0] is the first column;
# a flat-index == 0 test would only reach the top-left facet.
for ax in g.axes[:, 0]:
    ax.tick_params(axis='y', labelleft=True)
    ax.set_yticklabels(get_smart_odds_ratio_labels(ax))

plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, f"heatmap_auprc_single_gene_perturbation_{eval_set}.svg"), format='svg', dpi=1200)
plt.savefig(os.path.join(FIGDIR, f"heatmap_auprc_single_gene_perturbation_{eval_set}.png"), format='png', dpi=600)

plt.show()


# How do P-NET feature rankings compare with ground truth?
In this 1D (single-gene) perturbation setting, we spiked a signal into a single gene, "AR". We now investigate the rank that P-NET assigns to this gene.

Plot: an odds ratio × control frequency grid, where each cell shows the rank (or importance) that the model assigns to gene "AR".

In [None]:
# Per-run table with: 1) response (y) path, 2) feature/gene importance paths,
# 3) identifiers of the simulation setting (odds ratio, control frequency,
# n_features, model_type).
df_feature_importance_paths = df[["save_dir", "target_f", "model_type", "n_features", "odds_ratio", "control_frequency"]].copy()

# Restrict to P-NET runs
df_feature_importance_paths = df_feature_importance_paths[df_feature_importance_paths["model_type"] == "pnet"].reset_index(drop=True)

# save_dir values in the CSV use the relative prefix below; remap it onto the
# results root on this machine. LOCAL_RESULTS_ROOT is environment-specific.
RELATIVE_RESULTS_PREFIX = "../../results/"
LOCAL_RESULTS_ROOT = "/mnt/disks/gmiller_data1/pnet/results/"
df_feature_importance_paths["save_dir"] = df_feature_importance_paths["save_dir"].str.replace(RELATIVE_RESULTS_PREFIX, LOCAL_RESULTS_ROOT, regex=False)

eval_set = "test"
# Build importance-file paths column-wise from save_dir (a row-wise
# apply(axis=1) is slower and misbehaves on an empty frame).
feature_importances_file = f"{eval_set}_gene_feature_importances.csv"
gene_importances_file = f"{eval_set}_gene_importances.csv"
df_feature_importance_paths["feature_importances_path"] = df_feature_importance_paths["save_dir"].map(
    lambda save_dir: os.path.join(save_dir, feature_importances_file)
)
df_feature_importance_paths["gene_importances_path"] = df_feature_importance_paths["save_dir"].map(
    lambda save_dir: os.path.join(save_dir, gene_importances_file)
)

# Unique identifier of the simulation setting. model_type is not included
# because the table was filtered to P-NET runs above.
df_feature_importance_paths["group_identifier"] = [
    f"OR-{odds_ratio}_ctrlFreq-{control_frequency}_nFeatures-{n_features}"
    for odds_ratio, control_frequency, n_features in zip(
        df_feature_importance_paths["odds_ratio"],
        df_feature_importance_paths["control_frequency"],
        df_feature_importance_paths["n_features"],
    )
]


In [None]:
def load_feature_importances(importances_path):
    """
    Read an importance table from CSV.

    The CSV's first, unnamed column (which pandas labels 'Unnamed: 0') is used
    as the index.

    Args:
        importances_path (str): Location of the importance CSV.

    Returns:
        pd.DataFrame: Importances indexed by the CSV's first column.

    Raises:
        FileNotFoundError: If ``importances_path`` does not exist.
    """
    if os.path.exists(importances_path):
        logger.debug("Loading feature importances from %s", importances_path)
        table = pd.read_csv(importances_path)
        return table.set_index('Unnamed: 0')
    raise FileNotFoundError(f"File not found: {importances_path}")


def process_importances(imps, response_df):
    """
    Contrast mean feature importances between response classes.

    Joins ``imps`` with ``response_df`` on their indexes, averages every column
    within each ``response`` value, and returns the second class's means minus
    the first's. ``groupby`` sorts its keys, so for 0/1 labels this is
    class 1 minus class 0. Any classes beyond the first two are ignored.

    Args:
        imps (pd.DataFrame): Importances, indexed like ``response_df``.
        response_df (pd.DataFrame): Must contain a ``response`` column.

    Returns:
        pd.Series: Per-feature difference of class means.

    Raises:
        ValueError: If fewer than two response classes remain after the join.
    """
    # Compute the grouped class means once; only format the debug output when
    # DEBUG is actually enabled (f-strings are evaluated eagerly).
    class_means = imps.join(response_df).groupby('response').mean()
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"head of imps.join(response_df).groupby('response').mean(): {class_means.head()}")
        logger.debug(f"shape of imps.join(response_df).groupby('response').mean(): {class_means.shape}")
    if len(class_means) < 2:
        raise ValueError(
            "Expected at least two response classes after joining importances "
            f"with responses, found {len(class_means)}"
        )
    # Row 1 minus row 0 of the sorted class means
    return class_means.diff(axis=0).iloc[1]

def load_response_variable(response_path):
    """Read the response CSV, indexed by ``Tumor_Sample_Barcode``.

    The label column ``is_met`` is renamed to ``response`` so downstream code
    can group on a single, generic column name.
    """
    return (
        pd.read_csv(response_path)
        .set_index('Tumor_Sample_Barcode')
        .rename(columns={"is_met": "response"})
    )

def process_feature_importances(df_feature_importance_paths, response_df_path_column='target_f', importance_path_column='feature_importances_path', group_identifier_column='group_identifier'):
    """
    Load, contrast and rank importances for every run, grouped by setting.

    For each row, loads the response table and the importance table, reduces
    them with ``process_importances`` to one per-feature class-mean difference,
    and ranks features by absolute difference (rank 1 = largest magnitude).

    Args:
        df_feature_importance_paths (pd.DataFrame): One row per run.
        response_df_path_column (str): Column holding the response CSV path.
        importance_path_column (str): Column holding the importance CSV path.
        group_identifier_column (str): Column whose value groups runs together.

    Returns:
        tuple[dict, dict]: ``(df_imps_by_key, df_ranks_by_key)``; each maps a
        group identifier to a DataFrame with one row per successfully loaded
        run. Runs with a missing file are logged as warnings and skipped.
    """
    imps_by_key = {}
    ranks_by_key = {}
    # Cache responses by path so a response file shared by several runs is
    # read only once.
    response_cache = {}

    for _, row in df_feature_importance_paths.iterrows():
        response_path = row[response_df_path_column]
        try:
            if response_path not in response_cache:
                response_cache[response_path] = load_response_variable(response_path)
            response_df = response_cache[response_path]
            imps = load_feature_importances(row[importance_path_column])
        except FileNotFoundError as e:
            # Best-effort: skip runs whose files are missing, but record it.
            logger.warning("Skipping run: %s", e)
            continue

        processed_imps = process_importances(imps, response_df)
        ranks = processed_imps.abs().rank(ascending=False)

        key = row[group_identifier_column]
        imps_by_key.setdefault(key, []).append(processed_imps)
        ranks_by_key.setdefault(key, []).append(ranks)

    df_imps_by_key = {key: pd.DataFrame(imps_list) for key, imps_list in imps_by_key.items()}
    df_ranks_by_key = {key: pd.DataFrame(ranks_list) for key, ranks_list in ranks_by_key.items()}
    return df_imps_by_key, df_ranks_by_key



def extract_top_features_from_df(df_per_dataset, top_n=10, keep_smallest_n=True, index_label=None):
    """
    Extract the top N features by mean value for each dataset.

    Args:
        df_per_dataset (dict): Maps a dataset key to a DataFrame with one row
            per run and one column per feature.
        top_n (int): Number of top features to extract.
        keep_smallest_n (bool): Sort ascending (True, e.g. ranks where lower is
            better) or descending (False, e.g. importances).
        index_label (str, optional): Name given to the resulting index.

    Returns:
        pd.DataFrame: One column per dataset listing its top features in order,
        indexed from 1. Datasets with fewer than ``top_n`` features are padded
        with NaN; an empty input yields an empty DataFrame.
    """
    top_features = {}
    for dataset, df in df_per_dataset.items():
        # Mean over runs (rows) per feature, sorted, then keep the first top_n
        mean_values = df.mean(axis=0).sort_values(ascending=keep_smallest_n)
        top_features[dataset] = pd.Series(mean_values.index[:top_n])

    # A dict of Series aligns columns of unequal length (shorter ones get NaN),
    # instead of failing with a length mismatch.
    top_features_df = pd.DataFrame(top_features)
    # 1-based positions
    top_features_df.index = range(1, len(top_features_df) + 1)

    if index_label is not None:
        top_features_df.index.name = index_label

    return top_features_df


In [None]:
logger.info("Get the feature importances and ranks for each dataset combination")

# Alternative: gene x modality importances (per-feature file):
#   df_imps_by_key, df_ranks_by_key = process_feature_importances(df_feature_importance_paths)

# Gene-level importances, read from the gene_importances_path column
df_imps_by_key, df_ranks_by_key = process_feature_importances(
    df_feature_importance_paths,
    importance_path_column='gene_importances_path',
)

In [None]:
# Top 10 genes per setting by mean rank across runs (rank 1 = largest |importance difference|)
top_10_features_by_rank = extract_top_features_from_df(
    df_ranks_by_key,
    top_n=10,
    keep_smallest_n=True,
    index_label="rank",
)
top_10_features_by_rank

In [None]:
# Extract the perturbed gene's importances and ranks from the processed dicts.
# The gene into which the signal was spiked; output columns keep the "AR_"
# prefix that the plotting cells use as metric names.
PERTURBED_GENE = "AR"
AR_COLUMNS = ["group_identifier", "AR_importance", "AR_absolute_importance", "AR_rank"]

AR_records = []
for key, df_imp in df_imps_by_key.items():
    mean_imp = df_imp.mean()
    # Settings in which the gene never appears are left out (NaN after the merge)
    if PERTURBED_GENE not in mean_imp.index:
        continue
    mean_rank = df_ranks_by_key[key].mean()
    AR_records.append({
        "group_identifier": key,
        "AR_importance": mean_imp[PERTURBED_GENE],
        # Absolute value of the mean signed importance (not the mean of |importance|)
        "AR_absolute_importance": abs(mean_imp[PERTURBED_GENE]),
        "AR_rank": mean_rank[PERTURBED_GENE],
    })

# Explicit columns so the merge still finds "group_identifier" when no records exist
df_AR = pd.DataFrame(AR_records, columns=AR_COLUMNS)
# Merge AR stats into the plotting dataframe
df_plot = df_feature_importance_paths.merge(df_AR, on="group_identifier", how="left")
# numeric_only=True: df_plot carries string columns (save_dir, target_f, file
# paths, group_identifier) that pandas >= 2.0 refuses to average.
df_plot_avg = df_plot.groupby(group_by_cols, as_index=False).mean(numeric_only=True)


In [None]:
# current: using this for the rank/imps

metric = "AR_rank"  # or "AR_rank", "AR_absolute_importance", "AR_importance"
eval_set = "test"
# Optionally restrict to one n_features value, e.g.
#   df_plot_avg[df_plot_avg["n_features"] == 100].copy()
# .copy() so relabelling model_type below does not mutate df_plot_avg itself.
df_plot_avg_tmp = df_plot_avg.copy()

# rename values in the model_type column: pnet --> P-NET, rf --> Random Forest
df_plot_avg_tmp["model_type"] = df_plot_avg_tmp["model_type"].replace({
    "pnet": "P-NET",
    "rf": "Random Forest"
})

# Colormap per supported metric; ranks use the reversed map because a lower
# rank means a more important gene. Every documented option gets a colormap,
# so cmap_tmp is always defined.
cmap_by_metric = {
    "AR_rank": "cividis_r",
    "AR_importance": "cividis",
    "AR_absolute_importance": "cividis",
}
if metric not in cmap_by_metric:
    raise ValueError(f"Unsupported metric {metric!r}; expected one of {sorted(cmap_by_metric)}")
cmap_tmp = cmap_by_metric[metric]

# One facet per (n_features, model_type) combination
g = sns.FacetGrid(
    df_plot_avg_tmp,
    row="n_features",
    col="model_type",
    margin_titles=True,
    height=figsize_tuple[0],
    aspect=aspect_ratio_value
)

def draw_heatmap(data, **kwargs):
    """Draw one facet: annotated odds-ratio x control-frequency heatmap of ``metric``.

    Reads the notebook globals ``metric`` and ``cmap_tmp`` at call time; extra
    keyword arguments are passed straight through to ``sns.heatmap``.
    """
    grid = data.pivot(index="odds_ratio", columns="control_frequency", values=metric)
    heatmap_options = dict(cmap=cmap_tmp, annot=True, fmt=".1f", cbar=False)
    sns.heatmap(grid, **heatmap_options, **kwargs)

g.map_dataframe(draw_heatmap)
g.set_titles(row_template="# features = {row_name}", col_template="model = {col_name}")
g.set_axis_labels("control frequency", "odds ratio")
# Adjust spacing
g.figure.subplots_adjust(top=0.85, hspace=0.1)

# Human-readable name of the plotted quantity for the figure title; each
# supported metric gets its own wording (absolute importance is not "rank").
metric_label = {
    "AR_rank": "rank",
    "AR_importance": "importance",
    "AR_absolute_importance": "absolute importance",
}.get(metric, metric)

g.figure.suptitle(
    f"Perturbed gene {metric_label} ({eval_set} set)",
    fontsize=12
)

# Ensure all x-axis tick labels are shown; compact odds-ratio labels on every facet
for ax in g.axes.flat:
    ax.set_xticklabels(ax.get_xticklabels())
    ax.tick_params(axis='x', labelbottom=True)
    ax.set_yticklabels(get_smart_odds_ratio_labels(ax))

plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, f"heatmap_perturbed_gene_{metric}_{eval_set}.svg"), format='svg', dpi=1200)
plt.savefig(os.path.join(FIGDIR, f"heatmap_perturbed_gene_{metric}_{eval_set}.png"), format='png', dpi=600)
plt.show()


In [None]:
# current: using this for the rank/imps

metric = "AR_importance"  # or "AR_rank", "AR_absolute_importance", "AR_importance"
eval_set = "test"
# Optionally restrict to one n_features value, e.g.
#   df_plot_avg[df_plot_avg["n_features"] == 100].copy()
# .copy() so relabelling model_type below does not mutate df_plot_avg itself.
df_plot_avg_tmp = df_plot_avg.copy()

# rename values in the model_type column: pnet --> P-NET, rf --> Random Forest
df_plot_avg_tmp["model_type"] = df_plot_avg_tmp["model_type"].replace({
    "pnet": "P-NET",
    "rf": "Random Forest"
})

# Colormap per supported metric; ranks use the reversed map because a lower
# rank means a more important gene. Every documented option gets a colormap,
# so cmap_tmp is always defined.
cmap_by_metric = {
    "AR_rank": "cividis_r",
    "AR_importance": "cividis",
    "AR_absolute_importance": "cividis",
}
if metric not in cmap_by_metric:
    raise ValueError(f"Unsupported metric {metric!r}; expected one of {sorted(cmap_by_metric)}")
cmap_tmp = cmap_by_metric[metric]

# One facet per (n_features, model_type) combination
g = sns.FacetGrid(
    df_plot_avg_tmp,
    row="n_features",
    col="model_type",
    margin_titles=True,
    height=figsize_tuple[0],
    aspect=aspect_ratio_value
)

def draw_heatmap(data, **kwargs):
    """Draw one facet: annotated odds-ratio x control-frequency heatmap of ``metric``.

    Reads the notebook globals ``metric`` and ``cmap_tmp`` at call time; extra
    keyword arguments are passed straight through to ``sns.heatmap``.
    """
    grid = data.pivot(index="odds_ratio", columns="control_frequency", values=metric)
    heatmap_options = dict(cmap=cmap_tmp, annot=True, fmt=".1f", cbar=False)
    sns.heatmap(grid, **heatmap_options, **kwargs)

g.map_dataframe(draw_heatmap)
g.set_titles(row_template="# features = {row_name}", col_template="model = {col_name}")
g.set_axis_labels("control frequency", "odds ratio")
# Adjust spacing
g.figure.subplots_adjust(top=0.85, hspace=0.1)

# Human-readable name of the plotted quantity for the figure title; each
# supported metric gets its own wording (absolute importance is not "rank").
metric_label = {
    "AR_rank": "rank",
    "AR_importance": "importance",
    "AR_absolute_importance": "absolute importance",
}.get(metric, metric)

g.figure.suptitle(
    f"Perturbed gene {metric_label} ({eval_set} set)",
    fontsize=12
)

# Ensure all x-axis tick labels are shown; compact odds-ratio labels on every facet
for ax in g.axes.flat:
    ax.set_xticklabels(ax.get_xticklabels())
    ax.tick_params(axis='x', labelbottom=True)
    ax.set_yticklabels(get_smart_odds_ratio_labels(ax))

plt.tight_layout()
plt.savefig(os.path.join(FIGDIR, f"heatmap_perturbed_gene_{metric}_{eval_set}.svg"), format='svg', dpi=1200)
plt.savefig(os.path.join(FIGDIR, f"heatmap_perturbed_gene_{metric}_{eval_set}.png"), format='png', dpi=600)
plt.show()
