In [1]:
import scanpy as sc
import anndata as ad
import multigrate
from random import shuffle
from scipy import sparse
from matplotlib import pyplot as plt
import gdown

In [2]:
# Ask the inline matplotlib backend to emit figures in the 'retina' format.
%config InlineBackend.figure_format = 'retina'

# Load the dataset

In [1]:
# Fetch the RNA expression AnnData file (~1.8 GB) only when it is not already on disk,
# so that "Restart & Run All" does not repeat the download.
# The transfer is written to a ".part" file and renamed only after wget succeeds;
# an interrupted download therefore never leaves a truncated file under the final name.
!test -f hao2020-expressions.h5ad || ( wget "https://hmgubox2.helmholtz-muenchen.de/index.php/s/r2W5dMJdq6mFMZY/download?path=%2Fseurat-2020&files=expressions.h5ad" -O hao2020-expressions.h5ad.part && mv hao2020-expressions.h5ad.part hao2020-expressions.h5ad )

--2021-01-05 04:46:52--  https://hmgubox2.helmholtz-muenchen.de/index.php/s/r2W5dMJdq6mFMZY/download?path=%2Fseurat-2020&files=expressions.h5ad
Resolving localhost (localhost)... 127.0.0.1
Connecting to localhost (localhost)|127.0.0.1|:8085... connected.
Proxy request sent, awaiting response... 200 OK
Length: 1949492332 (1.8G) [application/octet-stream]
Saving to: ‘hao2020-expressions.h5ad’


2021-01-05 04:51:08 (7.26 MB/s) - ‘hao2020-expressions.h5ad’ saved [1949492332/1949492332]



In [3]:
# Read the downloaded RNA expression file into an AnnData object.
scrna = sc.read_h5ad('hao2020-expressions.h5ad')
# Bare last expression: display the AnnData summary as the cell output.
scrna

AnnData object with n_obs × n_vars = 161764 × 4000
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'cell_type'
    var: 'features', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'celltype.l1_colors', 'celltype.l2_colors', 'celltype.l3_colors', 'neighbors'
    obsm: 'X_apca', 'X_aumap', 'X_pca', 'X_spca', 'X_umap', 'X_wnn.umap'
    varm: 'PCs', 'SPCA'
    layers: 'count'
    obsp: 'distances'

In [6]:
# Fetch the protein AnnData file (~863 MB) only when it is not already on disk,
# so that "Restart & Run All" does not repeat the download.
# The transfer is written to a ".part" file and renamed only after wget succeeds;
# an interrupted download therefore never leaves a truncated file under the final name.
!test -f hao2020-proteins.h5ad || ( wget "https://hmgubox2.helmholtz-muenchen.de/index.php/s/r2W5dMJdq6mFMZY/download?path=%2Fseurat-2020&files=protein.h5ad" -O hao2020-proteins.h5ad.part && mv hao2020-proteins.h5ad.part hao2020-proteins.h5ad )

--2021-01-05 04:51:19--  https://hmgubox2.helmholtz-muenchen.de/index.php/s/r2W5dMJdq6mFMZY/download?path=%2Fseurat-2020&files=protein.h5ad
Resolving localhost (localhost)... 127.0.0.1
Connecting to localhost (localhost)|127.0.0.1|:8085... connected.
Proxy request sent, awaiting response... 200 OK
Length: 904554908 (863M) [application/octet-stream]
Saving to: ‘hao2020-proteins.h5ad’


2021-01-05 04:53:34 (6.42 MB/s) - ‘hao2020-proteins.h5ad’ saved [904554908/904554908]



In [4]:
# Read the downloaded protein file into an AnnData object.
cite = sc.read_h5ad('hao2020-proteins.h5ad')
# Bare last expression: display the AnnData summary as the cell output.
cite

AnnData object with n_obs × n_vars = 161764 × 224
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'cell_type'
    var: 'features'
    uns: 'celltype.l1_colors', 'celltype.l2_colors', 'celltype.l3_colors'
    obsm: 'X_apca', 'X_aumap', 'X_pca', 'X_spca', 'X_umap', 'X_wnn.umap'
    varm: 'APCA'
    layers: 'count'

# Paired setting

## Configure and train the model

