In [1]:
# Disable all warnings
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os

if not os.path.exists('../../data/replogle22/replogle22_processed.h5ad'):
    os.system('aws s3 cp s3://shift-personal-dev/henry/icml_data/replogle22/replogle22_processed.h5ad ../../data/replogle22/replogle22_processed.h5ad')

if not os.path.exists('../../data/replogle22/replogle22_names_df_vsrest.pkl'):
    os.system('aws s3 cp s3://shift-personal-dev/henry/icml_data/replogle22/replogle22_names_df_vsrest.pkl ../../data/replogle22/replogle22_names_df_vsrest.pkl')

if not os.path.exists('../../data/replogle22/replogle22_scores_df_vsrest.pkl'):
    os.system('aws s3 cp s3://shift-personal-dev/henry/icml_data/replogle22/replogle22_scores_df_vsrest.pkl ../../data/replogle22/replogle22_scores_df_vsrest.pkl')

if not os.path.exists('../../data/gears_predictions.pkl'):
    os.system('aws s3 cp s3://shift-personal-dev/lucas/icml/gears_predictions.pkl ../../data/gears_predictions.pkl')

if not os.path.exists('../../data/scgpt_predictions.pkl'):
    os.system('aws s3 cp s3://shift-personal-dev/lucas/icml/scgpt_predictions.pkl ../../data/scgpt_predictions.pkl')


In [None]:
import numpy as np
import pandas as pd

# Read the pickled DEG name/score tables (.pkl files, not .npy arrays), so use
# pandas' pickle reader rather than relying on np.load's pickle fallback.
# NOTE(review): initialize_analysis() in the next cell also returns
# names_df_vsrest / scores_df_vsrest and rebinds both names, so this cell acts
# only as an early check that the downloaded files are readable.
try:
    names_df_vsrest = pd.read_pickle('../../data/replogle22/replogle22_names_df_vsrest.pkl')
    print("Successfully loaded names_df_vsrest")
except Exception as e:
    print(f"Error loading names_df_vsrest: {e}")

try:
    scores_df_vsrest = pd.read_pickle('../../data/replogle22/replogle22_scores_df_vsrest.pkl')
    print("Successfully loaded scores_df_vsrest")
except Exception as e:
    print(f"Error loading scores_df_vsrest: {e}")


In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from scipy.stats import ranksums # Added ranksums
import scienceplots

import sys
sys.path.append(os.path.dirname(os.getcwd())) # For finding the 'analyses' package
from common import *


# Dataset key passed to initialize_analysis (provided by the star-imported `common` module).
DATASET_NAME = 'replogle22'

# Initialize analysis using the common function.
# NOTE(review): this positional unpacking must match the exact order of
# initialize_analysis's return tuple, which is defined outside this notebook;
# re-check it whenever that function changes. DATASET_NAME is rebound here to
# the value returned by initialize_analysis.
(
    adata,
    pert_means, # This is the dictionary from get_pert_means(adata)
    total_mean_original,
    ctrl_mean_original,
    DATASET_NAME,
    DATASET_CELL_COUNTS,
    DATASET_PERTS_TO_SWEEP,
    dataset_specific_subdir, # e.g. "norman19" or "replogle22"
    DATA_CACHE_DIR, # Base cache dir, e.g., "../../../data/"
    original_np_random_state, # presumably the NumPy RNG state captured at init — confirm in common
    ANALYSIS_DIR,
    pert_normalized_abs_scores_vsrest,
    pert_counts,
    scores_df_vsrest, # rebinds the value loaded in the earlier sanity-load cell
    names_df_vsrest, # rebinds the value loaded in the earlier sanity-load cell
) = initialize_analysis(DATASET_NAME, 'modeling_with_gears')

In [None]:
import pickle

# Load precomputed model predictions. They are indexed by perturbation name
# further below; presumably dicts of pert -> predicted mean expression vector
# (TODO confirm).
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# from a trusted source (here, the team S3 bucket).
with open('../../data/gears_predictions.pkl', 'rb') as f:
    gears_predictions = pickle.load(f)

with open('../../data/scgpt_predictions.pkl', 'rb') as f:
    scgpt_predictions = pickle.load(f)


