In [None]:
import logging
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import pnet.performance_and_feature_importance_stability as stability_utils
from pnet import report_and_eval

# Configure logging for the notebook: timestamp, padded level name, logger name, message.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s [%(name)s] %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,  # Only needed if you're reconfiguring logging in a running session (e.g., Jupyter)
)
# Module-level logger for this notebook.
logger = logging.getLogger(__name__)

%load_ext autoreload
%autoreload 2

In [None]:
# # wandb setup
# import wandb
# os.environ['WANDB_NOTEBOOK_NAME'] = "stability_and_performance_analysis.ipynb"
# wandb.login()
# run = wandb.init(
#     project="prostate_met_status",
#     group="analysis_prostate_somatic_and_germline"
# )

# TODO:
- [ ] examine stability of P-NET layers
- [ ] try adding germline (and additional data aka confounders) to RF, BDT models. Prioritize the BDT since it is more promising.

In [None]:
# Relative paths to the saved results of each model family.
PNET_DIR = '../../pnet/results/gene_rank_stability_v2' # (after improved P-NET stability)
BDT_DIR = '../../pnet/results/somatic_bdt_eval_set_test'
RF_DIR = '../../pnet/results/somatic_rf_eval_set_test'
ARXIV_PNET_DIR = '../../cancer-net/reprod_report' # based on the other group's reproducibility report

# Directory that saved figures are written under (joined with a file stem when saving).
FIGDIR = '../figures/'

# Relative performance
TODO:
- load RF, BDT data from W&B.
- load P-NET data from pickle files.
- make standardized DF for plotting (rows = runs, columns = model type, values = performance metric)

## P-NET on different germline:somatic dataset combinations

### Result: everything > all rare > all missense > all common > rare LOF > all LOF > just somatic
Here we look at box-and-whisker plots of validation AUC (across 5 runs each), examining the ordering from highest to lowest validation performance (sorted by median).

For germline:somatic we see
everything > all rare > all missense > all common > rare LOF > all LOF > just somatic

For germline only we see overall lower performance, and considerably different order:
all > all missense > all LOF > all common > rare LOF > all rare
Interestingly, we actually get pretty decent validation AUC with only germline data for the top two combinations: all LOF and missense, or just all rare LOF/missense variants. In fact, the mean AUC is around 0.72/0.71 for these groups compared to 0.59/0.56 in the worst two groups.

Q: why does the order of "usefulness" change so much? Is it because of which signal is redundant with somatic data?


In [None]:
# Pull the validation AUC of every finished run in one W&B sweep and summarise it
# per (model_type, datasets) combination.
import wandb
import pandas as pd
from collections import defaultdict

# Initialize the API
api = wandb.Api()

# Define parameters
project_name = "millergw/prostate_met_status"
sweep_id = "cmlmrw2s"

# Fetch the runs (only those that belong to the sweep and finished)
runs = api.runs(project_name, filters={"sweep": sweep_id, "state": "finished"})

# Group AUC scores by model_type and datasets
auc_scores_by_group = defaultdict(list)

# Extract relevant information and group by model_type and datasets
for run in runs:
    model_type = run.config.get("model_type", "unknown")
    datasets = run.config.get("datasets", "unknown")
    auc_score = run.summary.get("validation_roc_auc_score")
    # Runs that never logged a validation AUC are left out of the comparison.
    if auc_score is not None:
        auc_scores_by_group[(model_type, datasets)].append(auc_score)

# Convert to a long-format DataFrame: one row per run, columns model_type / datasets / auc.
# NOTE: `df` is read by the plotting cells below and is re-assigned by later cells.
df = pd.DataFrame([
    {"model_type": model_type, "datasets": datasets, "auc": auc_score}
    for (model_type, datasets), auc_scores in auc_scores_by_group.items()
    for auc_score in auc_scores
])

# Not emitted with the INFO level configured in the first cell.
logging.debug("Sorting model order by decreasing mean")
# Calculate the mean of each group
group_means = df.groupby(['model_type', 'datasets'])['auc'].mean().sort_values(ascending=False)

# Reorder the DataFrame based on the descending group means
df = df.set_index(['model_type', 'datasets']).loc[group_means.index].reset_index()
display(df)

# Calculate mean, standard deviation, and sample size of the AUC scores for each group
# (rows follow the same descending-mean order as `group_means`).
grouped_stats = df.groupby(['model_type', 'datasets'])['auc'].agg(
    mean='mean',
    std='std',
    num_samples='count'
).loc[group_means.index].reset_index()

display(grouped_stats)

In [None]:
def get_group_ordered_by_statistic(df, group_col_name, order_col_name='auc', stat="median", ascending=False):
    """Return the group labels of ``df`` sorted by a per-group summary statistic.

    Args:
        df: DataFrame holding at least ``group_col_name`` and ``order_col_name``.
        group_col_name: Column (or list of columns) to group by.
        order_col_name: Numeric column the statistic is computed on.
        stat: Summary statistic to rank groups by; ``"mean"`` or ``"median"``.
        ascending: Sort direction of the statistic (default: largest first).

    Returns:
        list: Group labels ordered by the chosen statistic.

    Raises:
        ValueError: If ``stat`` is not one of the supported statistics. (Previously
            this case printed a message and returned ``None``, which callers would
            then pass on as a plotting order without noticing.)
    """
    supported_stats = ("mean", "median")
    if stat not in supported_stats:
        raise ValueError(
            f"Unsupported stat {stat!r}; expected one of {supported_stats}"
        )

    group_stat = df.groupby(group_col_name)[order_col_name].agg(stat)
    return group_stat.sort_values(ascending=ascending).index.tolist()


