# Integration and reference mapping with multigrate

In this notebook, we demonstrate how to use Multigrate with scArches: we build a trimodal reference atlas with Multigrate by integrating CITE-seq and multiome data, and map unimodal as well as multimodal queries onto the reference. We use publically available datasets from NeurIPS 2021 workshop https://openproblems.bio/neurips_2021/.

In [1]:
import scarches as sca
import scanpy as sc
import anndata as ad
import numpy as np
import muon
import gdown
import json
import torch
import pandas as pd
from os.path import join as pj
import os
import warnings
warnings.filterwarnings("ignore")

sc.set_figure_params(figsize=(4, 4), fontsize=8)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:Global seed set to 0
  from .autonotebook import tqdm as notebook_tqdm
  jax.tree_util.register_keypaths(


## Data preprocessing
First, we download the datasets and split them into AnnData objects corresponding to individual modalities: gene expression (RNA) and protein abundance (ADT) for CITE-seq, and gene expression (RNA) and chromatin opennes (ATAC) for multiome.

In [2]:
task = 'wnn'
method = 'multigrate'
experiment = 'multigrate_offline'
result_path = pj('../../../result', task, experiment)
data_root = 'path/data/raw/rna+adt'
save_path = '../multigrate/offline_cellmask/'
batch_key='batch'
label_key='l1'

In [3]:
rna = sc.read_h5ad(pj(data_root, task,'rna_cellmask.h5ad'))
adt = sc.read_h5ad(pj(data_root, task,'adt_cellmask.h5ad'))
rna
muon.prot.pp.clr(adt)
adt
## Prep the input AnnData object
adata = sca.models.organize_multiome_anndatas(
    adatas = [[rna], [adt]],    # a list of anndata objects per modality, RNA-seq always goes first
    layers = [[None], [None]], # if need to use data from .layers, if None use .X
)
adata

sca.models.MultiVAE.setup_anndata(
    adata,
    categorical_covariate_keys=[batch_key],
    rna_indices_end=adata.uns['modality_lengths'][0],
)



# Offline Training

## Initialize the model

Next, we initialize the model. If using raw counts for RNA-seq, use NB loss, if normalized counts, use MSE. For ADT we use CLR-normalized counts and MSE loss. We need to specify `mmd='marginal'` and set the coeficient to the integration loss if we want to later map unimodal data onto this reference.

In [28]:
model = sca.models.MultiVAE(
    adata, 
    losses=['nb', 'mse'],
    loss_coefs={'kl': 1e-1,
               'integ': 3000,
               },
    z_dim=32,
    # integrate_on='Modality',
    mmd='marginal',
)
model.train()
# to load trained model
# state = torch.load('/opt/data/private/xx/code/MIRACLE/comparison/results/wnn/multigrate/offline_cellmask/model.pt')
# state.keys()
# for k, v in state['attr_dict'].items():
#     # print(k)
#     model.__dict__[k] = v
# model.module.load_state_dict(state['model_state_dict'])

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 23/200:  11%|█         | 22/200 [01:59<16:40,  5.62s/it, loss=1.67e+03, v_num=1]

### save latent and model

In [34]:
if not os.path.exists(pj('../', method, 'offline')):
    os.makedirs(pj('../', method, 'offline'))
model.save(pj('../', method, experiment),overwrite=True)
model.get_latent_representation()

In [None]:
for i in pd.unique(adata.obs[batch_key]):

    if type(i)=='str':
        j = int(i[1])-1
        print(j)
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/z/joint'%(j))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/z/joint'%(j)))
        print(adata[adata.obs[batch_key]=='P'+str(j+1)].obsm['latent'].shape)
        pd.DataFrame(adata[adata.obs[batch_key]=='P'+str(j+1)].obsm['latent']).to_csv(pj(result_path, 'default/predict','subset_%d/z/joint'%(j),'00.csv'), index=False, header=False)
    else:
        print(i)
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/z/joint'%(i))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/z/joint'%(i)))
        print(adata[adata.obs['batch']==i].obsm['latent'].shape)
        pd.DataFrame(adata[adata.obs['batch']==i].obsm['latent']).to_csv(pj(result_path, 'default/predict','subset_%d/z/joint'%(i),'00.csv'), index=False, header=False)

0
(6142, 32)
1
(3978, 32)
2
(4103, 32)
3
(5172, 32)
4
(4205, 32)
5
(5265, 32)
6
(8770, 32)
7
(8578, 32)


### save reconstruction

In [56]:
adata = sca.models.organize_multiome_anndatas(
    adatas = [[rna], [adt]],    # a list of anndata objects per modality, RNA-seq always goes first
    layers = [[None], [None]], # if need to use data from .layers, if None use .X
)
adata

sca.models.MultiVAE.setup_anndata(
    adata,
    categorical_covariate_keys=[batch_key],
    rna_indices_end=adata.uns['modality_lengths'][0],
)


adata = model._validate_anndata(adata)
scdl = model._make_data_loader(adata=adata)

rna_r = torch.Tensor([])
adt_r = torch.Tensor([])

for tensors in scdl: 
    x_r = model.module.sample(tensors)
    rna_r = torch.concat([rna_r, x_r[0].cpu()])
    adt_r = torch.concat([adt_r, x_r[1].cpu()])