# Randomly split each perturbation's cells into two disjoint halves: the first
# half acts as a "technical duplicate" predictor and the second half as the
# held-out ground truth.
# Group cell labels by condition once, instead of re-scanning adata.obs for
# every perturbation (O(N) overall rather than O(N * P)). Each group keeps
# adata.obs order, so the shuffled result is identical to a per-pert mask.
# NOTE(review): the shuffle uses NumPy's global RNG; reproducibility depends on
# a seed having been set beforehand (presumably inside initialize_analysis).
cells_by_condition = adata.obs.groupby('condition', observed=True).groups

first_half_cells = []
second_half_cells = []
for pert in tqdm(pert_means.keys(), desc="Processing perturbations"):
    pert_cells = cells_by_condition[pert].tolist() if pert in cells_by_condition else []

    np.random.shuffle(pert_cells)
    split_idx = len(pert_cells) // 2
    first_half_cells.extend(pert_cells[:split_idx])
    second_half_cells.extend(pert_cells[split_idx:])

adata_first_half = adata[first_half_cells].copy()
adata_second_half = adata[second_half_cells].copy()

In [None]:
# Get means for first half and second half
pert_means_first_half = get_pert_means(adata_first_half)
total_mean_first_half = np.mean(list(pert_means_first_half.values()), axis=0)
pert_means_second_half = get_pert_means(adata_second_half)
total_mean_second_half = np.mean(list(pert_means_second_half.values()), axis=0)

In [7]:
# One independent, empty dict per evaluation metric; each is filled per
# perturbation (and per prediction condition) by the metrics loop.
(
    pearson_delta_dict_predictive,
    pearson_delta_degs_dict_predictive,
    mse_dict_predictive,
    wmse_dict_predictive,
    r2_delta_dict_predictive,
    wr2_delta_dict_predictive,
) = (dict(), dict(), dict(), dict(), dict(), dict())

In [8]:
### Ideal predictive baseline: Tech Duplicate oracle ###
# The per-perturbation half means (first_half_mean / second_half_mean) are
# computed inside the metrics loop in the next cell. Computing them here would
# depend on `pert` leaking from the earlier cell-splitting loop (hidden kernel
# state), and the values would be overwritten before any use, so nothing is
# computed in this cell.

In [None]:
# Evaluate every non-control perturbation that has both a GEARS and an scGPT
# prediction, so all conditions are compared on the same perturbation set.
all_perts_for_predictive = [
    pert
    for pert in adata.obs['condition'].unique()
    if pert != 'control' and pert in gears_predictions and pert in scgpt_predictions
]

# Weighted and DEG-restricted metrics are computed only when a perturbation has
# more than this many DEGs; otherwise they are stored as NaN.
MIN_DEGS_FOR_METRIC = 200


def _record_prediction_metrics(key, pred_mean, pred_delta, truth_mean, truth_delta,
                               deg_mask, weights, pearson_value=None):
    """Compute all evaluation metrics for one prediction and store them under ``key``.

    Results are written into the notebook-level ``*_dict_predictive`` dicts.

    Args:
        key: Dict key, e.g. ``pert`` (tech duplicate) or ``f"{pert}_gears"``.
        pred_mean: Predicted mean expression (raw space), used for MSE/WMSE.
        pred_delta: Predicted delta, relative to the same reference as ``truth_delta``.
        truth_mean: Held-out (second-half) mean expression.
        truth_delta: Held-out delta relative to the shared reference.
        deg_mask: Boolean mask over genes marking this perturbation's DEGs.
        weights: Per-gene weights for WMSE / weighted R^2 (None when the
            perturbation has no entry in the weights dict).
        pearson_value: If given, stored for both Pearson metrics instead of
            computing them (used for the all-zeros data-mean delta, where
            Pearson correlation is undefined).
    """
    enough_degs = deg_mask.sum() > MIN_DEGS_FOR_METRIC

    # Raw-space error
    mse_dict_predictive[key] = mse(pred_mean, truth_mean)
    wmse_dict_predictive[key] = wmse(pred_mean, truth_mean, weights) if enough_degs else np.nan

    # Pearson correlation of deltas: all genes and DEGs only
    if pearson_value is None:
        pearson_delta_dict_predictive[key] = pearson(pred_delta, truth_delta)
        pearson_delta_degs_dict_predictive[key] = (
            pearson(pred_delta[deg_mask], truth_delta[deg_mask]) if enough_degs else np.nan
        )
    else:
        pearson_delta_dict_predictive[key] = pearson_value
        pearson_delta_degs_dict_predictive[key] = pearson_value

    # R^2 of deltas (ground truth first), unweighted and weighted
    r2_delta_dict_predictive[key] = r2_score_on_deltas(truth_delta, pred_delta)
    wr2_delta_dict_predictive[key] = (
        r2_score_on_deltas(truth_delta, pred_delta, weights) if enough_degs else np.nan
    )