def plot_auc_boxplot_for_grouped_df(df, group_order=None):
    """Draw one box-and-strip plot of AUC per dataset combination for the 'pnet' runs.

    Args:
        df: DataFrame with 'model_type', 'datasets' and 'auc' columns.
        group_order: Optional sequence of dataset labels. Only its length is used
            (to size the colour palette); it defaults to the unique values of
            ``df['datasets']``. Boxes are always ordered by descending median AUC.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can call ``.show()`` on it.
    """
    # TODO: wip. Start here.
    if group_order is None:
        # Extract unique dataset orders
        group_order = df['datasets'].unique()

    # Set color palette and plot size
    pnet_palette = sns.color_palette('colorblind', n_colors=len(group_order))
    plt.figure(figsize=(10, 6))

    # Box plot for 'pnet' model type, boxes sorted by descending median AUC
    pnet_df = df[df['model_type'] == 'pnet']
    pnet_order = get_group_ordered_by_statistic(pnet_df, group_col_name="datasets")
    sns.boxplot(data=pnet_df, x='datasets', y='auc', order=pnet_order, dodge=True, palette=pnet_palette)
    sns.stripplot(data=pnet_df, x='datasets', y='auc', order=pnet_order, color='black', jitter=0.2, alpha=0.3)
    plt.title('Validation AUC (P-NET)')
    plt.xlabel('Datasets')
    plt.ylabel('AUC')
    plt.xticks([])  # Hide xticks; the legend identifies the boxes instead
    plt.grid(True)

    # Set y-axis limits manually? For now this only pins the automatically chosen
    # limits. pyplot has no set_ylim/get_ylim, so go through the Axes object.
    ax = plt.gca()
    ax.set_ylim(*ax.get_ylim())

    # Custom legend. Colours are assigned to boxes in `pnet_order`, so the labels must
    # follow that same order (labelling with `group_order` mismatched colours and names).
    handles = [
        plt.Rectangle((0, 0), 1, 1, color=color)
        for color, _ in zip(pnet_palette, pnet_order)
    ]
    plt.legend(handles, pnet_order, title='Datasets', loc='center left', bbox_to_anchor=(1, 0.5))

    return plt

def plot_boxplots(df, group_order=None):
    """Draw stacked AUC box plots for the 'pnet' (top) and 'rf' (bottom) runs.

    Both panels share one x ordering (dataset combinations sorted by descending
    median AUC of the 'pnet' runs), one colour per dataset and a common y range.

    Args:
        df: DataFrame with 'model_type', 'datasets' and 'auc' columns.
        group_order: Optional sequence of dataset labels. Only its length is used
            (to size the colour palette); it defaults to the unique values of
            ``df['datasets']``.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can call ``.show()`` on it.
    """
    if group_order is None:
        # Extract unique dataset orders
        group_order = df['datasets'].unique()

    # Create color palette
    pnet_palette = sns.color_palette('colorblind', n_colors=len(group_order))

    # Create separate box-and-whisker plots for each model type
    plt.figure(figsize=(10, 6))

    # Box plot for 'pnet' model type
    ax1 = plt.subplot(2, 1, 1)
    pnet_df = df[df['model_type'] == 'pnet']
    pnet_order = get_group_ordered_by_statistic(pnet_df, group_col_name="datasets")
    sns.boxplot(data=pnet_df, x='datasets', y='auc', order=pnet_order, dodge=True, palette=pnet_palette)
    sns.stripplot(data=pnet_df, x='datasets', y='auc', order=pnet_order, color='black', jitter=0.2, alpha=0.3)
    plt.title('Validation AUC (P-NET)')
    plt.xlabel('Datasets')
    plt.ylabel('AUC')
    plt.xticks([])  # Hide xticks; the bottom panel carries the labels
    plt.grid(True)

    # Box plot for 'rf' model type, deliberately reusing the P-NET ordering so the
    # two panels line up column by column
    ax2 = plt.subplot(2, 1, 2)
    rf_df = df[df['model_type'] == 'rf']
    sns.boxplot(data=rf_df, x='datasets', y='auc', order=pnet_order, dodge=True, palette=pnet_palette)
    sns.stripplot(data=rf_df, x='datasets', y='auc', order=pnet_order, color='black', jitter=0.2, alpha=0.3)
    plt.title('Validation AUC (rf)')
    plt.xlabel('Datasets')
    plt.ylabel('AUC')
    plt.xticks(rotation=90)
    plt.grid(True)

    # Set y-axis limits to be the same for both plots
    ymin = min(ax1.get_ylim()[0], ax2.get_ylim()[0])
    ymax = max(ax1.get_ylim()[1], ax2.get_ylim()[1])
    ax1.set_ylim(ymin, ymax)
    ax2.set_ylim(ymin, ymax)

    # Custom legend. Colours are assigned to boxes in `pnet_order`, so the labels must
    # follow that same order (labelling with `group_order` mismatched colours and names).
    handles = [
        plt.Rectangle((0, 0), 1, 1, color=color)
        for color, _ in zip(pnet_palette, pnet_order)
    ]
    plt.legend(handles, pnet_order, title='Datasets', loc='center left', bbox_to_anchor=(1, 0.5))

    return plt

# Render the same P-NET / RF comparison for three slices of the sweep results.
uses_somatic = df['datasets'].str.contains('somatic')
uses_germline = df['datasets'].str.contains('germline')
is_somatic_baseline = df['datasets'] == "somatic_amp somatic_del somatic_mut"

dataset_views = [
    ("Only germline data", df[(~uses_somatic)]),
    ("Somatic and germline data, plus somatic only", df[(uses_somatic & uses_germline) | is_somatic_baseline]),
    ("All data combos", df),
]

for view_description, filtered_data in dataset_views:
    logging.info(view_description)
    plt = plot_boxplots(filtered_data)
    plt.show()


In [None]:
# Side-by-side box plots of AUC per dataset combination, one hue per model type.
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(data=df, x='datasets', y='auc', hue='model_type', dodge=True, ax=ax)

# Grid and legend
ax.grid(True)
ax.legend(title='Model Type')

# Labels and title; dataset names are long, so turn the tick labels vertical
ax.set_xlabel('Datasets')
ax.set_ylabel('AUC')
ax.set_title('AUC grouped by datasets, for P-NET and RF')
plt.setp(ax.get_xticklabels(), rotation=90)

# Show plot
fig.tight_layout()
plt.show()

## P-NET feature importance (can also examine gene-level importance)
1. SNR (signal to noise ratio) as Marc G uses
2. Manual look at top-ranked genes (just using scores assigned by P-NET)

### SNR for P-NET

In [None]:
def get_top_n_features(df, n=10):
    """Return, for every column of ``df``, its ``n`` largest rows.

    Args:
        df: DataFrame whose columns are scored separately (one column per dataset).
        n: Number of top rows to keep per column.

    Returns:
        dict: Column name -> single-column DataFrame holding that column's ``n``
        largest values, sorted from largest to smallest.
    """
    return {
        column: df[[column]].nlargest(n, column)
        for column in df.columns
    }


