In [1]:
# code for table 2 (post perturbation with full context w PCA compressed average control expression or AIDO control embeddings for cell line context)
import torch
import lightning as pl
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
# from contextualized.easy import ContextualizedCorrelationNetworks
import os

from contextualized.regression.lightning_modules import ContextualizedCorrelation
from contextualized.data import CorrelationDataModule
from lightning import seed_everything, Trainer

## Make Data Splits

In [None]:
# expression for pca compressed avg control expression, embeddings for aido embeddings for cell line context

# CONTEXT_MODE = 'expression' 
CONTEXT_MODE = 'embeddings' 

# File Paths
PATH_L1000 = 'data/merged_output4_head.csv'
PATH_CTLS = 'data/ctrls.csv'
EMB_FILE = 'data/aido_cell_100m_lincs_embeddings.npy'

N_DATA_PCS = 50    
TEST_SIZE = 0.33
RANDOM_STATE = 42

N_CTRL_PCS = 20
N_EMBEDDING_PCS = 20 

#specify type of perturbation to fit on
pert_to_fit_on = ['trt_cp']


# Validate file paths
if not os.path.exists(PATH_L1000):
    raise FileNotFoundError(f"L1000 data not found at {PATH_L1000}")
if not os.path.exists(PATH_CTLS):
    raise FileNotFoundError(f"Controls data not found at {PATH_CTLS}")
if CONTEXT_MODE == 'embeddings' and not os.path.exists(EMB_FILE):
    raise FileNotFoundError(f"Embeddings file not found at {EMB_FILE}")

print(f"Using cell line context mode: {CONTEXT_MODE}\n")

# Load L1000 data
df = pd.read_csv(PATH_L1000, engine='pyarrow')

df = df[df['pert_type'].isin(pert_to_fit_on)]

# Quality filters
bad = (
    (df['distil_cc_q75'] < 0.2) | (df['distil_cc_q75'] == -666) | (df['distil_cc_q75'].isna()) |
    (df['pct_self_rank_q25'] > 5) | (df['pct_self_rank_q25'] == -666) | (df['pct_self_rank_q25'].isna())
)
df = df[~bad]

# Ignore-flag columns for missing meta-data
df['ignore_flag_pert_time'] = (df['pert_time'] == -666).astype(int)
df['ignore_flag_pert_dose'] = (df['pert_dose'] == -666).astype(int)

# Replace –666 with column mean
for col in ['pert_time', 'pert_dose']:
    mean_val = df.loc[df[col] != -666, col].mean()
    df[col] = df[col].replace(-666, mean_val)

# Get X (gene expression data)
numeric_cols = df.select_dtypes(include=[np.number]).columns
drop_cols = ['pert_dose', 'pert_dose_unit', 'pert_time',
             'distil_cc_q75', 'pct_self_rank_q25']
feature_cols = [c for c in numeric_cols if c not in drop_cols]
X_raw = df[feature_cols].values

scaler_X = StandardScaler()
X_scaled = scaler_X.fit_transform(X_raw)  # shape (N, p)

# Get context components
pert_dummies = pd.get_dummies(df['pert_id'], drop_first=True)

pert_time = df['pert_time'].to_numpy().reshape(-1, 1)
pert_dose = df['pert_dose'].to_numpy().reshape(-1, 1)
ignore_time = df['ignore_flag_pert_time'].to_numpy().reshape(-1, 1)
ignore_dose = df['ignore_flag_pert_dose'].to_numpy().reshape(-1, 1)


cell2vec = {}
unique_cells_in_l1000 = np.sort(df['cell_id'].unique())

