In [1]:
import sys
import yaml
import torch
import logging
from pathlib import Path

# Add the path to the directory containing the omnicell package
# Assuming the omnicell package is in the parent directory of your notebook
sys.path.append('..')  # Adjust this path as needed

import yaml
import torch
import logging
from pathlib import Path
from omnicell.config.config import Config
from omnicell.data.loader import DataLoader
from omnicell.constants import PERT_KEY, GENE_EMBEDDING_KEY, CONTROL_PERT
# from omnicell.models.model_factory import get_model

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure paths. All configs live under one omnicell checkout; change the root once here.
OMNICELL_CONFIGS = "/orcd/data/omarabu/001/njwfish/omnicell/configs"
MODEL_CONFIG = f"{OMNICELL_CONFIGS}/models/sclambda_large_no_clip.yaml"
ETL_CONFIG = f"{OMNICELL_CONFIGS}/ETL/no_preprocessing.yaml"  # Change this to your desired ETL config
SPLIT_CONFIG = f"{OMNICELL_CONFIGS}/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_0/split_config.yaml"
EVAL_CONFIG = None  # Set this if you want to run evaluations

# Load configuration, then apply the ETL overrides used by this notebook
config = Config.from_yamls(MODEL_CONFIG, ETL_CONFIG, SPLIT_CONFIG, EVAL_CONFIG)
config.etl_config['gene_embedding'] = 'bioBERT'
config.etl_config['drop_unmatched_perts'] = True

# Set up device (defined once; also passed to get_model later)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize data loader and load training data
loader = DataLoader(config)
adata, pert_rep_map = loader.get_training_data()

# Get dimensions and perturbation IDs
input_dim = adata.obsm['embedding'].shape[1]
pert_ids = adata.obs[PERT_KEY].unique()
gene_emb_dim = adata.varm[GENE_EMBEDDING_KEY].shape[1] if GENE_EMBEDDING_KEY in adata.varm else None

print("Data loaded:")
print(f"- Number of cells: {adata.shape[0]}")
print(f"- Input dimension: {input_dim}")
print(f"- Number of perturbations: {len(pert_ids)}")


2025-01-16 13:35:18,578 - INFO - Loading training data at path: /orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/K562_essential_raw_singlecell_01.h5ad


Using device: cuda


2025-01-16 13:35:23,844 - INFO - Loaded unpreprocessed data, # of data points: 310385, # of genes: 8563.
2025-01-16 13:35:23,845 - INFO - Preprocessing training data
  embedding = torch.load(f"{dataset_details.folder_path}/{self.gene_embedding_name}.pt")
2025-01-16 13:35:23,865 - INFO - Removing observations with perturbations not in the dataset as a column
  adata.obsm["embedding"] = adata.X.toarray().astype('float32')
2025-01-16 13:36:16,763 - INFO - Doing OOD split


Data loaded:
- Number of cells: 279630
- Input dimension: 8563
- Number of perturbations: 1850


In [132]:
from omnicell.constants import PERT_KEY, GENE_EMBEDDING_KEY, CONTROL_PERT, CELL_KEY

import scanpy as sc
import pandas as pd 
import numpy as np
import scipy

# After the existing imports, add:
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.decomposition import PCA


def expected_distribute_shift(ctrl_cells, shift_pred):
    """
    Deterministic, non-integer alternative to ``distribute_shift``.

    Cell ``i`` receives ``shift_pred[g] * n_cells * (total_i / grand_total)`` extra
    expression of gene ``g``. Here ``total_i`` is the cell's summed count over all
    genes (its library size), not its count of gene ``g``. The weights average to 1
    over cells, so each gene's mean over cells rises by exactly ``shift_pred[g]``.

    Args:
        ctrl_cells: 2-D array-like (n_cells, n_genes) of control expression.
        shift_pred: Array-like of n_genes per-gene mean shifts.

    Returns:
        A new ndarray with the same shape as ``ctrl_cells`` holding the shifted values.
        If every control count is zero there is no library size to weight by, so each
        cell simply receives ``shift_pred``.
    """
    # asarray turns np.matrix input into a plain ndarray so [:, None] broadcasting works.
    ctrl = np.asarray(ctrl_cells)
    shift = np.asarray(shift_pred).ravel()
    n_cells = ctrl.shape[0]

    grand_total = ctrl.sum()
    if grand_total > 0:
        weights = ctrl.sum(axis=1) / grand_total * n_cells
    else:
        # 0/0 would produce NaN weights; fall back to an even split.
        weights = np.ones(n_cells)

    return ctrl + weights[:, None] * shift[None, :]