def get_top_n_features_proportion(df, n=10):
    """Return, per column, the share of the column total held by its ``n`` largest entries.

    The previous body raised unconditionally; the code below that raise divided the
    top-n sum by the column *mean* instead of the column *total*, so it could not
    yield a proportion.

    Args:
        df: DataFrame of scores with one column per dataset.
            NOTE(review): the ratio is only a meaningful proportion when the
            scores are non-negative — confirm for the importances passed in.
        n: Number of largest entries per column to include in the numerator.

    Returns:
        dict: Column name -> (sum of the n largest values) / (sum of all values).
        Columns whose total is zero map to NaN rather than raising.
    """
    column_totals = df.sum(axis=0)
    top_n_features_proportion = {}
    for dataset in df.columns:
        total = column_totals[dataset]
        top_sum = df[dataset].nlargest(n).sum()
        top_n_features_proportion[dataset] = top_sum / total if total != 0 else float("nan")
    return top_n_features_proportion

In [None]:
# Examining the impact of different input data
import wandb
import pandas as pd

# Metastasis labels, indexed by sample barcode so they can be joined onto
# per-sample importance tables.
prostate_response = (
    pd.read_csv('../../pnet_germline/data/pnet_database/prostate/processed/response_paper.csv')
    .rename(columns={'id': "Tumor_Sample_Barcode"})
    .set_index('Tumor_Sample_Barcode')
)

# Initialize the API
api = wandb.Api()

# Define parameters
who = "validation"
# NOTE(review): other cells address this project as "millergw/prostate_met_status";
# without the entity prefix the lookup presumably falls back to the default entity — confirm.
project_name = "prostate_met_status"
group_name = "pnet_somatic_and_germline_exp_003"
sweep_id = "cmlmrw2s"

# Fetch the runs
runs = api.runs(project_name, filters={"sweep": sweep_id, "state": "finished"})

# One record per run. Collecting plain dicts and building the frame once avoids a
# one-row DataFrame per run, works when the sweep has no finished runs, and does
# not overwrite the notebook-level `df`.
run_columns = ["run_id", "model_type", "datasets", "feature_importances_path"]
run_records = []
for run in runs:
    model_type = run.config.get("model_type", "unknown")
    datasets = run.config.get("datasets", "unknown")
    subdir_name = f"{model_type}_eval_set_{who}"

    # Path to the per-run feature importances file
    feature_importances_path = f'../results/{group_name}/{subdir_name}/wandbID_{run.id}/{who}_gene_feature_importances.csv'

    run_records.append({
        "run_id": run.id,
        "model_type": model_type,
        "datasets": datasets,
        "feature_importances_path": feature_importances_path,
    })

# Run table indexed by W&B run id; later cells read the importance paths from it
pnet_df = pd.DataFrame(run_records, columns=run_columns).set_index("run_id")
display(pnet_df)

In [None]:
# Examining the h1 regularization experiment results
import wandb
import pandas as pd

# Metastasis labels, indexed by sample barcode so they can be joined onto
# per-sample importance tables.
prostate_response = (
    pd.read_csv('../../pnet_germline/data/pnet_database/prostate/processed/response_paper.csv')
    .rename(columns={'id': "Tumor_Sample_Barcode"})
    .set_index('Tumor_Sample_Barcode')
)

# Initialize the API
api = wandb.Api()

# Define parameters
who = "validation"
# NOTE(review): other cells address this project as "millergw/prostate_met_status";
# without the entity prefix the lookup presumably falls back to the default entity — confirm.
project_name = "prostate_met_status"
group_name = "pnet_h1_regularization_001"
sweep_id = "g8cl6qur"

# Fetch the runs
runs = api.runs(project_name, filters={"sweep": sweep_id, "state": "finished"})

# One record per run. Collecting plain dicts and building the frame once avoids a
# one-row DataFrame per run, works when the sweep has no finished runs, and does
# not overwrite the notebook-level `df`.
run_columns = [
    "run_id",
    "model_type",
    "datasets",
    "h1_alpha",
    "h1_regularization_method",
    "feature_importances_path",
]
run_records = []
for run in runs:
    model_type = run.config.get("model_type", "unknown")
    h1_alpha = run.config.get("h1_alpha", "unknown")
    datasets = run.config.get("datasets", "unknown")
    h1_regularization_method = run.config.get("h1_regularization_method", "unknown")
    subdir_name = f"{model_type}_eval_set_{who}"

    # Path to the per-run feature importances file
    feature_importances_path = f'../results/{group_name}/{subdir_name}/wandbID_{run.id}/{who}_gene_feature_importances.csv'

    run_records.append({
        "run_id": run.id,
        "model_type": model_type,
        "datasets": datasets,
        "h1_alpha": h1_alpha,
        "h1_regularization_method": h1_regularization_method,
        "feature_importances_path": feature_importances_path,
    })

# Run table indexed by W&B run id; later cells read the importance paths from it
h1_df = pd.DataFrame(run_records, columns=run_columns).set_index("run_id")
display(h1_df)

In the code below:
1. We use the dictionaries `imps_by_key` and `ranks_by_key` to store feature importances and ranks, grouped by a key built from each run's `h1_regularization_method` and `h1_alpha`.
1. For each row in `wandb_sweep_df`, we append the importances and ranks to the list for that row's key.
1. After collecting the data, we convert each key's list into a DataFrame.
1. We compute the SNR for each key and store the results in a dictionary.
1. Finally, we convert the SNR dictionary to a DataFrame for easier handling and analysis.

In [None]:
# Choose which run table the SNR cell below iterates over: the h1-regularization
# sweep (h1_df) or, alternatively, the dataset-combination sweep (pnet_df).
wandb_sweep_df = h1_df #pnet_df
wandb_sweep_df

In [None]:
import pandas as pd

# Per-key collections: one importance Series and one rank Series per run
imps_by_key = {}
ranks_by_key = {}


# Ensure the DataFrame has the column the grouping key is built from
if 'h1_alpha' not in wandb_sweep_df.columns:
    raise ValueError("The DataFrame must contain a 'h1_alpha' column to group by")

# Loop through the rows of the DataFrame containing run information
for index, row in wandb_sweep_df.iterrows():
    # Read the per-sample feature importances from the path recorded for this run
    importances_path = row['feature_importances_path']
    imps = pd.read_csv(importances_path).set_index('Unnamed: 0')

    # Collapse to one score per feature: mean importance within each response
    # group, then the second group's mean minus the first group's mean.
    # NOTE(review): assumes prostate_response supplies a two-valued 'response'
    # column — confirm.
    imps = imps.join(prostate_response).groupby('response').mean().diff(axis=0).iloc[1]

    # Rank 1 = largest score
    ranks = imps.rank(ascending=False)

    # Build the key with an f-string: h1_alpha comes from the run config and may be
    # numeric, in which case `str + '_' + number` raises TypeError.
    key = f"{row['h1_regularization_method']}_{row['h1_alpha']}"

    # Store the results under the key, creating the lists on first use
    imps_by_key.setdefault(key, []).append(imps)
    ranks_by_key.setdefault(key, []).append(ranks)

