In [2]:
import sys
import yaml
import torch
import logging
from pathlib import Path

# Add the path to the directory containing the omnicell package
# Assuming the omnicell package is in the parent directory of your notebook
sys.path.append('..')  # Adjust this path as needed

import yaml
import torch
import logging
from pathlib import Path
from omnicell.config.config import Config, ETLConfig, ModelConfig, DatasplitConfig, EvalConfig, EmbeddingConfig
from omnicell.data.loader import DataLoader
from omnicell.constants import PERT_KEY, GENE_EMBEDDING_KEY, CONTROL_PERT, OMNICELL_ROOT
from omnicell.models.selector import load_model

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure paths
MODEL_CONFIG = ModelConfig.from_yaml(f"{OMNICELL_ROOT}/configs/models/sparsity_gt.yaml")
ETL_CONFIG = ETLConfig(name = "no_preprocessing", log1p = False, drop_unmatched_perts = True)
EMBEDDING_CONFIG = EmbeddingConfig(pert_embedding='GenePT')

SPLIT_CONFIG = DatasplitConfig.from_yaml(f"{OMNICELL_ROOT}/configs/splits/essential_gene_knockouts_raw/random_splits/acrossC_ood_ss:40/split_jurkat/split_config.yaml")
EVAL_CONFIG = EvalConfig.from_yaml(f"{OMNICELL_ROOT}/configs/splits/essential_gene_knockouts_raw/random_splits/acrossC_ood_ss:40/split_jurkat/eval_config.yaml")  # Set this if you want to run evaluations

# Load configurations
config = Config(model_config=MODEL_CONFIG,
                 etl_config=ETL_CONFIG, 
                 datasplit_config=SPLIT_CONFIG, 
                 eval_config=EVAL_CONFIG)




#Alternatively you can initialize the config objects manually as follows:
# etl_config = ETLConfig(name = XXX, log1p = False, drop_unmatched_perts = False, ...)
# model_config = ...
# embedding_config = ...
# datasplit_config = ...
# eval_config = ...
# config = Config(etl_config, model_config, datasplit_config, eval_config)

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize data loader and load training data
loader = DataLoader(config)
adata, gene_emb = loader.get_training_data()

# Get dimensions and perturbation IDs
input_dim = adata.X.shape[1]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pert_ids = adata.obs[PERT_KEY].unique()
gene_emb_dim = adata.varm[GENE_EMBEDDING_KEY].shape[1] if GENE_EMBEDDING_KEY in adata.varm else None

print(f"Data loaded:")
print(f"- Number of cells: {adata.shape[0]}")
print(f"- Input dimension: {input_dim}")
print(f"- Number of perturbations: {len(pert_ids)}")



2025-02-20 14:50:16,113 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/FrangiehIzar2021_protein.yaml
2025-02-20 14:50:16,116 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/PapalexiSatija2021_eccite_arrayed_RNA.yaml
2025-02-20 14:50:16,119 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/AdamsonWeissman2016_GSM2406677_10X005.yaml
2025-02-20 14:50:16,121 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/PapalexiSatija2021_eccite_protein.yaml
2025-02-20 14:50:16,124 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/SchraivogelSteinmetz2020_TAP_SCREEN__chromosome_11_screen.yaml
2025-02-20 14:50:16,126 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/DatlingerBock2021.yaml
2025-02-20 14:50:16,129 - INFO - Loading data

Using device: cpu


2025-02-20 14:52:36,148 - INFO - Loaded unpreprocessed data, # of data points: 966728, # of genes: 11907.
2025-02-20 14:52:36,149 - INFO - Preprocessing training data
2025-02-20 14:52:36,152 - INFO - Using identity features for perturbations
2025-02-20 14:52:36,811 - INFO - Removing observations with perturbations not in the dataset as a column
2025-02-20 14:52:37,720 - INFO - Removed 143 perturbations that were not in the dataset columns and 0 perturbations that did not have an embedding for a total of 143 perturbations removed out of an initial 2396 perturbations
2025-02-20 14:58:00,874 - INFO - Doing OOD split


Data loaded:
- Number of cells: 659782
- Input dimension: 11907
- Number of perturbations: 2253


In [3]:
from omnicell.constants import PERT_KEY, GENE_EMBEDDING_KEY, CONTROL_PERT, CELL_KEY

import scanpy as sc
import pandas as pd 
import numpy as np
import scipy

# After the existing imports, add:
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.decomposition import PCA
from pathlib import Path
from omnicell.models.mean_models.L4Regressor import L4Regressor
from sklearn.model_selection import GridSearchCV