def distribute_shift(ctrl_cells, mean_shift, rng=None):
    """
    Apply a per-gene mean shift to control cells as integer count changes.

    For each gene ``g``, the total change ``int(mean_shift[g] * n_cells)`` (truncated
    toward zero) is realised by a single multinomial draw over cells:

    * A positive total is added, with per-cell probabilities proportional to each
      cell's current count of ``g``. If no cell expresses ``g``, the counts are spread
      uniformly across cells.
    * A negative total is removed with the same proportional probabilities. The
      removal is capped at the gene's current total, and any cell pushed below zero is
      clamped to 0, so the realised removal can be smaller than requested.

    Args:
        ctrl_cells: 2-D array-like (n_cells, n_genes) of non-negative counts. Not modified.
        mean_shift: Array-like of n_genes per-cell mean shifts to apply.
        rng: Optional ``numpy.random.Generator`` / ``RandomState`` used for the draws.
            Defaults to the global ``numpy.random`` state, which was the previous behaviour.

    Returns:
        A new ndarray with the same shape and dtype as ``ctrl_cells``.
    """
    sampler = np.random if rng is None else rng

    # np.array copies, and with subok=False it also turns np.matrix input into a plain
    # ndarray, so the column slices below are 1-D as multinomial requires.
    shifted = np.array(ctrl_cells)
    n_cells, n_genes = shifted.shape
    sum_shift = (np.asarray(mean_shift) * n_cells).astype(int)

    for g in range(n_genes):
        diff = int(sum_shift[g])
        if diff == 0:
            continue

        gene_counts = shifted[:, g]
        # Sum and normalise in float64 so the probabilities sum to 1 within
        # multinomial's tolerance even for float32 input.
        current_total = float(gene_counts.sum(dtype=np.float64))

        if diff > 0:
            if current_total > 0:
                p = gene_counts.astype(np.float64) / current_total
            else:
                # 0/0 would give NaN probabilities and multinomial would raise.
                p = np.full(n_cells, 1.0 / n_cells)
            shifted[:, g] = gene_counts + sampler.multinomial(diff, p)
        else:
            if current_total <= 0:
                continue
            p = gene_counts.astype(np.float64) / current_total
            # multinomial needs an integer number of trials; current_total is a float.
            to_remove = int(min(-diff, np.floor(current_total)))
            draws = sampler.multinomial(to_remove, p)
            # Draws are with replacement, so a cell can be asked for more than it has.
            updated = gene_counts - draws
            updated[updated < 0] = 0
            shifted[:, g] = updated

    return shifted

def fit_supervised_model(X, Y, model_type='linear', **kwargs):
    """
    Fit one of a fixed set of sklearn regressors and report its in-sample fit.

    Args:
        X: Feature matrix, one row per training example. In this notebook these are
            gene embeddings of the perturbed genes.
        Y: Targets aligned row-wise with ``X``. In this notebook these are mean
            perturbation effects.
        model_type: Key selecting the estimator: 'linear', 'ridge', 'lasso',
            'elastic_net', 'rf' or 'svr'.
        **kwargs: Forwarded unchanged to the estimator's constructor.

    Returns:
        Tuple ``(model, mse, r2)``: the fitted estimator, plus its mean squared error
        and R^2 score, both evaluated on the training data itself.

    Raises:
        ValueError: If ``model_type`` is not a supported key.
    """
    registry = {
        'linear': LinearRegression,
        'ridge': Ridge,
        'lasso': Lasso,
        'elastic_net': ElasticNet,
        'rf': RandomForestRegressor,
        'svr': SVR
    }

    estimator_cls = registry.get(model_type)
    if estimator_cls is None:
        raise ValueError(f"Model type {model_type} not supported. Choose from {list(registry.keys())}")

    fitted = estimator_cls(**kwargs)
    fitted.fit(X, Y)

    # In-sample metrics only: no held-out data is involved here.
    train_pred = fitted.predict(X)
    return fitted, mean_squared_error(Y, train_pred), r2_score(Y, train_pred)

