In [1]:
import sys
import yaml
import torch
import logging
from pathlib import Path

# Make the local `omnicell` package importable. The relative path is resolved
# against the kernel's current working directory, so this assumes the package
# lives one directory above it; adjust REPO_ROOT as needed.
REPO_ROOT = '..'
sys.path.append(REPO_ROOT)
import yaml
import torch
import logging
from pathlib import Path
from omnicell.config.config import Config, ETLConfig, ModelConfig, DatasplitConfig, EvalConfig, EmbeddingConfig
from omnicell.data.loader import DataLoader
from omnicell.constants import PERT_KEY, GENE_EMBEDDING_KEY, CONTROL_PERT
from train import get_model

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure paths. The split config and eval config must come from the same
# split directory, so that directory is spelled out exactly once.
CONFIG_ROOT = "/orcd/data/omarabu/001/njwfish/omnicell/configs"
SPLIT_DIR = f"{CONFIG_ROOT}/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_0"
# Alternative split:
# SPLIT_DIR = f"{CONFIG_ROOT}/splits/satija_IFNB_raw/random_splits/acrossC_ood_ss:10/split_A549"

MODEL_CONFIG = ModelConfig.from_yaml(f"{CONFIG_ROOT}/models/scot/proportional_scot.yaml")
ETL_CONFIG = ETLConfig(name="no_preprocessing", log1p=False, drop_unmatched_perts=True)
# NOTE(review): EMBEDDING_CONFIG is never passed to Config below, and
# `pert_embedding` is later overridden to 'bioBERT' on the ETL config, so the
# 'GenePT' choice here has no visible effect. The recorded training-data log
# says "Using identity features for perturbations" - confirm which
# perturbation embedding is actually intended.
EMBEDDING_CONFIG = EmbeddingConfig(pert_embedding='GenePT')

SPLIT_CONFIG = DatasplitConfig.from_yaml(f"{SPLIT_DIR}/split_config.yaml")
EVAL_CONFIG = EvalConfig.from_yaml(f"{SPLIT_DIR}/eval_config.yaml")  # Set this if you want to run evaluations

# Load configurations
config = Config(model_config=MODEL_CONFIG,
                etl_config=ETL_CONFIG,
                datasplit_config=SPLIT_CONFIG,
                eval_config=EVAL_CONFIG)

# Alternatively you can initialize the config objects manually as follows:
# etl_config = ETLConfig(name=XXX, log1p=False, drop_unmatched_perts=False, ...)
# model_config = ...
# embedding_config = ...
# datasplit_config = ...
# eval_config = ...
# config = Config(model_config=model_config, etl_config=etl_config,
#                 datasplit_config=datasplit_config, eval_config=eval_config)

config.etl_config.pert_embedding = 'bioBERT'
# Same value ETL_CONFIG was built with; kept in case Config does not hold
# the very same ETL config object.
config.etl_config.drop_unmatched_perts = True

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize data loader and load training data
loader = DataLoader(config)
adata, pert_rep_map = loader.get_training_data()

# Get dimensions and perturbation IDs
input_dim = adata.obsm['embedding'].shape[1]
pert_ids = adata.obs[PERT_KEY].unique()
gene_emb_dim = adata.varm[GENE_EMBEDDING_KEY].shape[1] if GENE_EMBEDDING_KEY in adata.varm else None

print("Data loaded:")
print(f"- Number of cells: {adata.shape[0]}")
print(f"- Input dimension: {input_dim}")
print(f"- Number of perturbations: {len(pert_ids)}")


2025-02-02 12:24:41,360 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/repogle_k562_essential_raw.yaml
2025-02-02 12:24:41,362 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_raw.yaml
2025-02-02 12:24:41,363 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/adamson_INCOMPLETE.yaml
2025-02-02 12:24:41,365 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNB_HVG.yaml
2025-02-02 12:24:41,366 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/kang.yaml
2025-02-02 12:24:41,367 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/essential_gene_knockouts_raw.yaml
2025-02-02 12:24:41,368 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/satija_IFNG_raw_INCOMPLET