# Create DataFrames for feature importances and ranks by key (rows = runs, columns = features)
df_imps_by_key = {key: pd.DataFrame(imps_list) for key, imps_list in imps_by_key.items()}
df_ranks_by_key = {key: pd.DataFrame(ranks_list) for key, ranks_list in ranks_by_key.items()}

# Calculate signal-to-noise ratio (SNR) for each key: mean across runs divided by the
# standard deviation across runs; the small constant guards against division by zero
snr_per_key = {}
for key, df_imps in df_imps_by_key.items():
    snr = df_imps.mean(axis=0) / (df_imps.std(axis=0) + 1e-9)
    snr_per_key[key] = snr

# Convert the SNR dictionary to a DataFrame for easier handling (one row per key)
snr_df = pd.DataFrame(snr_per_key).transpose().reset_index().rename(columns={'index': 'key'})

display(snr_df)

In [None]:
# Top features per sweep key, ranked by mean P-NET importance score across runs.
# Reads `df_imps_by_key` built by the SNR cell above; the name previously used here
# (`df_imps_by_dataset`) is never assigned in this notebook.
TOP_N = 10

top_features_by_key = {}
for key, imps_df in df_imps_by_key.items():
    # Mean importance per feature across runs, largest first
    mean_imps = imps_df.mean(axis=0).sort_values(ascending=False)
    top_features_by_key[key] = pd.Series(list(mean_imps.index[:TOP_N]))

# One column per key; rows are positions 1..N. The index is sized from the data so
# keys with fewer than TOP_N features do not break the assignment.
top_10_df = pd.DataFrame(top_features_by_key)
top_10_df.index = range(1, len(top_10_df) + 1)

# Display the DataFrame
print("Top features by P-NET importance score:")
display(top_10_df)


In [None]:
# Two contrasting features: the per-run importance values of each are printed to show
# how run-to-run variability drives the SNR score.
# NOTE(review): `df_imps_by_dataset` is not assigned in any cell visible above (the SNR
# cell builds `df_imps_by_key`, keyed by regularization method and alpha rather than by
# dataset string) — confirm where this dictionary is meant to come from.
print("INTS3_somatic_del: High SNR score even though not important due to low variability run-to-run")
print(df_imps_by_dataset['somatic_amp somatic_del somatic_mut']['INTS3_somatic_del'])
print('')
print("AR_somatic_amp: Less high SNR score even though very important due to higher variability")
print(df_imps_by_dataset['somatic_amp somatic_del somatic_mut']['AR_somatic_amp'])

In [None]:
# Top features per sweep key, ranked by mean P-NET rank across runs (rank 1 = best).
# Reads `df_ranks_by_key` built by the SNR cell above; the name previously used here
# (`df_ranks_by_dataset`) is never assigned in this notebook.
TOP_N = 10

top_ranked_by_key = {}
for key, rankdf in df_ranks_by_key.items():
    # Mean rank per feature across runs, smallest (best) first
    mean_ranks = rankdf.mean(axis=0).sort_values(ascending=True)
    top_ranked_by_key[key] = pd.Series(list(mean_ranks.index[:TOP_N]))

# One column per key; rows are positions 1..N. The index is sized from the data so
# keys with fewer than TOP_N features do not break the assignment.
top_10_df = pd.DataFrame(top_ranked_by_key)
top_10_df.index = range(1, len(top_10_df) + 1)

# Display the DataFrame
print("Top features by P-NET ranking:")
display(top_10_df)


In [None]:
# # Get top 10 features by P-NET ranking for each dataset
# top_n_features_pnet_rank = get_top_n_features(df_ranks_by_dataset, n=10)

# # Display the results
# for dataset, top_features in top_n_features_pnet_rank.items():
#     print(f"Top {len(top_features)} P-NET ranked features for {dataset}:")
#     display(top_features)
#     print()


In [None]:
# NOTE(review): `top_n_features` is only assigned in the following cell, so on a fresh
# kernel this lookup runs before the name exists — consider moving it below that cell.
top_n_features['somatic_amp somatic_del somatic_mut']

In [None]:
# Get top 10 features by SNR for each key.
# `snr_per_key` is the dictionary built by the SNR cell above; the name previously
# used here (`snr_per_dataset`) is never assigned in this notebook.
top_n_features = get_top_n_features(pd.DataFrame(snr_per_key), n=10)

# Display the results
for dataset, top_features in top_n_features.items():
    print(f"Top {len(top_features)} features for {dataset}:")
    display(top_features)
    print()

In [None]:
# Here we calculate SNR for each group ("datasets") separately to see how the top-ranked genes change with different inputs

# Initialize empty dictionaries to store results for each dataset group
snr_dict = {}

# Restrict to just P-NET models for now (re-running this cell gives the same filtered frame)
pnet_df = pnet_df[pnet_df['model_type'] == "pnet"]

# Loop through each dataset group
for dataset_group in pnet_df['datasets'].unique():
    # Filter the DataFrame to include only rows corresponding to the current dataset group
    dataset_group_df = pnet_df[pnet_df['datasets'] == dataset_group]

    # Initialize empty DataFrames to store results for the current dataset group
    # (one column per run is added inside the loop below)
    df_imps = pd.DataFrame()
    df_ranks = pd.DataFrame()

    # Loop through the rows of the DataFrame containing run information for the current dataset group
    for index, row in dataset_group_df.iterrows():
        # Read the per-sample feature importances from the path recorded for this run
        importances_path = row['feature_importances_path']
        imps = pd.read_csv(importances_path).set_index('Unnamed: 0')

        # Collapse to one score per feature: mean importance within each response group,
        # then the second group's mean minus the first group's mean.
        # NOTE(review): assumes a two-valued 'response' column in prostate_response — confirm.
        imps = imps.join(prostate_response).groupby('response').mean().diff(axis=0).iloc[1]

        # Rank 1 = largest score
        ranks = imps.rank(ascending=False)

        # Store results in separate DataFrames for the current dataset group,
        # using the run id (the frame index) as the column name
        df_imps[f'{index}'] = imps
        df_ranks[f'{index}'] = ranks

    # Calculate signal-to-noise ratio (SNR) for the current dataset group:
    # mean across runs over std across runs; 1e-9 guards against division by zero
    snr = df_imps.mean(axis=1) / (df_imps.std(axis=1) + 1e-9)

    # Store the SNR results for the current dataset group in the dictionary, highest SNR first
    snr_dict[dataset_group] = snr.sort_values(ascending=False)

