In [1]:
import scanpy as sc
import cfp
import os
import random
import numpy as np
import pickle
import anndata as ad

  from optuna import progress_bar as pbar_module


In [2]:
# Directory on the shared filesystem that holds the PBMC dataset files.
data_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/pbmc"

In [3]:
# Load the 2000-gene AnnData file. The path is built from `data_dir` (defined in
# the previous cell) instead of repeating the absolute directory; the resulting
# path is the same as before.
adata = sc.read_h5ad(os.path.join(data_dir, "adata_hvg2000_LV.h5ad"))

In [4]:
# Show the AnnData summary (shape plus obs/var/uns/layers keys) as the cell output.
adata

AnnData object with n_obs × n_vars = 9697974 × 2000
    obs: 'sample', 'species', 'gene_count', 'tscp_count', 'mread_count', 'bc1_wind', 'bc2_wind', 'bc3_wind', 'bc1_well', 'bc2_well', 'bc3_well', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'total_counts_MT', 'pct_counts_MT', 'log1p_total_counts_MT', 'donor', 'cytokine', 'treatment', 'cell_type'
    var: 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'log1p'
    layers: 'counts'

In [5]:
# Start from the raw counts. Take a copy: assigning the layer directly makes
# `adata.X` and `adata.layers["counts"]` the same object, and the preprocessing
# calls below operate in place, which can overwrite the stored raw counts that
# are later written to disk together with the rest of the object.
adata.X = adata.layers["counts"].copy()
# Scale every cell to a total of 10,000 counts, then apply log(1 + x).
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)



In [6]:
# Read the pickled ESM2 embeddings and attach them to the AnnData object.
# NOTE(review): pickle.load executes arbitrary code contained in the file; only
# load this from a trusted location.
esm_data_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/pbmc"
esm2_path = os.path.join(esm_data_dir, "esm2_embeddings.pkl")
with open(esm2_path, "rb") as handle:
    esm2_embeddings = pickle.load(handle)
adata.uns["esm2_embeddings"] = esm2_embeddings

In [7]:
def get_donor_embeddings(adata: ad.AnnData) -> None:
    """Store a per-donor mean expression profile of PBS cells in ``adata.uns``.

    For every distinct value in ``adata.obs["donor"]``, the cells of that donor
    whose ``adata.obs["cytokine"]`` equals ``"PBS"`` are selected and the
    column-wise mean of their ``X`` is stored under
    ``adata.uns["donor_embeddings"][donor]``. Any existing
    ``adata.uns["donor_embeddings"]`` entry is replaced.

    Parameters
    ----------
    adata
        Annotated data with ``donor`` and ``cytokine`` columns in ``obs``.
        Modified in place.

    Returns
    -------
    None
    """
    adata.uns["donor_embeddings"] = {}
    is_pbs = adata.obs["cytokine"] == "PBS"
    for donor_id in adata.obs["donor"].unique():
        donor_pbs_mask = (adata.obs["donor"] == donor_id) & is_pbs
        mean_profile = adata[donor_pbs_mask].X.mean(axis=0)
        adata.uns["donor_embeddings"][donor_id] = np.array(mean_profile)

In [8]:
# Fill adata.uns["donor_embeddings"] with one mean PBS profile per donor.
get_donor_embeddings(adata)

In [9]:
# Make the "sample" annotation available under the column name "condition".
# NOTE(review): presumably the name downstream tooling expects — confirm.
adata.obs["condition"] = adata.obs["sample"]

In [10]:
# Flag control cells (treatment == "PBS"). A vectorised comparison yields the
# same boolean column as a row-wise DataFrame.apply, without calling a Python
# lambda once per cell.
# NOTE(review): get_donor_embeddings selects PBS cells via the "cytokine" column
# while this uses "treatment" — confirm the two columns agree for controls.
adata.obs["is_control"] = adata.obs["treatment"] == "PBS"

In [11]:
rng = np.random.default_rng(0)

unique_cytokines = adata.obs["cytokine"].unique()
# Held-out cytokines that every model is evaluated on. The draw uses numpy's
# default (sampling with replacement, PBS not excluded) and is left unchanged so
# that the selected set stays the same; instead, fail loudly if the draw ever
# yields duplicates or the PBS control.
cytokines_to_impute = rng.choice(unique_cytokines, size=10)
held_out = set(cytokines_to_impute)
if len(held_out) != len(cytokines_to_impute) or "PBS" in held_out:
    raise ValueError(
        "cytokines_to_impute must contain distinct, non-control cytokines; "
        f"got {list(cytokines_to_impute)}"
    )

# Pool of candidate training cytokines: everything except the PBS control and
# the held-out set. Sorted because iterating a set of strings depends on the
# per-process hash seed, which would make the seeded draws from this pool
# differ between kernel restarts.
unique_cytokines = sorted(set(unique_cytokines) - {"PBS"} - held_out, key=str)
cytokines_to_impute

array(['OX40L', 'IL-32-beta', 'IL-1Ra', 'IFN-gamma', 'IFN-omega', 'BAFF',
       'CD27L', 'ADSF', 'FasL', 'M-CSF'], dtype=object)

In [12]:
# Training-set sizes: number of cytokines sampled for each training configuration.
num_observed_cytokines = [1,2,4,8,16,32,64,80]

In [13]:
# For every training-set size, draw three cytokine subsets (without replacement
# within a subset) from the candidate pool and append the PBS control to each.
# Keys are the sizes as strings; values are lists of three cytokine lists.
cytokines_to_train_data = {
    str(n_observed): [
        list(rng.choice(unique_cytokines, size=n_observed, replace=False)) + ["PBS"]
        for _ in range(3)
    ]
    for n_observed in num_observed_cytokines
}

In [14]:
# Keep the held-out cytokines and the sampled training sets in `uns` so they are
# stored together with the AnnData object.
adata.uns['cytokines_to_impute'] = cytokines_to_impute
adata.uns['cytokines_to_train_data'] = cytokines_to_train_data

In [15]:
# Save the processed object into `data_dir` (defined near the top of the
# notebook); this is the same location the absolute path pointed to before.
adata.write(os.path.join(data_dir, "pbmc_new_donor_processed.h5ad"))

Choose 10 cytokines to always evaluate on.
Then, for each size in [1, 2, 4, 8, 16, 32, 64, 80], choose 3 sets of "train cytokines" to include, resulting in 12 x 3 x 8 = 288 models. In addition, include the "OOD patient" scenario 12 times, giving 288 + 12 = 300 models. Every model imputes the same 10 cytokines, for a total of 3000 predictions.
