In [1]:
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scarches as sca
from sklearn.metrics import classification_report

from scarches.dataset.trvae.data_handling import remove_sparsity
from lataq.models import EMBEDCVAE
from lataq.exp_dict import EXPERIMENT_INFO
import time
# Global display configuration for the notebook: scanpy figure parameters,
# default matplotlib figure size, and torch tensor print formatting.
# NOTE(review): figure params are set twice in a row. The second call passes
# dpi=500 and does not pass frameon, so the dpi=200 from the first call is
# superseded; frameon=False is presumably reset to the library default as
# well -- confirm against the scanpy docs and keep only the intended call.
sc.settings.set_figure_params(dpi=200, frameon=False)
sc.set_figure_params(dpi=500)
plt.rcParams['figure.figsize'] = (5, 5)
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

# IPython autoreload extension in mode 2, so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Read the AnnData object stored at the relative path below.
adata = sc.read('../data/hlca_counts_commonvars.h5ad')

In [3]:
# Name of the obs column passed as `batch_key` when registering the AnnData.
condition_key = "sample"
# Label columns; only the first entry is used as `labels_key` further down.
cell_type_key = ["ann_finest_level"]

In [4]:
# Epoch budgets for the two training stages.
vae_epochs = 50
scanvi_epochs = 20

# Early-stopping / LR-scheduling settings monitored on the ELBO.
# NOTE(review): none of the visible cells passes these dicts to a trainer;
# confirm whether they are still needed.
early_stopping_kwargs = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    patience=10,
    threshold=0,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

# Same structure, but monitored on accuracy and with an extra "on" entry.
early_stopping_kwargs_scanvi = dict(
    early_stopping_metric="accuracy",
    save_best_state_metric="accuracy",
    on="full_dataset",
    patience=10,
    threshold=0.001,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)

In [5]:
# Register the AnnData, using `condition_key` as the batch column and the
# first entry of `cell_type_key` as the label column.
sca.dataset.setup_anndata(
    adata,
    batch_key=condition_key,
    labels_key=cell_type_key[0],
)

# Instantiate the SCANVI model on the registered data. The second positional
# argument is passed as None.
vae = sca.models.SCANVI(
    adata,
    None,
    gene_likelihood="nb",  # because we have UMI data
    n_latent=30,  # to allow for capturing more heterogeneity
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_batch_norm="none",
    use_layer_norm="both",
)

[34mINFO    [0m Using batches from adata.obs[1m[[0m[32m"sample"[0m[1m][0m                                              
[34mINFO    [0m Using labels from adata.obs[1m[[0m[32m"ann_finest_level"[0m[1m][0m                                     
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              
[34mINFO    [0m Successfully registered anndata object containing [1;36m584884[0m cells, [1;36m1897[0m vars, [1;36m166[0m      
         batches, [1;36m58[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates  
         and [1;36m0[0m extra continuous covariates.                                                  
[34mINFO    [0m Please do not further modify adata until model is trained.                          




In [6]:
# Train the model for the epoch budget declared in the configuration cell
# (`vae_epochs`) instead of repeating the literal here, so the two cannot
# drift apart.
vae.train(
    max_epochs=vae_epochs
)

[34mINFO    [0m Training for [1;36m50[0m epochs.                                                             


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 50/50: 100%|██████████| 50/50 [1:15:52<00:00, 91.05s/it, loss=714, v_num=1]


In [7]:
# Persist the trained model under "scanvi_hlca_sample" (relative to the
# current working directory), replacing any previous save. The name has no
# placeholders, so a plain string literal is used rather than an f-string.
vae.save("scanvi_hlca_sample", overwrite=True)

In [8]:
# Wrap the model's latent representation in a new AnnData that reuses the
# obs table of the input data.
adata_latent = sc.AnnData(
    X=vae.get_latent_representation(),
    obs=adata.obs,
)



In [9]:
# Write the latent-space AnnData to the relative path below.
adata_latent.write(
    '../figure_notebooks/hlca_core_scanvi_sample.h5ad'
)