for pert in tqdm(all_perts_for_predictive, desc="Processing perturbations"):
    # Means of the two disjoint halves of this perturbation's cells; the second
    # half is the held-out ground truth for every predictor below.
    first_half_mean = adata_first_half[adata_first_half.obs['condition'] == pert].X.mean(axis=0).A1
    second_half_mean = adata_second_half[adata_second_half.obs['condition'] == pert].X.mean(axis=0).A1

    # DEGs (up or down vs rest) and per-gene weights, computed once and shared
    # by all predictors.
    current_pert_weights = pert_normalized_abs_scores_vsrest.get(pert)
    pert_deg_info = adata.uns['deg_dict_vsrest'][pert]
    pert_degs_vsrest = list(set(pert_deg_info['up']) | set(pert_deg_info['down']))
    pert_degs_vsrest_idx = adata.var_names.isin(pert_degs_vsrest)

    # Every delta (ground truth and all predictions) uses the same reference,
    # total_mean_first_half, so the delta metrics are comparable across conditions.
    delta_second_half = second_half_mean - total_mean_first_half
    truth_kwargs = dict(
        truth_mean=second_half_mean,
        truth_delta=delta_second_half,
        deg_mask=pert_degs_vsrest_idx,
        weights=current_pert_weights,
    )

    ### Ideal predictive baseline: Tech Duplicate oracle ###
    delta_first_half = first_half_mean - total_mean_first_half
    _record_prediction_metrics(pert, first_half_mean, delta_first_half, **truth_kwargs)

    ### Null baseline: Prediction of data mean ###
    # Raw-space metrics use total_mean_original. The delta prediction is all
    # zeros, so both Pearson metrics are fixed at 0.0.
    _record_prediction_metrics(
        f"{pert}_datamean", total_mean_original, np.zeros_like(total_mean_first_half),
        pearson_value=0.0, **truth_kwargs,
    )

    ### Control baseline: Prediction of control mean ###
    delta_control = ctrl_mean_original - total_mean_first_half
    _record_prediction_metrics(f"{pert}_control", ctrl_mean_original, delta_control, **truth_kwargs)

    ### GEARS: Prediction of unseen single genes ###
    gears_mean = gears_predictions[pert]
    delta_gears = gears_mean - total_mean_first_half
    _record_prediction_metrics(f"{pert}_gears", gears_mean, delta_gears, **truth_kwargs)

    ### scGPT: Prediction of unseen single genes ###
    # Delta uses total_mean_first_half, the same reference as the ground truth.
    scgpt_mean = scgpt_predictions[pert]
    delta_scgpt = scgpt_mean - total_mean_first_half
    _record_prediction_metrics(f"{pert}_scgpt", scgpt_mean, delta_scgpt, **truth_kwargs)


In [24]:
# Create plots for the predictive baseline metrics
PLOT_DIR = f'{ANALYSIS_DIR}/plots/'
os.makedirs(PLOT_DIR, exist_ok=True)

# Display label of each condition -> suffix of its key in the *_dict_predictive
# dicts (the tech-duplicate oracle is stored under the bare perturbation name).
# Raw strings avoid the invalid '\m' escape; the label values are unchanged.
CONDITION_KEY_SUFFIXES = {
    'Tech Duplicate': '',
    r'$\mu^c$ (ctrl mean)': '_control',
    r'$\mu^{all}$ (perts mean)': '_datamean',
    'GEARS': '_gears',
    'scGPT': '_scgpt',
}

# Metric display name -> dict holding that metric (rows are emitted in this order).
METRIC_DICTS = {
    'MSE': mse_dict_predictive,
    'WMSE': wmse_dict_predictive,
    'Pearson Delta': pearson_delta_dict_predictive,
    'Pearson Delta DEGs': pearson_delta_degs_dict_predictive,
    'R-Squared Delta': r2_delta_dict_predictive,
    'Weighted R-Squared Delta': wr2_delta_dict_predictive,
}