l = 0
for i in pd.unique(adata.obs[batch_key]):
    if type(i)==str:
        j = int(i[1])-1
        print(j)
        if j==0:
            rna_batch = rna_r[0:adata[adata.obs[batch_key]=='P'+str(j+1)].shape[0]]
            adt_batch = adt_r[0:adata[adata.obs[batch_key]=='P'+str(j+1)].shape[0]]
        else:
            rna_batch = rna_r[l:l+adata[adata.obs[batch_key]=='P'+str(j+1)].shape[0]]
            adt_batch = adt_r[l:l+adata[adata.obs[batch_key]=='P'+str(j+1)].shape[0]]
        l = l+adata[adata.obs[batch_key]=='P'+str(j+1)].shape[0]
        print(l)
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/x_bc/rna'%(j))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/x_bc/rna'%(j)))
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/x_bc/adt'%(j))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/x_bc/adt'%(j)))
        pd.DataFrame(rna_batch).to_csv(pj(result_path, 'default/predict','subset_%d/x_bc/rna'%(j),'00.csv'), index=False, header=False)
        pd.DataFrame(adt_batch).to_csv(pj(result_path, 'default/predict','subset_%d/x_bc/adt'%(j),'00.csv'), index=False, header=False)
    else:
        print(i)
        j = i
        if j==0:
            rna_batch = rna_r[0:adata[adata.obs[batch_key]==(j)].shape[0]]
            adt_batch = adt_r[0:adata[adata.obs[batch_key]==(j)].shape[0]]
        else:
            rna_batch = rna_r[l:l+adata[adata.obs[batch_key]==(j)].shape[0]]
            adt_batch = adt_r[l:l+adata[adata.obs[batch_key]==(j)].shape[0]]
        l = l+adata[adata.obs[batch_key]==j].shape[0]
        print(l)
        
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/x_bc/rna'%(i))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/x_bc/rna'%(i)))
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/x_bc/adt'%(i))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/x_bc/adt'%(i)))
        pd.DataFrame(rna_batch).to_csv(pj(result_path, 'default/predict','subset_%d/x_bc/rna'%(i),'00.csv'), index=False, header=False)
        pd.DataFrame(adt_batch).to_csv(pj(result_path, 'default/predict','subset_%d/x_bc/adt'%(i),'00.csv'), index=False, header=False)

0
6142
1
10120
2
14223
3
19395
4
23600
5
28865
6
37635
7
46213


In [57]:
adt_batch.shape

torch.Size([0, 224])

### save latent and model

In [None]:
if not os.path.exists(pj('../', method, experiment)):
    os.makedirs(pj('../', method, experiment))
model.save(pj('../', method, experiment),overwrite=True)
model.get_latent_representation()

In [None]:
for i in pd.unique(adata.obs[batch_key]):

    if type(i)=='str':
        j = int(i[1])-1
        print(j)
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/z/joint'%(j))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/z/joint'%(j)))
        print(adata[adata.obs[batch_key]=='P'+str(j+1)].obsm['latent'].shape)
        pd.DataFrame(adata[adata.obs[batch_key]=='P'+str(j+1)].obsm['latent']).to_csv(pj(result_path, 'default/predict','subset_%d/z/joint'%(j),'00.csv'), index=False, header=False)
    else:
        print(i)
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/z/joint'%(i))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/z/joint'%(i)))
        print(adata[adata.obs['batch']==i].obsm['latent'].shape)
        pd.DataFrame(adata[adata.obs['batch']==i].obsm['latent']).to_csv(pj(result_path, 'default/predict','subset_%d/z/joint'%(i),'00.csv'), index=False, header=False)

0
(6142, 32)
1
(3978, 32)
2
(4103, 32)
3
(5172, 32)
4
(4205, 32)
5
(5265, 32)
6
(8770, 32)
7
(8578, 32)


### save reconstruction

In [None]:
adata = model._validate_anndata(adata)
scdl = model._make_data_loader(adata=adata)

rna_r = torch.Tensor([])
adt_r = torch.Tensor([])

for tensors in scdl: 
    x_r = model.module.sample(tensors)
    rna_r = torch.concat([rna_r, x_r[0].cpu()])
    adt_r = torch.concat([adt_r, x_r[1].cpu()])

rna_r = sc.AnnData(np.array(rna_r))
rna_r.obs_names = adata.obs_names
rna_r.obs =  adata.obs

adt_r = sc.AnnData(np.array(adt_r))
adt_r.obs_names = adata.obs_names
adt_r.obs =  adata.obs

for i in pd.unique(adata.obs[batch_key]):
    if type(i)=='str':
        j = int(i[1])-1
        print(j)
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/x/rna'%(j))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/x/rna'%(j)))
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/x/adt'%(j))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/x/adt'%(j)))
        pd.DataFrame(rna_r[rna_r.obs[batch_key]=='P'+str(j+1)].X).to_csv(pj(result_path, 'default/predict','subset_%d/x/rna'%(j),'00.csv'), index=False, header=False)
        pd.DataFrame(rna_r[rna_r.obs[batch_key]=='P'+str(j+1)].X).to_csv(pj(result_path, 'default/predict','subset_%d/x/adt'%(j),'00.csv'), index=False, header=False)
    else:
        print(i)
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/x/rna'%(i))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/x/rna'%(i)))
        if not os.path.exists(pj(result_path, 'default/predict','subset_%d/x/adt'%(i))):
            os.makedirs(pj(result_path, 'default/predict','subset_%d/x/adt'%(i)))
        pd.DataFrame(adata[adata.obs[batch_key]==i].X).to_csv(pj(result_path, 'default/predict','subset_%d/x/rna'%(i),'00.csv'), index=False, header=False)
        pd.DataFrame(adata[adata.obs[batch_key]==i].X).to_csv(pj(result_path, 'default/predict','subset_%d/x/adt'%(i),'00.csv'), index=False, header=False)

0
1
2
3
4
5
6
7
