In [1]:
import sys
import yaml
import torch
import logging
from pathlib import Path

# Add the path to the directory containing the omnicell package
# Assuming the omnicell package is in the parent directory of your notebook
sys.path.append('..')  # Adjust this path as needed

import yaml
import torch
import logging
from pathlib import Path
from omnicell.config.config import Config, ETLConfig, ModelConfig, DatasplitConfig, EvalConfig, EmbeddingConfig
from omnicell.data.loader import DataLoader
from omnicell.constants import PERT_KEY, GENE_EMBEDDING_KEY, CONTROL_PERT
from omnicell.models.selector import load_model as get_model

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure paths
# NOTE(review): every path below is absolute and machine/user specific; consider deriving
# them from a single configurable root so the notebook runs for other users.
MODEL_CONFIG = ModelConfig.from_yaml("/home/jason497/omnicell/configs/models/autoencoder.yaml")
ETL_CONFIG = ETLConfig(name = "no_preprocessing", log1p = False, drop_unmatched_perts = True)
# NOTE(review): EMBEDDING_CONFIG is built here but never passed to Config(...) below, so it
# has no effect in this cell -- confirm whether it should be wired in.
EMBEDDING_CONFIG = EmbeddingConfig(pert_embedding='GenePT')

SPLIT_CONFIG = DatasplitConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_0/split_config.yaml")
#SPLIT_CONFIG = DatasplitConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_1/split_config.yaml")
EVAL_CONFIG = EvalConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_0/eval_config.yaml")  # Set this if you want to run evaluations
#EVAL_CONFIG = EvalConfig.from_yaml("/orcd/data/omarabu/001/njwfish/omnicell/configs/splits/repogle_k562_essential_raw/random_splits/rs_accP_k562_ood_ss:ns_20_2_most_pert_0.1/split_1/eval_config.yaml")

# Load configurations
config = Config(model_config=MODEL_CONFIG,
                 etl_config=ETL_CONFIG, 
                 datasplit_config=SPLIT_CONFIG, 
                 eval_config=EVAL_CONFIG)


#Alternatively you can initialize the config objects manually as follows:
# etl_config = ETLConfig(name = XXX, log1p = False, drop_unmatched_perts = False, ...)
# model_config = ...
# embedding_config = ...
# datasplit_config = ...
# eval_config = ...
# config = Config(etl_config, model_config, datasplit_config, eval_config)

# Set up device (computed once; the rest of the cell reuses this value)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize data loader and load training data
loader = DataLoader(config)
adata, pert_rep_map = loader.get_training_data()

# Get dimensions and perturbation IDs
input_dim = adata.shape[1]
pert_ids = adata.obs[PERT_KEY].unique()
gene_emb_dim = adata.varm[GENE_EMBEDDING_KEY].shape[1] if GENE_EMBEDDING_KEY in adata.varm else None

print("Data loaded:")
print(f"- Number of cells: {adata.shape[0]}")
print(f"- Input dimension: {input_dim}")
print(f"- Number of perturbations: {len(pert_ids)}")

# Map each non-control perturbation to the column index of the gene with the same name.
# The name -> index table is built once instead of calling list.index() per perturbation;
# iterating in reverse keeps the FIRST occurrence of a duplicated name, which is what
# list.index() returned. A perturbation missing from var_names now raises KeyError
# (previously ValueError).
pert_list = adata.var_names.values.tolist()
_gene_to_idx = {gene: idx for idx, gene in reversed(list(enumerate(pert_list)))}
pert_rep_map_idxs = {pert: _gene_to_idx[pert] for pert in pert_ids if pert != CONTROL_PERT}

2025-02-20 08:46:45,079 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/DatlingerBock2017.yaml
2025-02-20 08:46:45,081 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/SchraivogelSteinmetz2020_TAP_SCREEN__chromosome_11_screen.yaml
2025-02-20 08:46:45,083 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/ReplogleWeissman2022_K562_essential.yaml
2025-02-20 08:46:45,084 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/GasperiniShendure2019_atscale.yaml
2025-02-20 08:46:45,086 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/TianKampmann2021_CRISPRa.yaml
2025-02-20 08:46:45,087 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/repogle_k562_essential_raw.yaml
2025-02-20 08:46:45,088 - INFO - Loading data catalogue from /

Using device: cuda


2025-02-20 08:46:48,409 - INFO - Loaded unpreprocessed data, # of data points: 310385, # of genes: 8563.
2025-02-20 08:46:48,410 - INFO - Preprocessing training data
2025-02-20 08:46:48,411 - INFO - Using identity features for perturbations
2025-02-20 08:46:48,532 - INFO - Removing observations with perturbations not in the dataset as a column
2025-02-20 08:46:48,719 - INFO - Removed 189 perturbations that were not in the dataset columns and 0 perturbations that did not have an embedding for a total of 189 perturbations removed out of an initial 2058 perturbations
2025-02-20 08:47:17,709 - INFO - Doing OOD split


Data loaded:
- Number of cells: 279630
- Input dimension: 8563
- Number of perturbations: 1850


In [11]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
import scanpy as sc
import anndata


