In [1]:
import os
import scanpy as sc
import anndata
import numpy as np
import pandas as pd
import pickle
from sklearn.metrics import r2_score
from scipy.sparse import csr_matrix, vstack
from tqdm import tqdm
from ott.geometry import costs, pointcloud
from ott.solvers.linear import sinkhorn
from ott.solvers import linear
from ott.tools.sinkhorn_divergence import sinkhorn_divergence
from sklearn.metrics import r2_score

In [2]:
k_ood = 6
pathway = 'IFNG'
cell_type = 'BXPC3'
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/r2"
pca_components = 30

In [3]:
def compute_r_squared(x: np.ndarray, y: np.ndarray) -> float:
    return r2_score(np.mean(x, axis=0), np.mean(y, axis=0))

def rank_genes_groups_by_cov(
    adata,
    groupby,
    covariate,
    control_group,
    n_genes=50,
    rankby_abs=True,
    key_added="rank_genes_groups_cov",
    return_dict=False,
):
    gene_dict = {}
    cov_categories = adata.obs[covariate].unique()
    for cov_cat in cov_categories:
        # name of the control group in the groupby obs column
        control_group_cov = "_".join([cov_cat, control_group])
        adata_cov = adata[adata.obs[covariate] == cov_cat]

        # compute DEGs
        sc.tl.rank_genes_groups(
            adata_cov,
            groupby=groupby,
            reference=control_group_cov,
            rankby_abs=rankby_abs,
            n_genes=n_genes,
            use_raw=False,
        )
        # add entries to dictionary of gene sets
        de_genes = pd.DataFrame(adata_cov.uns["rank_genes_groups"]["names"])
        for group in de_genes:
            gene_dict[group] = de_genes[group].tolist()
    adata.uns[key_added] = gene_dict
    if return_dict:
        return gene_dict


def get_DE_genes(adata):
    adata.obs.loc[:, "control"] = adata.obs.gene.apply(lambda x: 1 if x == "NT" else 0)
    adata.obs = adata.obs.astype("category")
    rank_genes_groups_by_cov(
        adata,
        groupby="condition",
        covariate="cell_type",
        control_group="NT",
        n_genes=50,
        key_added="rank_genes_groups_cov_all",
    )
    return adata

In [4]:
# TODO create complete list of embeddings
ko_embeddings = pickle.load(open('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/embeddings/embeddings_ifng.pkl', 'rb'))
ko_embeddings = pd.DataFrame(ko_embeddings).T
cell_embeddings = pd.read_csv('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/embeddings/cell_line_embedding_full_ccle_300_normalized.csv', index_col=0)

In [5]:
adata = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/h5ad/" + pathway + "_Perturb_seq.h5ad")
adata

AnnData object with n_obs × n_vars = 245240 × 33525
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'

In [6]:
adata = adata[adata.obs['cell_type'] == cell_type, :]
adata

View of AnnData object with n_obs × n_vars = 67554 × 33525
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'

In [7]:
adata.obs['condition'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.gene]), axis=1)

  adata.obs['condition'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.gene]), axis=1)


In [8]:
condition_counts = adata.obs['condition'].value_counts()
conditions_with_at_least_100 = condition_counts[condition_counts >= 100].index.tolist()
adata = adata[adata.obs['condition'].isin(conditions_with_at_least_100), :]
adata

View of AnnData object with n_obs × n_vars = 67155 × 33525
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score', 'condition'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'

In [9]:
adata.obs['condition'].value_counts(ascending=True)

