# Build reference atlas from scratch

In [None]:
# Move the working directory one level up
# (presumably so later relative paths resolve from the parent folder).
import os
os.chdir('../')

# Suppress FutureWarning and UserWarning messages in the notebook output.
import warnings
for _category in (FutureWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=_category)

In [None]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

In [None]:
# Configure scanpy plotting in ONE call. Each call to set_figure_params
# re-applies its defaults for every argument that is not passed, so splitting
# the settings over several calls lets the last call undo the earlier dpi and
# frameon choices.
sc.settings.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

### Download raw Dataset

In [None]:
# Download the dataset from Google Drive into a local .h5ad file.
output = 'pbmc.h5ad'
url = 'https://drive.google.com/uc?id=1LaYOadbotGC6gXAlo-aKfHz-spoFnawk'
gdown.download(url=url, output=output, quiet=False)

In [None]:
# Load the downloaded .h5ad file into memory.
adata = sc.read('pbmc.h5ad')

In [None]:
# Use the matrix stored in the "counts" layer as the main data matrix
# (copied, so X and the layer are separate objects).
adata.X = adata.layers["counts"].copy()

We now split the data into a reference dataset and a query dataset to simulate the atlas-building process. Here we use the '10X' study as the query data.

In [None]:
# Cells whose study is listed in target_conditions become the query set;
# every other cell goes into the reference set.
target_conditions = ["10X"]
is_target = adata.obs.study.isin(target_conditions)
source_adata = adata[~is_target].copy()
target_adata = adata[is_target].copy()
print(source_adata)
print(target_adata)

For better model performance it is necessary to select highly variable genes (HVGs). We do this with the scanpy.pp function highly_variable_genes(). Here n_top_genes is set to 2000; however, for more complicated datasets you might have to increase the number of genes to capture more of the diversity in the data.

In [None]:
# Snapshot the current data (all genes, before normalization) in .raw so the
# original values can be restored after gene selection.
source_adata.raw = source_adata

In [None]:
# Display a summary of the reference AnnData object.
source_adata

In [None]:
# Normalize the reference data per cell; this, together with the log
# transform, only prepares the data for highly-variable-gene selection.
sc.pp.normalize_total(source_adata)

In [None]:
# Log-transform (log(1 + x)) the normalized reference data.
sc.pp.log1p(source_adata)

In [None]:
# Select 2000 highly variable genes, taking the "batch" column into account,
# and subset the reference data to those genes.
hvg_options = dict(n_top_genes=2000, batch_key="batch", subset=True)
sc.pp.highly_variable_genes(source_adata, **hvg_options)

For consistency we set adata.X back to the raw counts. In other datasets this may already be the case.

In [None]:
# Replace the normalized/log-transformed values with the values saved in .raw,
# restricted to the genes that remain after selection.
source_adata.X = source_adata.raw[:, source_adata.var_names].X

In [None]:
# Display the reference AnnData object after gene selection.
source_adata

### Create SCVI model and train it on reference dataset

Remember that the adata file has to have count data in adata.X for SCVI/SCANVI if not further specified.

In [None]:
# Register the reference data with SCVI, using the "batch" column in .obs
# as the batch covariate.
sca.models.SCVI.setup_anndata(source_adata, batch_key="batch")

Create the SCVI model instance with ZINB loss as default. Insert "gene_likelihood='nb'," to change the reconstruction loss to NB loss.

In [None]:
# Architecture options for the reference SCVI model.
scvi_arch = {
    "n_layers": 2,
    "encode_covariates": True,
    "deeply_inject_covariates": False,
    "use_layer_norm": "both",
    "use_batch_norm": "none",
}
vae = sca.models.SCVI(source_adata, **scvi_arch)

In [None]:
# Early-stopping configuration; it is reused when training the query model.
early_stopping_kwargs = dict(
    early_stopping_metric="elbo",
    save_best_state_metric="elbo",
    patience=10,
    threshold=0,
    reduce_lr_on_plateau=True,
    lr_patience=8,
    lr_factor=0.1,
)
# Train the reference model for up to 500 epochs.
vae.train(
    n_epochs=500,
    frequency=1,
    early_stopping_kwargs=early_stopping_kwargs,
)

The resulting latent representation of the data can then be visualized with UMAP.

