In [126]:
from pathlib import Path

import anndata as ad
import numpy as np
import torch

from chemCPA.data import PerturbationDataModule, load_dataset_splits
from chemCPA.lightning_module import ChemCPA

In [2]:
# Run selection. Edit `run_id` (and the output root below) to point at your own checkpoints.
# `split` is reused further down when naming the result files.
split = 0
run_id = "qipfyleo"
ckpt = "last.ckpt"

output_root = Path("/home/jovyan/git-repos/platform-publication/method_benchmarking/chemcpa/chemCPA/project_folder/output")
cp_path = output_root / run_id / ckpt

In [3]:
# Restore the ChemCPA module from the checkpoint file at `cp_path`.
module = ChemCPA.load_from_checkpoint(cp_path)

In [4]:
# Dataset settings stored in the module's config; unpacked below as keyword
# arguments to rebuild the data splits.
data_params = module.config["dataset"]

In [5]:
# Rebuild the splits from the stored settings. With return_dataset=True the call
# returns two objects: the splits (`datasets`) and a second value bound to `dataset`.
datasets, dataset = load_dataset_splits(**data_params, return_dataset=True)

  utils.warn_names_duplicates("obs")


In [92]:
# Wrap the splits in a PerturbationDataModule, using the training batch size
# recorded in the module's config.
train_batch_size = module.config["model"]["hparams"]["batch_size"]
dm = PerturbationDataModule(datasplits=datasets, train_bs=train_batch_size)
dm.setup(stage="fit")  # fit, validate/test, predict

In [93]:
# Put the wrapped model into evaluation mode; the trailing `;` suppresses the
# cell's output.
module.model.eval();

In [94]:
# Short aliases for the control/treated datasets exposed by the data module,
# one pair per split.
train_control_dataset, train_treated_dataset = dm.train_control_dataset, dm.train_treated_dataset
test_control_dataset, test_treated_dataset = dm.test_control_dataset, dm.test_treated_dataset
ood_control_dataset, ood_treated_dataset = dm.ood_control_dataset, dm.ood_treated_dataset

In [106]:
def _parse_pert_category(pert_category):
    """Split a ``"<cell line>_<drug>_<dose>"`` label into its three parts.

    The first ``_``-separated token is the cell line, the last one the dose
    (returned as a string), and everything in between, re-joined with ``_``,
    is the drug name.
    """
    tokens = pert_category.split("_")
    return tokens[0], "_".join(tokens[1:-1]), tokens[-1]