condition
BXPC3_MCRS1        103
BXPC3_TSC22D1      121
BXPC3_ZC3H3        149
BXPC3_PPARG        156
BXPC3_PTGES3       163
BXPC3_PIK3CA       168
BXPC3_CLK1         178
BXPC3_SP110        185
BXPC3_IFI16        185
BXPC3_ZFP36        194
BXPC3_FBXO6        200
BXPC3_ZNFX1        211
BXPC3_ZNF267       215
BXPC3_KLF4         232
BXPC3_IRF5         258
BXPC3_MAFF         321
BXPC3_CEBPE        353
BXPC3_STAT1        396
BXPC3_TRAFD1       418
BXPC3_STAT3        435
BXPC3_FMNL2        473
BXPC3_RNF14        663
BXPC3_PLEK         715
BXPC3_RFX5         744
BXPC3_SRC          752
BXPC3_CEBPB        785
BXPC3_IFNGR1       812
BXPC3_MYC          899
BXPC3_MAFB        1012
BXPC3_JAK1        1024
BXPC3_ETV7        1032
BXPC3_PRDM1       1047
BXPC3_IRF7        1055
BXPC3_RARRES3     1195
BXPC3_JAK2        1233
BXPC3_BATF2       1254
BXPC3_GUK1        1269
BXPC3_HLA-DQB1    1270
BXPC3_FOXN3       1357
BXPC3_IRF9        1456
BXPC3_IRF2        1584
BXPC3_JUN         1755
BXPC3_ATF5        1762
B

In [10]:
sc.pp.highly_variable_genes(adata, inplace=True, n_top_genes=2000)
adata = adata[:,adata.var["highly_variable"]==True]

  adata.uns["hvg"] = {"flavor": flavor}


In [11]:
adata = get_DE_genes(adata)

  self.obj[key] = value
  adata.uns[key_added] = {}
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"]

In [12]:
adata.uns