import pandas as pd
from scipy.sparse import issparse
import scipy
import warnings
# Column / key names used by the model code in this cell.
# NOTE(review): PERT_KEY and CONTROL_PERT were already imported from omnicell.constants in
# the first cell; these assignments rebind the notebook-global names, so any cell run after
# this one sees the literals below -- confirm they match the package constants.
# NOTE(review): the `from torch.utils.data import ... DataLoader ...` import earlier in this
# cell likewise rebinds `DataLoader`, which the first cell imported from omnicell.data.loader.
PERT_KEY = 'pert'
CELL_KEY = 'cell'
CONTROL_PERT = 'ctrl'
GENE_VAR_KEY = 'gene_name'

# Silence the UserWarning whose message matches the "Observation names are not unique"
# text below.
warnings.filterwarnings(
    "ignore",
    message="Observation names are not unique. To make them unique, call `.obs_names_make_unique`.",
    category=UserWarning
)



##############################################
# ADDED FOR KNN (Imports for building and vectorizing)
##############################################
from sklearn.neighbors import NearestNeighbors

def build_knn_indices(emb_tensor, k=10):
    """
    Find the k nearest neighbours of every row of an embedding matrix.

    emb_tensor: torch.Tensor of shape [N, d], on CPU or GPU. It is detached and
        copied to CPU NumPy because scikit-learn does the search.
    k: number of neighbours per point (the point itself is never included).

    Returns: knn_list, a list of N integer arrays; knn_list[i] holds the k
        neighbour indices of point i, nearest first.

    Raises: whatever scikit-learn raises when k + 1 exceeds N.
    """
    # Move to CPU NumPy for scikit-learn
    emb = emb_tensor.detach().cpu().numpy()  # shape [N, d]
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(emb)

    # Calling kneighbors() without a query returns the neighbours of each fitted
    # point and excludes the point itself by index. The previous version queried
    # k + 1 neighbours and dropped column 0, which silently kept the point as its
    # own neighbour whenever a duplicate embedding was ranked first.
    indices = nbrs.kneighbors(return_distance=False)  # shape [N, k]
    return list(indices)

def build_neighbor_idx(knn_list):
    """
    Pack per-node neighbour lists into one index tensor.

    knn_list: sequence of length N; entry i is a list/array with the k neighbour
        indices of node i (all entries must have the same length).
    Returns: LongTensor of shape [N, k] whose row i equals knn_list[i].
    """
    # One 1-D long tensor per node, stacked along a new leading axis.
    per_node = [torch.tensor(node_nbrs, dtype=torch.long) for node_nbrs in knn_list]
    return torch.stack(per_node, dim=0)

def knn_pull_loss_vec(emb_tensor, neighbor_idx):
    """
    Pull loss: mean squared Euclidean distance between every node and each of
    its listed neighbours.

    emb_tensor: [N, d] embedding matrix.
    neighbor_idx: [N, k] long tensor; row i lists the neighbours of node i.
    Returns: scalar tensor, the mean over all N*k (node, neighbour) pairs of the
        squared distance summed over the d embedding dimensions.
    """
    num_nodes, _emb_dim = emb_tensor.shape
    num_nbrs = neighbor_idx.shape[1]

    # Anchor i repeated once per neighbour -> [N, k, d]
    anchors = emb_tensor.unsqueeze(1).expand(num_nodes, num_nbrs, -1)
    # Neighbour embeddings looked up by index -> [N, k, d]
    neighbours = emb_tensor[neighbor_idx]

    # Squared distance per pair -> [N, k], then average over all pairs.
    pair_sq_dist = (anchors - neighbours).pow(2).sum(dim=2)
    return pair_sq_dist.mean()