In [None]:
# Wrap the latent representation of the reference cells in an AnnData object
# and carry over the cell-type and batch labels for plotting.
reference_latent = sc.AnnData(vae.get_latent_representation())
for latent_key, source_key in (("cell_type", "final_annotation"), ("batch", "batch")):
    reference_latent.obs[latent_key] = source_adata.obs[source_key].tolist()

In [None]:
# Neighborhood graph (8 neighbors), Leiden clustering and UMAP embedding
# computed on the reference latent space.
sc.pp.neighbors(reference_latent, n_neighbors=8)
sc.tl.leiden(reference_latent)
sc.tl.umap(reference_latent)

# Plot the UMAP, colored by batch and by cell type.
reference_umap_style = dict(color=['batch', 'cell_type'], frameon=False, wspace=0.6)
sc.pl.umap(reference_latent, **reference_umap_style)

After pretraining, the model can be saved for later use or uploaded via Zenodo to share it with other researchers. For the second option please also have a look at the Zenodo notebook.

In [None]:
# Directory for the trained reference model; it is read again when the query
# model is created.
ref_path = 'ref_model/'
# Save the model, overwriting any previously saved model at that path.
vae.save(ref_path, overwrite=True)

### Use pretrained reference model and apply surgery with a new query dataset to get a bigger reference atlas

In [None]:
# Display a summary of the query AnnData object.
target_adata

Since the model requires the datasets to have the same genes, we also filter the query dataset to have the same genes as the reference dataset.

In [None]:
# Restrict the query data to the genes of the reference, in the same order.
# .copy() turns the sliced view into an independent AnnData object, so later
# steps that write to it do not operate on a view.
target_adata = target_adata[:, source_adata.var_names].copy()
target_adata

We then can apply the model surgery with the new query dataset:

In [None]:
# Create the query model from the saved reference model and the query data.
model = sca.models.SCVI.load_query_data(target_adata, ref_path, freeze_dropout=True)

In [None]:
# Train the query model with the same early-stopping settings as the
# reference model; weight decay is set to 0 for this step.
model.train(n_epochs=500, frequency=1, early_stopping_kwargs=early_stopping_kwargs, weight_decay=0)

In [None]:
# Wrap the latent representation of the query cells in an AnnData object and
# carry over the cell-type and batch labels for plotting.
query_latent = sc.AnnData(model.get_latent_representation())
for latent_key, target_key in (("cell_type", "final_annotation"), ("batch", "batch")):
    query_latent.obs[latent_key] = target_adata.obs[target_key].tolist()

In [None]:
# Neighborhood graph, Leiden clustering and UMAP embedding computed on the
# query latent space.
sc.pp.neighbors(query_latent)
sc.tl.leiden(query_latent)
sc.tl.umap(query_latent)
# sc.pl.umap creates its own figure, so no explicit plt.figure() is needed
# (an explicit call would only leave an extra empty figure in the output).
sc.pl.umap(
    query_latent,
    color=["batch", "cell_type"],
    frameon=False,
    wspace=0.6,
)

And again we can save or upload the retrained model for later use or additional extensions.

In [None]:
# Directory for the model obtained after training on the query data.
surgery_path = 'surgery_model'
# Save the model, overwriting any previously saved model at that path.
model.save(surgery_path, overwrite=True)

### Get latent representation of reference + query dataset and compute UMAP

In [None]:
# Combine reference and query cells into one AnnData object; the "ref_query"
# key records which of the two inputs each cell came from.
adata_full = source_adata.concatenate(target_adata, batch_key="ref_query")
adata_full

In [None]:
# Embed the combined reference + query data with the query-trained model and
# carry over the cell-type and batch labels for plotting.
full_latent = sc.AnnData(model.get_latent_representation(adata=adata_full))
for latent_key, full_key in (("cell_type", "final_annotation"), ("batch", "batch")):
    full_latent.obs[latent_key] = adata_full.obs[full_key].tolist()

In [None]:
# Neighborhood graph, Leiden clustering and UMAP embedding computed on the
# joint (reference + query) latent space.
sc.pp.neighbors(full_latent)
sc.tl.leiden(full_latent)
sc.tl.umap(full_latent)
# sc.pl.umap creates its own figure, so no explicit plt.figure() is needed
# (an explicit call would only leave an extra empty figure in the output).
sc.pl.umap(
    full_latent,
    color=["batch", "cell_type"],
    frameon=False,
    wspace=0.6,
)