In [1]:
import argparse
import os
import pickle
import sys

import numpy as np
import scanpy as sc
import torch


from scvi._settings import settings
from scvi.model import SCVI, TOTALVI
from scvi.data import synthetic_iid

from contrastive_vi.data.dataloaders.contrastive_dataloader import ContrastiveDataLoader
from contrastive_vi.module.total_contrastive_vi import TotalContrastiveVIModule
from contrastive_vi.model.total_contrastive_vi import TotalContrastiveVIModel

from scripts import constants

Global seed set to 0
1: package ‘methods’ was built under R version 3.6.1 
2: package ‘datasets’ was built under R version 3.6.1 
3: package ‘utils’ was built under R version 3.6.1 
4: package ‘grDevices’ was built under R version 3.6.1 
5: package ‘graphics’ was built under R version 3.6.1 
6: package ‘stats’ was built under R version 3.6.1 


In [2]:
# Build a small synthetic dataset (genes + proteins) to smoke-test the model.
n_genes = 50
n_proteins = 4
adata = synthetic_iid(
    run_setup_anndata=False,
    n_batches=2,  # Same number of cells in each batch.
    n_genes=n_genes,
    n_proteins=n_proteins,
)
# Make number of cells unequal across batches to test edge cases.
# Slicing an AnnData yields a view; take an explicit copy so the layer
# assignment below writes to a real object instead of relying on anndata's
# implicit view-to-actual conversion (which emits a modification warning).
adata = adata[:-3, :].copy()
# Keep an untouched copy of the counts under the layer name used at setup.
adata.layers["count"] = adata.X.copy()

  adata = AnnData(data)


In [3]:
# Register the "count" layer and the protein measurements with scvi-tools.
# No batch_key is supplied, so every cell is registered as one batch (see the
# log output); the "batch" column is only used later to split background
# from target cells.
TOTALVI.setup_anndata(
    adata,
    layer="count",
    protein_expression_obsm_key=constants.PROTEIN_EXPRESSION_KEY,
)

[34mINFO    [0m No batch_key inputted, assuming all cells are same batch                            
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.layers[1m[[0m[32m"count"[0m[1m][0m                                               
[34mINFO    [0m Using protein expression from adata.obsm[1m[[0m[32m'protein_expression'[0m[1m][0m                      
[34mINFO    [0m Generating sequential protein names                                                 
[34mINFO    [0m Successfully registered anndata object containing [1;36m397[0m cells, [1;36m50[0m vars, [1;36m1[0m batches, [1;36m1[0m  
         labels, and [1;36m4[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra  
         continuous covariates.                                                              
[34mINFO    [0m Please do not further modify adata until model is

In [4]:
# Positional (0-based row) indices of the cells belonging to each batch.
# These lists are used positionally downstream (e.g. `adata[target_indices]`),
# so derive them from the boolean mask itself rather than casting the obs
# index *labels* to int, which only works while the labels happen to be
# stringified row positions.
batch_labels = adata.obs["batch"]
background_indices = np.flatnonzero(
    (batch_labels == "batch_0").to_numpy()
).tolist()
target_indices = np.flatnonzero(
    (batch_labels == "batch_1").to_numpy()
).tolist()

In [5]:
# Instantiate the contrastive model on the registered AnnData. The two latent
# sizes below are the column counts seen in the latent-representation outputs
# further down: (197, 4) for "background" and (197, 2) for "salient".
model = TotalContrastiveVIModel(
    adata,
    n_background_latent=4,
    n_salient_latent=2,
    n_hidden=16,
)

[34mINFO    [0m contrastive_vi: The model has been initialized                                      


In [6]:
# Short CPU-only training run that contrasts target cells against background.
model.train(
    target_indices=target_indices,
    background_indices=background_indices,
    train_size=0.8,  # presumably the fraction of cells used for training; confirm
    max_epochs=25,
    early_stopping=True,
    check_val_every_n_epoch=1,
    use_gpu=False,  # a GPU is available per the log, but is deliberately unused
)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
  rank_zero_warn(


Epoch 25/25: 100%|███████████████████████████████████████████████████████████| 25/25 [00:47<00:00,  1.91s/it, loss=456, v_num=1]


In [9]:
# Materialise the target cells as their own AnnData (a copy, not a view),
# then request their "salient" latent representation from the trained model.
target_adata = adata[target_indices].copy()
latent_representations = model.get_latent_representation(
    adata=target_adata,
    representation_kind="salient",
)

In [10]:
# Sanity check: the recorded output is (197, 2) -- presumably one row per
# target cell and one column per salient latent dimension (n_salient_latent=2).
latent_representations.shape

(197, 2)

In [11]:
# Same target cells, now requesting the "background" representation.
# NOTE: this rebinds `latent_representations`, replacing the salient result.
latent_representations = model.get_latent_representation(
    adata=target_adata,
    representation_kind="background",
)
# Recorded output is (197, 4); the second entry equals n_background_latent.
latent_representations.shape

(197, 4)