from omnicell.models.utils.distribute_shift import sample_pert, get_proportional_weighted_dist


def fit_supervised_model(X, Y, model_type='linear', param_grid=None, **kwargs):
    """
    Fit a supervised model, optionally performing hyperparameter tuning with GridSearchCV.
    
    Args:
        X: Input features.
        Y: Target values.
        model_type: Type of model to fit.
        param_grid: Hyperparameters grid for GridSearchCV.
        **kwargs: Fixed parameters for the model.
        
    Returns:
        Fitted model, training MSE, R2 score.
    """
    models = {
        'linear': LinearRegression,
        'ridge': Ridge,
        'lasso': Lasso,
        'elastic_net': ElasticNet,
        'rf': RandomForestRegressor,
        'svr': SVR,
        'l4': L4Regressor
    }
    
    if model_type not in models:
        raise ValueError(f"Model type {model_type} not supported.")
    
    model_class = models[model_type]
    base_model = model_class(n_jobs=8, **kwargs)
    
    if param_grid:
        print(f"Performing hyperparameter tuning with {model_type}")
        print(f"Parameter grid: {param_grid}")
        grid_search = GridSearchCV(base_model, param_grid, cv=2, n_jobs = 8, scoring='neg_mean_squared_error')
        grid_search.fit(X, Y)
        best_model = grid_search.best_estimator_
        print(f"Best parameters: {grid_search.best_params_}")
    else:
        best_model = base_model.fit(X, Y)
    
    Y_pred = best_model.predict(X)
    mse = mean_squared_error(Y, Y_pred)
    r2 = r2_score(Y, Y_pred)
    return best_model, mse, r2

def compute_cell_type_means(adata, cell_type):
    """Compute perturbation effect embeddings for a specific cell type"""
    
    # Filter data for this cell type
    cell_type_data = adata[adata.obs[CELL_KEY] == cell_type]
    
    # Compute control mean for this cell type
    ctrl_mean = np.mean(
        cell_type_data[cell_type_data.obs[PERT_KEY] == CONTROL_PERT].X, axis=0
    )
    
    # Convert to dense array if sparse
    X = cell_type_data.X
    
    # Create dataframe
    df = pd.DataFrame(X, index=cell_type_data.obs.index)
    df['perturbation'] = cell_type_data.obs[PERT_KEY].values
    
    # Compute means per perturbation
    pert_means = df.groupby('perturbation').mean()
    
    # Compute deltas from control
    pert_deltas = pd.DataFrame(pert_means.values - ctrl_mean, index=pert_means.index)
    pert_deltas_dict = {
        pert: np.array(means) 
        for pert, means in pert_deltas.iterrows() 
        if pert != CONTROL_PERT
    }
    
    return ctrl_mean, pert_deltas_dict

class MeanPredictor():

    def __init__(self, model_config: dict, pert_embedding: dict):
        self.model = None
        self.model_type = model_config['model_type']
        self.pca_pert_embeddings = model_config['pca_pert_embeddings']
        self.pca_pert_embeddings_components = model_config['pca_pert_embeddings_components']
        self.pca_cell_embeddings_components = model_config['pca_cell_emb_components']
        self.pert_embedding = pert_embedding
        self.cell_embeddings = {}
        self.param_grid = model_config.get('param_grid', None)


    def train(self, adata: sc.AnnData, **kwargs):
        if self.pca_pert_embeddings:
            pca = PCA(n_components=self.pca_pert_embeddings_components)
            pert_emb_temp = pca.fit_transform(np.array(list(self.pert_embedding.values())))
            self.pert_embedding = {pert : pert_emb_temp[i] for i, pert in enumerate(self.pert_embedding.keys())}

        
        #Generating cell embeddings with PCA

        self.pca_model_cells = PCA(n_components=self.pca_cell_embeddings_components)

        X = adata.X
        print(f"Fitting PCA model for cells with shape {X.shape}")
        self.pca_model_cells.fit(X)

        print("Computing cell embeddings")
        for cell_type in adata.obs[CELL_KEY].unique():
            cell_data = adata[(adata.obs[CELL_KEY] == cell_type) & (adata.obs[PERT_KEY] == CONTROL_PERT)].X
            self.cell_embeddings[cell_type] = np.mean(self.pca_model_cells.transform(cell_data), axis=0)


        # Get unique cell types
        cell_types = adata.obs[CELL_KEY].unique()

        # Compute embeddings for each cell type
        Xs = []
        Ys = []

        #We append the data for each cell type, 
        for cell_type in cell_types:
            print(f"Creating training data for {cell_type}")
            ctrl_mean, pert_deltas_dict = compute_cell_type_means(adata, cell_type)
            
            
            # Get perturbation IDs for this cell type
            idxs = pert_deltas_dict.keys()
            
            # Create feature matrix X and target matrix Y
            Y = np.array([pert_deltas_dict[pert] for pert in idxs])
            print(Y.shape)



            X = np.array([np.concatenate([self.pert_embedding[g], self.cell_embeddings[cell_type]])  for g in idxs])

            print(f"X shape: {X.shape}")
            print(f"Y shape: {Y.shape}")
            # Store the embeddings
            Xs.append(X)
            Ys.append(Y)

        # Now you can train a model for each cell type
        X = np.concatenate(Xs)
        Y = np.concatenate(Ys)
        self.model, mse, r2 = fit_supervised_model(X, Y, model_type=self.model_type, param_grid=self.param_grid)
        
    def make_predict(self, adata: sc.AnnData, pert_id: str, cell_type: str) -> np.ndarray:
        ctrl_cells = adata[(adata.obs[PERT_KEY] == CONTROL_PERT) & (adata.obs[CELL_KEY] == cell_type)].X

        cell_embedding = np.mean(self.pca_model_cells.transform(ctrl_cells), axis=0)

        print(f"cell_embedding shape: {cell_embedding.shape}")
        X_new = np.concatenate([self.pert_embedding[pert_id], cell_embedding])
        print(f"X_new shape: {X_new.shape}")
        X_new = X_new.reshape(1, -1)
        print(f"X_new reshaped shape: {X_new.shape}")
        
        mean_shift_pred = np.array(self.model.predict(X_new)).flatten().astype(np.float32)
        weighted_dist = get_proportional_weighted_dist(ctrl_cells)
        samples = sample_pert(ctrl_cells, weighted_dist, mean_shift_pred)
        return samples