def compute_cell_type_means(adata, cell_type):
    """
    Compute the control mean and per-perturbation mean shifts for one cell type.

    Args:
        adata: AnnData with ``CELL_KEY`` and ``PERT_KEY`` columns in ``obs``.
        cell_type: Value of ``adata.obs[CELL_KEY]`` selecting the cells to use.

    Returns:
        ``(ctrl_mean, pert_deltas_dict)``.
        ``ctrl_mean`` is a 1-D ndarray (one entry per gene) holding the mean
        expression of this cell type's ``CONTROL_PERT`` cells.
        ``pert_deltas_dict`` maps every other perturbation that has cells of this type
        to a 1-D ndarray ``mean(perturbed cells) - ctrl_mean``.

    Raises:
        ValueError: If the cell type has no control cells, since every delta would be NaN.
    """
    cell_type_data = adata[adata.obs[CELL_KEY] == cell_type]
    perts = cell_type_data.obs[PERT_KEY]

    # Densify once and derive both the control mean and the per-perturbation means from
    # it. np.mean on a scipy sparse matrix returns a (1, n_genes) np.matrix, not a 1-D vector.
    X = cell_type_data.X
    X = X.toarray() if scipy.sparse.issparse(X) else np.asarray(X)

    ctrl_mask = (perts == CONTROL_PERT).to_numpy()
    if not ctrl_mask.any():
        raise ValueError(f"No {CONTROL_PERT!r} cells found for cell type {cell_type!r}")
    ctrl_mean = X[ctrl_mask].mean(axis=0)

    df = pd.DataFrame(X, index=cell_type_data.obs.index)
    df['perturbation'] = perts.values
    # observed=True: if PERT_KEY is categorical, observed=False would add all-NaN rows
    # for perturbation categories that have no cells in this cell type.
    pert_means = df.groupby('perturbation', observed=True).mean()

    deltas = pert_means.to_numpy() - ctrl_mean
    pert_deltas_dict = {
        pert: delta
        for pert, delta in zip(pert_means.index, deltas)
        if pert != CONTROL_PERT
    }

    return ctrl_mean, pert_deltas_dict

class MeanPredictor():
    """Predict perturbed expression from regressed mean perturbation shifts.

    Training works in two steps. First, for every cell type, each perturbation's mean
    expression delta relative to that cell type's control cells is computed with
    ``compute_cell_type_means``. Then one regressor (``fit_supervised_model``) is fit,
    mapping the perturbed gene's embedding (optionally PCA-reduced) to that delta.

    Prediction maps the perturbation's gene embedding to a predicted delta, which
    ``distribute_shift`` then applies to the control cells of the requested cell type.
    """

    def __init__(self, adata: sc.AnnData, model_config: dict):
        """
        Args:
            adata: Full dataset, stored on ``self.total_adata`` (train/predict take
                their own AnnData argument).
            model_config: Dict with keys 'model_type' (see ``fit_supervised_model``),
                'pca_gene_embeddings' (bool) and 'pca_gene_embeddings_components' (int).
        """
        self.model = None
        self.model_type = model_config['model_type']
        self.pca_gene_embeddings = model_config['pca_gene_embeddings']
        self.pca_gene_embeddings_components = model_config['pca_gene_embeddings_components']
        self.total_adata = adata
        self.gene_emb = None
        # In-sample fit metrics from the most recent ``train`` call.
        self.train_mse = None
        self.train_r2 = None

    def train(self, adata: sc.AnnData):
        """Fit the regressor on every cell type in ``adata``.

        Perturbations with no row in ``adata.varm[GENE_EMBEDDING_KEY]`` (matched via
        ``adata.var_names``) are skipped with a warning instead of raising ``KeyError``.
        """
        gene_emb_matrix = np.asarray(adata.varm[GENE_EMBEDDING_KEY])
        if self.pca_gene_embeddings:
            pca = PCA(n_components=self.pca_gene_embeddings_components)
            gene_emb_matrix = pca.fit_transform(gene_emb_matrix)

        # Gene name -> (optionally PCA-reduced) embedding row.
        self.gene_emb = dict(zip(adata.var_names, gene_emb_matrix))

        feature_rows, target_rows, missing = [], [], set()
        for cell_type in adata.obs[CELL_KEY].unique():
            _, pert_deltas_dict = compute_cell_type_means(adata, cell_type)
            for pert, delta in pert_deltas_dict.items():
                emb = self.gene_emb.get(pert)
                if emb is None:
                    missing.add(pert)
                    continue
                feature_rows.append(emb)
                target_rows.append(delta)

        if missing:
            logger.warning(
                f"Skipping {len(missing)} perturbation(s) without a gene embedding: "
                f"{sorted(map(str, missing))[:10]}"
            )
        if not feature_rows:
            raise ValueError("No perturbations with a gene embedding were found; nothing to train on.")

        X = np.stack(feature_rows)
        Y = np.stack(target_rows)
        self.model, self.train_mse, self.train_r2 = fit_supervised_model(X, Y, model_type=self.model_type)

    def make_predict(self, adata: sc.AnnData, pert_id: str, cell_type: str) -> np.ndarray:
        """Sample perturbed counts for ``pert_id`` from the control cells of ``cell_type``.

        Returns:
            ndarray (n_control_cells, n_genes) produced by ``distribute_shift``.

        Raises:
            RuntimeError: If called before ``train``.
            KeyError: If ``pert_id`` has no gene embedding.
        """
        if self.model is None or self.gene_emb is None:
            raise RuntimeError("MeanPredictor.make_predict called before train()")
        if pert_id not in self.gene_emb:
            raise KeyError(f"No gene embedding for perturbation {pert_id!r}")

        ctrl_mask = (adata.obs[PERT_KEY] == CONTROL_PERT) & (adata.obs[CELL_KEY] == cell_type)
        ctrl_X = adata[ctrl_mask].X
        # Support both sparse and dense X (ndarray has no .toarray()).
        ctrl_cells = ctrl_X.toarray() if scipy.sparse.issparse(ctrl_X) else np.asarray(ctrl_X)

        X_new = np.asarray(self.gene_emb[pert_id]).reshape(1, -1)
        shift_pred = np.asarray(self.model.predict(X_new)).ravel()
        return distribute_shift(ctrl_cells, shift_pred)
    