# Split keys by suffix. regular_keys (the tech-duplicate entries) must exclude
# every model/baseline suffix, including '_scgpt'.
_model_suffixes = tuple(sfx for sfx in CONDITION_KEY_SUFFIXES.values() if sfx)
regular_keys = [key for key in mse_dict_predictive if not key.endswith(_model_suffixes)]
control_keys = [key for key in mse_dict_predictive if key.endswith('_control')]
datamean_keys = [key for key in mse_dict_predictive if key.endswith('_datamean')]
gears_keys = [key for key in mse_dict_predictive if key.endswith('_gears')]
scgpt_keys = [key for key in mse_dict_predictive if key.endswith('_scgpt')]

# Only plot perturbations with an entry for every condition. A set makes each
# lookup O(1), so this stays linear in the number of keys.
_available_keys = set(mse_dict_predictive)
complete_perts = [
    pert for pert in regular_keys
    if all(f"{pert}{sfx}" in _available_keys for sfx in CONDITION_KEY_SUFFIXES.values())
]

# Long-format rows, ordered metric -> perturbation -> condition.
data_for_plotting = [
    {
        'Perturbation': pert,
        'Metric': metric_label,
        'Condition': condition_label,
        'Value': metric_dict[f"{pert}{sfx}"],
    }
    for metric_label, metric_dict in METRIC_DICTS.items()
    for pert in complete_perts
    for condition_label, sfx in CONDITION_KEY_SUFFIXES.items()
]

# Create main DataFrame for plotting
df_for_plotting = pd.DataFrame(data_for_plotting)

In [25]:

# Function to create comparison violin plots across the prediction conditions
def plot_predictive_conditions_boxplot(df, metric_name, y_label, plot_title, plot_dir, dataset_name, plot_suffix=''):
    """Plot one metric as side-by-side violins (one per prediction condition) and save it as a PDF.

    Despite the name, this draws violins with the individual points overlaid,
    not box plots.

    Args:
        df: Long-format frame with 'Perturbation', 'Metric', 'Condition' and 'Value' columns.
        metric_name: Value of the 'Metric' column to plot; also used to build the filename.
        y_label: Y-axis label.
        plot_title: Plot title; the dataset name is appended in parentheses.
        plot_dir: Directory the PDF is written to.
        dataset_name: Dataset name shown in the title.
        plot_suffix: Optional string appended to the filename.
    """
    df_metric = df[df['Metric'] == metric_name].copy()

    fig, ax = plt.subplots(figsize=(8, 7))

    # Raw strings: '\m' is not a valid escape sequence in a normal string literal.
    condition_colors = {
        'Tech Duplicate': 'steelblue',
        r'$\mu^c$ (ctrl mean)': 'forestgreen',
        r'$\mu^{all}$ (perts mean)': 'indianred',
        'GEARS': 'purple',
        'scGPT': 'orange',
    }
    # Left to right: simple baselines, the two models, then the tech-duplicate oracle.
    condition_order = [r'$\mu^c$ (ctrl mean)', r'$\mu^{all}$ (perts mean)', 'GEARS', 'scGPT', 'Tech Duplicate']

    sns.violinplot(
        x='Condition',
        y='Value',
        data=df_metric,
        palette=condition_colors,
        ax=ax,
        order=condition_order,
        inner='quartile',  # Show quartiles inside the violins
        cut=0,             # Don't extend beyond observed data
    )

    # Overlay the individual per-perturbation values
    sns.stripplot(
        x='Condition',
        y='Value',
        data=df_metric,
        color='black',
        size=3,
        alpha=0.3,
        ax=ax,
        dodge=True,
        order=condition_order,
    )

    # Annotate each violin with its median value
    for i, condition in enumerate(condition_order):
        condition_values = df_metric.loc[df_metric['Condition'] == condition, 'Value']
        if condition_values.empty:
            continue
        median_val = condition_values.median()
        if np.isnan(median_val):
            continue
        # Medians at or below -1 are pinned inside the bounded-metric y-range set below.
        yloc = median_val * 1.02 if median_val > -1 else -.94
        ax.text(
            i + 0.15, yloc,
            f'Med: {median_val:.3f}',
            color='black',
            fontweight='bold',
            ha='left',
            va='bottom',
            fontsize=8,
        )

    # Bounded metrics (R^2 / Pearson): zero line, fixed y-range, and a count of
    # values that fall below the visible -1 limit.
    if metric_name in ['R-Squared', 'Pearson Delta', "Pearson Delta DEGs", 'R-Squared Delta', 'Weighted R-Squared Delta']:
        ax.axhline(y=0, color='firebrick', linestyle='--', linewidth=0.8, zorder=20, alpha=0.7)
        ax.set_ylim(-1.05, 1.05)  # Y-axis from -1 to 1, with a little padding

        for i, condition in enumerate(condition_order):
            num_outliers = (df_metric.loc[df_metric['Condition'] == condition, 'Value'] < -1).sum()
            if num_outliers > 0:
                ax.text(
                    i + 0.15, -1,
                    f'N < -1: {num_outliers}',
                    color='black',
                    fontweight='bold',
                    ha='left',
                    va='bottom',
                    fontsize=8,
                )

    ax.set_title(f'{plot_title} ({dataset_name})', fontsize=14)
    ax.set_ylabel(y_label, fontsize=12)
    ax.grid(axis='y', alpha=0.3)

    # The optional suffix differentiates otherwise identically named plots.
    filename = f"condition_comparison_{metric_name.lower().replace(' ', '_')}"
    if plot_suffix:
        filename += f"_{plot_suffix}"
    plot_path = os.path.join(plot_dir, f"{filename}.pdf")

    fig.tight_layout()
    fig.savefig(plot_path, dpi=300)
    print(f"Plot saved to {plot_path}")
    plt.show()
    plt.close(fig)