# Now, snr_dict contains SNR results for each dataset group
snr_dict

In [None]:
# Top features by SNR score for each dataset group in `snr_dict`.
TOP_N = 10

# The loop variable is not called `df`, so the notebook-level `df` is left untouched.
top_snr_by_dataset = {
    dataset: pd.Series(list(snr_series.sort_values(ascending=False).index[:TOP_N]))
    for dataset, snr_series in snr_dict.items()
}

# One column per dataset group; rows are positions 1..N. The index is sized from the
# data so groups with fewer than TOP_N features do not break the assignment.
top_10_snr_df = pd.DataFrame(top_snr_by_dataset)
top_10_snr_df.index = range(1, len(top_10_snr_df) + 1)

# Display the DataFrame
print("Top features by SNR score:")
display(top_10_snr_df)


In [None]:
# Show the first ten entries of each dataset group's SNR Series, preceded by the group name.
for dataset_group, group_snr in snr_dict.items():
    print(f"{dataset_group}")
    display(group_snr[:10])

In [None]:
# calculating the SNR on all pnet model runs combined
# Initialize empty DataFrames to store results (one column per run is added in the loop)
df_imps = pd.DataFrame()
df_ranks = pd.DataFrame()

# Restrict to just P-NET models for now
pnet_df = pnet_df[pnet_df['model_type']=="pnet"]
# Loop through the rows of the DataFrame containing run information
for index, row in pnet_df.iterrows():
    # Read the per-sample feature importances from the path recorded for this run
    importances_path = row['feature_importances_path']
    imps = pd.read_csv(importances_path).set_index('Unnamed: 0')

    # Collapse to one score per feature: mean importance within each response group,
    # then the second group's mean minus the first group's mean.
    # NOTE(review): assumes a two-valued 'response' column in prostate_response — confirm.
    imps = imps.join(prostate_response).groupby('response').mean().diff(axis=0).iloc[1]

    # Rank 1 = largest score
    ranks = imps.rank(ascending=False)

    # Store results in separate DataFrames, using the run id as the column name.
    # NOTE(review): column assignment aligns each Series to the frame's existing index,
    # so features absent from the first run are presumably dropped when runs use
    # different dataset combinations — confirm this is intended.
    df_imps[f'{index}'] = imps
    df_ranks[f'{index}'] = ranks

# Calculate signal-to-noise ratio (SNR) -- calculate the average score for a gene across all runs, then divide by its stdev across all runs
# (1e-9 guards against division by zero)
snr = df_imps.mean(axis=1) / (df_imps.std(axis=1) + 1e-9)


In [None]:
# Sanity check: the 25 highest-SNR features look correct
# (they fit my model for positive control success).
display(snr.sort_values(ascending=False).head(25))

## P-NET: impact of extra regularization on the first hidden layer (aka genes)
Sweep `g8cl6qur` was just looking at the somatic datasets (restricted to the ~943 paired samples only) and run on the vanilla P-NET model.

See the run information for this W&B sweep at https://wandb.ai/millergw/prostate_met_status/sweeps/g8cl6qur/table?nw=nwusermillergw

In [None]:
# Pull the validation AUC of every finished run in the h1-regularization sweep and
# summarise it per (h1_regularization_method, h1_alpha) combination.
import wandb
import pandas as pd
from collections import defaultdict

# Initialize the API
api = wandb.Api()

# Define parameters
project_name = "millergw/prostate_met_status"
sweep_id = "g8cl6qur"

# Fetch the runs
runs = api.runs(project_name, filters={"sweep": sweep_id, "state": "finished"})
print("We have found {} finished runs in project {} and sweep_id {}".format(len(runs), project_name, sweep_id))

# Group AUC scores by the h1 regularization parameters
auc_scores_by_group = defaultdict(list)

# Extract the regularization settings of each run and group its AUC under them
for run in runs:
    h1_alpha = run.config.get("h1_alpha", "unknown")
    h1_regularization_method = run.config.get("h1_regularization_method", "unknown")
    l1_ratio = run.config.get("l1_ratio", "unknown")
    auc_score = run.summary.get("validation_roc_auc_score")
    # Runs that never logged a validation AUC are left out.
    if auc_score is not None:
        auc_scores_by_group[(h1_regularization_method, h1_alpha, l1_ratio)].append(auc_score)

# Convert to a long-format DataFrame: one row per run.
# NOTE: this re-assigns the notebook-level `df`.
df = pd.DataFrame([
    {"h1_regularization_method":h1_regularization_method, "h1_alpha":h1_alpha, "l1_ratio":l1_ratio,
     "auc": auc_score}
    for (h1_regularization_method, h1_alpha, l1_ratio), auc_scores in auc_scores_by_group.items()
    for auc_score in auc_scores
])

# Not emitted with the INFO level configured in the first cell.
logging.debug("Sorting model order by decreasing mean")
# Calculate the mean of each group.
# NOTE(review): l1_ratio is collected above but is not part of this grouping, so runs
# with different l1_ratio values are pooled — confirm that is intended.
group_means = df.groupby(['h1_regularization_method', 'h1_alpha'])['auc'].mean().sort_values(ascending=False)

# Reorder the DataFrame based on the descending group means
df = df.set_index(['h1_regularization_method', 'h1_alpha']).loc[group_means.index].reset_index()
print(df.shape)

# Calculate mean, standard deviation, and sample size of the AUC scores for each group
# (rows follow the same descending-mean order as `group_means`).
grouped_stats = df.groupby(['h1_regularization_method', 'h1_alpha'])['auc'].agg(
    mean='mean',
    std='std',
    num_samples='count'
).loc[group_means.index].reset_index()

display(grouped_stats)

## Arxiv GCN reproduction

In [None]:
api = wandb.Api()

# The other cells in this notebook filter on the lower-case state "finished"; use the
# same spelling here.
runs = api.runs("millergw/init_variance", filters={"sweep": "9sob1jgy", "state": "finished"})

