In [1]:
!pip uninstall -y typing_extensions
!pip install --quiet scvi-colab

from scvi_colab import install

install()

Found existing installation: typing_extensions 4.11.0
Uninstalling typing_extensions-4.11.0:
  Successfully uninstalled typing_extensions-4.11.0
[34mINFO    [0m scvi-colab: Installing scvi-tools.                                                                        
[34mINFO    [0m scvi-colab: Install successful. Testing import.                                                           


In [2]:
import sys
import warnings

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip install --quiet scrublet

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scrublet as scr
import scvi

In [3]:
# Suppress every FutureWarning for the rest of the session.
warnings.simplefilter(action="ignore", category=FutureWarning)


# Default figure size for scanpy plots.
sc.set_figure_params(figsize=(4, 4))
# Fix the scvi-tools seed so training runs are repeatable.
scvi.settings.seed = 94705

# Inline figure rendering: white figure background, retina-resolution output.
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

INFO: Seed set to 94705
INFO:lightning.fabric.utilities.seed:Seed set to 94705


# Data preparation

In [17]:
def get_hvg(ds, n_top_genes):
    """
    Returns an AnnData's highly-variable genes list.

    Side effect: ``ds`` is annotated by ``sc.pp.highly_variable_genes``; the
    boolean ``ds.var["highly_variable"]`` column read below is written by that
    call, so the caller's object is modified.

    Parameters
    ----------
    ds: AnnData
        The dataset to compute the hvg genes of.
    n_top_genes: int
        Number of highly-variable genes requested; forwarded unchanged to
        ``sc.pp.highly_variable_genes``.

    Returns
    -------
    List[str]:
        Entries of the ``ds.var`` index flagged as highly variable, in the
        order they appear in ``ds.var``.
    """
    sc.pp.highly_variable_genes(ds, n_top_genes=n_top_genes)

    # Fetch the flag column once and use it as its own boolean mask.
    hvg_mask = ds.var["highly_variable"]
    return hvg_mask[hvg_mask].index.to_list()

In [5]:
# `google.colab` is only importable inside Colab, so this import stays with the
# cell that mounts Drive instead of the shared import cell.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

# Anchor the project directory to the mount point rather than to the current
# working directory, so later reads/writes do not depend on the cwd.
# The trailing slash is required: downstream cells build file paths by plain
# string concatenation.
path_to_dataset = DRIVE_MOUNT_POINT + "/MyDrive/research/in-progress/RefCM/"

Mounted at /content/drive


In [68]:
# Variant without batch effect.
# NOTE(review): the saved run of this cell was interrupted, and the following
# "with batch effect" cell reassigns ref_adata, q_adata, ref_labels and
# q_labels -- only one of the two loading cells should be executed.
ref_adata = sc.read_csv(f"{path_to_dataset}data/splatter_ref_counts.csv")
q_adata = sc.read_csv(f"{path_to_dataset}data/splatter_q_counts.csv")

# Label files have no header row; the first column holds the labels.
ref_labels = pd.read_csv(
    f"{path_to_dataset}data/splatter_ref_labels.csv", header=None
).to_numpy()[:, 0]
q_labels = pd.read_csv(
    f"{path_to_dataset}data/splatter_q_labels.csv", header=None
).to_numpy()[:, 0]

KeyboardInterrupt: 

In [69]:
# Variant with batch effect; `b` selects which `_b<N>` file set is loaded.
b = 1
# Count tables are transposed after loading (first CSV column is the index).
# NOTE(review): only the raw values are kept, so row/column labels from the
# CSVs are not carried into the AnnData objects -- confirm that reference and
# query share the same gene order.
X_train = pd.read_csv(
    f"{path_to_dataset}data/splatter_ref_counts_b{b}.csv", index_col=0
).to_numpy().T
X_test = pd.read_csv(
    f"{path_to_dataset}data/splatter_q_counts_b{b}.csv", index_col=0
).to_numpy().T

# Labels are flattened to 1-D arrays.
ref_labels = pd.read_csv(
    f"{path_to_dataset}data/splatter_ref_labels_b{b}.csv", index_col=0
).to_numpy().flatten()
q_labels = pd.read_csv(
    f"{path_to_dataset}data/splatter_q_labels_b{b}.csv", index_col=0
).to_numpy().flatten()

# Wrap the float64 matrices as AnnData objects.
ref_adata = anndata.AnnData(X_train.astype(np.float64))
q_adata = anndata.AnnData(X_test.astype(np.float64))

In [70]:
# Attach the label arrays as the per-cell 'labels' column of each dataset.
# NOTE(review): q_adata.obs['labels'] holds the ground-truth query labels and is
# still present when the query model is loaded and trained further down, and it
# is the column accuracy is measured against -- confirm the true labels are not
# consumed as supervision during query training.
ref_adata.obs['labels'] = ref_labels
q_adata.obs['labels'] = q_labels

In [71]:
# select the top 200 varying genes
# `subset=True` asks scanpy to keep only the selected genes in ref_adata, so
# this cell modifies ref_adata in place.
sc.pp.highly_variable_genes(ref_adata, n_top_genes=200, flavor='seurat_v3', subset=True)
# Restrict the query to the reference's variable names (and copy, so q_adata is
# a standalone object rather than a view).
q_adata = q_adata[:, ref_adata.var_names].copy()

# Train on reference

In [72]:
# Register ref_adata with scvi-tools. No batch or label column is passed here;
# the label column is supplied later when the SCANVI model is created.
# NOTE(review): reference-to-query mapping usually relies on a registered batch
# covariate -- confirm whether a batch_key should be set for this experiment.
scvi.model.SCVI.setup_anndata(ref_adata)

In [73]:
# Pre-train an scVI model on the reference data.
# Architecture options forwarded as keyword arguments to the SCVI constructor.
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 2,
}

vae_ref = scvi.model.SCVI(ref_adata, **arches_params)
vae_ref.train(max_epochs=50)

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


Epoch 50/50: 100%|██████████| 50/50 [00:31<00:00,  1.95it/s, v_num=1, train_loss_step=585, train_loss_epoch=597]

INFO: `Trainer.fit` stopped: `max_epochs=50` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 50/50: 100%|██████████| 50/50 [00:31<00:00,  1.57it/s, v_num=1, train_loss_step=585, train_loss_epoch=597]


In [74]:
# Train the scANVI model with labels, initialised from the pre-trained scVI model.
# Labels are read from the 'labels' obs column; "Unknown" is the category name
# reserved for unlabeled cells.
vae_ref_scan = scvi.model.SCANVI.from_scvi_model(
    vae_ref,
    unlabeled_category="Unknown",
    labels_key="labels",
)

vae_ref_scan.train(max_epochs=20)

[34mINFO    [0m Training for [1;36m20[0m epochs.                                                                                   


INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


Epoch 20/20: 100%|██████████| 20/20 [00:23<00:00,  1.18s/it, v_num=1, train_loss_step=577, train_loss_epoch=593]

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 20/20: 100%|██████████| 20/20 [00:23<00:00,  1.17s/it, v_num=1, train_loss_step=577, train_loss_epoch=593]


In [75]:
# Directory (under the project folder) where the trained reference scANVI model
# is stored; later cells load the query model from this same path.
dir_path_scan = path_to_dataset + "models/vae_ref_scan_top_genes"
# overwrite=True keeps the cell re-runnable: without it, saving into a directory
# left over from a previous run raises instead of replacing the stale model.
vae_ref_scan.save(dir_path_scan, overwrite=True)

# Update with query

In [76]:
# Align q_adata with the variables expected by the reference model saved at
# dir_path_scan (the captured log reports the share of reference vars found).
scvi.model.SCANVI.prepare_query_anndata(q_adata, dir_path_scan)

[34mINFO    [0m File drive/MyDrive/research/in-progress/RefCM/models/vae_ref_scan_top_genes/model.pt already downloaded   
[34mINFO    [0m Found [1;36m100.0[0m% reference vars in query data.                                                                


In [77]:
# Build the query model from the saved reference model and the prepared query data.
# NOTE(review): q_adata still carries its 'labels' column at this point --
# confirm how load_query_data treats an existing label column.
vae_q = scvi.model.SCANVI.load_query_data(
    q_adata,
    dir_path_scan,
)

[34mINFO    [0m File drive/MyDrive/research/in-progress/RefCM/models/vae_ref_scan_top_genes/model.pt already downloaded   


In [78]:
# Fine-tune the query model for 100 epochs with weight decay disabled in the
# training plan; check_val_every_n_epoch is forwarded to the trainer.
vae_q.train(
    max_epochs=100,
    plan_kwargs=dict(weight_decay=0.0),
    check_val_every_n_epoch=10,
)

[34mINFO    [0m Training for [1;36m100[0m epochs.                                                                                  


INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


Epoch 100/100: 100%|██████████| 100/100 [01:40<00:00,  1.03s/it, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.4e+3] 

INFO: `Trainer.fit` stopped: `max_epochs=100` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|██████████| 100/100 [01:40<00:00,  1.01s/it, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.4e+3]


In [79]:
# Store the query model's predicted label for each query cell.
q_adata.obs["predictions"] = vae_q.predict()

In [80]:
# Accuracy: fraction of query cells whose prediction equals the 'labels' column.
(q_adata.obs["predictions"] == q_adata.obs["labels"]).mean()

0.488