In [9]:
# Paired setting: build a MultiVAE over the two modalities loaded above.
# Each modality contributes a single AnnData (RNA first, protein second), and both
# are given pair-group id 0.
# NOTE(review): presumably an identical pair-group id marks the two datasets as
# measurements of the same cells — confirm against the multigrate documentation.
model = multigrate.models.MultiVAE(
    adatas=[[scrna], [cite]],
    names=[['scRNA-seq'], ['scCITE-seq']],
    pair_groups=[[0], [0]],
    z_dim=20,
    h_dim=128,
    # No extra hidden layers are configured for either modality, the shared part,
    # or the adversarial part (all lists are empty).
    hiddens=[[], []],
    output_activations=['linear', 'linear'],
    shared_hiddens=[],
    adver_hiddens=[],
    # Loss-term weights; the cycle term is switched off (0) and the adversarial
    # component is disabled below.
    recon_coef=1,
    kl_coef=1e-5,
    integ_coef=1e-2,
    cycle_coef=0,
    adversarial=False,
    dropout=0.2,
)

In [None]:
# Train the model configured above.
# NOTE(review): no random seed is set in the visible cells, so repeated runs may
# not reproduce the same result — consider seeding before training.
# NOTE(review): this looks like a long-running cell (50000 iterations); consider
# timing it or caching the trained model — TODO confirm the actual cost.
model.train(
    n_iters=50000,
    batch_size=64,
    lr=3e-4,
    val_split=0.1,
    adv_iters=0,  # matches adversarial=False in the model configuration
    kl_anneal_iters=20000,
    validate_every=5000,
    verbose=1
)

  res = method(*args, **kwargs)


## Plot training history

In [None]:
# Display the recorded training history; the plotting cell below reads its
# 'iteration', 'train_*' and 'val_*' entries.
model.history

In [None]:
# Plot train vs. validation curves for every tracked loss term in a 2x2 grid.
# Each entry pairs the suffix of the `model.history` keys ('train_<suffix>' /
# 'val_<suffix>') with the wording used in the legend labels. Adding a new loss
# term to the figure only requires a new entry here (and a larger grid).
loss_panels = [
    ('loss', 'loss'),
    ('recon', 'recon loss'),
    ('kl', 'kl loss'),
    ('integ', 'integ loss'),
]

# Explicit figure/axes objects instead of the implicit pyplot state machine;
# `axes.flat` walks the grid row by row, matching subplot positions 221..224.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
for ax, (suffix, label) in zip(axes.flat, loss_panels):
    ax.plot(model.history['iteration'], model.history[f'train_{suffix}'], '.-', label=f'Train {label}')
    ax.plot(model.history['iteration'], model.history[f'val_{suffix}'], '.-', label=f'Val {label}')
    ax.set_xlabel('#Iterations')
    ax.legend()

## Recover and visualize the latent space

In [None]:
# Run the trained model on the same datasets and modality names used to build it.
# NOTE(review): the result is used below as an AnnData (passed to sc.pp.neighbors /
# sc.tl.umap and coloured by 'modality' and 'cell_type') — confirm the return type.
z = model.predict(
    adatas=[[scrna], [cite]],
    names=[['scRNA-seq'], ['scCITE-seq']],
    batch_size=64,
)
# Bare last expression: display the result as the cell output.
z

In [None]:
# Build the neighbourhood graph on the predicted latent representation, then
# compute a UMAP embedding from it. Both calls use scanpy defaults and modify
# `z` in place.
sc.pp.neighbors(z)
sc.tl.umap(z)

In [None]:
# Draw the UMAP twice, coloured by 'modality' and by 'cell_type', stacked in a
# single column (ncols=1).
sc.pl.umap(z, color=['modality', 'cell_type'], ncols=1)

## Metrics

In [None]:
# Compute a PCA on the latent representation before scoring it.
# NOTE(review): presumably the metrics routine expects a PCA to be present on its
# input — confirm against the multigrate documentation.
sc.pp.pca(z)
# NOTE(review): `z` is passed for both positional arguments; presumably these are
# the reference and the integrated object and the latent space serves as both —
# verify this is intended.
metrics = multigrate.metrics.metrics(
    z, z,
    batch_key='modality',
    label_key='cell_type',
    method='multigrate'
)
# Bare last expression: display the computed metrics as the cell output.
metrics