In [107]:
import os
import scanpy as sc
import anndata
import numpy as np
import pandas as pd
import pickle
from sklearn.metrics import r2_score
from scipy.sparse import csr_matrix, vstack
from tqdm import tqdm
from ott.geometry import costs, pointcloud
from ott.solvers.linear import sinkhorn
from ott.solvers import linear
from ott.tools.sinkhorn_divergence import sinkhorn_divergence
from sklearn.metrics import r2_score

In [171]:
k_ood = 6
pathway = 'IFNG'
cell_type = 'HAP1'
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/satija/datasets/ifng/r2"

In [172]:
def compute_r_squared(x: np.ndarray, y: np.ndarray) -> float:
    return r2_score(np.mean(x, axis=0), np.mean(y, axis=0))

In [173]:
# TODO create complete list of embeddings
ko_embeddings = pickle.load(open('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/embeddings/embeddings_ifng.pkl', 'rb'))
ko_embeddings = pd.DataFrame(ko_embeddings).T
cell_embeddings = pd.read_csv('/lustre/groups/ml01/workspace/ot_perturbation/data/satija/embeddings/cell_line_embedding_full_ccle_300_normalized.csv', index_col=0)

In [174]:
adata = sc.read("/lustre/groups/ml01/workspace/ot_perturbation/data/satija/h5ad/" + pathway + "_Perturb_seq.h5ad")
adata

AnnData object with n_obs × n_vars = 245240 × 33525
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'

In [175]:
adata = adata[adata.obs['cell_type'] == cell_type, :]
adata

View of AnnData object with n_obs × n_vars = 22817 × 33525
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'

In [176]:
adata.obs['condition'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.gene]), axis=1)

  adata.obs['condition'] = adata.obs.apply(lambda x: "_".join([x.cell_type, x.gene]), axis=1)


In [177]:
condition_counts = adata.obs['condition'].value_counts()
conditions_with_at_least_100 = condition_counts[condition_counts >= 100].index.tolist()
adata = adata[adata.obs['condition'].isin(conditions_with_at_least_100), :]
adata

View of AnnData object with n_obs × n_vars = 22402 × 33525
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score', 'condition'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable'

In [178]:
adata.obs['condition'].value_counts(ascending=True)

condition
HAP1_PIK3CA       104
HAP1_EHF          129
HAP1_ZC3H3        131
HAP1_IRF2         133
HAP1_HLX          134
HAP1_TAPBPL       136
HAP1_NFKB1        138
HAP1_SRC          166
HAP1_IRF9         174
HAP1_RARRES3      182
HAP1_ATF5         202
HAP1_STAT2        203
HAP1_ZNF267       206
HAP1_RFX5         207
HAP1_SP110        217
HAP1_GUK1         221
HAP1_PTGES3       239
HAP1_CEBPB        277
HAP1_IFNGR1       300
HAP1_TRAFD1       309
HAP1_FBXO6        322
HAP1_FMNL2        334
HAP1_KLF4         335
HAP1_HLA-DQB1     344
HAP1_ZNFX1        356
HAP1_IRF7         372
HAP1_RNF14        390
HAP1_MAFB         405
HAP1_ZFP36        428
HAP1_PTPN11       436
HAP1_JAK2         439
HAP1_PLEK         447
HAP1_CEBPE        448
HAP1_PARP12       473
HAP1_ATF3         493
HAP1_JUN          497
HAP1_SP100        512
HAP1_STAT3        544
HAP1_TBX21        554
HAP1_FOXN3        560
HAP1_MYC          585
HAP1_TSC22D1      586
HAP1_IRF1         644
HAP1_MAFF         682
HAP1_PRDM1        686


In [179]:
sc.pp.highly_variable_genes(adata, inplace=True, n_top_genes=2000)
adata = adata[:,adata.var["highly_variable"]==True]

  adata.uns["hvg"] = {"flavor": flavor}


In [180]:
controls = {}
for ct in adata.obs["cell_type"].unique():
    controls[ct] = adata[adata.obs["condition"]==ct+'_NT'].X.A