Using device: cuda


2025-02-02 12:24:44,787 - INFO - Loaded unpreprocessed data, # of data points: 310385, # of genes: 8563.
2025-02-02 12:24:44,788 - INFO - Preprocessing training data
2025-02-02 12:24:44,789 - INFO - Using identity features for perturbations
2025-02-02 12:24:44,911 - INFO - Removing observations with perturbations not in the dataset as a column
2025-02-02 12:24:45,109 - INFO - Removed 189 perturbations that were not in the dataset columns and 0 perturbations that did not have an embedding for a total of 189 perturbations removed out of an initial 2058 perturbations
  adata.obsm["embedding"] = adata.X.toarray().astype('float32')
2025-02-02 12:25:22,401 - INFO - Doing OOD split


Data loaded:
- Number of cells: 279630
- Input dimension: 8563
- Number of perturbations: 1850


In [2]:
import numpy as np
from omnicell.models.datamodules import get_dataloader

# Build the OT-collated dataset and dataloader. Each cell's perturbation label
# is handed over as a plain numpy array; `pert_rep_map` comes from the loader.
cell_pert_ids = np.array(adata.obs[PERT_KEY].values)
dset, ns, dl = get_dataloader(
    adata,
    pert_ids=cell_pert_ids,
    pert_map=pert_rep_map,
    collate='ot',
)


['k562', 'k562', 'k562', 'k562', 'k562', ..., 'k562', 'k562', 'k562', 'k562', 'k562']
Length: 10691
Categories (1, object): ['k562'] ['k562'] 1
Creating source indices
Creating target indices
Creating pert indices
Creating source and target dicts
Strata probs [1.85915765e-05 1.85915765e-05 1.85915765e-05 ... 3.62163911e-03
 3.68856878e-03 7.42175735e-03]


In [8]:
# `time` was previously only imported by a later cell, so this cell raised
# NameError on a fresh top-to-bottom run.
import time

from omnicell.models.distribute_shift_numpy import sample_pert, get_proportional_weighted_dist

# Sample a synthetic counterfactual for every perturbation, using all control
# cells of each stratum at once. Results are kept in memory in
# synthetic_counterfactuals[stratum][pert].
# Print sampling time every `log_every` perturbations (1 = every one).
log_every = 1

synthetic_counterfactuals = {}
for stratum in dset.strata:
    synthetic_counterfactuals[stratum] = {}
    X_ctrl = dset.source[stratum]
    mean_ctrl = X_ctrl.mean(axis=0)
    # Depends only on the controls, so it is computed once per stratum.
    weighted_dist = get_proportional_weighted_dist(X_ctrl)
    for i, pert in enumerate(dset.unique_pert_ids):
        timed = i % log_every == 0
        if timed:
            pert_start = time.time()

        # Shift to apply: perturbed mean profile minus control mean profile.
        X_pert = dset.target[stratum][pert]
        mean_pert = X_pert.mean(axis=0)
        mean_shift = mean_pert - mean_ctrl

        # NOTE(review): in the recorded run this numpy sampler failed on
        # perturbation 2 with "ValueError: pvals < 0, pvals > 1 or pvals contains
        # NaNs", preceded by a warning at `p = p / p.sum()` - presumably a
        # zero-sum probability vector. Investigate before relying on this cell.
        preds = sample_pert(X_ctrl, weighted_dist, mean_shift, max_rejections=100)
        if timed:
            pert_time = time.time() - pert_start
            print(f"Perturbation {i} took: {pert_time:.2f}s")
        synthetic_counterfactuals[stratum][pert] = preds

Perturbation 0 took: 5.06s
Perturbation 1 took: 5.17s


  p = p / p.sum()


