In [18]:
%pylab inline
%config InlineBackend.figure_format ='retina'
import time

import pandas as pd
import anndata

import torch

from scvi.dataset import GeneExpressionDataset
from scvi.models import VAE, LDVAE
from scvi.inference import UnsupervisedTrainer
from scvi.inference.posterior import Posterior

Populating the interactive namespace from numpy and matplotlib


In [2]:
# Load the raw cell atlas. `anndata.read` is a deprecated alias; `read_h5ad`
# is the explicit, supported reader for .h5ad files.
ATLAS_PATH = 'cao_atlas.h5ad'  # relative to the notebook's working directory
adata = anndata.read_h5ad(ATLAS_PATH)

In [4]:
# Keep only cells not flagged as doublets.
# The flag is compared via its string form, so this works whether
# `detected_doublet` holds the strings "True"/"False" (as the original
# `query('detected_doublet == "False"')` assumed) or real booleans, where the
# string query would silently match nothing. A positional boolean mask also
# avoids relying on obs_names being unique, which label-based indexing does.
# NOTE(review): confirm the column's dtype in the source file.
singlet_mask = (adata.obs['detected_doublet'].astype(str) == 'False').values
adata = adata[singlet_mask]
assert adata.n_obs > 0, 'doublet filter removed every cell - check detected_doublet dtype'

In [5]:
# Derive scVI's per-cell inputs from the filtered count matrix; the five values
# come back in the order the GeneExpressionDataset constructor below expects.
(X, local_means, local_vars,
 batch_indices, labels) = GeneExpressionDataset.get_attributes_from_matrix(adata.X)

In [6]:
# Wrap the matrix and its derived attributes in an scVI dataset. Gene names are
# taken from the AnnData var index and stored as a plain `str` numpy array.
cells_dataset = GeneExpressionDataset(
    X,
    local_means,
    local_vars,
    batch_indices,
    labels,
    gene_names=np.array(adata.var.index.values, dtype=str),
)

In [7]:
# Restrict the dataset to N_GENES genes (in place). The selection criterion is
# whatever scVI's `subsample_genes` uses by default - TODO confirm in scVI docs.
N_GENES = 1000  # was a bare literal; tune here
cells_dataset.subsample_genes(N_GENES)

Downsampling from 26183 to 1000 genes
Downsampling from 1949131 to 1949131 cells


In [13]:
# Seed the global torch / numpy RNGs right before building the model so weight
# initialisation (and any later draws scVI makes from these RNGs) repeats on
# Restart & Run All. Seeded per model so this run is reproducible on its own.
torch.manual_seed(0)
np.random.seed(0)
# reconstruction_loss='nb' presumably selects a negative-binomial likelihood
# (scVI option) - confirm against the installed scVI version.
vae = VAE(cells_dataset.nb_genes, reconstruction_loss='nb')

In [14]:
# Use the GPU only when one is present instead of hard-coding use_cuda=True,
# so the notebook still runs on CPU-only machines.
# frequency=1 appears to evaluate metrics every epoch: the recorded run stores
# 4 history entries for 3 epochs.
trainer = UnsupervisedTrainer(vae, cells_dataset,
                              use_cuda=torch.cuda.is_available(), frequency=1)

In [15]:
# Three passes over ~1.9M cells; the recorded run took ~19 minutes.
trainer.train(n_epochs=3)

training: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [18:49<00:00, 376.28s/it]


In [16]:
# VAE metric on the training split, one value per evaluation. The values fall as
# training proceeds, so lower is better - presumably a negative log-likelihood
# despite the "ll" name; confirm in scVI docs.
trainer.history["ll_train_set"]

[2703.8185090486486, 603.1618393949595, 602.4474014493832, 600.6084155099206]

In [17]:
# Same metric on the held-out split; tracks the training curve closely here.
trainer.history["ll_test_set"]

[2702.0280159131485, 603.0747248539095, 602.3740936183711, 600.5486634235347]

In [21]:
# Snapshot the VAE curves into a table now, before `trainer` is rebound to the
# LDVAE trainer further down.
hist_df = pd.DataFrame(
    data={
        'vae_ll_train_set': trainer.history['ll_train_set'],
        'vae_ll_test_set': trainer.history['ll_test_set'],
    }
)

hist_df

Unnamed: 0,vae_ll_train_set,vae_ll_test_set
0,2703.818509,2702.028016
1,603.161839,603.074725
2,602.447401,602.374094
3,600.608416,600.548663


In [22]:
# Seed the global torch / numpy RNGs right before building the model so weight
# initialisation (and any later draws scVI makes from these RNGs) repeats on
# Restart & Run All. Same seed as the VAE run so the comparison is like-for-like.
torch.manual_seed(0)
np.random.seed(0)
# reconstruction_loss='nb' presumably selects a negative-binomial likelihood
# (scVI option) - confirm against the installed scVI version.
ldvae = LDVAE(cells_dataset.nb_genes, reconstruction_loss='nb')

In [23]:
# Use the GPU only when one is present instead of hard-coding use_cuda=True,
# so the notebook still runs on CPU-only machines.
# NOTE: this rebinds `trainer`; the VAE history was already copied into hist_df.
trainer = UnsupervisedTrainer(ldvae, cells_dataset,
                              use_cuda=torch.cuda.is_available(), frequency=1)

In [24]:
# Same budget as the VAE (3 epochs); the recorded run took ~18 minutes.
trainer.train(n_epochs=3)

training: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [17:33<00:00, 352.50s/it]


In [25]:
# LDVAE metric on the training split, one value per evaluation (lower is better,
# judging by its trajectory).
trainer.history["ll_train_set"]

[2679.306746674418, 611.5832323827825, 611.2913683851434, 611.4970870630121]

In [26]:
# LDVAE metric on the held-out split.
trainer.history["ll_test_set"]

[2677.9429170664166, 611.4190416718506, 611.1215942925484, 611.3273155564038]

In [27]:
# Add the LDVAE curves next to the VAE ones, row-aligned by evaluation index.
# `assign` builds a new frame instead of mutating the one displayed earlier
# (in-place mutation also silently rewrites that cell's cached output).
# A length mismatch between the two histories still raises ValueError.
hist_df = hist_df.assign(**{
    'ldvae_ll_train_set': trainer.history['ll_train_set'],
    'ldvae_ll_test_set': trainer.history['ll_test_set'],
})

In [28]:
# Display the combined VAE / LDVAE history: one row per evaluation, one column
# per (model, split) pair.
hist_df

Unnamed: 0,vae_ll_train_set,vae_ll_test_set,ldvae_ll_train_set,ldvae_ll_test_set
0,2703.818509,2702.028016,2679.306747,2677.942917
1,603.161839,603.074725,611.583232,611.419042
2,602.447401,602.374094,611.291368,611.121594
3,600.608416,600.548663,611.497087,611.327316


In [29]:
# Persist the combined training curves. Naming the index column gives the CSV a
# real header for it; left unnamed, pandas reads it back as "Unnamed: 0".
HIST_CSV_PATH = 'cao_full_training_hist.csv'  # relative to the working directory
hist_df.to_csv(HIST_CSV_PATH, index_label='eval_step')