In [1]:
# Disable all warnings
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os

# Fetch the Norman19 input files from S3 unless a local copy already exists.
# Requires the `aws` CLI to be installed and configured with access to the bucket.
NORMAN19_S3_PREFIX = 's3://shift-personal-dev/henry/icml_data/norman19'
NORMAN19_LOCAL_DIR = '../../data/norman19'

for norman19_file in (
    'norman19_processed.h5ad',
    'norman19_names_df_vsrest.pkl',
    'norman19_scores_df_vsrest.pkl',
):
    local_path = f'{NORMAN19_LOCAL_DIR}/{norman19_file}'
    if os.path.exists(local_path):
        continue  # already cached locally, nothing to download

    exit_status = os.system(f'aws s3 cp {NORMAN19_S3_PREFIX}/{norman19_file} {local_path}')
    # os.system does not raise on failure; surface a non-zero exit status so a
    # failed download is visible here instead of as a missing-file error later.
    if exit_status != 0:
        print(f"Warning: download of {norman19_file} failed (exit status {exit_status})")

In [3]:
import numpy as np
import pandas as pd

def load_pickled_object(path, label):
    """Load a pickled object from ``path`` and report the outcome.

    The files are pickles, not ``.npy``/``.npz`` archives; ``np.load`` with
    ``allow_pickle=True`` returns whatever object the pickle contains.

    NOTE: unpickling can execute arbitrary code, so only use this on trusted files.

    Args:
        path: Location of the pickle file.
        label: Name used in the success / error message.

    Returns:
        The unpickled object, or ``None`` when loading fails for any reason
        (the error is printed rather than raised, so the notebook keeps running).
    """
    try:
        loaded = np.load(path, allow_pickle=True)
    except Exception as e:
        print(f"Error loading {label}: {e}")
        return None
    print(f"Successfully loaded {label}")
    return loaded


# Both names are always bound after this cell: to the loaded object on success,
# or to None on failure (instead of being left undefined).
names_df_vsrest = load_pickled_object('../../data/norman19/norman19_names_df_vsrest.pkl', 'names_df_vsrest')
scores_df_vsrest = load_pickled_object('../../data/norman19/norman19_scores_df_vsrest.pkl', 'scores_df_vsrest')


Successfully loaded names_df_vsrest
Successfully loaded scores_df_vsrest


In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from scipy.stats import ranksums # Added ranksums
import scienceplots
import pickle

import sys
sys.path.append(os.path.dirname(os.getcwd())) # For finding the 'analyses' package
from common import *


DATASET_NAME = 'norman19'

# Shared setup from `common`: returns the AnnData object together with the
# per-perturbation statistics and bookkeeping values unpacked below.
# Note that DATASET_NAME is rebound to the value returned by the initializer.
(
    adata,
    pert_means,  # result of get_pert_means(adata), keyed by perturbation
    total_mean_original,
    ctrl_mean_original,
    DATASET_NAME,
    DATASET_CELL_COUNTS,
    DATASET_PERTS_TO_SWEEP,
    dataset_specific_subdir,
    DATA_CACHE_DIR,
    original_np_random_state,
    ANALYSIS_DIR,
    pert_normalized_abs_scores_vsrest,
    pert_counts,
    scores_df_vsrest,
    names_df_vsrest,
) = initialize_analysis(DATASET_NAME, 'modeling_with_gears')

# Per-perturbation loss weights: the normalized absolute scores scaled by 100.
# Keys without a '+' get a '+ctrl' suffix; keys that already contain '+' are kept.
loss_weights_dict = {
    (pert if '+' in pert else pert + '+ctrl'): 100 * pert_normalized_abs_scores_vsrest[pert].values
    for pert in pert_normalized_abs_scores_vsrest
}

# Expose the var index as an explicit 'gene_name' column.
gene_names = adata.var.index.tolist()
adata.var['gene_name'] = gene_names

In [None]:
# Import required libraries
from gears import PertData
from gears_icml import GEARS
import pickle

# Obtain GEARS predictions for every (loss, weight) combination, either from
# cached pickle files or by training the models.
#
# Training scheme: the two-gene perturbations are divided into two halves. One
# model is trained per half (its two-gene perturbations plus all single-gene
# perturbations), and each model then predicts the two-gene perturbations of
# the *other* half. The combined predictions are cached per combination.
losses = ['mse', 'default_loss']
weights = ['unweighted', 'weighted']

# One cached prediction file per (loss, weight) combination.
prediction_files = [
    f'../../data/gears_predictions_{loss}_{weight}_norman19.pkl'
    for loss in losses
    for weight in weights
]