ValueError: pvals < 0, pvals > 1 or pvals contains NaNs

In [3]:
import pickle
import time
from pathlib import Path

from omnicell.models.distribute_shift_numpy import (
    sample_pert,
    get_proportional_weighted_dist
)

# Batched counterfactual generation with the NumPy sampler on a column-normalized
# weighted distribution. Each batch of `batch_size` controls is sampled against
# every perturbation and written to OUTPUT_DIR/synthetic_counterfactuals_{offset}.pkl.
OUTPUT_DIR = Path('/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/proportional_scot')
# Create the directory up front so a missing directory fails fast instead of
# after a whole batch has been computed.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

batch_size = 512
# Timing is reported once per group of this many perturbations.
log_every = 10
# Size the loop by the largest stratum (not a hard-coded one) so no stratum's
# controls are skipped; smaller strata simply get empty trailing slices.
num_ctrls = max(dset.source[stratum].shape[0] for stratum in dset.strata)

total_start_time = time.time()

for i in range(0, num_ctrls, batch_size):
    iteration_start_time = time.time()
    source_batch = {}
    synthetic_counterfactual_batch = {}

    for stratum in dset.strata:
        source_batch[stratum] = X_ctrl = dset.source[stratum][i:i + batch_size]
        synthetic_counterfactual_batch[stratum] = {}
        if X_ctrl.shape[0] == 0:
            # No controls left in this stratum at this offset; its mean would be NaN.
            continue

        mean_ctrl = X_ctrl.mean(axis=0)

        dist_start = time.time()
        weighted_dist = get_proportional_weighted_dist(X_ctrl).astype(np.float64)
        # Rescale each column with a positive sum so it sums to 1; columns whose
        # sum is not positive are left unchanged (dividing them would give NaN).
        # NOTE(review): assumes sample_pert reads columns of weighted_dist as
        # distributions - confirm against distribute_shift_numpy. The recorded run
        # still failed with "pvals ... contains NaNs" after this normalization.
        col_sums = weighted_dist.sum(axis=0)
        positive = col_sums > 0
        weighted_dist[:, positive] /= col_sums[positive]
        dist_time = time.time() - dist_start
        print(f"Weighted dist calculation took: {dist_time:.2f}s")

        for j, pert in enumerate(dset.unique_pert_ids):
            if j % log_every == 0:
                group_start = time.time()
                print(f"{i} / {num_ctrls}, {j} / {len(dset.unique_pert_ids)}")

            X_pert = dset.target[stratum][pert]
            mean_shift = X_pert.mean(axis=0) - mean_ctrl

            preds = sample_pert(
                X_ctrl,
                weighted_dist,
                mean_shift,
                max_rejections=100,
            )
            synthetic_counterfactual_batch[stratum][pert] = preds.astype(np.int16)

            if (j + 1) % log_every == 0:
                # The timer was started log_every perturbations ago, so this is
                # the time for the whole group, not for perturbation j alone.
                group_time = time.time() - group_start
                print(f"Perturbations {j - log_every + 1}-{j} took: {group_time:.2f}s")

    data_dict = {
        'synthetic_counterfactuals': synthetic_counterfactual_batch,
        'source': source_batch,
        'unique_pert_ids': dset.unique_pert_ids,
        'strata': dset.strata,
    }

    with open(OUTPUT_DIR / f'synthetic_counterfactuals_{i}.pkl', 'wb') as f:
        pickle.dump(data_dict, f)
    print(f"Batch at offset {i} took: {time.time() - iteration_start_time:.2f}s")

print(f"Total time: {time.time() - total_start_time:.2f}s")



Weighted dist calculation took: 0.06s
0 / 10691, 0 / 1849


  p = p / p.sum()


ValueError: pvals < 0, pvals > 1 or pvals contains NaNs

In [30]:
import pickle
import time
from pathlib import Path