def compute_preds(control_dataset, treated_dataset):
    """Predict expression for every perturbation category of ``treated_dataset``.

    For each distinct entry of ``treated_dataset.pert_categories`` the control
    profiles of the matching cell line (taken from ``control_dataset``) are fed
    through ``module.model.predict`` together with the drug index, dosage and
    covariates of the first treated sample seen for that category.

    Uses the notebook-level ``module`` and moves its model to ``"cuda"``.

    Parameters
    ----------
    control_dataset
        Object exposing ``genes`` (iterated row by row) and
        ``covariate_names["cell_type"]`` (one cell-line label per row).
    treated_dataset
        Object exposing ``genes`` and ``pert_categories``. Iterating it yields
        indexable items with the drug index at position 1, the dosage at
        position 2 and the covariates from position 4 onwards.

    Returns
    -------
    anndata.AnnData
        Control, predicted and target rows concatenated in that order, with obs
        columns ``condition`` ("control" / "prediction" / "target"),
        ``cell_line``, ``perturbation``, ``dose`` and ``pert_category``.
        ``pert_category`` stays ``None`` for rows matching no category of
        ``treated_dataset``.
    """
    # Group the control profiles by cell line. Rows are collected in lists and
    # stacked once per cell line instead of re-concatenating a growing tensor.
    control_rows = {}
    cell_lines = control_dataset.covariate_names["cell_type"]
    for cell_line, profile in zip(cell_lines, control_dataset.genes):
        control_rows.setdefault(cell_line, []).append(profile)
    control_genes = {cell_line: torch.stack(rows, dim=0) for cell_line, rows in control_rows.items()}

    module.model.eval()
    module.model.to("cuda")

    preds = {}
    targs = {}

    # Pure inference: no autograd graph is needed (outputs were detached anyway).
    with torch.no_grad():
        for pert_cat, item in zip(treated_dataset.pert_categories, treated_dataset):
            if pert_cat in preds:
                # Only the first sample of each category is used as a template.
                continue

            drug_idx, dosage, covariates = item[1], item[2], item[4:]
            cl, _, _ = _parse_pert_category(pert_cat)

            genes = control_genes[cl]
            n_obs = len(genes)

            # Broadcast the single treated sample's condition to every control cell.
            gene_reconstructions, _, _ = module.model.predict(
                genes=genes,
                drugs=None,
                drugs_idx=drug_idx.repeat(n_obs),
                dosages=dosage.repeat(n_obs),
                covariates=[cov.repeat(n_obs, 1) for cov in covariates],
                return_latent_basal=True,
            )

            # Only the first half of the output columns is kept as the prediction
            # (presumably the mean of a mean/variance output; confirm against the model).
            dim = gene_reconstructions.size(1) // 2
            preds[pert_cat] = gene_reconstructions[:, :dim].detach().cpu().numpy()
            targs[pert_cat] = (
                treated_dataset.genes[treated_dataset.pert_categories == pert_cat].clone().numpy()
            )

    predictions, targets = [], []
    cl_p, drug_p, dose_p = [], [], []
    cl_t, drug_t, dose_t = [], [], []
    control = {}
    for key, val in preds.items():
        cl, drug, dose = _parse_pert_category(key)
        dose = float(dose)

        # Controls are included once per cell line that has at least one prediction.
        control[cl] = control_genes[cl].numpy()

        n_pred = val.shape[0]
        predictions.append(val)
        cl_p.extend(n_pred * [cl])
        drug_p.extend(n_pred * [drug])
        dose_p.extend(n_pred * [dose])

        n_targ = targs[key].shape[0]
        targets.append(targs[key])
        cl_t.extend(n_targ * [cl])
        drug_t.extend(n_targ * [drug])
        dose_t.extend(n_targ * [dose])

    adata_c = ad.AnnData(np.concatenate(list(control.values()), axis=0))
    adata_c.obs["cell_line"] = [cl for cl, values in control.items() for _ in range(values.shape[0])]
    adata_c.obs["condition"] = "control"
    adata_c.obs["perturbation"] = "Vehicle"
    adata_c.obs["dose"] = 1.0

    adata_p = ad.AnnData(np.concatenate(predictions, axis=0))
    adata_p.obs["condition"] = "prediction"
    adata_p.obs["cell_line"] = cl_p
    adata_p.obs["perturbation"] = drug_p
    adata_p.obs["dose"] = dose_p

    adata_t = ad.AnnData(np.concatenate(targets, axis=0))
    adata_t.obs["condition"] = "target"
    adata_t.obs["cell_line"] = cl_t
    adata_t.obs["perturbation"] = drug_t
    adata_t.obs["dose"] = dose_t

    adata = ad.concat([adata_c, adata_p, adata_t])
    adata.obs_names_make_unique()
    adata.obs["pert_category"] = None

    # Label rows with the categories of the dataset that was actually passed in,
    # not a notebook-level dataset, so the function is correct for every split.
    for key in np.unique(treated_dataset.pert_categories):
        cl, drug, dose = _parse_pert_category(key)

        cond = (
            (adata.obs["cell_line"] == cl)
            & (adata.obs["perturbation"] == drug)
            & (adata.obs["dose"] == float(dose))
        )
        adata.obs.loc[cond, "pert_category"] = key
    return adata

In [107]:
# Predictions, targets and controls for the `ood_*` control/treated pair.
ood_preds = compute_preds(ood_control_dataset, ood_treated_dataset)

  utils.warn_names_duplicates("obs")


In [108]:
# Save the `ood_preds` AnnData as .h5ad; the split index is embedded in the file name.
ood_preds.write(f"/home/jovyan/git-repos/platform-publication/method_benchmarking/chemcpa/chemCPA/project_folder/output/results/adata_biolord_split_{split}_300_pred_ood.h5ad")

In [109]:
# Predictions, targets and controls for the `test_*` control/treated pair.
test_preds = compute_preds(test_control_dataset, test_treated_dataset)

  utils.warn_names_duplicates("obs")


In [112]:
# Cast the category labels to plain strings (missing labels become "None") before writing.
test_preds.obs["pert_category"] = test_preds.obs["pert_category"].astype(str)

In [113]:
# Save the `test_preds` AnnData as .h5ad; the split index is embedded in the file name.
test_preds.write(f"/home/jovyan/git-repos/platform-publication/method_benchmarking/chemcpa/chemCPA/project_folder/output/results/adata_biolord_split_{split}_300_pred_test.h5ad")

In [114]:
# Predictions, targets and controls for the `train_*` control/treated pair.
train_preds = compute_preds(train_control_dataset, train_treated_dataset)

  utils.warn_names_duplicates("obs")