if all(os.path.exists(f) for f in prediction_files):
    print("All prediction files exist. Loading predictions...")
    gears_predictions = {}
    # Iterate in the same order as the training branch so the key order of
    # gears_predictions does not depend on whether the cache was present.
    for loss in losses:
        for weight in weights:
            # NOTE: pickle.load can execute arbitrary code; these files are
            # expected to be caches written by this notebook.
            with open(f'../../data/gears_predictions_{loss}_{weight}_norman19.pkl', 'rb') as f:
                gears_predictions[f'{loss}_{weight}'] = pickle.load(f)
else:
    # ===================GEARS DATA PREPARATION ===================
    print("Starting GEARS data preparation...")

    gears_predictions = {}

    # Work on a copy so the relabelling below does not touch the original adata
    print("Creating data copy...")
    adata_gears = adata.copy()
    np.random.seed(42)  # Seed the global NumPy RNG used by the shuffles below

    # Convert condition labels to the 'GENE+ctrl' / 'GENE1+GENE2' / 'ctrl' convention
    print("Processing condition labels...")
    adata_gears.obs['condition'] = adata_gears.obs['condition'].astype(str)
    adata_gears.obs['condition'] = adata_gears.obs['condition'].apply(lambda x: x + '+ctrl' if '+' not in x else x)
    # regex=False: the '+' must be matched literally. Under regex semantics it
    # is a quantifier and 'control+ctrl' would never match the label.
    adata_gears.obs['condition'] = adata_gears.obs['condition'].str.replace('control+ctrl', 'ctrl', regex=False)

    # Get unique perturbations (everything except the control label)
    print("Getting unique perturbations...")
    all_perturbations = adata_gears.obs['condition'].unique()
    all_perturbations = all_perturbations[all_perturbations != 'ctrl']
    print(f"Found {len(all_perturbations)} unique perturbations")

    print("Splitting perturbations into halves...")
    # Separate single-gene ('GENE+ctrl') from two-gene ('GENE1+GENE2') perturbations
    single_gene_perturbations = np.array([p for p in all_perturbations if '+ctrl' in p])
    two_gene_perturbations = np.array([p for p in all_perturbations if 'ctrl' not in p])

    print(f"Found {len(single_gene_perturbations)} single gene perturbations")
    print(f"Found {len(two_gene_perturbations)} two gene perturbations")

    # Only the two-gene perturbations are halved; the single-gene perturbations
    # are added to the training data of both halves further down.
    first_half_perturbations = two_gene_perturbations[:len(two_gene_perturbations)//2]
    second_half_perturbations = two_gene_perturbations[len(two_gene_perturbations)//2:]
    print(f"First half: {len(first_half_perturbations)} perturbations")
    print(f"Second half: {len(second_half_perturbations)} perturbations")

    # Process first half splits
    print("\nProcessing first half splits...")
    # Alias, not a copy: the in-place shuffle also reorders
    # first_half_perturbations (itself a slice of two_gene_perturbations).
    first_half_train_val = first_half_perturbations
    np.random.shuffle(first_half_train_val)
    split_idx = int(len(first_half_train_val) * 0.8)  # 80% train, remaining 20% val
    first_half_train = first_half_train_val[:split_idx]
    first_half_val = first_half_train_val[split_idx:]

    # 'test' deliberately lists every perturbation of this half (train + val +
    # singles); the cross-half predictions are produced via predict() below.
    first_half_split_dict = {
        'train': np.concatenate([first_half_train, single_gene_perturbations]),
        'val': first_half_val,
        'test': np.concatenate([first_half_train_val, single_gene_perturbations])
    }
    print(f"First half splits - Train: {len(first_half_train)}, Val: {len(first_half_val)}")

    # Process second half splits
    print("\nProcessing second half splits...")
    # Same aliasing as for the first half.
    second_half_train_val = second_half_perturbations
    np.random.shuffle(second_half_train_val)
    split_idx = int(len(second_half_train_val) * 0.8)  # 80% train, remaining 20% val
    second_half_train = second_half_train_val[:split_idx]
    second_half_val = second_half_train_val[split_idx:]

    second_half_split_dict = {
        'train': np.concatenate([second_half_train, single_gene_perturbations]),
        'val': second_half_val,
        'test': np.concatenate([second_half_train_val, single_gene_perturbations])
    }
    print(f"Second half splits - Train: {len(second_half_train)}, Val: {len(second_half_val)}")

    # Prepare data for GEARS: each half keeps control cells, its own two-gene
    # perturbations and all single-gene perturbations
    print("\nPreparing data for GEARS...")
    first_half_perturbations = np.concatenate([['ctrl'], first_half_perturbations, single_gene_perturbations])
    second_half_perturbations = np.concatenate([['ctrl'], second_half_perturbations, single_gene_perturbations])

    # Create subsetted datasets
    print("Creating subsetted datasets...")
    adata_first_half_gears = adata_gears[adata_gears.obs['condition'].isin(first_half_perturbations)].copy()
    adata_second_half_gears = adata_gears[adata_gears.obs['condition'].isin(second_half_perturbations)].copy()
    print(f"First half dataset size: {adata_first_half_gears.n_obs} cells")
    print(f"Second half dataset size: {adata_second_half_gears.n_obs} cells")

    # Process first half data
    print("\nProcessing first half data with GEARS...")
    pert_data_first_half = PertData('../../data')
    # Only run the processing step when its output directory is missing
    if not os.path.exists('../../data/norman19_gears_first_half_gears'):
        pert_data_first_half.new_data_process(dataset_name='norman19_gears_first_half_gears', adata=adata_first_half_gears)
    pert_data_first_half.load(data_path='../../data/norman19_gears_first_half_gears')
    # Drop perturbations with a component that is neither 'ctrl' nor a gene2go key.
    # The allowed names are built once as a set instead of once per element.
    first_half_allowed = {'ctrl', *pert_data_first_half.gene2go.keys()}
    for split in ['train', 'val', 'test']:
        original_count = len(first_half_split_dict[split])
        first_half_split_dict[split] = [gene for gene in first_half_split_dict[split]
                                        if gene.split('+')[0] in first_half_allowed and gene.split('+')[1] in first_half_allowed]
        filtered_count = len(first_half_split_dict[split])
        print(f"{split} split: {filtered_count}/{original_count} genes kept ({original_count - filtered_count} removed)")
    with open('../../data/norman19_gears_first_half_split_dict.pkl', 'wb') as f:
        pickle.dump(first_half_split_dict, f)
    pert_data_first_half.prepare_split(split='custom', seed=42, split_dict_path='../../data/norman19_gears_first_half_split_dict.pkl')
    print("First half data processing complete")

    # Process second half data
    print("\nProcessing second half data with GEARS...")
    pert_data_second_half = PertData('../../data')
    if not os.path.exists('../../data/norman19_gears_second_half_gears'):
        pert_data_second_half.new_data_process(dataset_name='norman19_gears_second_half_gears', adata=adata_second_half_gears)
    pert_data_second_half.load(data_path='../../data/norman19_gears_second_half_gears')
    # Same filtering as for the first half
    second_half_allowed = {'ctrl', *pert_data_second_half.gene2go.keys()}
    for split in ['train', 'val', 'test']:
        original_count = len(second_half_split_dict[split])
        second_half_split_dict[split] = [gene for gene in second_half_split_dict[split]
                                         if gene.split('+')[0] in second_half_allowed and gene.split('+')[1] in second_half_allowed]
        filtered_count = len(second_half_split_dict[split])
        print(f"{split} split: {filtered_count}/{original_count} genes kept ({original_count - filtered_count} removed)")
    with open('../../data/norman19_gears_second_half_split_dict.pkl', 'wb') as f:
        pickle.dump(second_half_split_dict, f)
    pert_data_second_half.prepare_split(split='custom', seed=42, split_dict_path='../../data/norman19_gears_second_half_split_dict.pkl')
    print("Second half data processing complete")

    print("\nGEARS data preparation completed successfully!")

    # Get dataloaders
    pert_data_first_half.get_dataloader(batch_size = 32, test_batch_size = 512)
    pert_data_second_half.get_dataloader(batch_size = 32, test_batch_size = 512)

    # Two-gene perturbations to predict, as [gene1, gene2] lists, restricted to
    # pairs whose genes are both gene2go keys.
    # NOTE(review): both filters use pert_data_first_half.gene2go, although the
    # first-half perturbations are predicted by the second-half model. Confirm
    # the two gene2go mappings are identical, or filter the first list with
    # pert_data_second_half.gene2go.
    first_half_train_val_without_ctrl = [gene.split('+') for gene in first_half_train_val if gene.split('+')[0] in pert_data_first_half.gene2go.keys() and gene.split('+')[1] in pert_data_first_half.gene2go.keys()]
    second_half_train_val_without_ctrl = [gene.split('+') for gene in second_half_train_val if gene.split('+')[0] in pert_data_first_half.gene2go.keys() and gene.split('+')[1] in pert_data_first_half.gene2go.keys()]

    # Train and evaluate models with different loss functions and weights
    for loss in losses:

        for weight in weights:

            # Combinations that already have a cached file are loaded, not retrained
            if not os.path.exists(f'../../data/gears_predictions_{loss}_{weight}_norman19.pkl'):

                print(f"Training GEARS model with {loss} loss and {weight} weight")

                # ===================GEARS MODEL TRAINING ===================

                gears_model_first_half = GEARS(pert_data_first_half, device = 'cuda',
                                        weight_bias_track = False,
                                        proj_name = 'first_half',
                                        exp_name = 'first_half',
                                        loss_weights_dict = loss_weights_dict if weight == 'weighted' else None,
                                        use_mse_loss = loss == 'mse')
                gears_model_first_half.model_initialize()

                gears_model_second_half = GEARS(pert_data_second_half, device = 'cuda',
                                        weight_bias_track = False,
                                        proj_name = 'second_half',
                                        exp_name = 'second_half',
                                        loss_weights_dict = loss_weights_dict if weight == 'weighted' else None,
                                        use_mse_loss = loss == 'mse')
                gears_model_second_half.model_initialize()

                gears_model_first_half.train(epochs = 10)
                gears_model_second_half.train(epochs = 10)

                os.makedirs("../../data/gears_models", exist_ok=True)

                gears_model_first_half.save_model(f'../../data/gears_models/norman19_first_half_{loss}_{weight}')
                gears_model_second_half.save_model(f'../../data/gears_models/norman19_second_half_{loss}_{weight}')

                # ===================GEARS PREDICTIONS ===================

                # Cross-prediction: each model predicts the two-gene
                # perturbations of the half it was not trained on
                predictions_first_half = gears_model_second_half.predict(first_half_train_val_without_ctrl)
                predictions_second_half = gears_model_first_half.predict(second_half_train_val_without_ctrl)

                # Combine predictions from both halves; the '_' in the returned
                # keys is turned back into the '+' separator used elsewhere
                gears_predictions[f'{loss}_{weight}'] = {}
                for pert in tqdm(predictions_first_half.keys()):
                    gears_predictions[f'{loss}_{weight}'][pert.replace('_', '+')] = predictions_first_half[pert]
                for pert in tqdm(predictions_second_half.keys()):
                    gears_predictions[f'{loss}_{weight}'][pert.replace('_', '+')] = predictions_second_half[pert]
                # The control entry is copied from pert_means rather than predicted
                gears_predictions[f'{loss}_{weight}']['control'] = pert_means['control']

                # Save GEARS predictions to pickle file
                with open(f'../../data/gears_predictions_{loss}_{weight}_norman19.pkl', 'wb') as f:
                    pickle.dump(gears_predictions[f'{loss}_{weight}'], f)

            else:

                # Load GEARS predictions from pickle file
                with open(f'../../data/gears_predictions_{loss}_{weight}_norman19.pkl', 'rb') as f:
                    gears_predictions[f'{loss}_{weight}'] = pickle.load(f)

Starting GEARS data preparation...
Creating data copy...
Processing condition labels...
Getting unique perturbations...
Found 175 unique perturbations
Splitting perturbations into halves...
Found 91 single gene perturbations
Found 84 two gene perturbations
First half: 42 perturbations
Second half: 42 perturbations

Processing first half splits...
First half splits - Train: 33, Val: 9

Processing second half splits...
Second half splits - Train: 33, Val: 9

Preparing data for GEARS...
Creating subsetted datasets...


Found local copy...


First half dataset size: 42240 cells
Second half dataset size: 42240 cells

Processing first half data with GEARS...


Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['LYL1+IER5L' 'IER5L+ctrl' 'KIAA1804+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


train split: 121/124 genes kept (3 removed)
val split: 9/9 genes kept (0 removed)


Found local copy...


test split: 130/133 genes kept (3 removed)
First half data processing complete

Processing second half data with GEARS...


Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['IER5L+ctrl' 'KIAA1804+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


train split: 122/124 genes kept (2 removed)
val split: 9/9 genes kept (0 removed)


Creating dataloaders....
Done!
Creating dataloaders....
Done!


test split: 131/133 genes kept (2 removed)
Second half data processing complete

GEARS data preparation completed successfully!
Training GEARS model with mse loss and unweighted weight


Found local copy...