# Output metric name -> key to read from each run's W&B summary.
# NOTE(review): the summary key names are assumed to match the keys of the arxiv
# P-NET result dicts ('auc', 'accuracy', ...) — confirm against what the sweep logged.
GCN_SUMMARY_KEYS = {
    'auc': 'auc',
    'acc': 'accuracy',
    'aupr': 'aupr',
    'f1': 'f1',
    'precision': 'precision',
    'recall': 'recall',
}

# One list of values per metric, filled from the run summaries. (The previous loop
# appended to lists that did not exist yet and read from an empty dict.)
arxiv_gcn_metrics_dict = {name: [] for name in GCN_SUMMARY_KEYS}
n_runs_skipped = 0
for run in runs:
    run_values = {name: run.summary.get(key) for name, key in GCN_SUMMARY_KEYS.items()}
    # Keep the metric lists the same length: skip runs missing any metric.
    if any(value is None for value in run_values.values()):
        n_runs_skipped += 1
        continue
    for name, value in run_values.items():
        arxiv_gcn_metrics_dict[name].append(value)

if n_runs_skipped:
    logging.warning(f"Skipped {n_runs_skipped} runs with missing summary metrics")

# arxiv_gcn_metrics_dict now maps each metric name to a list of per-run values
logging.info(f"Computing performance metric stats using data from {len(arxiv_gcn_metrics_dict['auc'])} runs")
for k, v in arxiv_gcn_metrics_dict.items():
    print("Arxiv GCN %s: %.2f +/- %.2f" % (k, np.mean(v), np.std(v)))

## Arxiv P-NET reproduction

In [None]:
logging.info("Grabbing the performance metrics for arxiv P-NET from saved pickle file")
# Alternative results file: 'pnet_results.h6.torch.num_workers_16.pkl'
arxiv_pnet_pickle_path = os.path.join(ARXIV_PNET_DIR, 'pnet_results.h6.torch.pkl')
with open(arxiv_pnet_pickle_path, 'rb') as pickle_file:
    arxiv_pnet_results = pickle.load(pickle_file)

# Show which keys the first result entry carries
arxiv_pnet_results[0].keys()

In [None]:
# Pull each metric out of the per-run result dicts, one list per metric
aucs = [run_result['auc'] for run_result in arxiv_pnet_results]
accs = [run_result['accuracy'] for run_result in arxiv_pnet_results]
auprs = [run_result['aupr'] for run_result in arxiv_pnet_results]
f1_scores = [run_result['f1'] for run_result in arxiv_pnet_results]
precisions = [run_result['precision'] for run_result in arxiv_pnet_results]
recalls = [run_result['recall'] for run_result in arxiv_pnet_results]

# Metric name -> list of per-run values
arxiv_pnet_metrics_dict = dict(
    auc=aucs,
    acc=accs,
    aupr=auprs,
    f1=f1_scores,
    precision=precisions,
    recall=recalls,
)

# Report mean +/- standard deviation of every metric across runs
logging.info(f"Computing performance metric stats using data from {len(arxiv_pnet_results)} runs")
for metric_name, metric_values in arxiv_pnet_metrics_dict.items():
    print("Arxiv P-NET %s: %.2f +/- %.2f" % (metric_name, np.mean(metric_values), np.std(metric_values)))

## BDT, RF, and Marc's torch P-NET

In [None]:
logging.info("Grabbing the performance metrics for BDT and RF from W&B")
# W&B coordinates and metric shared by both lookups
entity = "millergw"
project_name = "prostate_met_status"
metric = "test_roc_auc_score"

# BDT run group
run_group = "bdt_stability_experiment_004"
bdt_eval_auc = stability_utils.get_summary_metric_from_wandb(
    entity, project_name, metric, run_group=run_group
)
print("BDT AUC:", np.mean(bdt_eval_auc))

# RF run group
run_group = "rf_stability_experiment_003"
rf_eval_auc = stability_utils.get_summary_metric_from_wandb(
    entity, project_name, metric, run_group=run_group
)
print("RF AUC:", np.mean(rf_eval_auc))

logging.info("Grabbing the performance metrics for P-NET from saved pickle file")
# Read the saved P-NET AUCs from their pickle file
pnet_aucs_path = os.path.join(PNET_DIR, 'aucs.pkl')
with open(pnet_aucs_path, 'rb') as pickle_file:
    pnet_aucs = pickle.load(pickle_file)
print("P-NET AUC:", np.mean(pnet_aucs))

In [None]:
logging.info("Constructing a DF of AUCs")
# Display label -> per-run AUC values for each model.
aucs_by_model = {
    'P-NET': pnet_aucs,
    'RF': rf_eval_auc,
    'BDT': bdt_eval_auc,
    'arxiv P-NET (h=6)': arxiv_pnet_metrics_dict['auc'],
    'arxiv GCN': arxiv_gcn_metrics_dict['auc'],
}

# Long format: one (Group, Value) row per run. Rows are built by iterating each
# collection rather than joining them with `+`, which only concatenates for plain
# lists (for NumPy arrays `+` adds element-wise instead).
auc_df = pd.DataFrame(
    [(label, value) for label, values in aucs_by_model.items() for value in list(values)],
    columns=['Group', 'Value'],
)

# Not emitted with the INFO level configured in the first cell.
logging.debug("Sorting model order by decreasing mean")
# Calculate the mean of each group
group_means = auc_df.groupby('Group')['Value'].mean().sort_values(ascending=False)

# Reorder the DataFrame based on the descending group means
auc_df = auc_df.set_index('Group').loc[group_means.index].reset_index()
grouped_auc_df = auc_df.groupby('Group')
display(auc_df)


In [None]:
# Overall statistics of the AUC values, pooled across all models.
# auc_df also holds the string column 'Group'; numeric_only=True restricts the
# reduction to 'Value' (pandas >= 2.0 raises on the string column otherwise).
print("mean\n", auc_df.mean(numeric_only=True))
print("median\n", auc_df.median(numeric_only=True))
print("stdev\n", auc_df.std(numeric_only=True))

In [None]:
# Per-model median and standard deviation of the AUC, using the 'Group' groupby
# created when auc_df was built.
display(grouped_auc_df.median())
display(grouped_auc_df.std())