if CONTEXT_MODE == 'expression':
    print("Preparing cell line context using PCA of control expression...")
    ctrls_df = pd.read_csv(PATH_CTLS, index_col=0)  # index = cell_id
    
    # Filter controls to only include cells present in the L1000 dataset
    ctrls_df = ctrls_df.loc[ctrls_df.index.intersection(unique_cells_in_l1000)]
    
    if ctrls_df.empty:
        raise ValueError("No common cell IDs found between lincs1000.csv and ctrls.csv for PCA control expression.")

    scaler_ctrls = StandardScaler()
    ctrls_scaled = scaler_ctrls.fit_transform(ctrls_df.values)

    n_cells = ctrls_scaled.shape[0]
    n_components_for_context = min(N_CTRL_PCS, n_cells)

    pca_ctrls = PCA(n_components=n_components_for_context, random_state=RANDOM_STATE)
    ctrls_pcs = pca_ctrls.fit_transform(ctrls_scaled)  # shape (#cells, N_CTRL_PCS)

    cell2vec = dict(zip(ctrls_df.index, ctrls_pcs))

elif CONTEXT_MODE == 'embeddings':
    all_embeddings_raw = np.load(EMB_FILE)

    # Use ctrls.csv to map embedding rows to cell IDs
    ctrls_meta_df = pd.read_csv(PATH_CTLS, index_col=0)
    embedding_cell_ids_full = ctrls_meta_df.index.to_numpy()

    if len(embedding_cell_ids_full) != all_embeddings_raw.shape[0]:
        raise ValueError(
            f"Mismatch: embeddings file '{EMB_FILE}' has {all_embeddings_raw.shape[0]} entries, "
            f"but ctrls.csv has {len(embedding_cell_ids_full)} cell IDs. "
            "Please ensure they correspond row-wise."
        )

    # Z-score normalize embeddings
    scaler_embeddings = StandardScaler()
    embeddings_scaled = scaler_embeddings.fit_transform(all_embeddings_raw)

    # Apply PCA to embeddings
    n_embeddings_dim = embeddings_scaled.shape[1]
    n_components_for_context = min(N_EMBEDDING_PCS, n_embeddings_dim)

    pca_embeddings = PCA(n_components=n_components_for_context, random_state=RANDOM_STATE)
    embeddings_pcs = pca_embeddings.fit_transform(embeddings_scaled)

    # Create a mapping from cell_id to its processed embedding vector for all loaded embeddings
    full_cell_embedding_map = dict(zip(embedding_cell_ids_full, embeddings_pcs))

    # Filter to only include cells present in the L1000 dataset
    for cell_id in unique_cells_in_l1000:
        if cell_id in full_cell_embedding_map:
            cell2vec[cell_id] = full_cell_embedding_map[cell_id]

    if not cell2vec:
        raise ValueError(
            "No common cell IDs found between lincs1000.csv and embeddings/ctrls.csv. "
            "Cannot proceed. Please check your data files."
        )

    print(f"AIDO embeddings context: Original dim {n_embeddings_dim}, "
          f"now {n_components_for_context}D after scaling and PCA for {len(cell2vec)} unique cells.")

else:
    raise ValueError(f"Invalid CONTEXT_MODE: {CONTEXT_MODE}. Choose 'expression' or 'embeddings'.")

# Update the list of unique cells to process, based on what's available in cell2vec
unique_cells = np.sort(list(cell2vec.keys()))

if not unique_cells.shape[0] > 0:
    raise RuntimeError("No cell IDs found to process after context loading and filtering. Check data consistency.")

continuous_context_list = []
other_context_list = []
cell_ids = df['cell_id'].to_numpy()

for cell_id in unique_cells:
    mask = cell_ids == cell_id
    if mask.sum() == 0:
        continue
        
    if mask.sum() < 2:  # At least 2 samples needed for a train/test split
        print(f"Skipping cell {cell_id} due to insufficient samples ({mask.sum()}). Needs at least 2 for split.")
        continue

    # Continuous context features: cell context + pert_time + pert_dose
    continuous_context = np.hstack([
        np.tile(cell2vec[cell_id], (mask.sum(), 1)),  # Cell-specific context (PCA of expr or embeddings)
        pert_time[mask],                              
        pert_dose[mask],                              
    ])
    
    # Other (categorical/binary) context features
    other_context = np.hstack([
        pert_dummies.loc[mask].values,                # Perturbation ID (one-hot encoded)
        ignore_time[mask],                            
        ignore_dose[mask],                            
    ])
    
    continuous_context_list.append(continuous_context)
    other_context_list.append(other_context)