In [116]:
# Cast the category labels to plain strings (missing labels become "None"), then
# save the AnnData; the split index is embedded in the file name.
train_preds.obs["pert_category"] = train_preds.obs["pert_category"].astype(str)
train_preds_path = f"/home/jovyan/git-repos/platform-publication/method_benchmarking/chemcpa/chemCPA/project_folder/output/results/adata_biolord_split_{split}_300_pred_train.h5ad"
train_preds.write(train_preds_path)

In [117]:
# Tag every AnnData with the name of the split it came from, then stack them.
# Passing a mapping to `ad.concat` uses its keys as the concat keys: they fill the
# 'dataset' column and are appended to the obs names with "-" to keep them unique.
preds_by_dataset = {
    'train_preds': train_preds,
    'test_preds': test_preds,
    'ood_preds': ood_preds,
}
for dataset_label, preds_adata in preds_by_dataset.items():
    preds_adata.obs['dataset'] = dataset_label

combined_anndata = ad.concat(preds_by_dataset, label='dataset', index_unique="-")


In [119]:
# Save the concatenated train/test/ood AnnData as a single .h5ad file.
combined_anndata.write_h5ad("/home/jovyan/git-repos/platform-publication/method_benchmarking/chemcpa/chemCPA/project_folder/output/results/combined_data.h5ad")


In [161]:
# Load an AnnData file; its `.var` is used further down to look up gene names.
# NOTE(review): the path is relative to the working directory (the output paths are
# absolute) and the file name hard-codes split 0 instead of using `split`. Confirm
# both are intended.
adata_raw = ad.read_h5ad("project_folder/adata_biolord_300_split_0_chemcpa.h5ad")

  utils.warn_names_duplicates("obs")


In [121]:
import pandas as pd

# Distinct (cell_line, perturbation) combinations present in the combined data.
obs_df = combined_anndata.obs[['cell_line', 'perturbation']]
unique_pairs = obs_df.drop_duplicates()

# The same pairs as a plain list of tuples ...
unique_pairs_list = list(zip(unique_pairs['cell_line'], unique_pairs['perturbation']))

# ... and as a fresh DataFrame with a default integer index.
unique_pairs_df = pd.DataFrame(unique_pairs_list, columns=['cell_line', 'perturbation'])

In [123]:
# pip install decoupler

In [124]:
import anndata as ad
import decoupler as dc

# The pseudobulk call reads from the 'normalized' layer; create it from .X when absent.
if 'normalized' not in combined_anndata.layers:
    combined_anndata.layers['normalized'] = combined_anndata.X.copy()

# Average the 'normalized' values within each (cell_line, perturbation) group.
# NOTE(review): rows are grouped by cell line and perturbation only, so 'prediction'
# and 'target' rows (and all doses and splits) fall into the same group. Confirm
# that this is intended.
pseudobulk_options = dict(
    sample_col="cell_line",
    groups_col=['cell_line', 'perturbation'],
    layer='normalized',
    mode='mean',
    skip_checks=True,
)
pseudobulk_adata = dc.get_pseudobulk(adata=combined_anndata, **pseudobulk_options)

In [162]:
# Expose the gene names (the var index) as a "gene_name" column and attach a
# positional "gene_id" (0 .. n_vars - 1).
# - `rename_axis` makes this work whether or not the var index already has a name.
# - The guard keeps the cell re-runnable: a second run does not insert another
#   "gene_name" column.
raw_var = adata_raw.var.copy()
if "gene_name" not in raw_var.columns:
    raw_var = raw_var.rename_axis("gene_name").reset_index()
raw_var["gene_id"] = np.arange(len(raw_var))
adata_raw.var = raw_var

In [163]:
# Lookup table from positional gene id to gene name, taken from the raw data.
gene_id_to_name = adata_raw.var[["gene_id", "gene_name"]]

# The pseudobulk var names are parsed as integer gene ids (this assumes they are
# stringified column positions; `astype(int)` raises otherwise).
pseudobulk_var = pseudobulk_adata.var.copy()
pseudobulk_var["gene_id"] = pseudobulk_var.index.astype(int)

# Look the names up with `map` rather than `merge`: a column merge replaces the var
# index with a fresh 0..n-1 range, which silently renumbers the var names if they
# are not already exactly that range. `map` keeps the index untouched and also
# makes the cell safe to re-run.
pseudobulk_var["gene_name"] = pseudobulk_var["gene_id"].map(
    gene_id_to_name.set_index("gene_id")["gene_name"]
)
pseudobulk_adata.var = pseudobulk_var