In [None]:
# Raw-space error metrics; the same text serves as y-axis label and title.
for plot_metric, plot_label in [
    ('MSE', 'MSE (vs Second Half)'),
    ('WMSE', 'Weighted MSE (vs Second Half)'),
]:
    plot_predictive_conditions_boxplot(df_for_plotting, plot_metric, plot_label, plot_label, PLOT_DIR, DATASET_NAME)

In [None]:
# R^2 on deltas, unweighted then weighted.
for plot_metric, plot_ylabel, plot_heading in [
    ('R-Squared Delta',
     r'$R^2$ Delta ($\mu_{total}$ as delta control)',
     r'$R^2$ Delta (vs Second Half)'),
    ('Weighted R-Squared Delta',
     r'Weighted $R^2$ Delta ($\mu_{total}$ as delta control)',
     r'Weighted $R^2$ Delta (vs Second Half)'),
]:
    plot_predictive_conditions_boxplot(df_for_plotting, plot_metric, plot_ylabel, plot_heading, PLOT_DIR, DATASET_NAME)

In [None]:
# Pearson correlation on deltas, all genes then DEGs only.
for plot_metric, plot_ylabel, plot_heading in [
    ('Pearson Delta',
     r'Pearson Delta ($\mu_{total}$ as delta control)',
     r'Pearson Delta (vs Second Half)'),
    ('Pearson Delta DEGs',
     r'Pearson Delta DEGs ($\mu_{total}$ as delta control)',
     r'Pearson Delta DEGs (vs Second Half)'),
]:
    plot_predictive_conditions_boxplot(df_for_plotting, plot_metric, plot_ylabel, plot_heading, PLOT_DIR, DATASET_NAME)

In [None]:
# Find the perturbation whose GEARS prediction has the highest Pearson delta on DEGs.
_GEARS_SUFFIX = '_gears'
gears_pearson_delta_degs = {
    key[:-len(_GEARS_SUFFIX)]: value
    for key, value in pearson_delta_degs_dict_predictive.items()
    # Match the suffix only (a substring test would also hit names that merely
    # contain '_gears'), and skip NaN entries (perturbations with too few DEGs).
    if key.endswith(_GEARS_SUFFIX) and not pd.isna(value)
}

if gears_pearson_delta_degs:
    max_pert_gears = max(gears_pearson_delta_degs, key=gears_pearson_delta_degs.get)
    max_value_gears = gears_pearson_delta_degs[max_pert_gears]
    print(f"Perturbation with highest Pearson delta DEGs (GEARS): {max_pert_gears}, Value: {max_value_gears}")
else:
    print("No GEARS predictions found or all values are NaN.")


