# Norman 2019 Training Demo

In [1]:
import sys
#if branch is stable, will install via pypi, else will install from source
branch = "stable"
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB and branch == "stable":
    !pip install cpa-tools
elif IN_COLAB and branch != "stable":
    !pip install --quiet --upgrade jsonschema
    !pip install --quiet git+https://github.com/theislab/cpa

In [2]:
import os

# Machine-specific setup for the original author's workstation.
# The working directory and the GPU pin are applied only when that project
# directory actually exists, so the notebook no longer raises
# FileNotFoundError on other machines (e.g. Google Colab). Hosts with a
# single GPU also keep device 0 visible instead of having it hidden by
# CUDA_VISIBLE_DEVICES='1'.
project_dir = '/home/mohsen/projects/cpa/'
if os.path.isdir(project_dir):
    os.chdir(project_dir)
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [3]:
import cpa
import scanpy as sc

Global seed set to 0


In [4]:
# Configure scanpy's figure parameters with a resolution of 100 dpi.
sc.settings.set_figure_params(dpi=100)

In [5]:
# Location of the preprocessed Norman 2019 dataset (.h5ad) that is read below.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
data_path = '/home/mohsen/projects/cpa/datasets/Norman2019_prep_new.h5ad'

In [6]:
# Read the dataset from `data_path`; `backup_url` is the download source
# handed to scanpy for the case where the file is not available locally.
NORMAN_BACKUP_URL = 'https://drive.google.com/u/0/uc?id=1pxT0fvXtqBBtdv1CCPVwJaMLHe9XpMHo&export=download&confirm=t'

adata = sc.read(data_path, backup_url=NORMAN_BACKUP_URL)
adata

AnnData object with n_obs × n_vars = 108497 × 5000
    obs: 'cov_drug_dose_name', 'dose_val', 'control', 'condition', 'guide_identity', 'drug_dose_name', 'cell_type', 'split', 'split1', 'split2', 'split3', 'split4', 'split5', 'split6', 'split7', 'split8', 'split9', 'split10', 'split11', 'split12', 'split13', 'split14', 'split15', 'split16', 'split17', 'split18', 'split19', 'split20', 'split21', 'split22', 'split23', 'split24', 'split25'
    var: 'gene_symbols', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'rank_genes_groups_cov'
    layers: 'counts'

In [7]:
# Register `adata` with CPA: which obs columns hold the perturbation,
# its dose, the control flag and the categorical covariates.
setup_kwargs = {
    'drug_key': 'condition',
    'dose_key': 'dose_val',
    'control_key': 'control',
    'categorical_covariate_keys': ['cell_type'],
    'combinatorial': True,
}
cpa.CPA.setup_anndata(adata, **setup_kwargs)