##############################################
# ADDED FOR NEGATIVE SAMPLING (Push)
##############################################
def sample_negatives(knn_list, N, num_neg=5000, max_tries=500000):
    """
    Rejection-sample 'num_neg' random (i, j) index pairs that are:
      - i != j
      - j not in knn_list[i]

    Pairs are drawn uniformly from [0, N) x [0, N) until 'num_neg' valid pairs
    are collected or 'max_tries' draws have been made. If fewer than 'num_neg'
    valid pairs were found, the ones that were found are repeated cyclically so
    the result always has the requested length (the previous implementation
    appended at most one extra copy and could return fewer than 'num_neg').

    knn_list: sequence of length N; knn_list[i] is a list/array of neighbour
        indices of node i.
    N: number of nodes.
    num_neg: number of negative pairs to return.
    max_tries: upper bound on the number of random draws.

    Returns two LongTensors i_neg, j_neg of shape [num_neg]. They are empty
    only when no valid pair was found at all (or num_neg is 0).
    """
    # Build a set of neighbors per node for quick membership checks
    neighbor_sets = [set(np.asarray(nbrs).tolist()) for nbrs in knn_list]

    i_neg_list = []
    j_neg_list = []
    tries = 0

    while len(i_neg_list) < num_neg and tries < max_tries:
        tries += 1
        i_ = np.random.randint(0, N)
        j_ = np.random.randint(0, N)
        # Reject self-pairs and pairs that are kNN neighbours.
        if j_ == i_ or j_ in neighbor_sets[i_]:
            continue
        i_neg_list.append(i_)
        j_neg_list.append(j_)

    # If we didn't get enough pairs, cycle through the ones we have.
    found = len(i_neg_list)
    if 0 < found < num_neg:
        repeats = -(-num_neg // found)  # ceil(num_neg / found)
        i_neg_list = (i_neg_list * repeats)[:num_neg]
        j_neg_list = (j_neg_list * repeats)[:num_neg]

    i_neg_t = torch.tensor(i_neg_list[:num_neg], dtype=torch.long)
    j_neg_t = torch.tensor(j_neg_list[:num_neg], dtype=torch.long)
    return i_neg_t, j_neg_t

def neg_push_loss(emb_tensor, i_neg, j_neg, margin=1.0):
    """
    Hinge-style push loss over negative pairs.

    Each pair (i_neg[n], j_neg[n]) contributes max(0, margin - dist), where dist
    is the Euclidean distance between the two embeddings, so only pairs closer
    than 'margin' are penalised.

    emb_tensor: [N, d] embedding matrix.
    i_neg, j_neg: index tensors of shape [num_neg].
    margin: float, distance below which a pair is penalised.
    Returns: scalar tensor, the mean hinge value over the pairs.
    """
    offsets = emb_tensor[i_neg] - emb_tensor[j_neg]  # [num_neg, d]

    # The small constant keeps sqrt differentiable when the two points coincide.
    pair_dist = (offsets.pow(2).sum(dim=1) + 1e-8).sqrt()  # [num_neg]

    return F.relu(margin - pair_dist).mean()


class SinusoidalIntegerApprox(nn.Module):
    """
    Activation that blends a sinusoidal "staircase" with the identity.

    forward(x) = (1 - alpha) * (x - sin(2*pi*x) / (2*pi)) + alpha * x

    The staircase term has zero slope at integer inputs; mixing in alpha * x
    leaves a slope of alpha there so gradients do not vanish.
    """
    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha  # weight of the identity term (slope kept at integers)
        self.two_pi = 2 * np.pi

    def forward(self, x):
        # Staircase-like term: identity minus one sine period per unit interval.
        ripple = torch.sin(self.two_pi * x) / self.two_pi
        staircase = x - ripple

        # Blend with the identity so the slope never reaches zero.
        weighted_staircase = (1 - self.alpha) * staircase
        return weighted_staircase + self.alpha * x

class SmoothPositiveIntegerApprox(nn.Module):
    """
    Positive variant of the sinusoidal integer-approximating activation.

    The input is first mapped to a positive value with softplus (sharpness
    'beta'); the result p is then passed through
        (1 - alpha) * (p - sin(2*pi*p) / (2*pi)) + alpha * p
    so the output is positive, staircase-like, and keeps a slope contribution
    of alpha * p everywhere.
    """
    def __init__(self, alpha=0.1, beta=2):
        super().__init__()
        self.alpha = alpha  # weight of the plain softplus term
        self.beta = beta    # softplus sharpness
        self.two_pi = 2 * np.pi

    def forward(self, x):
        # Smooth map onto the positive reals.
        positive = F.softplus(x, beta=self.beta)

        # Staircase-like term built on the positive value.
        ripple = torch.sin(self.two_pi * positive) / self.two_pi
        staircase = positive - ripple

        weighted_staircase = (1 - self.alpha) * staircase
        return weighted_staircase + self.alpha * positive
    
class autoencoder(nn.Module):
    """
    Per-gene perturbation-effect predictor.

    Despite the name there is no separate encoder/decoder. For every
    (cell, gene) pair the model feeds the concatenation

        [control expression vector of the cell (G values),
         embedding of the perturbed gene,
         embedding of the target gene]

    through one hidden layer and a scalar output head, producing a
    non-negative predicted delta of shape (b, G).

    A single nn.Embedding table ('shared_embedding', num_genes rows) supplies
    both the perturbation embedding and the per-gene embedding, so
    perturbation indices index the same table as genes.
    """
    def __init__(
        self,
        model_config: dict,
        num_genes: int,
        enc_dim_cell: int = 340,
        enc_dim_pert: int = 80,
        hidden_enc_1: int = 500,
        hidden_dec_1: int = 340
    ):
        """
        model_config: accepted for interface compatibility; not read here.
        num_genes: number of genes G (input width and embedding rows).
        enc_dim_cell: accepted for interface compatibility; not used.
        enc_dim_pert: width of the shared gene/perturbation embedding.
        hidden_enc_1: width of the hidden layer.
        hidden_dec_1: accepted for interface compatibility; not used.
        """
        super().__init__()
        self.model = None
        self.num_genes = num_genes

        # Shared embedding for perturbations and genes.
        self.shared_embedding = nn.Embedding(num_genes, enc_dim_pert)

        # Hidden layer input is [x_ctrl (G) | pert_emb (P) | gene_emb (P)].
        self.hidden_layer = nn.Linear(2*enc_dim_pert+num_genes, hidden_enc_1)
        self.output_layer = nn.Linear(hidden_enc_1, 1)

    def forward(self, x_ctrl_log, whichpert_idx, multiplier):
        """
        Inputs:
          x_ctrl_log: (b, G) control expression, G must equal num_genes
          whichpert_idx: (b,) long indices into the shared embedding
          multiplier: accepted for interface compatibility; not used
        Returns:
          pred_delta: (b, G), non-negative (final ReLU)

        The result equals applying hidden_layer to the concatenation
        [x_ctrl | pert_emb | gene_emb] for every (cell, gene) pair. Because a
        linear layer distributes over concatenation, the three parts are
        projected separately and summed by broadcasting, which avoids
        materialising a (b, G, G + 2P) input tensor.

        NOTE(review): the final ReLU means the model can never predict a
        negative delta, while the training target (true_pert - true_ctrl) can
        be negative -- confirm this is intended.
        """
        b, G = x_ctrl_log.size()
        if G != self.num_genes:
            raise ValueError(
                f"x_ctrl_log has {G} genes but the model was built for {self.num_genes}"
            )

        pert_dim = self.shared_embedding.embedding_dim
        # Column blocks of the hidden weight matching the concatenation order.
        w_ctrl, w_pert, w_gene = torch.split(
            self.hidden_layer.weight, [G, pert_dim, pert_dim], dim=1
        )

        ctrl_part = F.linear(x_ctrl_log, w_ctrl)                             # (b, H)
        pert_part = F.linear(self.shared_embedding(whichpert_idx), w_pert)   # (b, H)
        # The bias is added exactly once, here.
        gene_part = F.linear(self.shared_embedding.weight, w_gene,
                             self.hidden_layer.bias)                         # (G, H)

        x_hidden = F.relu((ctrl_part + pert_part).unsqueeze(1) + gene_part.unsqueeze(0))  # (b, G, H)
        pred_delta = F.relu(self.output_layer(x_hidden)).squeeze(-1)         # (b, G)

        return pred_delta

    def log_normalized_mse_dual(self, pred_delta, true_ctrl, true_pert, normalize_to=10000):
        """
        Plain MSE between pred_delta and (true_pert - true_ctrl).

        Despite the name, no normalisation or log1p is applied here and
        'normalize_to' is not used; inputs are compared as given.
        Returns a scalar tensor.
        """
        true_delta = true_pert - true_ctrl
        diff_sq_delta = (pred_delta - true_delta)**2
        loss_delta = diff_sq_delta.mean()
        return loss_delta

    def mse_dual(self, pred_ctrl, pred_delta, true_ctrl, true_pert):
        """
        Weighted MSE over:
          - pred_ctrl vs. true_ctrl              (weight 0.9)
          - pred_delta vs. (true_pert - true_ctrl) (weight 0.1)
        Returns (loss_total tensor, loss_ctrl float, loss_delta float).
        """
        true_delta = true_pert - true_ctrl
        diff_sq_ctrl  = (pred_ctrl - true_ctrl)**2
        diff_sq_delta = (pred_delta - true_delta)**2

        loss_ctrl  = diff_sq_ctrl.mean()
        loss_delta = diff_sq_delta.mean()
        loss_total = 0.9*loss_ctrl + 0.1*loss_delta
        return loss_total, loss_ctrl.item(), loss_delta.item()

    def train(self, dl=True, lr=1e-3):
        """
        Run the custom training loop for a single epoch over 'dl'.

        This method shadows nn.Module.train(mode). To keep the standard API
        working (model.eval() calls self.train(False) internally), a bool first
        argument is forwarded to nn.Module.train and the module is returned.

        dl: iterable yielding (x_ctrl, x_pert, whichpert) batches and
            supporting len(); or a bool, see above.
        lr: Adam learning rate. A fresh optimizer is created on every call.

        The optimised loss is the delta MSE plus a weighted kNN pull loss and a
        weighted negative-pair push loss on the shared embedding. Progress is
        printed every 500 iterations.
        """
        if isinstance(dl, bool):
            return super().train(dl)

        import time

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        super().train(True)

        optimizer = optim.Adam(self.parameters(), lr=lr)
        start_epoch = 0
        num_epochs  = 1
        print_interval = 500

        KNN_K = 10
        KNN_EMB_LOSS_WEIGHT = 1e-2
        REBUILD_EVERY = 2
        NEG_LOSS_WEIGHT = 1e-2
        NEG_PAIRS = 7000
        MARGIN = 1.0

        neg_i = None
        neg_j = None
        neighbor_idx = None

        start_time = time.time()

        for epoch in range(start_epoch, num_epochs):
            # Rebuild adjacency (knn_list) + neighbor_idx + negative pairs
            if epoch % REBUILD_EVERY == 0:
                emb_mat = self.shared_embedding.weight  # shape [N, d]
                knn_list = build_knn_indices(emb_mat, k=KNN_K)
                neighbor_idx = build_neighbor_idx(knn_list).to(emb_mat.device)

                i_neg_t, j_neg_t = sample_negatives(knn_list, emb_mat.size(0), num_neg=NEG_PAIRS)
                neg_i = i_neg_t.to(emb_mat.device)
                neg_j = j_neg_t.to(emb_mat.device)

            # Accumulators for the current printing interval
            running_loss_total = 0.0
            running_loss_knn   = 0.0
            running_loss_neg   = 0.0
            num_samples = 0

            for batch_idx, (x_ctrl_batch_in, x_pert_batch_in, whichpert_batch_in) in enumerate(dl):
                x_ctrl_batch = x_ctrl_batch_in.to(device)
                x_pert_batch = x_pert_batch_in.to(device)
                whichpert_batch = whichpert_batch_in.to(device)

                optimizer.zero_grad()

                pred_delta = self.forward(x_ctrl_batch, whichpert_batch, multiplier=1)

                # Prediction loss only (this is what gets logged as "MSE").
                loss_total = self.log_normalized_mse_dual(
                    pred_delta, x_ctrl_batch, x_pert_batch
                )

                # kNN pull loss on the shared embedding
                loss_knn = None
                if neighbor_idx is not None:
                    loss_knn = knn_pull_loss_vec(self.shared_embedding.weight, neighbor_idx)

                # Negative-pair push loss; skipped when no pairs were sampled,
                # since the mean of an empty tensor is NaN.
                loss_neg = None
                if neg_i is not None and neg_j is not None and neg_i.numel() > 0:
                    loss_neg = neg_push_loss(self.shared_embedding.weight, neg_i, neg_j, margin=MARGIN)

                # Build the optimised loss out of place: the previous in-place
                # `+=` on an alias of loss_total also changed the logged MSE.
                loss_total_with_knn = loss_total
                if loss_knn is not None:
                    loss_total_with_knn = loss_total_with_knn + KNN_EMB_LOSS_WEIGHT * loss_knn
                if loss_neg is not None:
                    loss_total_with_knn = loss_total_with_knn + NEG_LOSS_WEIGHT * loss_neg

                loss_total_with_knn.backward()
                optimizer.step()

                bs = x_ctrl_batch.size(0)
                running_loss_total += loss_total.item() * bs
                running_loss_knn   += (loss_knn.item() if loss_knn is not None else 0.0) * bs
                running_loss_neg   += (loss_neg.item() if loss_neg is not None else 0.0) * bs
                num_samples += bs

                iteration = batch_idx + 1 + epoch * len(dl)
                if iteration % print_interval == 0:
                    curr_total = running_loss_total / num_samples
                    curr_knn   = running_loss_knn   / num_samples
                    curr_neg   = running_loss_neg   / num_samples

                    elapsed = time.time() - start_time
                    print(f"[Epoch {epoch+1}, Iter {batch_idx+1}] "
                          f"Train MSE Total={curr_total:.6f} | "
                          f"knn_loss={curr_knn:.6f} | neg_loss={curr_neg:.6f}   "
                          f"Time for last {print_interval} iters: {elapsed:.2f}s")

                    # Reset the accumulators and timer
                    running_loss_total = 0.0
                    running_loss_knn   = 0.0
                    running_loss_neg   = 0.0
                    num_samples = 0
                    start_time = time.time()

    def make_predict(self, adata: "sc.AnnData", pert_id: str, cell_type: str) -> np.ndarray:
        """
        Inference routine: select the control cells of 'cell_type' in 'adata'
        and predict their expression under perturbation 'pert_id'.

        Returns a rounded, non-negative array of shape (n_control_cells, G),
        or an empty 1-D array when there are no matching control cells or
        'pert_id' is not found in adata.var[GENE_VAR_KEY].

        NOTE(review): the controls are log1p-transformed here before being fed
        to forward(), whereas train() passes batches to forward() unchanged,
        and the perturbation index is looked up via adata.var[GENE_VAR_KEY]
        rather than the index map used to build the training set -- confirm
        both are consistent with how the training data is prepared.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)

        # Extract control cells for the given cell_type
        ctrl_mask = (adata.obs[PERT_KEY] == CONTROL_PERT) & (adata.obs[CELL_KEY] == cell_type)
        ctrl_cells_adata = adata[ctrl_mask]
        if ctrl_cells_adata.shape[0] == 0:
            # If no such cells exist, return an empty array
            logger.warning(f"No control cells found for cell_type={cell_type}. Returning empty array.")
            return np.array([])

        # We rely on adata.var[GENE_VAR_KEY] to find index of pert_id
        gene_names = adata.var[GENE_VAR_KEY].values
        if pert_id not in gene_names:
            logger.warning(f"Pert_id={pert_id} not in adata.var. Returning empty.")
            return np.array([])

        # Accept both sparse and dense expression matrices.
        X = ctrl_cells_adata.X
        dense = X.toarray() if issparse(X) else np.asarray(X)
        ctrl_cells = torch.from_numpy(np.array(dense)).float().to(device)
        ctrl_cells = torch.log1p(ctrl_cells)

        whichpert_idx = int(np.where(gene_names == pert_id)[0][0])
        whichpert = torch.full(
            (ctrl_cells.shape[0],), whichpert_idx, dtype=torch.long, device=device
        )

        batch_size = 32
        pred_delta_list = []
        with torch.no_grad():
            for i in range(0, ctrl_cells.size(0), batch_size):
                batch_ctrl_cells = ctrl_cells[i:i+batch_size]
                batch_whichpert  = whichpert[i:i+batch_size]
                pred_delta_list.append(
                    self.forward(batch_ctrl_cells, batch_whichpert, multiplier=1)
                )

        pred_delta = torch.cat(pred_delta_list, dim=0)
        pred = torch.clamp(ctrl_cells + pred_delta, min=0)  # clamp negative predictions to 0
        pred = pred.cpu().detach().numpy()
        pred = np.round(pred)
        return pred


In [5]:
import numpy as np

# Load the object stored in synthetic_counterfactuals_0.pkl; later in this cell it is
# indexed with the keys 'source' and 'synthetic_counterfactuals'.
# NOTE(review): allow_pickle=True lets np.load unpickle the file, which can execute
# arbitrary code -- only load files from a trusted location.
npz = np.load("/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/proportional_scot/synthetic_counterfactuals_0.pkl", allow_pickle=True)

class PairedStratifiedDataset(torch.utils.data.Dataset):
    """
    Dataset of (source, target, perturbation-representation) triples grouped
    by stratum.

    source_dict: mapping stratum -> array of source samples (indexed by row).
    target_dict: mapping stratum -> mapping pert_id -> array of target
        samples, row-aligned with the stratum's source array.
    pert_map: mapping pert_id -> representation returned as the third item.

    Items are addressed by a ((stratum_idx,), idx) key, which is what
    __getitem__ unpacks; the perturbation is drawn at random on every access.
    """
    def __init__(
            self, source_dict, target_dict, pert_map
    ):
        self.source = source_dict
        self.target = target_dict
        self.strata = np.array(list(self.source.keys()))
        print(self.strata)
        # Perturbation ids are taken from the first stratum only.
        self.unique_pert_ids = np.array(list(self.target[self.strata[0]].keys()))
        print(self.unique_pert_ids)
        self.pert_map = pert_map
        # Number of source samples per stratum, in the order of self.strata.
        self.ns = np.array([
            len(self.source[stratum]) for stratum in self.strata
        ])

        self.samples_per_epoch = len(self.unique_pert_ids) * self.source[self.strata[0]].shape[0]

    def __len__(self):
        # The previous body read self.source_dict / self.target_dict, which are
        # never set, and multiplied them; report the nominal epoch size instead.
        return int(self.samples_per_epoch)
    
    def __getitem__(self, strata_idx):
        """Return (source row, target row for a random perturbation, pert_map entry)."""
        (stratum_idx,), idx = strata_idx
        stratum = self.strata[stratum_idx]
        pert = np.random.choice(self.unique_pert_ids)
        return (
            self.source[stratum][idx],
            self.target[stratum][pert][idx],
            self.pert_map[pert]
        )
    
# Build the in-memory paired dataset from the loaded file. `pert_rep_map_idxs` is the
# perturbation -> gene column index map created in the first cell, so that cell must have
# been run in this kernel.
dset = PairedStratifiedDataset(
    source_dict=npz['source'],
    target_dict=npz['synthetic_counterfactuals'],
    pert_map=pert_rep_map_idxs
)


['k562']
['AAAS' 'AAMP' 'AARS' ... 'ZRSR2' 'ZW10' 'ZWINT']


In [None]:
from omnicell.models.utils.datamodules import StratifiedBatchSampler, StreamingOfflinePairedStratifiedDataset

# Toggle: when True, `dset` is replaced by the streaming dataset over the data directory;
# when False, the `dset` already present in the kernel (created by the cell that defines
# PairedStratifiedDataset) is reused unchanged.
load_all_data = False
if load_all_data:
    dset = StreamingOfflinePairedStratifiedDataset(
        data_dir='/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/proportional_scot',
        pert_embedding=pert_rep_map_idxs,
        num_files=21,
        device='cuda'
    )

# torch.utils.data.DataLoader is referenced by its full path here, so this call is not
# affected by the notebook-global name `DataLoader` being rebound in other cells.
# Batches are drawn by StratifiedBatchSampler from the dataset's `ns` and
# `samples_per_epoch` attributes.
dl = torch.utils.data.DataLoader(
    dset, 
    batch_sampler=StratifiedBatchSampler(
        ns=dset.ns, batch_size=16, samples_per_epoch=dset.samples_per_epoch
    )
)

In [12]:
# Size the model from the loaded data (`input_dim` is adata.shape[1], set in the first
# cell) instead of hard-coding the gene count, so the cell keeps working on other datasets.
model = autoencoder(None, input_dim)

In [13]:
# Run one pass over `dl` with Adam at lr=1e-3 (the loop inside train() is hard-coded to a
# single epoch).
# NOTE(review): the saved output under this cell ("Train MSE so far=... (ctrl+delta)") does
# not match the message train() currently prints ("Train MSE Total=... | knn_loss=..."), so
# it appears to come from an earlier version of the code -- re-run before trusting it.
model.train(dl, lr=1e-3)

[Epoch 1, Iter 500] Train MSE so far=0.897812 (ctrl+delta)   Time last 500 iters: 7.37s
[Epoch 1, Iter 1000] Train MSE so far=0.449795 (ctrl+delta)   Time last 500 iters: 6.84s
[Epoch 1, Iter 1500] Train MSE so far=0.287755 (ctrl+delta)   Time last 500 iters: 6.85s
[Epoch 1, Iter 2000] Train MSE so far=0.203946 (ctrl+delta)   Time last 500 iters: 6.86s
[Epoch 1, Iter 2500] Train MSE so far=0.163224 (ctrl+delta)   Time last 500 iters: 6.86s
[Epoch 1, Iter 3000] Train MSE so far=0.138004 (ctrl+delta)   Time last 500 iters: 6.99s
[Epoch 1, Iter 3500] Train MSE so far=0.124729 (ctrl+delta)   Time last 500 iters: 6.96s
[Epoch 1, Iter 4000] Train MSE so far=0.118541 (ctrl+delta)   Time last 500 iters: 6.97s
[Epoch 1, Iter 4500] Train MSE so far=0.115352 (ctrl+delta)   Time last 500 iters: 6.95s
[Epoch 1, Iter 5000] Train MSE so far=0.113665 (ctrl+delta)   Time last 500 iters: 6.96s
[Epoch 1, Iter 5500] Train MSE so far=0.109977 (ctrl+delta)   Time last 500 iters: 6.95s
[Epoch 1, Iter 6000] T

KeyboardInterrupt: 

In [6]:
# Continue training the same `model` for another pass at a lower learning rate. train()
# builds a new Adam optimizer on every call, so optimizer state from the previous run is
# not carried over.
model.train(dl, lr=2e-4)

[Epoch 1, Iter 500] Train MSE so far=0.114817 (ctrl+delta)   Time last 500 iters: 7.11s
[Epoch 1, Iter 1000] Train MSE so far=0.114359 (ctrl+delta)   Time last 500 iters: 7.02s
[Epoch 1, Iter 1500] Train MSE so far=0.113917 (ctrl+delta)   Time last 500 iters: 7.07s
[Epoch 1, Iter 2000] Train MSE so far=0.113912 (ctrl+delta)   Time last 500 iters: 7.06s
[Epoch 1, Iter 2500] Train MSE so far=0.114046 (ctrl+delta)   Time last 500 iters: 7.06s
[Epoch 1, Iter 3000] Train MSE so far=0.114259 (ctrl+delta)   Time last 500 iters: 7.05s
[Epoch 1, Iter 3500] Train MSE so far=0.113569 (ctrl+delta)   Time last 500 iters: 7.05s
[Epoch 1, Iter 4000] Train MSE so far=0.113675 (ctrl+delta)   Time last 500 iters: 7.04s
[Epoch 1, Iter 4500] Train MSE so far=0.113463 (ctrl+delta)   Time last 500 iters: 7.05s
[Epoch 1, Iter 5000] Train MSE so far=0.113553 (ctrl+delta)   Time last 500 iters: 7.06s
[Epoch 1, Iter 5500] Train MSE so far=0.113685 (ctrl+delta)   Time last 500 iters: 7.05s
[Epoch 1, Iter 6000] T

In [8]:
import numpy as np

logger.info("Running evaluation")


def _library_size_log1p(counts, target_sum=10_000):
    """Scale each row of `counts` to sum to `target_sum`, then apply log1p.

    Rows whose total is zero are left as zeros instead of becoming NaN
    through a division by zero.
    """
    totals = counts.sum(axis=1).reshape(-1, 1)
    totals = np.where(totals == 0, 1, totals)
    return np.log1p(counts / totals * target_sum)


# evaluate each pair of cells and perts
eval_dict = {}
for cell_id, pert_id, ctrl_data, gt_data in loader.get_eval_data():
    logger.debug(f"Making predictions for cell: {cell_id}, pert: {pert_id}")

    preds = model.make_predict(ctrl_data, pert_id, cell_id)
    if preds.size == 0:
        # make_predict returns an empty 1-D array when it cannot predict (no control
        # cells / unknown perturbation); such an entry would crash the per-row
        # normalisation below, so it is skipped.
        logger.warning(f"No predictions for cell: {cell_id}, pert: {pert_id}; skipping")
        continue
    eval_dict[(cell_id, pert_id)] = (ctrl_data.X.toarray(), gt_data.X.toarray(), preds)

if not config.etl_config.log1p:
    # Data were loaded without log1p, so bring control, ground truth and predictions to
    # the same log-normalised scale before computing metrics.
    for (cell, pert) in eval_dict:
        ctrl_data, gt_data, pred_pert = eval_dict[(cell, pert)]
        eval_dict[(cell, pert)] = (
            _library_size_log1p(ctrl_data),
            _library_size_log1p(gt_data),
            _library_size_log1p(pred_pert),
        )


2025-02-19 20:13:32,388 - INFO - Running evaluation
2025-02-19 20:13:32,389 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/DatlingerBock2017.yaml
2025-02-19 20:13:32,391 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/SchraivogelSteinmetz2020_TAP_SCREEN__chromosome_11_screen.yaml
2025-02-19 20:13:32,393 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/ReplogleWeissman2022_K562_essential.yaml
2025-02-19 20:13:32,394 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/GasperiniShendure2019_atscale.yaml
2025-02-19 20:13:32,395 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/TianKampmann2021_CRISPRa.yaml
2025-02-19 20:13:32,397 - INFO - Loading data catalogue from /orcd/data/omarabu/001/njwfish/omnicell/configs/catalogue/repogle_k562_essential_raw.yaml
2025-02-19

In [9]:
import scanpy as sc
from omnicell.evaluation.utils import get_DEGs, get_eval, get_DEG_Coverage_Recall, get_DEGs_overlaps
# Thresholds passed to get_eval / get_DEGs_overlaps below; pval_threshold is also used to
# count significant ground-truth DEGs for the debug log.
pval_threshold = 0.05
log_fold_change_threshold = 0.0

# For every evaluated (cell, pert) pair, compute the metrics and DEG overlaps and store
# them as results_dict[(cell, pert)] = (r2_and_mse, DEGs_overlaps).
# Requires `eval_dict` from the evaluation cell to exist in the kernel.
results_dict = {}
for (cell, pert) in eval_dict:  
    ctrl_data, gt_data, pred_pert = eval_dict[(cell, pert)]

    # Wrap the raw arrays in AnnData objects, which is what the omnicell evaluation
    # helpers are called with here.
    pred_pert = sc.AnnData(X=pred_pert)
    true_pert = sc.AnnData(X=gt_data)
    control = sc.AnnData(X=ctrl_data)

    logger.debug(f"Getting ground Truth DEGs for {pert} and {cell}")
    true_DEGs_df = get_DEGs(control, true_pert)
    # Only used for the debug message below.
    signif_true_DEG = true_DEGs_df[true_DEGs_df['pvals_adj'] < pval_threshold]

    logger.debug(f"Number of significant DEGS from ground truth: {signif_true_DEG.shape[0]}")

    logger.debug(f"Getting predicted DEGs for {pert} and {cell}")
    pred_DEGs_df = get_DEGs(control, pred_pert)


    logger.debug(f"Getting evaluation metrics for {pert} and {cell}")
    # [100,50,20] presumably selects the top-N DEG cutoffs (the printed overlap keys
    # use the same numbers) -- TODO confirm against omnicell.evaluation.utils.
    r2_and_mse = get_eval(control, true_pert, pred_pert, true_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

    logger.debug(f"Getting DEG overlaps for {pert} and {cell}")
    DEGs_overlaps = get_DEGs_overlaps(true_DEGs_df, pred_DEGs_df, [100,50,20], pval_threshold, log_fold_change_threshold)

    results_dict[(cell, pert)] = (r2_and_mse, DEGs_overlaps)



  c /= stddev[:, None]
  c /= stddev[None, :]
  c /= stddev[:, None]
  c /= stddev[None, :]


In [10]:
# Print the DEG-overlap summary for every evaluated (cell, pert) pair, separated by a
# 100-character rule. The loop variables (including r2_and_mse) stay bound to the last
# pair after the loop finishes.
for (cell, pert) in results_dict:
    r2_and_mse, DEGs_overlaps = results_dict[(cell, pert)]
    print(f"Cell: {cell}, Pert: {pert}")
    # print(f"R2 and MSE: {r2_and_mse}")
    print(f"DEGs Overlaps: {DEGs_overlaps}")
    print("-"*100)

Cell: k562, Pert: RPL15
DEGs Overlaps: {'Overlap_in_top_2157_DEGs': 645, 'Overlap_in_top_100_DEGs': 36, 'Overlap_in_top_50_DEGs': 11, 'Overlap_in_top_20_DEGs': 1, 'Jaccard': 0.15932572050027188}
----------------------------------------------------------------------------------------------------
Cell: k562, Pert: RPL4
DEGs Overlaps: {'Overlap_in_top_3714_DEGs': 1512, 'Overlap_in_top_100_DEGs': 31, 'Overlap_in_top_50_DEGs': 13, 'Overlap_in_top_20_DEGs': 6, 'Jaccard': 0.332315743567379}
----------------------------------------------------------------------------------------------------
Cell: k562, Pert: RPL7
DEGs Overlaps: {'Overlap_in_top_3845_DEGs': 1668, 'Overlap_in_top_100_DEGs': 34, 'Overlap_in_top_50_DEGs': 13, 'Overlap_in_top_20_DEGs': 5, 'Jaccard': 0.37248359470581693}
----------------------------------------------------------------------------------------------------
Cell: k562, Pert: RNF113A
DEGs Overlaps: {'Overlap_in_top_2852_DEGs': 194, 'Overlap_in_top_100_DEGs': 5, 'Overlap_

In [18]:
# Displays the metrics of the LAST (cell, pert) pair only: `r2_and_mse` is a loop variable
# left over from an earlier cell, so this cell fails on a fresh kernel unless those cells
# have been run first.
r2_and_mse

r2 {'all_genes_mean_sub_diff_R': np.float32(-0.12489348), 'all_genes_mean_sub_diff_R2': np.float32(0.015598381), 'all_genes_mean_sub_diff_MSE': np.float32(0.068853445), 'all_genes_mean_fold_diff_R': np.float32(-0.46685344), 'all_genes_mean_fold_diff_R2': np.float32(0.21795213), 'all_genes_mean_fold_diff_MSE': np.float32(24.076696), 'all_genes_mean_R': np.float32(0.88978803), 'all_genes_mean_R2': np.float32(0.7917227), 'all_genes_mean_MSE': np.float32(0.068853445), 'all_genes_var_R': np.float32(0.48825538), 'all_genes_var_R2': np.float32(0.23839332), 'all_genes_var_MSE': np.float32(0.02451049), 'all_genes_corr_mtx_R': np.float64(0.2062134434958776), 'all_genes_corr_mtx_R2': np.float64(0.04252398427842751), 'all_genes_corr_mtx_MSE': np.float64(0.0030738678439887306), 'all_genes_cov_mtx_R': np.float64(0.2403351520107548), 'all_genes_cov_mtx_R2': np.float64(0.05776098529203262), 'all_genes_cov_mtx_MSE': np.float64(9.166059826543978e-05), 'Top_3226_DEGs_sub_diff_R': np.float32(-0.13143796),