In [None]:
# Show the correlation plot between the GEARS prediction and ground truth for the
# best-scoring GEARS perturbation: all genes (left) and DEGs only (right).
if gears_pearson_delta_degs:
    selected_pert = max_pert_gears

    # GEARS prediction and held-out ground truth (second-half mean)
    gears_pred = gears_predictions[selected_pert]
    ground_truth = adata_second_half[adata_second_half.obs['condition'] == selected_pert].X.mean(axis=0).A1

    # DEG mask and per-gene weights for this perturbation
    pert_degs = list(set(adata.uns['deg_dict_vsrest'][selected_pert]['up']) |
                     set(adata.uns['deg_dict_vsrest'][selected_pert]['down']))
    pert_degs_idx = adata.var_names.isin(pert_degs)
    current_pert_weights = pert_normalized_abs_scores_vsrest.get(selected_pert)

    # Deltas relative to the same reference used in the metrics loop
    delta_gears = gears_pred - total_mean_first_half
    delta_ground_truth = ground_truth - total_mean_first_half

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Panel 1: all genes
    ax1.scatter(delta_ground_truth, delta_gears, alpha=0.3, s=1, color='gray', label='All genes')
    lims = [min(ax1.get_xlim()[0], ax1.get_ylim()[0]),
            max(ax1.get_xlim()[1], ax1.get_ylim()[1])]
    ax1.plot(lims, lims, 'k--', alpha=0.5, zorder=0)

    # Report unweighted and weighted R² separately, each with its own label.
    corr_all = pearson(delta_ground_truth, delta_gears)
    r2_all = r2_score_on_deltas(delta_ground_truth, delta_gears)
    r2_all_weighted = r2_score_on_deltas(delta_ground_truth, delta_gears, current_pert_weights)
    ax1.set_xlabel('Ground Truth (Δ Expression)', fontsize=12)
    ax1.set_ylabel('GEARS Prediction (Δ Expression)', fontsize=12)
    ax1.set_title(f'{selected_pert} - All Genes\nPearson r = {corr_all:.3f}, R² = {r2_all:.3f}, R² weighted = {r2_all_weighted:.3f}', fontsize=14)
    ax1.grid(True, alpha=0.3)

    # Panel 2: DEGs only
    ax2.scatter(delta_ground_truth[pert_degs_idx], delta_gears[pert_degs_idx],
                alpha=0.6, s=10, color='darkred', label='DEGs')
    lims = [min(ax2.get_xlim()[0], ax2.get_ylim()[0]),
            max(ax2.get_xlim()[1], ax2.get_ylim()[1])]
    ax2.plot(lims, lims, 'k--', alpha=0.5, zorder=0)

    corr_degs = pearson(delta_ground_truth[pert_degs_idx], delta_gears[pert_degs_idx])
    r2_degs = r2_score_on_deltas(delta_ground_truth[pert_degs_idx], delta_gears[pert_degs_idx])
    r2_degs_weighted = r2_score_on_deltas(delta_ground_truth[pert_degs_idx], delta_gears[pert_degs_idx],
                                          current_pert_weights[pert_degs_idx])
    ax2.set_xlabel('Ground Truth (Δ Expression)', fontsize=12)
    ax2.set_ylabel('GEARS Prediction (Δ Expression)', fontsize=12)
    ax2.set_title(f'{selected_pert} - DEGs Only ({pert_degs_idx.sum()} genes)\nPearson r = {corr_degs:.3f}, R² = {r2_degs:.3f}, R² weighted = {r2_degs_weighted:.3f}', fontsize=14)
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()

    fig.savefig(f'{PLOT_DIR}/gears_correlation_best_pert_{selected_pert}.pdf', dpi=300, bbox_inches='tight')
    fig.savefig(f'{PLOT_DIR}/gears_correlation_best_pert_{selected_pert}.png', dpi=300, bbox_inches='tight')

    plt.show()

    print(f"\nStatistics for {selected_pert}:")
    print(f"Number of cells in ground truth: {(adata_second_half.obs['condition'] == selected_pert).sum()}")
    print(f"Number of DEGs: {pert_degs_idx.sum()}")
    print(f"MSE (all genes): {mse_dict_predictive[f'{selected_pert}_gears']:.6f}")
    print(f"WMSE (weighted by DEGs): {wmse_dict_predictive[f'{selected_pert}_gears']:.6f}")
    print(f"R² (all genes): {r2_all:.6f}")
    print(f"R² weighted (all genes): {r2_all_weighted:.6f}")
    print(f"R² (DEGs only): {r2_degs:.6f}")
    print(f"R² weighted (DEGs only): {r2_degs_weighted:.6f}")