In [181]:
pert_effects = {}
for c in tqdm(adata.obs["condition"].unique()):
    if c.endswith('_NT'):
        continue
    cell_type = c.split("_")[0]
    pert_effects[c] = float(compute_r_squared(
        controls[cell_type],
        adata[adata.obs["condition"]==c].X.A))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53/53 [00:01<00:00, 36.56it/s]


In [182]:
pert_effects

{'HAP1_PARP12': 0.9951486773727936,
 'HAP1_HLA-DQB1': 0.9951123417185797,
 'HAP1_ETV7': 0.9968376895289561,
 'HAP1_TRAFD1': 0.9944211037483908,
 'HAP1_PLEK': 0.9958565240759455,
 'HAP1_TBX21': 0.99642143421027,
 'HAP1_CEBPE': 0.9964098481052932,
 'HAP1_HLX': 0.9693206925722188,
 'HAP1_MAFB': 0.9947618655042224,
 'HAP1_BATF2': 0.9911634155531313,
 'HAP1_PTPN11': 0.9916861716802698,
 'HAP1_PTGES3': 0.9917827337371468,
 'HAP1_PRDM1': 0.9961540716062789,
 'HAP1_MAFF': 0.9928683962983671,
 'HAP1_RUNX1': 0.9971833328884568,
 'HAP1_IRF5': 0.9967849039171066,
 'HAP1_IFNGR2': 0.8386700681646465,
 'HAP1_STAT3': 0.9958027883687551,
 'HAP1_TSC22D1': 0.9961330367657327,
 'HAP1_RARRES3': 0.9908305892494338,
 'HAP1_JAK1': 0.91949323892956,
 'HAP1_JUN': 0.9961206182416456,
 'HAP1_MYC': 0.9963200998352286,
 'HAP1_FMNL2': 0.9950976079060341,
 'HAP1_CEBPB': 0.9927641183402391,
 'HAP1_IRF9': 0.9867422629451881,
 'HAP1_ATF5': 0.9921214496886448,
 'HAP1_ATF3': 0.9945984101088307,
 'HAP1_GUK1': 0.99071308005

In [183]:
features_cell_line = np.zeros((adata.shape[0], cell_embeddings.shape[1]))
features_ko = np.zeros((adata.shape[0], ko_embeddings.shape[1]))
for cell_line, emb in cell_embeddings.iterrows():
    features_cell_line[adata.obs['cell_type'].isin([cell_line])] = emb.values
for ko, emb in ko_embeddings.iterrows():
    features_ko[adata.obs['gene'].isin([ko])] = emb.values

In [184]:
adata.obsm['cell_line_emb'] = features_cell_line
adata.obsm['gene_emb'] = features_ko
adata.obsm['cond_emb'] = np.concatenate([adata.obsm['cell_line_emb'], adata.obsm['gene_emb']], axis = -1)

  adata.obsm['cell_line_emb'] = features_cell_line


In [185]:
adata.obsm['cell_line_emb'].shape

(22402, 300)

In [186]:
adata.obsm['gene_emb'].shape

(22402, 2560)

In [187]:
df_effects = pd.DataFrame({"conditions": pert_effects.keys(), "sinkhorn_div": pert_effects.values()})
df_effects_sorted = df_effects.sort_values("sinkhorn_div")
df_effects_sorted

Unnamed: 0,conditions,sinkhorn_div
16,HAP1_IFNGR2,0.83867
35,HAP1_STAT1,0.847299
36,HAP1_JAK2,0.8553
46,HAP1_IFNGR1,0.864424
20,HAP1_JAK1,0.919493
37,HAP1_IRF1,0.943917
7,HAP1_HLX,0.969321
49,HAP1_ZC3H3,0.972156
39,HAP1_EHF,0.98631
25,HAP1_IRF9,0.986742


In [188]:
ood_ko = df_effects_sorted[::k_ood]
ood_ko

Unnamed: 0,conditions,sinkhorn_div
16,HAP1_IFNGR2,0.83867
7,HAP1_HLX,0.969321
51,HAP1_PIK3CA,0.98855
9,HAP1_BATF2,0.991163
44,HAP1_IRF7,0.992463
3,HAP1_TRAFD1,0.994421
0,HAP1_PARP12,0.995149
29,HAP1_RNF14,0.996044
5,HAP1_TBX21,0.996421


In [189]:
ood_conditions = ood_ko.conditions.values
ood_conditions

array(['HAP1_IFNGR2', 'HAP1_HLX', 'HAP1_PIK3CA', 'HAP1_BATF2',
       'HAP1_IRF7', 'HAP1_TRAFD1', 'HAP1_PARP12', 'HAP1_RNF14',
       'HAP1_TBX21'], dtype=object)

In [190]:
ood_unique = list(set([value for entry in ood_conditions for value in entry.split('_')]))
ood_unique

['TBX21',
 'HAP1',
 'HLX',
 'PARP12',
 'IFNGR2',
 'PIK3CA',
 'BATF2',
 'RNF14',
 'TRAFD1',
 'IRF7']

In [191]:
len(ood_unique)

10

In [192]:
remaining_conditions = df_effects_sorted[~df_effects_sorted['conditions'].isin(ood_conditions)].conditions.values
remaining_unique = list(set([value for entry in remaining_conditions for value in entry.split('_')]))
remaining_unique

['NFKB1',
 'FMNL2',
 'ATF3',
 'MYC',
 'MAFB',
 'GUK1',
 'SRC',
 'ZNF267',
 'ZC3H3',
 'IRF1',
 'FBXO6',
 'STAT2',
 'IRF5',
 'CEBPE',
 'FOXN3',
 'PTPN11',
 'IRF9',
 'IRF2',
 'RUNX1',
 'EHF',
 'TSC22D1',
 'IFNGR1',
 'HLA-DQB1',
 'ATF5',
 'STAT3',
 'ZFP36',
 'PRDM1',
 'ZNFX1',
 'JUN',
 'ETV7',
 'MAFF',
 'JAK2',
 'SP100',
 'SP110',
 'STAT1',
 'TAPBPL',
 'PLEK',
 'JAK1',
 'HAP1',
 'CEBPB',
 'KLF4',
 'RFX5',
 'RARRES3',
 'PTGES3']

In [193]:
len(set(ood_unique) - set(remaining_unique))

9

In [194]:
adata.obs["ood"] = adata.obs.apply(lambda x: x["condition"] if x["condition"] in ood_conditions else False, axis=1)
adata.obs["is_ood"] = adata.obs.apply(lambda x: x["condition"] in ood_conditions, axis=1)
adata.obs["ood"] = adata.obs["ood"].astype("category")

In [195]:
adata_train = adata[~adata.obs["condition"].isin(ood_conditions)].copy()
adata_ood = adata[adata.obs["condition"].isin(ood_conditions)].copy()

In [196]:
adata_train.n_obs, adata_ood.n_obs

(18394, 4008)

In [197]:
rng = np.random.default_rng(0)
split_dfs = []
for c in adata_train.obs["condition"].unique():
    adata_subset = adata_train[(adata_train.obs["condition"]==c)]
    n_cells = adata_subset.n_obs
    if c.endswith('_NT'):
        idx_test = rng.choice(np.arange(n_cells), 500, replace=False)
        remaining_idcs = set(np.arange(n_cells)) - set(idx_test)
        split = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
        df = adata_subset.obs[["condition"]].copy()
        df["split"] = split
        split_dfs.append(df)
    elif n_cells>300:
        idx_test = rng.choice(np.arange(n_cells), 100, replace=False)
        split = ["test" if idx in idx_test else "train" for idx in range(n_cells)]
        df = adata_subset.obs[["condition"]].copy()
        df["split"] = split
        split_dfs.append(df)

In [198]:
df_concat = pd.concat(split_dfs, axis=0)

In [199]:
len(df_concat), adata_train.n_obs

(15133, 18394)

In [200]:
df_concat[df_concat["condition"].str.endswith('_NT')]

Unnamed: 0,condition,split
19_54_52_1_1_1_1_1_1_1_1_1,HAP1_NT,train
18_66_05_1_1_1_1_1_1_1_1_1,HAP1_NT,test
19_41_75_1_1_1_1_1_1_1_1_1,HAP1_NT,train
18_02_05_1_1_1_1_1_1_1_1_1,HAP1_NT,test
18_20_51_1_1_1_1_1_1_1_1_1,HAP1_NT,train
...,...,...
14_33_94_2_2,HAP1_NT,train
14_94_31_2_2,HAP1_NT,train
16_21_49_2_2,HAP1_NT,test
16_23_27_2_2,HAP1_NT,test


In [201]:
df_concat["condition"].value_counts()

condition
HAP1_NT          1061
HAP1_RUNX1        947
HAP1_STAT1        844
HAP1_ETV7         776
HAP1_IRF5         713
HAP1_JAK1         709
HAP1_PRDM1        686
HAP1_MAFF         682
HAP1_IRF1         644
HAP1_TSC22D1      586
HAP1_MYC          585
HAP1_FOXN3        560
HAP1_STAT3        544
HAP1_SP100        512
HAP1_JUN          497
HAP1_ATF3         493
HAP1_CEBPE        448
HAP1_PLEK         447
HAP1_JAK2         439
HAP1_PTPN11       436
HAP1_ZFP36        428
HAP1_MAFB         405
HAP1_ZNFX1        356
HAP1_HLA-DQB1     344
HAP1_KLF4         335
HAP1_FMNL2        334
HAP1_FBXO6        322
Name: count, dtype: int64

In [202]:
adata_train.obs["split"] = df_concat[["split"]]

In [203]:
adata_ood[adata_ood.obs["condition"].str.endswith('_NT')]

View of AnnData object with n_obs × n_vars = 0 × 2000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'cell_type', 'percent.mito', 'sample_ID', 'Batch_info', 'bc1_well', 'bc2_well', 'bc3_well', 'guide', 'gene', 'mixscale_score', 'condition', 'ood', 'is_ood'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg'
    obsm: 'cell_line_emb', 'gene_emb', 'cond_emb'

In [204]:
adata_train_final = adata_train[adata_train.obs["split"]=="train"]
adata_test_final = adata_train[adata_train.obs["split"]=="test"]
adata_ood_final = anndata.concat((adata_ood, adata_test_final[adata_test_final.obs["condition"].str.endswith('_NT')]))

In [205]:
adata_train_final.obs["condition"].value_counts()

condition
HAP1_RUNX1       847
HAP1_STAT1       744
HAP1_ETV7        676
HAP1_IRF5        613
HAP1_JAK1        609
HAP1_PRDM1       586
HAP1_MAFF        582
HAP1_NT          561
HAP1_IRF1        544
HAP1_TSC22D1     486
HAP1_MYC         485
HAP1_FOXN3       460
HAP1_STAT3       444
HAP1_SP100       412
HAP1_JUN         397
HAP1_ATF3        393
HAP1_CEBPE       348
HAP1_PLEK        347
HAP1_JAK2        339
HAP1_PTPN11      336
HAP1_ZFP36       328
HAP1_MAFB        305
HAP1_ZNFX1       256
HAP1_HLA-DQB1    244
HAP1_KLF4        235
HAP1_FMNL2       234
HAP1_FBXO6       222
Name: count, dtype: int64

In [206]:
adata_ood_final.obs["condition"].value_counts()

condition
HAP1_BATF2     845
HAP1_IFNGR2    827
HAP1_TBX21     554
HAP1_NT        500
HAP1_PARP12    473
HAP1_RNF14     390
HAP1_IRF7      372
HAP1_TRAFD1    309
HAP1_HLX       134
HAP1_PIK3CA    104
Name: count, dtype: int64

In [207]:
adata_test_final.obs["condition"].value_counts()

condition
HAP1_NT          500
HAP1_CEBPE       100
HAP1_MAFB        100
HAP1_PTPN11      100
HAP1_IRF5        100
HAP1_FMNL2       100
HAP1_ZFP36       100
HAP1_STAT3       100
HAP1_STAT1       100
HAP1_ETV7        100
HAP1_JUN         100
HAP1_HLA-DQB1    100
HAP1_TSC22D1     100
HAP1_JAK1        100
HAP1_KLF4        100
HAP1_RUNX1       100
HAP1_MAFF        100
HAP1_JAK2        100
HAP1_PLEK        100
HAP1_ATF3        100
HAP1_PRDM1       100
HAP1_ZNFX1       100
HAP1_FOXN3       100
HAP1_FBXO6       100
HAP1_SP100       100
HAP1_MYC         100
HAP1_IRF1        100
Name: count, dtype: int64

In [208]:
adata_train_final.varm["X_train_mean"] = adata_train_final.X.mean(axis=0).T

  adata_train_final.varm["X_train_mean"] = adata_train_final.X.mean(axis=0).T


In [209]:
train_mean = adata_train_final.varm["X_train_mean"].T
adata_train_final.layers["centered_X"] = csr_matrix(adata_train_final.X.A - train_mean)

In [210]:
adata_train_final.layers["X_log1p"] = adata_train_final.X.copy()
adata_train_final.X = adata_train_final.layers["centered_X"]

In [211]:
sc.pp.pca(adata_train_final, zero_center=False, n_comps=30)

In [212]:
adata_train_final.X = adata_train_final.layers["X_log1p"]

In [213]:
adata_ood_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]
adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]

  adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]