In [None]:
# SAVENAME = 'pnet_rf_bdt_performance_benchmark_AUC_w_points'
SAVENAME = 'pnet_gcn_arxivPnet_rf_bdt_performance_benchmark_AUC_w_points'
# SAVENAME_nopoints = 'pnet_rf_bdt_performance_benchmark_AUC'
SAVENAME_nopoints = 'pnet_gcn_arxivPnet_rf_bdt_performance_benchmark_AUC'

# Lowest per-model median AUC, drawn as a dashed reference line
smallest_median_val = auc_df.groupby('Group').median().min().min()

# Plot the models in their order of appearance in auc_df and label the ticks in that
# same order. value_counts() sorts by count, so building the labels from its index
# could attach a label to the wrong box when the models have different run counts.
plot_order = list(auc_df['Group'].unique())
sample_counts = auc_df['Group'].value_counts()
x_labels = [f"{group} (n={sample_counts[group]})" for group in plot_order]


def _plot_auc_benchmark(show_points, savename):
    """Draw, save and show the per-model AUC box plot.

    Args:
        show_points: If True, overlay the individual runs as jittered points and
            hide box-plot fliers; if False, draw the boxes alone with fliers.
        savename: File stem (joined to FIGDIR) that the figure is saved under.
    """
    plt.figure(figsize=(6, 4))
    sns.boxplot(x='Group', y='Value',
        data=auc_df, order=plot_order, color='gray', showfliers=not show_points,
        boxprops={'facecolor': 'lightgrey', 'color': 'lightgrey'},
        whiskerprops={'color': 'gainsboro'},
        capprops={'color': 'gainsboro'},
        medianprops={'color': 'dimgrey'})
    if show_points:
        sns.stripplot(x='Group', y='Value', data=auc_df, order=plot_order, color='black', jitter=0.2, alpha=0.3)

    # Customize plot appearance
    ax = plt.gca()
    ax.spines[['top', 'right']].set_visible(False)
    plt.axhline(y=smallest_median_val, color='coral', linestyle='--', label=f'y_min = {smallest_median_val}', alpha=0.5)

    # Custom x-axis labels with per-model sample counts
    plt.xticks(range(len(plot_order)), x_labels, rotation=45)

    ax.set_ylabel('AUC', size=14)
    ax.set_xlabel('Model', size=14)
    ax.set_ylim((0.5, 1))
    plt.title("Relative performance on test set")
    report_and_eval.savefig(plt, os.path.join(FIGDIR, savename), png=True, pdf=True)
    plt.show()


# Boxplot with the individual runs overlaid
_plot_auc_benchmark(show_points=True, savename=SAVENAME)

# Boxplot WITHOUT points
_plot_auc_benchmark(show_points=False, savename=SAVENAME_nopoints)



# Relative stability of feature importance
- have gene-level for P-NET, RF, and BDT
- also have layer/pathway level information for P-NET

In [None]:
def _load_pnet_pickle(filename):
    """Unpickle and return the object stored in ``PNET_DIR/filename``.

    `pickle` is imported in the notebook's top import cell.
    NOTE: unpickling can execute arbitrary code, so only load result files
    that were produced by this project.
    """
    with open(os.path.join(PNET_DIR, filename), 'rb') as file:
        return pickle.load(file)


# gene_imps (format: len 20 list --> pandas DFs, samples x genes? -- TODO confirm)
gene_imps = _load_pnet_pickle('gene_imps.pkl')

# layerwise_imps (format: len 20 list --> len 5 list --> pandas DF, samples x features)
layerwise_imps = _load_pnet_pickle('layerwise_imps.pkl')

In [None]:
# Shape of each per-layer importance table from the first entry of layerwise_imps
[layer_imps.shape for layer_imps in layerwise_imps[0]]

In [None]:
# Load the P-NET gene importances and compute the per-patient stability metric.
pnet_gene_imps = stability_utils.get_pnet_gene_imps(PNET_DIR)
pnet_patient_stabs = stability_utils.calc_perpatient_stability_metric(pnet_gene_imps)

logger.info("Plotting histogram of patient-level stability metric")
# Metric: median stdev of the rank of the top 50 genes (top 50 relative to each patient).
# Pretty darn unstable.
# For example, a gene at rank 50 with a stdev of 50 means that (assuming roughly
# normally distributed ranks) 68% of the time its rank was between 0 and 100.
fig, ax = plt.subplots()
ax.hist(pnet_patient_stabs)
ax.set_xlabel("Patient-level stability metric")
ax.set_ylabel("Count")
ax.set_title("P-NET: distribution of patient-level stability metric")
plt.show()

In [None]:
# The results directory is derived from the model type and the evaluation split.
MODEL_TYPE, EVAL_SET = "rf", 'test'  # other EVAL_SET options: val, validation
SAVEDIR = f'../../pnet/results/somatic_{MODEL_TYPE}_eval_set_{EVAL_SET}'

# Feature importances of the random-forest runs stored under SAVEDIR
rf_gene_imps = stability_utils.get_sklearn_feature_imps(SAVEDIR)

In [None]:
# PROBLEM: RF and BDT don't have per-patient feature importances.
# How did Marc deal with this? He said that he normally grouped by the response variable.
# Maybe I'll just look at the overall model-level stability here.
# That'll be easiest to compare anyway. :/
# Print the container type of the first entry for P-NET, then for RF.
for model_imps in (pnet_gene_imps, rf_gene_imps):
    print(type(model_imps[0]))

In [None]:
# Model-level stability of the RF importances, restricted to the top 50 genes
rf_model_stability = stability_utils.calc_model_stability(
    rf_gene_imps,
    n_top_genes=50,
)
rf_model_stability

In [None]:
# Runs x features table of ranks (rank 1 = largest absolute importance in that run).
# A per-row `.sort_values()` inside `apply(axis=1)` has no lasting effect: pandas
# realigns every returned row on the column labels, so the sort is discarded.
# Instead, rank row-wise in one vectorised call and order the columns by their
# mean rank across runs, so the most consistently important features come first.
(
    pd.DataFrame(rf_gene_imps)
    .abs()
    .rank(axis=1, ascending=False)
    .pipe(lambda ranks: ranks[ranks.mean().sort_values(kind='stable').index])
)

In [None]:
# TODO: create sorted DF across all runs
for i in range(len(rf_gene_imps)):
    print(f"Run {i}")
    print(rf_gene_imps[i].rank(ascending=False).sort_values()[:10])
    print("\n")

In [None]:
# The results directory is derived from the model type and the evaluation split.
MODEL_TYPE, EVAL_SET = "bdt", 'test'  # other EVAL_SET option: val
SAVEDIR = f'../../pnet/results/somatic_{MODEL_TYPE}_eval_set_{EVAL_SET}'