In [None]:
# Tech-duplicate sanity check for one random evaluated perturbation: first-half
# vs second-half deltas, all genes (left) and DEGs only (right).
# Sample only from perturbations the metrics loop evaluated. This excludes
# 'control' and guarantees that DEG info and metric entries exist.
# Uses NumPy's global RNG.
selected_pert = np.random.choice(all_perts_for_predictive)
print(f"Selected perturbation: {selected_pert}")
second_half_mean = adata_second_half[adata_second_half.obs['condition'] == selected_pert].X.mean(axis=0).A1
first_half_mean = adata_first_half[adata_first_half.obs['condition'] == selected_pert].X.mean(axis=0).A1

# DEG mask and per-gene weights for this perturbation
pert_degs = list(set(adata.uns['deg_dict_vsrest'][selected_pert]['up']) |
                 set(adata.uns['deg_dict_vsrest'][selected_pert]['down']))
pert_degs_idx = adata.var_names.isin(pert_degs)
current_pert_weights = pert_normalized_abs_scores_vsrest.get(selected_pert)

# Deltas relative to the same reference used in the metrics loop
delta_first_half = first_half_mean - total_mean_first_half
delta_second_half = second_half_mean - total_mean_first_half

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Panel 1: all genes
ax1.scatter(delta_second_half, delta_first_half, alpha=0.3, s=1, color='gray', label='All genes')
lims = [min(ax1.get_xlim()[0], ax1.get_ylim()[0]),
        max(ax1.get_xlim()[1], ax1.get_ylim()[1])]
ax1.plot(lims, lims, 'k--', alpha=0.5, zorder=0)

corr_all = pearson(delta_second_half, delta_first_half)
r2_all = r2_score_on_deltas(delta_second_half, delta_first_half)
r2_all_weighted = r2_score_on_deltas(delta_second_half, delta_first_half, current_pert_weights)
ax1.set_xlabel('Second Half (Δ Expression)', fontsize=12)
ax1.set_ylabel('First Half (Δ Expression)', fontsize=12)
ax1.set_title(f'{selected_pert} - All Genes\nPearson r = {corr_all:.3f}, R² = {r2_all:.3f}, R² weighted = {r2_all_weighted:.3f}', fontsize=14)
ax1.grid(True, alpha=0.3)

# Panel 2: DEGs only
ax2.scatter(delta_second_half[pert_degs_idx], delta_first_half[pert_degs_idx],
            alpha=0.6, s=10, color='darkred', label='DEGs')
lims = [min(ax2.get_xlim()[0], ax2.get_ylim()[0]),
        max(ax2.get_xlim()[1], ax2.get_ylim()[1])]
ax2.plot(lims, lims, 'k--', alpha=0.5, zorder=0)

corr_degs = pearson(delta_second_half[pert_degs_idx], delta_first_half[pert_degs_idx])
r2_degs = r2_score_on_deltas(delta_second_half[pert_degs_idx], delta_first_half[pert_degs_idx])
r2_degs_weighted = r2_score_on_deltas(delta_second_half[pert_degs_idx], delta_first_half[pert_degs_idx],
                                      current_pert_weights[pert_degs_idx])
ax2.set_xlabel('Second Half (Δ Expression)', fontsize=12)
ax2.set_ylabel('First Half (Δ Expression)', fontsize=12)
ax2.set_title(f'{selected_pert} - DEGs Only ({pert_degs_idx.sum()} genes)\nPearson r = {corr_degs:.3f}, R² = {r2_degs:.3f}, R² weighted = {r2_degs_weighted:.3f}', fontsize=14)
ax2.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

# Statistics for the tech-duplicate predictor. The bare-name key holds the
# first-half-vs-second-half metrics, not the GEARS ones.
print(f"\nStatistics for {selected_pert}:")
print(f"Number of cells in ground truth: {(adata_second_half.obs['condition'] == selected_pert).sum()}")
print(f"Number of DEGs: {pert_degs_idx.sum()}")
print(f"MSE (Tech Duplicate, all genes): {mse_dict_predictive[selected_pert]:.6f}")
print(f"WMSE (Tech Duplicate, weighted by DEGs): {wmse_dict_predictive[selected_pert]:.6f}")
print(f"R² (all genes): {r2_all:.6f}")
print(f"R² weighted (all genes): {r2_all_weighted:.6f}")
print(f"R² (DEGs only): {r2_degs:.6f}")
print(f"R² weighted (DEGs only): {r2_degs_weighted:.6f}")
