In [1]:
import os
import tempfile
import pertpy as pt

import numpy as np
import requests
import scanpy as sc
import scvi
import seaborn as sns
import torch

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [2]:
# Fix the scvi-tools global seed so repeated runs of the notebook are comparable.
scvi.settings.seed = 0
# Record the library version this notebook was last executed with.
print("Last run with scvi-tools version:", scvi.__version__)

Seed set to 0


Last run with scvi-tools version: 1.1.3


In [3]:
# Global plotting / runtime configuration for the rest of the notebook.
# NOTE(review): sns.set_theme() runs after sc.set_figure_params(); keep this order
# unless you have confirmed the two calls do not override each other's settings.
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
# Request the "high" float32 matmul precision setting from torch.
torch.set_float32_matmul_precision("high")
# Temporary directory handle; not referenced by any later cell visible here.
save_dir = tempfile.TemporaryDirectory()

# Inline figures: white facecolor and retina figure format.
%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

In [4]:
# Load the Papalexi 2021 dataset through pertpy's dataset loader.
# Presumably a multimodal (MuData-like) container: later cells index it with "rna".
mdata = pt.dt.papalexi_2021()



In [5]:
# Keep an independent copy of the raw counts in a dedicated layer.
# The copy is essential: the preprocessing cell below calls sc.pp.normalize_total
# and sc.pp.log1p on this object, which update .X in place. Without .copy() the
# "counts" layer would reference the same matrix as .X and could be overwritten
# with normalised/log-transformed values, while ContrastiveVI is set up on
# layer="counts" and expects raw counts.
mdata["rna"].layers["counts"] = mdata["rna"].X.copy()

In [None]:
# Standard scanpy preprocessing of the RNA modality, applied in place:
# library-size normalisation, log1p transform, then highly-variable-gene
# selection (2000 genes, computed per "replicate" batch). subset=True restricts
# the object to the selected genes.
rna = mdata["rna"]
sc.pp.normalize_total(rna)
sc.pp.log1p(rna)
sc.pp.highly_variable_genes(
    rna,
    n_top_genes=2000,
    batch_key="replicate",
    subset=True,
)

In [6]:
# Work on an independent copy of the RNA modality so that model setup and
# downstream annotations do not modify the object held inside `mdata`.
adata_ref = mdata["rna"].copy()


In [7]:
# Show a summary of the object (dimensions, obs/var columns, layers).
adata_ref

AnnData object with n_obs × n_vars = 20729 × 18649
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_HTO', 'nFeature_HTO', 'nCount_GDO', 'nCount_ADT', 'nFeature_ADT', 'percent.mito', 'MULTI_ID', 'HTO_classification', 'guide_ID', 'gene_target', 'NT', 'perturbation', 'replicate', 'S.Score', 'G2M.Score', 'Phase'
    var: 'name'
    layers: 'counts'

In [8]:
# Register the data with ContrastiveVI: use the "counts" layer as the model
# input and the "replicate" obs column as the batch key.
scvi.external.ContrastiveVI.setup_anndata(
    adata_ref,
    batch_key="replicate",
    layer="counts",
)

  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


In [9]:
# Build the ContrastiveVI model with 10 salient and 10 background latent
# dimensions; use_observed_lib_size is explicitly switched off.
contrastive_vi_model = scvi.external.ContrastiveVI(
    adata_ref,
    n_salient_latent=10,
    n_background_latent=10,
    use_observed_lib_size=False,
)

In [10]:
# Split cells into background (gene_target == "NT") and target (everything
# else) groups, expressed as positional integer indices into adata_ref.
is_background = (adata_ref.obs["gene_target"] == "NT").to_numpy()
background_indices = np.flatnonzero(is_background)
target_indices = np.flatnonzero(~is_background)

In [None]:
# Train the model, passing the background and target cell indices computed
# above. Training is capped at 500 epochs with early stopping enabled.
# NOTE(review): the early-stopping criterion is whatever the library default is
# — confirm if it matters for the analysis.
contrastive_vi_model.train(
    background_indices=background_indices,
    target_indices=target_indices,
    early_stopping=True,
    max_epochs=500,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `Da

Epoch 1/500:   0%|                                                                                             | 0/500 [00:00<?, ?it/s]

/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 110. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 2/500:   0%|                       | 1/500 [00:50<7:00:12, 50.53s/it, v_num=1, train_loss_step=3.37e+4, train_loss_epoch=1.02e+5]

/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 128. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 14/500:   3%|▌                       | 13/500 [11:11<7:00:33, 51.81s/it, v_num=1, train_loss_step=1.97e+4, train_loss_epoch=2e+4]

In [None]:
# Only consider perturbed cells (gene_target != "NT").
# Take an explicit copy rather than a view: later cells write to
# perturbed_adata.obsm and run neighbors/UMAP on it, and writing to a view
# would trigger an implicit view-to-copy conversion (with a warning).
perturbed_adata = adata_ref[adata_ref.obs["gene_target"] != "NT"].copy()


In [None]:
# Compute the salient latent representation of the perturbed cells and store
# it under obsm["salient_rep"] for the neighbor graph / UMAP below.
salient_latent = contrastive_vi_model.get_latent_representation(
    perturbed_adata,
    representation_kind="salient",
)
perturbed_adata.obsm["salient_rep"] = salient_latent

In [None]:
# Neighbor graph and UMAP built on the salient representation, then plotted
# coloured by cell-cycle phase and by perturbation target.
umap_color_keys = ["Phase", "gene_target"]
sc.pp.neighbors(perturbed_adata, use_rep="salient_rep")
sc.tl.umap(perturbed_adata)
sc.pl.umap(perturbed_adata, color=umap_color_keys)

In [None]:
# Same UMAP embedding, coloured by the "NT" obs column.
sc.pl.umap(perturbed_adata, color=["NT"])