from omnicell.models.distribute_shift import (
    get_proportional_weighted_dist,
    sample_multinomial_batch,
    sample_pert
)

# Batched counterfactual generation with the (threaded) distribute_shift
# sampler. Every `batch_size` controls are sampled against every perturbation
# and pickled, together with that batch's timing records, to OUTPUT_DIR.
OUTPUT_DIR = Path('/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/proportional_scot')
# Fail fast on a missing directory rather than after a batch has been computed.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

batch_size = 512
num_threads = 2
# Size the loop by the largest stratum so no stratum's controls are skipped.
num_ctrls = max(dset.source[stratum].shape[0] for stratum in dset.strata)

# Add overall timing
total_start_time = time.time()
iteration_times = []
pert_times = []  # timing records accumulated over every batch

for i in range(0, num_ctrls, batch_size):
    iteration_start_time = time.time()
    source_batch = {}
    synthetic_counterfactual_batch = {}
    # Only this batch's records go into this batch's file; saving the global
    # list made every file contain the entire history so far.
    batch_pert_times = []

    for stratum in dset.strata:
        # Slice with batch_size (not a literal 512) so batches tile the controls
        # exactly as the range() step assumes.
        source_batch[stratum] = X_ctrl = dset.source[stratum][i:i + batch_size]
        synthetic_counterfactual_batch[stratum] = {}
        if X_ctrl.shape[0] == 0:
            # No controls left in this stratum at this offset.
            continue

        mean_ctrl = X_ctrl.mean(axis=0)

        # Time the weighted dist calculation
        dist_start = time.time()
        weighted_dist = get_proportional_weighted_dist(X_ctrl)
        dist_time = time.time() - dist_start
        print(f"Weighted dist calculation took: {dist_time:.2f}s")

        # Depends only on X_ctrl's second dimension, so compute it once per
        # stratum rather than once per perturbation.
        min_block_size = X_ctrl.shape[1] // num_threads + 1

        for j, pert in enumerate(dset.unique_pert_ids):
            pert_start = time.time()

            print(f"{i} / {num_ctrls}, {j} / {len(dset.unique_pert_ids)}")

            X_pert = dset.target[stratum][pert]
            mean_pert = X_pert.mean(axis=0)
            mean_shift = mean_pert - mean_ctrl

            # Time the sample_pert call
            sample_start = time.time()
            preds = sample_pert(X_ctrl, weighted_dist, mean_shift,
                                max_rejections=100, min_block_size=min_block_size, num_threads=num_threads)
            sample_time = time.time() - sample_start

            synthetic_counterfactual_batch[stratum][pert] = preds.astype(np.int16)

            pert_time = time.time() - pert_start
            record = {
                'iteration': i,
                'pert_idx': j,
                'total_time': pert_time,
                'sample_time': sample_time
            }
            batch_pert_times.append(record)
            pert_times.append(record)

            print(f"Perturbation {j} took: {pert_time:.2f}s (sampling: {sample_time:.2f}s)")

    iteration_time = time.time() - iteration_start_time
    iteration_times.append(iteration_time)

    # Save timing data along with results
    data_dict = {
        'synthetic_counterfactuals': synthetic_counterfactual_batch,
        'source': source_batch,
        'unique_pert_ids': dset.unique_pert_ids,
        'strata': dset.strata,
        'timing': {
            'pert_times': batch_pert_times,
            'iteration_time': iteration_time
        }
    }

    print(f"\nIteration {i} took: {iteration_times[-1]:.2f}s")
    if batch_pert_times:
        # Average over this batch's records (all strata), not just the last N.
        print(f"Average perturbation time: {np.mean([p['total_time'] for p in batch_pert_times]):.2f}s")
        print(f"Average sampling time: {np.mean([p['sample_time'] for p in batch_pert_times]):.2f}s")

    with open(OUTPUT_DIR / f'synthetic_counterfactuals_{i}.pkl', 'wb') as f:
        pickle.dump(data_dict, f)

