In [2]:
import scanpy as sc
import scvi

[2020-07-29 15:16:39,535] INFO - scvi._settings | Added StreamHandler with custom formatter to 'scvi' logger.


In [45]:
import scIB

In [4]:
import scvi.models

In [58]:
adata = sc.read('/storage/groups/ml01/workspace/scIB/simulations_1_1/sim1_1_norm.h5ad')

In [26]:
from scvi.dataset import AnnDatasetFromAnnData
from sklearn.preprocessing import LabelEncoder

In [48]:
import numpy as np

In [17]:
adata

AnnData object with n_obs × n_vars = 26668 × 18756 
    obs: 'barcode', 'batch', 'cell_type', 'cell_type_union', 'channel', 'log_counts', 'marker_gene', 'n_counts', 'n_genes', 'percent_mito', 'sample', 'sample_id', 'sex', 'size_factors', 'study', 'tissue'
    var: 'gene_ids-1-1'
    layers: 'counts'

In [27]:
le = LabelEncoder()
adata.obs['labels'] = le.fit_transform(adata.obs['cell_type_union'].values)

In [32]:
net_adata = AnnDatasetFromAnnData(adata, batch_label='batch', class_label='labels')

[2020-07-29 15:48:49,661] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2020-07-29 15:48:49,663] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2020-07-29 15:48:49,823] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2020-07-29 15:48:49,995] INFO - scvi.dataset.dataset | Downsampled from 26668 to 26668 cells


In [37]:
net_adata.labels = adata.obs['labels']

In [39]:
net_adata.n_labels

107

In [53]:
def runScanvi(adata, batch, labels, hvg=None):
    # Use non-normalized (count) data for scvi!
    # Expects data only on HVGs
    
    scIB.utils.checkSanity(adata, batch, hvg)

    # Check for counts data layer
    if 'counts' not in adata.layers:
        raise TypeError('Adata does not contain a `counts` layer in `adata.layers[`counts`]`')

    from scvi.models import VAE, SCANVI
    from scvi.inference import AlternateSemiSupervisedTrainer, SemiSupervisedTrainer
    from sklearn.preprocessing import LabelEncoder
    from scvi.dataset import AnnDatasetFromAnnData

    # Defaults from SCVI github tutorials scanpy_pbmc3k and harmonization
    n_epochs=np.min([round((20000/adata.n_obs)*400), 400])
    n_latent=30
    n_hidden=128
    n_layers=2
    
    net_adata = adata.copy()
    net_adata.X = adata.layers['counts']
    del net_adata.layers['counts']
    # Ensure that the raw counts are not accidentally used
    del net_adata.raw # Note that this only works from anndata 0.7

    # Define batch indices
    le = LabelEncoder()
    net_adata.obs['batch_indices'] = le.fit_transform(net_adata.obs[batch].values)
    net_adata.obs['labels'] = le.fit_transform(net_adata.obs[labels].values)

    net_adata = AnnDatasetFromAnnData(net_adata)
    net_adata.labels = adata.obs['labels']

    scanvi = SCANVI(
        net_adata.nb_genes,
        reconstruction_loss='nb',
        n_batch=net_adata.n_batches,
        n_labels=net_adata.n_labels,
        n_layers=n_layers,
        n_latent=n_latent,
        n_hidden=n_hidden,
    )

    trainer = SemiSupervisedTrainer(
        scanvi,
        net_adata,
        train_size=1.0,
        use_cuda=False,
    )

    trainer.train(n_epochs=5)

    full = trainer.create_posterior(trainer.model, net_adata, indices=np.arange(len(net_adata)))
    latent, _, _ = full.sequential().get_latent()

    adata.obsm['X_emb'] = latent

    return adata


In [55]:
integrated = runScanvi(adata, 'batch', 'cell_type_union')

[2020-07-29 16:08:39,334] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2020-07-29 16:08:39,336] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2020-07-29 16:08:39,495] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2020-07-29 16:08:39,646] INFO - scvi.dataset.dataset | Downsampled from 26668 to 26668 cells


training: 100%|██████████| 5/5 [05:37<00:00, 67.49s/it]


In [57]:
integrated

AnnData object with n_obs × n_vars = 26668 × 18756 
    obs: 'barcode', 'batch', 'cell_type', 'cell_type_union', 'channel', 'log_counts', 'marker_gene', 'n_counts', 'n_genes', 'percent_mito', 'sample', 'sample_id', 'sex', 'size_factors', 'study', 'tissue', 'labels'
    var: 'gene_ids-1-1'
    obsm: 'X_emb'
    layers: 'counts'