In [133]:
# Mean-shift baseline: linear regression on a 10-component PCA of the gene embeddings.
model_config = dict(
    model_type='linear',
    pca_gene_embeddings=True,
    pca_gene_embeddings_components=10,
)

model = MeanPredictor(adata, model_config)
model.train(adata)


  pert_means = df.groupby('perturbation').mean()


In [134]:
# Observed mean shift for one perturbation in one cell type, plus that cell type's
# control cells (the inputs distribute_shift works on). The shift is computed here
# directly rather than reusing `pert_deltas_dict`, which only exists as a local
# inside MeanPredictor.train.
pert_id = 'SLC39A9'
cell_type = 'k562'

in_cell_type = adata.obs[CELL_KEY] == cell_type
ctrl_X = adata[in_cell_type & (adata.obs[PERT_KEY] == CONTROL_PERT)].X
pert_X = adata[in_cell_type & (adata.obs[PERT_KEY] == pert_id)].X
ctrl_cells = ctrl_X.toarray() if scipy.sparse.issparse(ctrl_X) else np.asarray(ctrl_X)
pert_cells = pert_X.toarray() if scipy.sparse.issparse(pert_X) else np.asarray(pert_X)

# Same quantity as compute_cell_type_means(adata, cell_type)[1][pert_id], without
# densifying every cell of the cell type.
shift_pred = pert_cells.mean(axis=0) - ctrl_cells.mean(axis=0)

In [135]:
# Sample perturbed counts for SLC39A9 in the 'k562' cell type from the mean-shift model.
Z = model.make_predict(adata, pert_id='SLC39A9', cell_type='k562')

matrix([[18.943586,  0.      ,  0.      , ..., 30.510262,  0.      ,
          0.      ],
        [21.323915,  0.      ,  0.      , ..., 34.868988,  0.      ,
          0.      ],
        [18.176373,  0.      ,  0.      , ..., 29.274597,  0.      ,
          0.      ],
        ...,
        [21.275522,  0.      ,  0.      , ..., 34.26604 ,  0.      ,
          0.      ],
        [21.526276,  0.      ,  0.      , ..., 35.150177,  0.      ,
          0.      ],
        [21.763897,  0.      ,  0.      , ..., 35.46508 ,  0.      ,
          0.      ]], dtype=float32)