# Concatenate all continuous context features across cells
all_continuous_context = np.vstack(continuous_context_list)
all_other_context = np.vstack(other_context_list)

# Scale the continuous context features together
print("Scaling continuous context features...")
scaler_continuous_context = StandardScaler()
all_continuous_context_scaled = scaler_continuous_context.fit_transform(all_continuous_context)

# Combine scaled continuous context with other context features
all_context_scaled = np.hstack([all_continuous_context_scaled, all_other_context])

print(f"Context scaling complete. Continuous features scaled: {all_continuous_context.shape[1]}, "
      f"Other features: {all_other_context.shape[1]}, Total: {all_context_scaled.shape[1]}")

# Now split the data
X_tr_lst, X_te_lst = [], []
C_tr_lst, C_te_lst = [], []
cell_tr_lst, cell_te_lst = [], []

start_idx = 0
for i, cell_id in enumerate(unique_cells):
    mask = cell_ids == cell_id
    if mask.sum() == 0:
        continue
        
    if mask.sum() < 2:
        continue

    end_idx = start_idx + mask.sum()
    
    X_cell = X_scaled[mask]
    C_cell = all_context_scaled[start_idx:end_idx]
    ids_cell = cell_ids[mask]
    
    X_tr, X_te, C_tr, C_te, ids_tr, ids_te = train_test_split(
        X_cell, C_cell, ids_cell,
        test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    X_tr_lst.append(X_tr)
    X_te_lst.append(X_te)
    C_tr_lst.append(C_tr)
    C_te_lst.append(C_te)
    cell_tr_lst.append(ids_tr)
    cell_te_lst.append(ids_te)
    
    start_idx = end_idx

if not X_tr_lst or not X_te_lst:
    raise RuntimeError(
        "No data collected for training/testing after splits. "
        "This might be due to all cells being filtered or having insufficient samples."
    )

# Concatenate splits across cells
X_train = np.vstack(X_tr_lst)
X_test = np.vstack(X_te_lst)
C_train = np.vstack(C_tr_lst)
C_test = np.vstack(C_te_lst)
cell_ids_train = np.concatenate(cell_tr_lst)
cell_ids_test = np.concatenate(cell_te_lst)

print(f'\nContext matrix:   train {C_train.shape}   test {C_test.shape}')

# Do any extra processing based on train split
# --- PCA on Gene Expression Data (X) ---
print("Applying PCA to gene expression data...")
pca_data = PCA(n_components=N_DATA_PCS, random_state=RANDOM_STATE)
X_train_pca = pca_data.fit_transform(X_train)
X_test_pca = pca_data.transform(X_test)

# Z-score in latent space
mu, sigma = X_train_pca.mean(0), X_train_pca.std(0)
X_train_norm = (X_train_pca - mu) / sigma
X_test_norm = (X_test_pca - mu) / sigma
print(f"Gene expression data: Reduced to {N_DATA_PCS} PCs and Z-score normalized.")

# Set useful variables
train_group_ids = cell_ids_train
test_group_ids = cell_ids_test
C_train = C_train
C_test = C_test
X_train = X_train_norm
X_test = X_test_norm

Using cell line context mode: embeddings

AIDO embeddings context: Original dim 640, now 20D after scaling and PCA for 1 unique cells.
Scaling continuous context features...
Context scaling complete. Continuous features scaled: 22, Other features: 137, Total: 159

Context matrix:   train (270, 159)   test (134, 159)
Applying PCA to gene expression data...
Gene expression data: Reduced to 50 PCs and Z-score normalized.


## Fit Population Baseline

In [3]:
from contextualized.baselines.networks import CorrelationNetwork
pop_model = CorrelationNetwork()
pop_model.fit(X_train)
print(f"Train MSE: {pop_model.measure_mses(X_train).mean()}")
print(f"Test MSE: {pop_model.measure_mses(X_test).mean()}")

Train MSE: 0.980000000000001
Test MSE: 0.39288096437129766


## Fit Grouped Baseline

In [4]:
from contextualized.baselines.networks import GroupedNetworks
grouped_model = GroupedNetworks(CorrelationNetwork)
grouped_model.fit(X_train, train_group_ids)
print(f"Grouped Train MSE: {grouped_model.measure_mses(X_train, train_group_ids).mean()}")
print(f"Grouped Test MSE: {grouped_model.measure_mses(X_test, test_group_ids).mean()}")

Grouped Train MSE: 0.980000000000001
Grouped Test MSE: 0.39288096437129766


## Fit Contextualized Model

In [None]:
import wandb
wandb.login(key='add-your-key-here')  # Add your WandB API key here

contextualized_model = ContextualizedCorrelation(
    context_dim=C_train.shape[1],
    x_dim=X_train.shape[1],
    encoder_type='mlp',
    num_archetypes=50,
)
# Random val split
C_val = train_test_split(C_train, test_size=0.2, random_state=RANDOM_STATE)[0]
X_val = train_test_split(X_train, test_size=0.2, random_state=RANDOM_STATE)[0]
datamodule = CorrelationDataModule(
    C_train=C_train,
    X_train=X_train,
    C_val=C_val,
    X_val=X_val,
    C_test=C_test,
    X_test=X_test,
    C_predict=np.concatenate((C_train, C_test), axis=0),
    X_predict=np.concatenate((X_train, X_test), axis=0),
    batch_size=32,
)
checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    filename='best_model',
)
logger = pl.pytorch.loggers.WandbLogger(
    project='contextpert',
    name='cell_line_context',
    log_model=True,
    save_dir='logs/',
)
trainer = Trainer(
    max_epochs=10,
    accelerator='auto',
    devices='auto',
    callbacks=[checkpoint_callback],
    logger=logger,
)
trainer.fit(contextualized_model, datamodule=datamodule)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/calebellington/.netrc
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | metamodel | SubtypeMetamodel | 255 K  | train
-------------------------------------------------------
255 K     Trainable params
0         Non-trainable params
255 K     Total params
1.021     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/opt/homebrew/Caskroom/miniforge/base/envs/contextpert/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/opt/homebrew/Caskroom/miniforge/base/envs/contextpert/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/opt/homebrew/Caskroom/miniforge/base/envs/contextpert/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 9: 100%|██████████| 9/9 [00:00<00:00, 23.63it/s, v_num=2awo] 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 9/9 [00:00<00:00, 22.99it/s, v_num=2awo]