[34mINFO    [0m No batch_key inputted, assuming all cells are same batch                            
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Successfully registered anndata object containing [1;36m108497[0m cells, [1;36m5000[0m vars, [1;36m1[0m        
         batches, [1;36m1[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates   
         and [1;36m0[0m extra continuous covariates.                                                  
[34mINFO    [0m Please do not further modify adata until model is trained.                          


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 108497/108497 [00:14<00:00, 7551.64it/s]


In [8]:
# Model hyperparameters; unpacked into the `cpa.CPA(...)` constructor below.
ae_hparams = dict(
    adversary_depth=2,
    adversary_width=256,
    autoencoder_depth=4,
    autoencoder_width=256,
    dosers_depth=3,
    dosers_width=32,
    use_batch_norm=True,
    use_layer_norm=False,
    dropout_rate=0.0,
    variational=False,
    seed=31,
)

# Training-plan settings; passed to `model.train(...)` as `plan_kwargs`.
trainer_params = dict(
    n_epochs_warmup=0,
    adversary_lr=0.00012219948594647382,
    adversary_steps=2,
    adversary_wd=1.4033946047401463e-05,
    autoencoder_lr=0.00014147035543889223,
    autoencoder_wd=2.2782368178455333e-08,
    dosers_lr=0.0007022363227033381,
    dosers_wd=6.249509305603141e-06,
    penalty_adversary=0.013702812231919399,
    reg_adversary=4.02272482876072,
    step_size_lr=45,
)

In [9]:
# Build the CPA model on the registered AnnData. The fixed settings are
# merged with the hyperparameters from `ae_hparams` before the call.
model_kwargs = dict(
    n_latent=512,
    loss_ae='gauss',
    doser_type='logsigm',
    split_key='split',
    **ae_hparams,
)
model = cpa.CPA(adata=adata, **model_kwargs)

Global seed set to 31


In [None]:
# Train for at most 2000 epochs on GPU with early stopping enabled; the
# optimizer/adversary settings come from `trainer_params`.
# NOTE(review): validation runs every 20 epochs, so a patience of 15 presumably
# spans up to 300 epochs without improvement — confirm this is intended.
# NOTE(review): `save_path` is an absolute, machine-specific directory.
model.train(max_epochs=2000,
            use_gpu=True, 
            batch_size=256,
            early_stopping=True,
            plan_kwargs=trainer_params,
            early_stopping_patience=15,
            check_val_every_n_epoch=20,
            save_path='/home/mohsen/projects/cpa/lightning_logs/Norman2019_prep_new/',
           )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name   | Type      | Params
-------------------------------------
0 | module | CPAModule | 4.8 M 
-------------------------------------
4.8 M     Trainable params
0         Non-trainable params
4.8 M     Total params
19.130    Total estimated model params size (MB)


Epoch 420/2000:  21%|▏| 420/2000 [5:09:51<7:14:49, 16.51s/it, recon=-1.32, adv_loss=0.0679, val_recon=-2.61, val_disnt_basal=0.0218, val_disnt_after=0

In [None]:
# Plot the training history recorded for `model`.
cpa.pl.plot_history(model)

In [28]:
# Wrap the dataset and the model in CPA's API object; it is used below for
# combination embeddings and predictions.
# NOTE(review): execution counts are non-sequential from here on (28-30, then
# 17, 19, 22, 23) — verify the notebook survives Restart & Run All.
cpa_api = cpa.ComPertAPI(adata, model)

In [29]:
# Plotting helper built on top of `cpa_api`.
# NOTE(review): fileprefix=None presumably means figures are not written to
# disk — confirm against the CompertVisuals documentation.
cpa_plots = cpa.pl.CompertVisuals(cpa_api, fileprefix=None)

In [30]:
# Compute the combination embeddings on the API object.
# NOTE(review): the meaning of thrh=30 is not visible here — presumably a
# threshold on the number of cells per combination; confirm.
cpa_api.compute_comb_emb(thrh=30)

In [17]:
import pandas as pd

In [None]:
# Construct the prediction AnnData: start from A549 control cells and ask the
# model what they would look like under two perturbations.
subset = adata[adata.obs['cell_type'] == 'A549'].copy()
genes_control = subset[subset.obs['condition'] == 'ctrl'].copy()

# One row per perturbation to predict.
df = pd.DataFrame({
    'condition': ['TSC22D1+ctrl', 'KLF1+MAP2K6'],
    'dose_val': ['1+1', '1+1'],
    'cell_type': ['A549', 'A549'],
})

pred = cpa_api.predict(genes_control.X.toarray(), df)  # normally would put `sample=True` here

# Label each predicted cell as '<cell_type>_<condition>_<dose_val>', the same
# format as adata.obs['cov_drug_dose_name'], and store it as a categorical.
pred.obs['cov_drug_dose_name'] = (
    pred.obs['cell_type'].astype(str)
    + '_' + pred.obs['condition'].astype(str)
    + '_' + pred.obs['dose_val'].astype(str)
).astype('category')
pred

[34mINFO    [0m .obs[1m[[0m_scvi_batch[1m][0m not found in target, assuming every cell is same category         
[34mINFO    [0m .obs[1m[[0m_scvi_labels[1m][0m not found in target, assuming every cell is same category        
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Registered keys:[1m[[0m[32m'X'[0m, [32m'batch_indices'[0m, [32m'labels'[0m, [32m'drugs_doses'[0m, [32m'cell_type'[0m[1m][0m        
[34mINFO    [0m Successfully registered anndata object containing [1;36m8907[0m cells, [1;36m5000[0m vars, [1;36m1[0m batches, 
         [1;36m1[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra
         continuous covariates.                                                              




[34mINFO    [0m .obs[1m[[0m_scvi_batch[1m][0m not found in target, assuming every cell is same category         
[34mINFO    [0m .obs[1m[[0m_scvi_labels[1m][0m not found in target, assuming every cell is same category        
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Registered keys:[1m[[0m[32m'X'[0m, [32m'batch_indices'[0m, [32m'labels'[0m, [32m'drugs_doses'[0m, [32m'cell_type'[0m[1m][0m        
[34mINFO    [0m Successfully registered anndata object containing [1;36m8907[0m cells, [1;36m5000[0m vars, [1;36m1[0m batches, 
         [1;36m1[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra
         continuous covariates.                                                              


AnnData expects .obs.index to contain strings, but got values like:
    [0, 1, 2, 3, 4]

    Inferred to be: integer

  value_idx = self._prep_dim_index(value.index, attr)


AnnData object with n_obs × n_vars = 17814 × 5000
    obs: 'condition', 'dose_val', 'cell_type', 'uncertainty_cosine', 'uncertainty_euclidean', 'closest_cond_cosine', 'closest_cond_euclidean', 'cov_drug_dose_name'
    layers: 'variance'

In [19]:
# List the distinct '<cell_type>_<condition>_<dose>' labels present in the dataset.
adata.obs['cov_drug_dose_name'].unique()

['A549_ctrl_1', 'A549_TSC22D1+ctrl_1+1', 'A549_KLF1+MAP2K6_1+1', 'A549_CEBPE+RUNX1T1_1+1', 'A549_MAML2+ctrl_1+1', ..., 'A549_SNAI1+ctrl_1+1', 'A549_PLK4+STIL_1+1', 'A549_ZBTB10+ELMSAN1_1+1', 'A549_CDKN1C+ctrl_1+1', 'A549_C3orf72+FOXL2_1+1']
Length: 284
Categories (284, object): ['A549_AHR+FEV_1+1', 'A549_AHR+KLF1_1+1', 'A549_AHR+ctrl_1+1', 'A549_ARID1A+ctrl_1+1', ..., 'A549_ctrl+UBASH3B_1+1', 'A549_ctrl+ZBTB1_1+1', 'A549_ctrl+ZBTB25_1+1', 'A549_ctrl_1']

In [22]:
# Map every key of the stored 'rank_genes_groups_cov' entry to the full list
# of variable names.
# NOTE(review): the ranked genes stored in `.uns` are discarded here — each
# key ends up pointing at all of `adata.var_names`, so `de_dict[...][:6]`
# below selects the first six variables, not the top ranked genes. Confirm
# this is intended.
de_dict = dict.fromkeys(
    adata[adata.obs.split == 'train'].uns['rank_genes_groups_cov'],
    adata.var_names,
)

In [23]:
# Stack predicted and measured cells into one AnnData; the new obs column
# 'source' tags each row as 'pred' or 'true'.
# NOTE(review): newer anndata releases appear to deprecate `.concatenate` in
# favour of `anndata.concat` — check against the installed version.
pred_adata = pred.concatenate(adata, batch_key='source', batch_categories=['pred', 'true'])

Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
Observation names are not unique. To make them unique, call `.obs_names_make_unique`.


In [None]:
# Compare predicted vs. measured expression for one perturbation, using the
# first six entries of that perturbation's gene list.
drug_name = 'TSC22D1+ctrl'
is_drug = pred_adata.obs.condition == drug_name
genes_to_plot = de_dict[f'A549_{drug_name}_1+1'][:6]
sc.pl.violin(pred_adata[is_drug], keys=genes_to_plot, groupby='source')

  c.reorder_categories(natsorted(c.categories), inplace=True)
Trying to set attribute `.obs` of view, copying.
... storing 'condition' as categorical
  c.reorder_categories(natsorted(c.categories), inplace=True)
Trying to set attribute `.obs` of view, copying.
... storing 'dose_val' as categorical
  c.reorder_categories(natsorted(c.categories), inplace=True)
Trying to set attribute `.obs` of view, copying.
... storing 'cell_type' as categorical
  c.reorder_categories(natsorted(c.categories), inplace=True)
Trying to set attribute `.obs` of view, copying.
... storing 'closest_cond_cosine' as categorical
  c.reorder_categories(natsorted(c.categories), inplace=True)
Trying to set attribute `.obs` of view, copying.
... storing 'closest_cond_euclidean' as categorical
  c.reorder_categories(natsorted(c.categories), inplace=True)
Trying to set attribute `.obs` of view, copying.
... storing 'cov_drug_dose_name' as categorical