print(f"Total time: {time.time() - total_start_time:.2f}s")



Weighted dist calculation took: 0.02s
0 / 10691, 0 / 1849
Perturbation 0 took: 0.18s (sampling: 0.17s)
0 / 10691, 1 / 1849
Perturbation 1 took: 0.16s (sampling: 0.16s)
0 / 10691, 2 / 1849
Perturbation 2 took: 0.25s (sampling: 0.25s)
0 / 10691, 3 / 1849
Perturbation 3 took: 0.18s (sampling: 0.18s)
0 / 10691, 4 / 1849
Perturbation 4 took: 0.18s (sampling: 0.18s)
0 / 10691, 5 / 1849
Perturbation 5 took: 0.16s (sampling: 0.16s)
0 / 10691, 6 / 1849
Perturbation 6 took: 0.17s (sampling: 0.17s)
0 / 10691, 7 / 1849
Perturbation 7 took: 0.16s (sampling: 0.15s)
0 / 10691, 8 / 1849
Perturbation 8 took: 0.18s (sampling: 0.18s)
0 / 10691, 9 / 1849
Perturbation 9 took: 0.17s (sampling: 0.16s)
0 / 10691, 10 / 1849
Perturbation 10 took: 0.20s (sampling: 0.20s)
0 / 10691, 11 / 1849
Perturbation 11 took: 0.18s (sampling: 0.17s)
0 / 10691, 12 / 1849
Perturbation 12 took: 0.18s (sampling: 0.18s)
0 / 10691, 13 / 1849
Perturbation 13 took: 0.17s (sampling: 0.16s)
0 / 10691, 14 / 1849
Perturbation 14 took: 0

KeyboardInterrupt: 

In [24]:
import pickle
import time
from pathlib import Path

from omnicell.models.distribute_shift_numpy import (
    get_proportional_weighted_dist,
    sample_multinomial_batch,
    sample_pert
)

# Batched counterfactual generation with the NumPy sampler. Every `batch_size`
# controls are sampled against every perturbation and pickled, together with
# that batch's timing records, to OUTPUT_DIR.
# NOTE(review): the recorded run of this cell failed on perturbation 2 with
# "ValueError: pvals < 0, pvals > 1 or pvals contains NaNs"; the distribute_shift
# (non-NumPy) sampler got past that point without error in its recorded run.
OUTPUT_DIR = Path('/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/proportional_scot')
# Fail fast on a missing directory rather than after a batch has been computed.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

batch_size = 512
# Size the loop by the largest stratum so no stratum's controls are skipped.
num_ctrls = max(dset.source[stratum].shape[0] for stratum in dset.strata)

# Add overall timing
total_start_time = time.time()
iteration_times = []
pert_times = []  # timing records accumulated over every batch