In [164]:
# Save the gene-annotated pseudobulk AnnData as .h5ad.
pseudobulk_adata.write_h5ad("/home/jovyan/git-repos/platform-publication/method_benchmarking/chemcpa/chemCPA/project_folder/output/results/pseudobulk.h5ad")

In [177]:
# Filter for Vehicle (control) data
control_mask = pseudobulk_adata.obs["perturbation"] == "Vehicle"
control_adata = pseudobulk_adata[control_mask]

# Prepare lists to store results
signature_data = []
obs_data = []

# Iterate over unique cell_lines
for cell_line in pseudobulk_adata.obs["cell_line"].unique():
    # Get control data for the current cell_line
    control_data = control_adata[control_adata.obs["cell_line"] == cell_line]
    
    # Skip if no control data exists
    if control_data.shape[0] == 0:
        continue
    
    # Compute the mean of control in .X space
    control_values = control_data.X.mean(axis=0)
    
    # Get perturbation data for the current cell_line
    perturbation_data = pseudobulk_adata[
        (pseudobulk_adata.obs["cell_line"] == cell_line) & 
        (pseudobulk_adata.obs["perturbation"] != "Vehicle")
    ]
    
    # Iterate over unique perturbations
    for perturbation in perturbation_data.obs["perturbation"].unique():
        # Get data for the current perturbation
        perturbation_values = perturbation_data[
            perturbation_data.obs["perturbation"] == perturbation
        ].X.mean(axis=0)
        
        # Calculate perturbation signature: difference between perturbation and control
        signature = perturbation_values - control_values
        
        # Append signature and observation metadata
        signature_data.append(signature)
        obs_data.append({"cell_line": cell_line, "perturbation": perturbation})

# Convert results to a new AnnData object
signature_matrix = np.vstack(signature_data)  # Stack all signatures into a 2D array
obs_df = pd.DataFrame(obs_data)  # Create a DataFrame for metadata

# Create an AnnData object to store perturbation signatures
signature_adata = ad.AnnData(
    X=signature_matrix,
    obs=obs_df
)

# Assign gene names to var
signature_adata.var = pseudobulk_adata.var.copy()

# Save to file or inspect the result
print(signature_adata)
signature_adata.write("/home/jovyan/git-repos/platform-publication/method_benchmarking/chemcpa/chemCPA/project_folder/output/results/perturbation_signatures_across_lines.h5ad")




AnnData object with n_obs × n_vars = 485 × 1999
    obs: 'cell_line', 'perturbation'
    var: 'gene_id', 'gene_name'


In [178]:
import numpy as np
import anndata as ad
import pandas as pd

# Per-compound perturbation signatures, pooled over cell lines:
#   signature = mean(all pseudobulk rows of the compound) - mean(all Vehicle pseudobulk rows)
for required_column in ("cell_line", "perturbation"):
    assert required_column in pseudobulk_adata.obs, f"'{required_column}' column is missing in pseudobulk_adata.obs"

# Single reference profile: the mean over every Vehicle row, regardless of cell line.
control_mask = pseudobulk_adata.obs["perturbation"] == "Vehicle"
control_adata = pseudobulk_adata[control_mask]
control_values = control_adata.X.mean(axis=0)

# Every perturbation except the Vehicle control itself.
unique_perturbations = pseudobulk_adata.obs["perturbation"].unique()
unique_perturbations = unique_perturbations[unique_perturbations != "Vehicle"]

signature_data, obs_data = [], []
for perturbation in unique_perturbations:
    is_current = pseudobulk_adata.obs["perturbation"] == perturbation
    perturbation_values = pseudobulk_adata[is_current].X.mean(axis=0)

    signature_data.append(perturbation_values - control_values)
    obs_data.append({"perturbation": perturbation})

# Assemble the signatures into an AnnData that shares the pseudobulk gene annotation.
signature_matrix = np.vstack(signature_data)
obs_df = pd.DataFrame(obs_data)
signature_adata = ad.AnnData(X=signature_matrix, obs=obs_df)
signature_adata.var = pseudobulk_adata.var.copy()

# Show a summary of the result.
print(signature_adata)




AnnData object with n_obs × n_vars = 186 × 1999
    obs: 'perturbation'
    var: 'gene_id', 'gene_name'


In [179]:
# Save the current `signature_adata` (the per-compound signatures when the cells are
# run in order) as .h5ad.
signature_adata.write("/home/jovyan/git-repos/platform-publication/method_benchmarking/chemcpa/chemCPA/project_folder/output/results/perturbation_signatures_per_compound.h5ad")