In [214]:
adata_test_final.layers["centered_X"] = csr_matrix(adata_test_final.X.A - train_mean)
adata_ood_final.layers["centered_X"] = csr_matrix(adata_ood_final.X.A - train_mean)

In [215]:
adata_train_final.obsm["X_pca"] = sc.pp.pca(adata_train_final.layers["centered_X"], zero_center=False, n_comps=50)

In [216]:
adata_train_final.obsm["X_pca"].shape

(12033, 50)

In [217]:
adata_ood_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]
adata_test_final.varm["X_train_mean"] = adata_train_final.varm["X_train_mean"]

In [218]:
adata_test_final.layers["centered_X"] = csr_matrix(adata_test_final.X.A - train_mean)
adata_ood_final.layers["centered_X"] = csr_matrix(adata_ood_final.X.A - train_mean)

In [219]:
adata_test_final.obsm["X_pca"] = np.matmul(adata_test_final.layers["centered_X"].A, adata_train_final.varm["PCs"])
adata_ood_final.obsm["X_pca"] = np.matmul(adata_ood_final.layers["centered_X"].A, adata_train_final.varm["PCs"])

In [220]:
for ob in [adata_train_final, adata_ood_final, adata_test_final]:
    for column in ob.obs.columns:
        if pd.api.types.is_categorical_dtype(ob.obs[column]):
            ob.obs[column] = ob.obs[column].astype(str)

  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[column]):
  if pd.api.types.is_categorical_dtype(ob.obs[co

In [221]:
adata_train_final.varm['X_train_mean'] = adata_train_final.varm['X_train_mean'].A
adata_ood_final.varm['X_train_mean'] = adata_ood_final.varm['X_train_mean'].A
adata_test_final.varm['X_train_mean'] = adata_test_final.varm['X_train_mean'].A

In [222]:
adata_train_final.write(os.path.join(output_dir, "adata_train_" + pathway + "_" + cell_type + ".h5ad"))
adata_ood_final.write(os.path.join(output_dir, "adata_ood_" + pathway + "_" + cell_type + ".h5ad"))
adata_test_final.write(os.path.join(output_dir, "adata_test_" + pathway + "_" + cell_type + ".h5ad"))

In [223]:
adata_loaded = sc.read(os.path.join(output_dir, "adata_ood_" + pathway + "_" + cell_type + ".h5ad"))

In [224]:
decoded_test = np.matmul(adata_test_final.obsm["X_pca"], adata_train_final.varm["PCs"].T)

In [225]:
compute_r_squared(np.asarray(decoded_test+adata_test_final.varm["X_train_mean"].T), adata_test_final.X.A)

0.9996564467456698

In [226]:
decoded_ood = np.matmul(adata_ood_final.obsm["X_pca"], adata_train_final.varm["PCs"].T)

In [227]:
compute_r_squared(np.asarray(decoded_ood+adata_ood_final.varm["X_train_mean"].T), adata_ood_final.X.A)

0.9995589681864923