OrderedDict([('hvg', {'flavor': 'seurat'}),
             ('rank_genes_groups_cov_all',
              {'BXPC3_ATF3': ['ATF3',
                'AC105460.1',
                'CALB1',
                'LRMDA',
                'LAMB3',
                'XACT',
                'CHN2',
                'AC068633.1',
                'CXCL9',
                'LINC00431',
                'KCNIP4',
                'WARS',
                'SAMD5',
                'CEACAM6',
                'KRT17',
                'DENND5A',
                'FAT2',
                'SLC7A11',
                'CMSS1',
                'JAZF1',
                'FYB1',
                'AL390334.1',
                'CCSER1',
                'AC016205.1',
                'PCDH7',
                'POLQ',
                'PCED1B',
                'PELI2',
                'MGST1',
                'EPHB2',
                'FIRRE',
                'COL4A6',
                'AGPAT4',
                'MDGA1',
                'EML6

In [13]:
len(adata.obs["condition"].unique())

53

In [14]:
controls = {}
for ct in adata.obs["cell_type"].unique():
    controls[ct] = adata[adata.obs["condition"]==ct+'_NT'].X.A

In [15]:
pert_effects = {}
for c in tqdm(adata.obs["condition"].unique()):
    if c.endswith('_NT'):
        continue
    cell_type = c.split("_")[0]
    pert_effects[c] = float(compute_r_squared(
        controls[cell_type],
        adata[adata.obs["condition"]==c].X.A))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53/53 [00:02<00:00, 18.86it/s]


In [16]:
pert_effects

{'BXPC3_IRF9': 0.9856853363998033,
 'BXPC3_TRAFD1': 0.9852078045582372,
 'BXPC3_PRDM1': 0.9886977260532006,
 'BXPC3_BATF2': 0.988948642545062,
 'BXPC3_FOXN3': 0.98519454073412,
 'BXPC3_ZC3H3': 0.969852896166733,
 'BXPC3_PTPN11': 0.9936922697011042,
 'BXPC3_PARP12': 0.9898363502999318,
 'BXPC3_IRF1': 0.9939819336396337,
 'BXPC3_TSC22D1': 0.9782140961499074,
 'BXPC3_IFNGR2': 0.9452783928890172,
 'BXPC3_RARRES3': 0.9880637545854505,
 'BXPC3_EHF': 0.9794072504476613,
 'BXPC3_RUNX1': 0.9911735268629562,
 'BXPC3_IFNGR1': 0.9072536030081575,
 'BXPC3_ATF3': 0.994784859143625,
 'BXPC3_CEBPB': 0.9784095687397139,
 'BXPC3_ATF5': 0.9941769713911562,
 'BXPC3_MAFB': 0.9877679591169174,
 'BXPC3_ETV7': 0.9860055124383612,
 'BXPC3_ZNFX1': 0.9820538891171373,
 'BXPC3_IRF5': 0.98665570302769,
 'BXPC3_FBXO6': 0.9802992943575491,
 'BXPC3_SP100': 0.9953189793233818,
 'BXPC3_MAFF': 0.964723269078971,
 'BXPC3_TBX21': 0.996405731419676,
 'BXPC3_IRF2': 0.9668081881637749,
 'BXPC3_JAK1': 0.9320600372339054,
 'BX

In [17]:
features_cell_line = np.zeros((adata.shape[0], cell_embeddings.shape[1]))
features_ko = np.zeros((adata.shape[0], ko_embeddings.shape[1]))
for cell_line, emb in cell_embeddings.iterrows():
    features_cell_line[adata.obs['cell_type'].isin([cell_line])] = emb.values
for ko, emb in ko_embeddings.iterrows():
    features_ko[adata.obs['gene'].isin([ko])] = emb.values

In [18]:
adata.obsm['cell_line_emb'] = features_cell_line
adata.obsm['gene_emb'] = features_ko
adata.obsm['cond_emb'] = np.concatenate([adata.obsm['cell_line_emb'], adata.obsm['gene_emb']], axis = -1)

In [19]:
adata.obsm['cell_line_emb'].shape

(67155, 300)

In [20]:
adata.obsm['gene_emb'].shape

(67155, 2560)

In [21]:
for col in adata.obs.select_dtypes(include=["category"]):
    adata.obs[col].cat.remove_unused_categories()

In [22]:
adata = get_DE_genes(adata)

  adata.uns[key_added] = {}
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global

In [23]:
df_effects = pd.DataFrame({"conditions": pert_effects.keys(), "sinkhorn_div": pert_effects.values()})
df_effects_sorted = df_effects.sort_values("sinkhorn_div")
df_effects_sorted

Unnamed: 0,conditions,sinkhorn_div
49,BXPC3_STAT1,0.88969
14,BXPC3_IFNGR1,0.907254
27,BXPC3_JAK1,0.93206
10,BXPC3_IFNGR2,0.945278
30,BXPC3_JAK2,0.957406
24,BXPC3_MAFF,0.964723
26,BXPC3_IRF2,0.966808
5,BXPC3_ZC3H3,0.969853
41,BXPC3_IFI16,0.974417
44,BXPC3_ZFP36,0.97631


In [24]:
ood_ko = df_effects_sorted[::len(df_effects_sorted) // k_ood]
ood_ko

Unnamed: 0,conditions,sinkhorn_div
49,BXPC3_STAT1,0.88969
41,BXPC3_IFI16,0.974417
12,BXPC3_EHF,0.979407
1,BXPC3_TRAFD1,0.985208
18,BXPC3_MAFB,0.987768
47,BXPC3_IRF7,0.991036
17,BXPC3_ATF5,0.994177


In [89]:
ood_conditions = ood_ko.conditions.values
ood_conditions

array(['HAP1_IFNGR2', 'HAP1_EHF', 'HAP1_GUK1', 'HAP1_IRF7', 'HAP1_SP100',
       'HAP1_ZFP36', 'HAP1_TBX21'], dtype=object)

In [90]:
ood_unique = list(set([value for entry in ood_conditions for value in entry.split('_')]))
ood_unique

['SP100', 'IFNGR2', 'GUK1', 'EHF', 'IRF7', 'HAP1', 'ZFP36', 'TBX21']

In [91]:
len(ood_unique)

8

In [92]:
remaining_conditions = df_effects_sorted[~df_effects_sorted['conditions'].isin(ood_conditions)].conditions.values
remaining_unique = list(set([value for entry in remaining_conditions for value in entry.split('_')]))
remaining_unique

['TRAFD1',
 'IRF2',
 'SP110',
 'SRC',
 'MAFB',
 'FBXO6',
 'IRF9',
 'KLF4',
 'IRF5',
 'HLX',
 'ZC3H3',
 'RFX5',
 'PTPN11',
 'STAT1',
 'HAP1',
 'CEBPE',
 'ZNFX1',
 'HLA-DQB1',
 'CEBPB',
 'RNF14',
 'ETV7',
 'ATF5',
 'IFNGR1',
 'JAK1',
 'BATF2',
 'NFKB1',
 'FOXN3',
 'ZNF267',
 'RARRES3',
 'JUN',
 'TAPBPL',
 'PIK3CA',
 'JAK2',
 'PTGES3',
 'FMNL2',
 'PLEK',
 'ATF3',
 'STAT3',
 'PARP12',
 'MYC',
 'IRF1',
 'RUNX1',
 'TSC22D1',
 'MAFF',
 'STAT2',
 'PRDM1']

In [93]:
len(set(ood_unique) - set(remaining_unique))

7

In [94]:
adata.obs["ood"] = adata.obs.apply(lambda x: x["condition"] if x["condition"] in ood_conditions else False, axis=1)
adata.obs["is_ood"] = adata.obs.apply(lambda x: x["condition"] in ood_conditions, axis=1)
adata.obs["ood"] = adata.obs["ood"].astype("category")

In [95]:
adata_train = adata[~adata.obs["is_ood"]]
adata_ood = adata[adata.obs["is_ood"]]

In [96]:
adata_ood.uns

{'hvg': {'flavor': 'seurat'},
 'rank_genes_groups_cov_all': {'HAP1_ATF3': ['DLEU1',
   'AC022639.1',
   'HLA-B',
   'AC009271.1',
   'GABRB3',
   'CRADD',
   'UNC5C',
   'AL356805.1',
   'WARS',
   'PPFIA2',
   'B3GALT4',
   'XIST',
   'SMIM2-AS1',
   'C9orf135',
   'NRP1',
   'KMT2E-AS1',
   'SEC1P',
   'A2M',
   'GPR137C',
   'LINC01376',
   'HIVEP2',
   'NKAIN3',
   'GBP5',
   'DDX60',
   'GRAMD1B',
   'EPHB2',
   'LINC01788',
   'BCAS3',
   'SYN2',
   'OR2I1P',
   'FAT3',
   'RASGEF1B',
   'CFH',
   'Z93403.1',
   'MSC-AS1',
   'FAM177B',
   'CFAP299',
   'SP100',
   'SHROOM4',
   'KCNQ5',
   'ACSBG1',
   'TMPRSS4',
   'DACH1',
   'PRKCA',
   'AL121601.1',
   'WDR72',
   'IGFBP2',
   'EFEMP1',
   'AL049637.1',
   'TLE4'],
  'HAP1_ATF5': ['KIRREL3',
   'ZNF804A',
   'RALYL',
   'LSAMP',
   'CELF2',
   'EPHA6',
   'CDH13',
   'HAVCR1',
   'GRID2',
   'SH3GL3',
   'IQCJ',
   'KMT2E-AS1',
   'SLC35F3',
   'CEP135',
   'AL356805.1',
   'AC024901.1',
   'FUT9',
   'KHDRBS2',
   'CDK14',


In [97]:
adata_train.n_obs, adata_ood.n_obs

(19359, 3043)

In [98]:
rng = np.random.default_rng(0)
split_dfs = []
for c in adata_train.obs["condition"].unique():
    adata_subset = adata_train[(adata_train.obs["condition"]==c)]
    n_cells = adata_subset.n_obs
    if c.endswith('_NT'):
        idx_test = rng.choice(np.arange(n_cells), 500, replace=False)
        remaining_idcs = set(np.arange(n_cells)) - set(idx_test)
        split = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
        df = adata_subset.obs[["condition"]].copy()
        df["split"] = split
        split_dfs.append(df)
    elif n_cells>300:
        idx_test = rng.choice(np.arange(n_cells), 100, replace=False)
        split = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
        df = adata_subset.obs[["condition"]].copy()
        df["split"] = split
        split_dfs.append(df)

In [99]:
df_concat = pd.concat(split_dfs, axis=0)

In [100]:
len(df_concat), adata_train.n_obs

(16210, 19359)

In [101]:
adata_train.obs["split"] = df_concat[["split"]]

  adata_train.obs["split"] = df_concat[["split"]]


In [102]:
adata_ood[adata_ood.obs["condition"].str.endswith('_NT')]

View of AnnData object with n_obs × n_vars = 0 × 2000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score', 'condition', 'control', 'ood', 'is_ood'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'rank_genes_groups_cov_all'
    obsm: 'cell_line_emb', 'gene_emb', 'cond_emb'

In [103]:
adata_train_final = adata_train[adata_train.obs["split"]=="train"]
adata_test_final = adata_train[adata_train.obs["split"]=="test"]
adata_ood_final = anndata.concat((adata_ood, adata_test_final[adata_test_final.obs["condition"].str.endswith('_NT')]))

In [104]:
adata_ood_final.uns = adata_ood.uns

In [105]:
adata_train_final.obs["condition"].value_counts()

condition
HAP1_RUNX1       847
HAP1_BATF2       745
HAP1_STAT1       744
HAP1_ETV7        676
HAP1_IRF5        613
HAP1_JAK1        609
HAP1_PRDM1       586
HAP1_MAFF        582
HAP1_NT          561
HAP1_IRF1        544
HAP1_TSC22D1     486
HAP1_MYC         485
HAP1_FOXN3       460
HAP1_STAT3       444
HAP1_JUN         397
HAP1_ATF3        393
HAP1_PARP12      373
HAP1_CEBPE       348
HAP1_PLEK        347
HAP1_JAK2        339
HAP1_PTPN11      336
HAP1_MAFB        305
HAP1_RNF14       290
HAP1_ZNFX1       256
HAP1_HLA-DQB1    244
HAP1_KLF4        235
HAP1_FMNL2       234
HAP1_FBXO6       222
HAP1_TRAFD1      209
Name: count, dtype: int64

In [106]:
adata_ood_final.obs["condition"].value_counts()

condition
HAP1_IFNGR2    827
HAP1_TBX21     554
HAP1_SP100     512
HAP1_NT        500
HAP1_ZFP36     428
HAP1_IRF7      372
HAP1_GUK1      221
HAP1_EHF       129
Name: count, dtype: int64

In [107]:
adata_test_final.obs["condition"].value_counts()

condition
HAP1_NT          500
HAP1_BATF2       100
HAP1_CEBPE       100
HAP1_ETV7        100
HAP1_ATF3        100
HAP1_FMNL2       100
HAP1_FOXN3       100
HAP1_HLA-DQB1    100
HAP1_IRF1        100
HAP1_IRF5        100
HAP1_JAK1        100
HAP1_JAK2        100
HAP1_FBXO6       100
HAP1_JUN         100
HAP1_KLF4        100
HAP1_MAFF        100
HAP1_MAFB        100
HAP1_MYC         100
HAP1_PARP12      100
HAP1_PLEK        100
HAP1_PRDM1       100
HAP1_PTPN11      100
HAP1_RNF14       100
HAP1_RUNX1       100
HAP1_STAT1       100
HAP1_STAT3       100
HAP1_TRAFD1      100
HAP1_TSC22D1     100
HAP1_ZNFX1       100
Name: count, dtype: int64

In [108]:
adata_train_final.varm["X_train_mean"] = adata_train_final.X.mean(axis=0).T

  adata_train_final.varm["X_train_mean"] = adata_train_final.X.mean(axis=0).T


In [109]:
train_mean = adata_train_final.varm["X_train_mean"].T
adata_train_final.layers["centered_X"] = csr_matrix(adata_train_final.X.A - train_mean)

In [110]:
adata_train_final.layers["X_log1p"] = adata_train_final.X.copy()
adata_train_final.X = adata_train_final.layers["centered_X"]

In [111]:
sc.pp.pca(adata_train_final, zero_center=False, n_comps=30)

In [112]:
adata_train_final.X = adata_train_final.layers["X_log1p"]

In [113]:
adata_train_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]
adata_ood_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]
adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]

  adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]


In [114]:
adata_test_final.layers["centered_X"] = csr_matrix(adata_test_final.X.A - train_mean)
adata_ood_final.layers["centered_X"] = csr_matrix(adata_ood_final.X.A - train_mean)

In [115]:
adata_test_final.obsm["X_pca"] = np.matmul(adata_test_final.layers["centered_X"].A, adata_train_final.varm["PCs"])
adata_ood_final.obsm["X_pca"] = np.matmul(adata_ood_final.layers["centered_X"].A, adata_train_final.varm["PCs"])

In [116]:
adata_train_final.obsm["X_pca"].shape

(12910, 30)

In [117]:
for ob in [adata_train_final, adata_ood_final, adata_test_final]:
    for column in ob.obs.columns:
        if pd.api.types.is_categorical_dtype(ob.obs[column]):
            ob.obs[column] = ob.obs[column].astype(str)

  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[co

In [118]:
adata_train_final.varm['X_train_mean'] = adata_train_final.varm['X_train_mean'].A
adata_ood_final.varm['X_train_mean'] = adata_ood_final.varm['X_train_mean'].A
adata_test_final.varm['X_train_mean'] = adata_test_final.varm['X_train_mean'].A

In [119]:
os.path.join(output_dir, "adata_train_" + pathway + "_" + cell_type + ".h5ad")

'/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/r2/adata_train_IFNG_HAP1.h5ad'

In [120]:
adata_train_final.write(os.path.join(output_dir, "adata_train_" + pathway + "_" + cell_type + ".h5ad"))
adata_ood_final.write(os.path.join(output_dir, "adata_ood_" + pathway + "_" + cell_type + ".h5ad"))
adata_test_final.write(os.path.join(output_dir, "adata_test_" + pathway + "_" + cell_type + ".h5ad"))

In [121]:
adata_loaded = sc.read(os.path.join(output_dir, "adata_ood_" + pathway + "_" + cell_type + ".h5ad"))

In [122]:
decoded_test = np.matmul(adata_test_final.obsm["X_pca"], adata_train_final.varm["PCs"].T)

In [123]:
compute_r_squared(np.asarray(decoded_test+adata_test_final.varm["X_train_mean"].T), adata_test_final.X.A)

0.9996340942475105

In [124]:
decoded_ood = np.matmul(adata_ood_final.obsm["X_pca"], adata_train_final.varm["PCs"].T)

In [125]:
compute_r_squared(np.asarray(decoded_ood+adata_ood_final.varm["X_train_mean"].T), adata_ood_final.X.A)

0.9994210548548264

In [126]:
adata_train_final

AnnData object with n_obs × n_vars = 12910 × 2000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score', 'condition', 'control', 'ood', 'is_ood', 'split'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'rank_genes_groups_cov_all', 'pca'
    obsm: 'cell_line_emb', 'gene_emb', 'cond_emb', 'X_pca'
    varm: 'X_train_mean', 'PCs'
    layers: 'centered_X', 'X_log1p'

In [127]:
adata_train_final.obsm['X_pca'].shape

(12910, 30)

In [128]:
adata_test_final.obsm['X_pca'].shape

(3300, 30)

In [129]:
adata_ood_final.uns

{'hvg': {'flavor': 'seurat'},
 'rank_genes_groups_cov_all': {'HAP1_ATF3': ['DLEU1',
   'AC022639.1',
   'HLA-B',
   'AC009271.1',
   'GABRB3',
   'CRADD',
   'UNC5C',
   'AL356805.1',
   'WARS',
   'PPFIA2',
   'B3GALT4',
   'XIST',
   'SMIM2-AS1',
   'C9orf135',
   'NRP1',
   'KMT2E-AS1',
   'SEC1P',
   'A2M',
   'GPR137C',
   'LINC01376',
   'HIVEP2',
   'NKAIN3',
   'GBP5',
   'DDX60',
   'GRAMD1B',
   'EPHB2',
   'LINC01788',
   'BCAS3',
   'SYN2',
   'OR2I1P',
   'FAT3',
   'RASGEF1B',
   'CFH',
   'Z93403.1',
   'MSC-AS1',
   'FAM177B',
   'CFAP299',
   'SP100',
   'SHROOM4',
   'KCNQ5',
   'ACSBG1',
   'TMPRSS4',
   'DACH1',
   'PRKCA',
   'AL121601.1',
   'WDR72',
   'IGFBP2',
   'EFEMP1',
   'AL049637.1',
   'TLE4'],
  'HAP1_ATF5': ['KIRREL3',
   'ZNF804A',
   'RALYL',
   'LSAMP',
   'CELF2',
   'EPHA6',
   'CDH13',
   'HAVCR1',
   'GRID2',
   'SH3GL3',
   'IQCJ',
   'KMT2E-AS1',
   'SLC35F3',
   'CEP135',
   'AL356805.1',
   'AC024901.1',
   'FUT9',
   'KHDRBS2',
   'CDK14',