In [4]:
model_config = {
    'model_type': 'rf',
    'pca_pert_embeddings': True,
    'pca_pert_embeddings_components': 10,
    'pca_cell_emb_components': 50,
    'param_grid':
    {
        'n_estimators' : [100, 200, 400]
    }
    


}

model = MeanPredictor(model_config, gene_emb)


In [5]:
adata_small = adata[:1000]

In [6]:


model.train(adata_small)

Fitting PCA model for cells with shape (1000, 11907)
Computing cell embeddings
Creating training data for hepg2


  pert_means = df.groupby('perturbation').mean()


(705, 11907)
X shape: (705, 60)
Y shape: (705, 11907)
Performing hyperparameter tuning with rf
Parameter grid: {'n_estimators': [100, 200, 400]}
Best parameters: {'n_estimators': 100}


In [7]:
adata_pert = adata[(adata.obs[PERT_KEY] == 'ALDOA') & (adata.obs[CELL_KEY] == 'hepg2')].X


model.make_predict(adata_small, "ALDOA", 'hepg2')

cell_embedding shape: (50,)
X_new shape: (60,)
X_new reshaped shape: (1, 60)


array([[ 0.,  0.,  2., ...,  0.,  1.,  0.],
       [ 0.,  3., 22., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ...,
       [ 0.,  2., 29., ...,  1.,  3.,  0.],
       [ 0.,  0.,  3., ...,  0.,  3.,  0.],
       [ 0.,  0.,  4., ...,  0.,  2.,  0.]], dtype=float32)

In [8]:
import numpy as np

logger.info("Running evaluation")

# evaluate each pair of cells and perts
eval_dict = {}
for cell_id, pert_id, ctrl_data, gt_data in loader.get_eval_data():
    logger.debug(f"Making predictions for cell: {cell_id}, pert: {pert_id}")

    preds = model.make_predict(ctrl_data, pert_id, cell_id)
    eval_dict[(cell_id, pert_id)] = (ctrl_data.X.toarray(), gt_data.X.toarray(), preds)
    break
    
if not config.etl_config.log1p:
    for (cell, pert) in eval_dict:  
        ctrl_data, gt_data, pred_pert = eval_dict[(cell, pert)]
        # normalize to sum to 1
        ctrl_data = ctrl_data / ctrl_data.sum(axis=1).reshape(-1, 1) * 10_000
        gt_data = gt_data / gt_data.sum(axis=1).reshape(-1, 1) * 10_000
        pred_pert = pred_pert / pred_pert.sum(axis=1).reshape(-1, 1) * 10_000
        eval_dict[(cell, pert)] =  (np.log1p(ctrl_data), np.log1p(gt_data), np.log1p(pred_pert))


2025-02-20 15:02:11,942 - INFO - Running evaluation
2025-02-20 15:02:11,944 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/FrangiehIzar2021_protein.yaml
2025-02-20 15:02:11,948 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/PapalexiSatija2021_eccite_arrayed_RNA.yaml
2025-02-20 15:02:11,952 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/AdamsonWeissman2016_GSM2406677_10X005.yaml
2025-02-20 15:02:11,955 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/PapalexiSatija2021_eccite_protein.yaml
2025-02-20 15:02:11,958 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/SchraivogelSteinmetz2020_TAP_SCREEN__chromosome_11_screen.yaml
2025-02-20 15:02:11,961 - INFO - Loading data catalogue from /orcd/data/omarabu/001/opitcho/omnicell/configs/catalogue/DatlingerBock202

cell_embedding shape: (50,)
X_new shape: (60,)
X_new reshaped shape: (1, 60)


In [9]:
import scanpy as sc
from omnicell.evaluation.utils import get_DEGs, get_eval, get_DEG_Coverage_Recall, get_DEGs_overlaps
pval_threshold = 0.05
log_fold_change_threshold = 0.0

results_dict = {}
for (cell, pert) in eval_dict:  
    ctrl_data, gt_data, pred_pert = eval_dict[(cell, pert)]

    pred_pert = sc.AnnData(X=pred_pert)
    true_pert = sc.AnnData(X=gt_data)
    control = sc.AnnData(X=ctrl_data)

    logger.debug(f"Getting ground Truth DEGs for {pert} and {cell}")
    true_DEGs_df = get_DEGs(control, true_pert)
    signif_true_DEG = true_DEGs_df[true_DEGs_df['pvals_adj'] < pval_threshold]

    logger.debug(f"Number of significant DEGS from ground truth: {signif_true_DEG.shape[0]}")

    logger.debug(f"Getting predicted DEGs for {pert} and {cell}")
    pred_DEGs_df = get_DEGs(control, pred_pert)


    logger.debug(f"Getting evaluation metrics for {pert} and {cell}")
    r2_and_mse = get_eval(control, true_pert, pred_pert, true_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

    logger.debug(f"Getting DEG overlaps for {pert} and {cell}")
    DEGs_overlaps = get_DEGs_overlaps(true_DEGs_df, pred_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

    results_dict[(cell, pert)] = (r2_and_mse, DEGs_overlaps)



  utils.warn_names_duplicates("obs")
  scores[group_index, :] = (
  utils.warn_names_duplicates("obs")


KeyboardInterrupt: 

In [8]:
for (cell, pert) in results_dict:
    r2_and_mse, DEGs_overlaps = results_dict[(cell, pert)]
    print(f"Cell: {cell}, Pert: {pert}")
    # print(f"R2 and MSE: {r2_and_mse}")
    print(f"DEGs Overlaps: {DEGs_overlaps}")
    print("-"*100)

Cell: k562, Pert: RPL15
DEGs Overlaps: {'Overlap_in_top_2157_DEGs': 988, 'Overlap_in_top_100_DEGs': 33, 'Overlap_in_top_50_DEGs': 12, 'Overlap_in_top_20_DEGs': 2, 'Jaccard': 0.25747165922363446}
----------------------------------------------------------------------------------------------------


In [11]:
r2_and_mse


{'all_genes_mean_sub_diff_R': np.float32(0.9394147),
 'all_genes_mean_sub_diff_R2': np.float32(0.88249993),
 'all_genes_mean_sub_diff_MSE': np.float32(0.0017594295),
 'all_genes_mean_fold_diff_R': np.float32(0.9819075),
 'all_genes_mean_fold_diff_R2': np.float32(0.9641423),
 'all_genes_mean_fold_diff_MSE': np.float32(0.10515996),
 'all_genes_mean_R': np.float32(0.99663806),
 'all_genes_mean_R2': np.float32(0.99328744),
 'all_genes_mean_MSE': np.float32(0.0017594295),
 'all_genes_var_R': np.float32(0.914543),
 'all_genes_var_R2': np.float32(0.8363888),
 'all_genes_var_MSE': np.float32(0.0018092245),
 'all_genes_corr_mtx_R': np.float64(0.17428146307689132),
 'all_genes_corr_mtx_R2': np.float64(0.030374028372221834),
 'all_genes_corr_mtx_MSE': np.float64(0.007376010796250857),
 'all_genes_cov_mtx_R': np.float64(0.23831284114404422),
 'all_genes_cov_mtx_R2': np.float64(0.05679301025414645),
 'all_genes_cov_mtx_MSE': np.float64(0.00021684922227372595),
 'Top_2157_DEGs_sub_diff_R': np.float3