for i in range(0, num_ctrls, batch_size):
    iteration_start_time = time.time()
    source_batch = {}
    synthetic_counterfactual_batch = {}
    # Only this batch's records go into this batch's file.
    batch_pert_times = []

    for stratum in dset.strata:
        # Slice with batch_size (not a literal 512) so batches tile the controls
        # exactly as the range() step assumes.
        source_batch[stratum] = X_ctrl = dset.source[stratum][i:i + batch_size]
        synthetic_counterfactual_batch[stratum] = {}
        if X_ctrl.shape[0] == 0:
            # No controls left in this stratum at this offset.
            continue

        mean_ctrl = X_ctrl.mean(axis=0)

        # Time the weighted dist calculation
        dist_start = time.time()
        weighted_dist = get_proportional_weighted_dist(X_ctrl)
        dist_time = time.time() - dist_start
        print(f"Weighted dist calculation took: {dist_time:.2f}s")

        for j, pert in enumerate(dset.unique_pert_ids):
            pert_start = time.time()

            print(f"{i} / {num_ctrls}, {j} / {len(dset.unique_pert_ids)}")

            X_pert = dset.target[stratum][pert]
            mean_pert = X_pert.mean(axis=0)
            mean_shift = mean_pert - mean_ctrl

            # Time the sample_pert call
            sample_start = time.time()
            preds = sample_pert(X_ctrl, weighted_dist, mean_shift,
                                max_rejections=100)
            sample_time = time.time() - sample_start

            synthetic_counterfactual_batch[stratum][pert] = preds.astype(np.int16)

            pert_time = time.time() - pert_start
            record = {
                'iteration': i,
                'pert_idx': j,
                'total_time': pert_time,
                'sample_time': sample_time
            }
            batch_pert_times.append(record)
            pert_times.append(record)

            print(f"Perturbation {j} took: {pert_time:.2f}s (sampling: {sample_time:.2f}s)")

    iteration_time = time.time() - iteration_start_time
    iteration_times.append(iteration_time)

    # Save timing data along with results
    data_dict = {
        'synthetic_counterfactuals': synthetic_counterfactual_batch,
        'source': source_batch,
        'unique_pert_ids': dset.unique_pert_ids,
        'strata': dset.strata,
        'timing': {
            'pert_times': batch_pert_times,
            'iteration_time': iteration_time
        }
    }

    print(f"\nIteration {i} took: {iteration_times[-1]:.2f}s")
    if batch_pert_times:
        # Average over this batch's records (all strata), not just the last N.
        print(f"Average perturbation time: {np.mean([p['total_time'] for p in batch_pert_times]):.2f}s")
        print(f"Average sampling time: {np.mean([p['sample_time'] for p in batch_pert_times]):.2f}s")

    with open(OUTPUT_DIR / f'synthetic_counterfactuals_{i}.pkl', 'wb') as f:
        pickle.dump(data_dict, f)

print(f"Total time: {time.time() - total_start_time:.2f}s")


Weighted dist calculation took: 0.02s
0 / 10691, 0 / 1849
Perturbation 0 took: 0.29s (sampling: 0.29s)
0 / 10691, 1 / 1849
Perturbation 1 took: 0.29s (sampling: 0.29s)
0 / 10691, 2 / 1849


  p = p / p.sum()


ValueError: pvals < 0, pvals > 1 or pvals contains NaNs

In [None]:
import pickle
from pathlib import Path
from omnicell.models.distribute_shift import sample_pert, get_proportional_weighted_dist

# Batched counterfactual generation without timing instrumentation. Each batch
# of `batch_size` controls is sampled against every perturbation and pickled to
# OUTPUT_DIR/synthetic_counterfactuals_{offset}.pkl.
OUTPUT_DIR = Path('/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/proportional_scot')
# Fail fast on a missing directory rather than after a batch has been computed.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