In [9]:
print(f"Testing model on training data...")
trainer.test(contextualized_model, datamodule.train_dataloader())
print(f"Testing model on test data...")
trainer.test(contextualized_model, datamodule.test_dataloader())

Testing model on training data...


/opt/homebrew/Caskroom/miniforge/base/envs/contextpert/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:476: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/opt/homebrew/Caskroom/miniforge/base/envs/contextpert/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 9/9 [00:00<00:00, 147.98it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.9736742973327637
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Testing model on test data...
Testing DataLoader 0: 100%|██████████| 5/5 [00:00<00:00, 95.69it/s] 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.39202195405960083
────────────────────────────────────────────

[{'test_loss': 0.39202195405960083}]

In [10]:
print(checkpoint_callback.best_model_path)

logs/contextpert/mqse2awo/checkpoints/best_model.ckpt


## Predict Networks

In [None]:
# Necessary to save predictions from multiple devices in parallel
from contextualized.callbacks import PredictionWriter
from pathlib import Path

output_dir = Path(checkpoint_callback.best_model_path).parent / 'predictions'
writer_callback = PredictionWriter(
    output_dir=output_dir,
    write_interval='batch',
)
trainer = Trainer(
    accelerator='auto',
    devices='auto',
    callbacks=[checkpoint_callback, writer_callback],
)
_ = trainer.predict(contextualized_model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/homebrew/Caskroom/miniforge/base/envs/contextpert/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/opt/homebrew/Caskroom/miniforge/base/envs/contextpert/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLo

Predicting DataLoader 0:  54%|█████▍    | 7/13 [00:00<00:00, 11.78it/s]

/Users/calebellington/Workbench/contextpert/Contextualized/contextualized/regression/lightning_modules.py:785: MPS: nonzero op is supported natively starting from macOS 14.0. Falling back on CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/Indexing.mm:404.)


Predicting DataLoader 0: 100%|██████████| 13/13 [00:00<00:00, 16.54it/s]


In [24]:
# Compile distributed predictions and put into order
import torch
import glob

# Convert context to hashable type for lookup
C_train_hashable = [tuple(row) for row in C_train]
C_test_hashable = [tuple(row) for row in C_test]

# Gather preds and move to CPU
all_correlations = {}
all_betas = {}
all_mus = {}
pred_files = glob.glob(str(output_dir / 'predictions_*.pt'))
for file in pred_files:
    preds = torch.load(file)
    for context, correlation, beta, mu in zip(preds['contexts'], preds['correlations'], preds['betas'], preds['mus']):
        context_tuple = tuple(context.tolist())
        all_correlations[context_tuple] = correlation.cpu().numpy()
        all_betas[context_tuple] = beta.cpu().numpy()
        all_mus[context_tuple] = mu.cpu().numpy()

# Remake preds in order of C_train and C_test
correlations_train = np.array([all_correlations[c] for c in C_train_hashable])
correlations_test = np.array([all_correlations[c] for c in C_test_hashable])
betas_train = np.array([all_betas[c] for c in C_train_hashable])
betas_test = np.array([all_betas[c] for c in C_test_hashable])
mus_train = np.array([all_mus[c] for c in C_train_hashable])
mus_test = np.array([all_mus[c] for c in C_test_hashable])

In [32]:
# Get individual MSEs by sample
# Sanity check: These should closely match the trainer.test() outputs from earlier
def measure_mses(betas, mus, X):
    mses = np.zeros(len(X))
    for i in range(len(X)):
        sample_mse = 0
        for j in range(X.shape[-1]):
            for k in range(X.shape[-1]):
                residual = X[i, j] - betas[i, j, k] * X[i, k] - mus[i, j, k]
                sample_mse += residual**2 / (X.shape[-1] ** 2)
        mses += sample_mse / len(X)
    return mses

mse_train = measure_mses(betas_train, mus_train, X_train)
mse_test = measure_mses(betas_test, mus_test, X_test)
print(f"Train MSEs: {mse_train.mean()}")
print(f"Test MSEs: {mse_test.mean()}")


Train MSEs: 0.9736743374936839
Test MSEs: 0.3920219624105609


In [33]:
# Iterate over the unique cells that were included in the splits for per-cell MSE
print("Per-cell MSE:")
for cell_id in unique_cells:
    tr_mask = cell_ids_train == cell_id
    te_mask = cell_ids_test == cell_id

    if tr_mask.sum() == 0 and te_mask.sum() == 0:
        continue

    tr_mse = mse_train[tr_mask].mean() if tr_mask.any() else np.nan
    te_mse = mse_test[te_mask].mean() if te_mask.any() else np.nan
    print(f'Cell {cell_id:<15}:  train MSE = {tr_mse:7.4f}   '
          f'test MSE = {te_mse:7.4f}   (n={tr_mask.sum():3d}/{te_mask.sum():3d})')

Per-cell MSE:
Cell A375           :  train MSE =  0.9737   test MSE =  0.3920   (n=270/134)