In [22]:
# Number of cells per perturbation label, most frequent first.
pert_counts = adata.obs[PERT_KEY].value_counts()
pert_counts

pert
ctrl       10691
RPL3        1996
NCBP2        992
KIF11        974
SLC39A9      752
           ...  
NUP155         7
POLR3A         6
SEC62          5
RBM22          5
POT1           5
Name: count, Length: 1850, dtype: int64

In [53]:
# Fit a supervised sklearn model selected by `model_type` (implemented in `fit_supervised_model` above).
# NOTE: this cell contains no code; the array displayed below is stale output from code that was removed.

array([[ 0.00410382, -0.01693404, -0.02247995, ...,  0.02091249,
         0.02567016,  0.0138661 ],
       [ 0.01048473,  0.02137071, -0.0019055 , ..., -0.01459742,
        -0.02456123,  0.0438426 ],
       [ 0.05892141, -0.03006956, -0.29437995, ...,  0.03454808,
        -0.01858041, -0.04018317],
       ...,
       [ 0.05377207, -0.0211549 ,  0.00326538, ...,  0.01787109,
        -0.01472393,  0.00791901],
       [-0.01146409, -0.01612045, -0.00243032, ...,  0.0074872 ,
        -0.0009381 ,  0.01571979],
       [ 0.02805913, -0.00253367, -0.16755474, ...,  0.01492483,
        -0.06151099,  0.04340555]], dtype=float32)

In [54]:
# Mean absolute in-sample residual of an ordinary least-squares fit Y ~ X. There is no
# intercept column, unlike sklearn's LinearRegression. lstsq replaces the explicit
# inv(X.T @ X), which is numerically unstable and fails when X.T @ X is singular.
# NOTE(review): X and Y are not defined at notebook scope by any visible cell (they are
# locals of MeanPredictor.train), so this cell relies on hidden kernel state.
ols_coef, *_ = np.linalg.lstsq(X, Y, rcond=None)
np.abs(Y - X @ ols_coef).mean()

np.float32(0.03486446)

In [2]:
from train import get_model
# Build the model named in the model config (the log shows SCLambda) from the loaded
# data and its dimensions.
# NOTE(review): this rebinds `model`, the same name used for the MeanPredictor elsewhere
# in the notebook; a distinct name would avoid clobbering it.
model = get_model(
    config.get_model_name(),
    config.model_config,
    loader,
    pert_rep_map,
    input_dim,
    device,
    pert_ids,
    gene_emb_dim,
)

  from .autonotebook import tqdm as notebook_tqdm
2025-01-12 19:57:51,653 - INFO - SCLambda model selected


In [3]:
# Train the model returned by get_model on the loaded training data.
# NOTE(review): in the captured run below, the loss becomes nan by epoch 10 and the
# validation correlation delta becomes nan, i.e. training diverged. The ETL log says the
# input is unpreprocessed data; check input scaling / learning rate before trusting results.
model.train(adata)

2025-01-12 19:59:05,108 - INFO - Computing 32000-dimensional perturbation embeddings for 279630 cells...
2025-01-12 19:59:35,468 - INFO - Splitting data...
  2%|▏         | 4/200 [02:07<1:44:02, 31.85s/it]2025-01-12 20:03:23,463 - INFO - Epoch  5 complete! -  Loss: 3856.907958984375
2025-01-12 20:06:54,662 - INFO - Validation correlation delta 0.23106926226955496
  4%|▍         | 9/200 [08:16<2:31:07, 47.48s/it] 2025-01-12 20:09:32,692 - INFO - Epoch  10 complete! -  Loss: nan
2025-01-12 20:13:08,032 - INFO - Validation correlation delta nan
  7%|▋         | 14/200 [14:31<2:35:28, 50.15s/it] 2025-01-12 20:15:48,229 - INFO - Epoch  15 complete! -  Loss: nan
2025-01-12 20:19:25,387 - INFO - Validation correlation delta nan
 10%|▉         | 19/200 [20:49<2:32:53, 50.68s/it] 2025-01-12 20:22:05,256 - INFO - Epoch  20 complete! -  Loss: nan
2025-01-12 20:25:39,663 - INFO - Validation correlation delta nan
 12%|█▏        | 24/200 [27:03<2:28:12, 50.53s/it] 2025-01-12 20:28:19,324 - INFO - Ep