# Feature importances of the BDT runs, and their model-level stability over the top 50 genes
bdt_gene_imps = stability_utils.get_sklearn_feature_imps(SAVEDIR)
bdt_model_stability = stability_utils.calc_model_stability(
    bdt_gene_imps,
    n_top_genes=50,
)
bdt_model_stability

In [None]:
## Here we create a runs x rank-position DF (values = feature name). This is useful for
## comparing across runs. The input is a list of series, where each series is the gene
## importance list from a given run.
N = 10  # number of top-ranked features to keep per run
series_list = bdt_gene_imps
# series_list = rf_gene_imps

# For each run: feature names ordered from most to least important (by absolute importance).
tmp_ranks = [
    series.abs().rank(ascending=False, method='dense').astype(int).sort_values().index.tolist()
    for series in series_list
]
rank_df = pd.DataFrame(tmp_ranks)

# Positional slice: exactly the first N rank positions.
# (`.loc[:, :N]` is label-based and inclusive, so it would return N + 1 columns.)
top_rank_df = rank_df.iloc[:, :N]
display(top_rank_df)

# How often each feature appears anywhere in the top N across runs
print(top_rank_df.stack().value_counts())
top_unique_features = top_rank_df.stack().unique()
print([i.split("_")[0] for i in top_unique_features])
print(len(top_unique_features))


# Display a gene x rank DF (value = # times that gene had that rank across the runs)
top_gene_by_rank_consistency_df = top_rank_df.apply(lambda col: col.value_counts()).fillna('').reindex(top_unique_features)
display(top_gene_by_rank_consistency_df)

In [None]:
# One bar chart per RF run: importances of that run's 100 best-ranked features.
for run in range(len(rf_gene_imps)):
    run_imps = rf_gene_imps[run]
    top_features = run_imps.rank(ascending=False).sort_values().index[:100]

    # Plot at integer positions and label the ticks afterwards. Passing the gene
    # symbols directly as x values makes matplotlib treat them as categories, so
    # features of the same gene (e.g. *_amp / *_del / *_mut) would share one x
    # position and be drawn on top of each other.
    bar_positions = list(range(len(top_features)))
    fig, ax = plt.subplots()
    ax.bar(bar_positions, run_imps.loc[top_features])
    ax.set_xticks(bar_positions)
    ax.set_xticklabels([i.split('_')[0] for i in top_features], rotation=90, fontsize=6)
    ax.set_ylabel("Feature importance")
    ax.set_title(f"RF run {run}: top {len(top_features)} features")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across runs

In [None]:
# One bar chart per BDT run: importances of that run's 100 best-ranked features.
for run in range(len(bdt_gene_imps)):
    run_imps = bdt_gene_imps[run]
    top_features = run_imps.rank(ascending=False).sort_values().index[:100]

    # Plot at integer positions and label the ticks afterwards. Passing the gene
    # symbols directly as x values makes matplotlib treat them as categories, so
    # features of the same gene (e.g. *_amp / *_del / *_mut) would share one x
    # position and be drawn on top of each other.
    bar_positions = list(range(len(top_features)))
    fig, ax = plt.subplots()
    ax.bar(bar_positions, run_imps.loc[top_features])
    ax.set_xticks(bar_positions)
    ax.set_xticklabels([i.split('_')[0] for i in top_features], rotation=90, fontsize=6)
    ax.set_ylabel("Feature importance")
    ax.set_title(f"BDT run {run}: top {len(top_features)} features")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across runs

In [None]:
logger.info("Exploring the ranks of particular features of interest")
features_of_interest = ['MDM4_somatic_amp', 'MDM4_somatic_del', 'MDM4_somatic_mut']
# features_of_interest = ['FGFR1_somatic_amp', 'FGFR1_somatic_del', 'FGFR1_somatic_mut']

# One inner list per BDT run: the rank of each feature of interest, in the order
# listed above. Selecting by label makes a prior sort unnecessary.
tmp = [
    run_imps.rank(ascending=False).loc[features_of_interest].tolist()
    for run_imps in bdt_gene_imps
]
tmp

## Exploring magnitude of feature importances
- How do they change from model to model?
- What does the distribution look like across top genes? All genes? Flat, or large drop-off?

In [None]:
# Bar chart of the 10 best-ranked features of a single BDT run.
run = 1
top_rank_run1 = bdt_gene_imps[run].rank(ascending=False).sort_values().index[:10]

# Plot at integer positions and label the ticks afterwards. Passing the gene
# symbols directly as x values makes matplotlib treat them as categories, so two
# features of the same gene (e.g. *_amp and *_mut) would share one x position
# and be drawn on top of each other.
bar_positions = list(range(len(top_rank_run1)))
fig, ax = plt.subplots()
ax.bar(bar_positions, bdt_gene_imps[run].loc[top_rank_run1])
ax.set_xticks(bar_positions)
ax.set_xticklabels([i.split('_')[0] for i in top_rank_run1], rotation=45)
ax.set_ylabel("Feature importance")
ax.set_title(f"BDT run {run}: top {len(top_rank_run1)} features")
plt.show()

In [None]:
# TODO: create sorted DF across all runs
for i in range(len(bdt_gene_imps)):
    print(f"Run {i}")
    print(bdt_gene_imps[i].rank(ascending=False).sort_values()[:10])
    print("\n")

In [None]:
# Runs x features table of within-run ranks (rank 1 = largest absolute importance),
# computed row-wise in a single vectorised call.
tmp_ranking = pd.DataFrame(rf_gene_imps).abs().rank(axis=1, ascending=False)

# Show only the three MDM4 alteration features.
mdm4_features = ['MDM4_somatic_mut', 'MDM4_somatic_amp', 'MDM4_somatic_del']
tmp_ranking[mdm4_features]

In [None]:
logging.info("Checking that both methods return the same result.")

# Method 1: build the ranking tables for all patients, then take the first one.
pnet_rankings = stability_utils.make_perpatient_rankings_dfs(pnet_gene_imps)
print(pnet_rankings[0]['MDM4'])

# Method 2: build the ranking table for patient index 0 directly.
pnet_patient_0_rank = stability_utils.make_pnet_gene_ranking_df(pnet_gene_imps, 0)
pnet_patient_0_rank['MDM4']

# calc_stability_metric_on_runs_by_generank_df(rankings_df, n_top_genes)