batch_size = 512
# Size the loop by the largest stratum so no stratum's controls are skipped.
num_ctrls = max(dset.source[stratum].shape[0] for stratum in dset.strata)
# Ceiling division: the final, partial batch counts as a batch too.
num_batches = -(-num_ctrls // batch_size)

for batch_idx, i in enumerate(range(0, num_ctrls, batch_size)):
    source_batch = {}
    synthetic_counterfactual_batch = {}
    for stratum in dset.strata:
        # Slice with batch_size (not a literal 512) so batches tile the controls
        # exactly as the range() step assumes.
        source_batch[stratum] = X_ctrl = dset.source[stratum][i:i + batch_size]
        synthetic_counterfactual_batch[stratum] = {}
        if X_ctrl.shape[0] == 0:
            # No controls left in this stratum at this offset.
            continue

        # Both depend only on the control batch, so compute them once per
        # stratum instead of once per perturbation (the other batched cells do
        # the same).
        mean_ctrl = X_ctrl.mean(axis=0)
        weighted_dist = get_proportional_weighted_dist(X_ctrl)

        for j, pert in enumerate(dset.unique_pert_ids):
            # Report batch index against batch count (previously the row offset
            # `i` was printed against the batch count).
            print(f"{batch_idx} / {num_batches}, {j} / {len(dset.unique_pert_ids)}")
            X_pert = dset.target[stratum][pert]
            mean_pert = X_pert.mean(axis=0)

            mean_shift = mean_pert - mean_ctrl
            logger.debug(f"Mean shift shape: {mean_shift.shape}")

            preds = sample_pert(X_ctrl, weighted_dist, mean_shift, max_rejections=100)
            synthetic_counterfactual_batch[stratum][pert] = preds.astype(np.int16)

    # Save as before
    data_dict = {
        'synthetic_counterfactuals': synthetic_counterfactual_batch,
        'source': source_batch,
        'unique_pert_ids': dset.unique_pert_ids,
        'strata': dset.strata
    }

    with open(OUTPUT_DIR / f'synthetic_counterfactuals_{i}.pkl', 'wb') as f:
        pickle.dump(data_dict, f)

In [24]:
# Downcast the synthetic counterfactuals, the source arrays and the target arrays
# to np.int16 to cut memory and pickle size. They stay dense: a sparse conversion
# was considered (scipy import below) but is not performed.
# from scipy import sparse


def _to_int16(arr):
    """Return ``arr`` cast to np.int16, raising instead of silently wrapping.

    Assumes ``arr`` is a dense numpy array of integer-valued counts - TODO confirm.

    Raises:
        ValueError: if ``arr`` holds NaN/inf or values outside the int16 range,
            which a plain ``astype(np.int16)`` would silently corrupt.
    """
    if arr.size:
        if np.issubdtype(arr.dtype, np.floating) and not np.isfinite(arr).all():
            raise ValueError("array contains NaN or inf; refusing to cast to int16")
        bounds = np.iinfo(np.int16)
        lo, hi = arr.min(), arr.max()
        if lo < bounds.min or hi > bounds.max:
            raise ValueError(
                f"values in [{lo}, {hi}] do not fit int16 [{bounds.min}, {bounds.max}]"
            )
    return arr.astype(np.int16)


# Synthetic counterfactuals: stratum -> perturbation -> sampled array.
synthetic_counterfactuals = {
    stratum: {pert: _to_int16(preds) for pert, preds in perts.items()}
    for stratum, perts in synthetic_counterfactuals.items()
}

# NOTE: dset.source / dset.target are replaced in place on the dataset object,
# so re-running earlier cells that read dset afterwards will see int16 data.
dset.source = {
    stratum: _to_int16(source)
    for stratum, source in dset.source.items()
}

dset.target = {
    stratum: {pert: _to_int16(cells) for pert, cells in perts.items()}
    for stratum, perts in dset.target.items()
}


In [None]:
# Saving: persist the synthetic counterfactuals together with the source
# (control) and target (perturbed) arrays, perturbation ids and strata in one pickle.
import pickle

full_output_path = '/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/proportional_scot/synthetic_counterfactuals.pkl'
data_dict = dict(
    synthetic_counterfactuals=synthetic_counterfactuals,
    source=dset.source,
    target=dset.target,
    unique_pert_ids=dset.unique_pert_ids,
    strata=dset.strata,
)
with open(full_output_path, 'wb') as out_file:
    pickle.dump(data_dict, out_file)




In [1]:
# Loading: read the combined pickle back and unpack its fields into globals.
# pickle.load can execute arbitrary code - only open files this pipeline wrote.
import pickle

with open('/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/proportional_scot/synthetic_counterfactuals.pkl', 'rb') as pickle_file:
    data_dict = pickle.load(pickle_file)

(synthetic_counterfactuals, source, target, unique_pert_ids, strata) = (
    data_dict[field]
    for field in ('synthetic_counterfactuals', 'source', 'target', 'unique_pert_ids', 'strata')
)