In [1]:
import argparse
import os
import pickle
import sys

import numpy as np
import scanpy as sc
import torch


from scvi._settings import settings
from scvi.model import SCVI, TOTALVI
from scvi.data import synthetic_iid

from contrastive_vi.data.dataloaders.contrastive_dataloader import ContrastiveDataLoader
from contrastive_vi.module.total_contrastive_vi import TotalContrastiveVIModule

from scripts import constants

Global seed set to 0


In [2]:
n_genes = 50
n_proteins = 4
adata = synthetic_iid(
    run_setup_anndata=False,
    n_batches=2,  # Same number of cells in each batch.
    n_genes=n_genes,
    n_proteins=n_proteins,
)
# Make number of cells unequal across batches to test edge cases.
# `.copy()` materializes the slice: writing `.layers` on an AnnData *view*
# would otherwise force an implicit copy and emit an ImplicitModificationWarning.
adata = adata[:-3, :].copy()
adata.layers["raw_counts"] = adata.X.copy()
# Verify the edge case we are trying to create actually holds.
assert adata.obs["batch"].value_counts().nunique() > 1, "batches should differ in size"

  adata = AnnData(data)


In [3]:
# Register the fields TOTALVI reads: gene counts from the "raw_counts" layer,
# batch labels from obs["batch"], and protein counts from obsm.
TOTALVI.setup_anndata(
    adata,
    layer="raw_counts",
    batch_key="batch",
    protein_expression_obsm_key="protein_expression",
)

[34mINFO    [0m Using batches from adata.obs[1m[[0m[32m"batch"[0m[1m][0m                                               
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.layers[1m[[0m[32m"raw_counts"[0m[1m][0m                                          
[34mINFO    [0m Using protein expression from adata.obsm[1m[[0m[32m'protein_expression'[0m[1m][0m                      
[34mINFO    [0m Generating sequential protein names                                                 
[34mINFO    [0m Successfully registered anndata object containing [1;36m397[0m cells, [1;36m50[0m vars, [1;36m2[0m batches, [1;36m1[0m  
         labels, and [1;36m4[0m proteins. Also registered [1;36m0[0m extra categorical covariates and [1;36m0[0m extra  
         continuous covariates.                                                              
[34mINFO    [0m Please do not further mo

In [4]:
# Row positions (0-based) of each condition. Taking positions from the boolean
# mask is correct for any obs_names; casting obs_names to int only works when
# they happen to be "0", "1", ... in row order.
# NOTE(review): assumes ContrastiveDataLoader expects row positions (as scvi's
# AnnDataLoader does) — confirm.
background_indices = np.flatnonzero(
    (adata.obs["batch"] == "batch_0").to_numpy()
).tolist()
target_indices = np.flatnonzero(
    (adata.obs["batch"] == "batch_1").to_numpy()
).tolist()
assert background_indices and target_indices, "both conditions must contain cells"

In [5]:
dataloader = ContrastiveDataLoader(
    adata,
    background_indices,
    target_indices,
    batch_size=32,
    shuffle=False,
)
# Exhaust the loader and keep only its final batch. With unequal condition
# sizes this is presumably a partial batch — the edge case being tested.
# Reset first so a `batch` left over from an earlier kernel run is never reused.
batch = None
for batch in dataloader:
    pass
assert batch is not None, "ContrastiveDataLoader yielded no batches"

In [6]:
# Input sizes must match the synthetic data above; the small hidden/latent
# sizes are presumably chosen to keep this check quick.
total_contrastive_vi_module = TotalContrastiveVIModule(
    n_input_proteins=n_proteins,
    n_input_genes=n_genes,
    n_background_latent=8,
    n_salient_latent=4,
    n_hidden=12,
)

In [7]:
# Split the inference inputs into their background and target entries.
inference_input = total_contrastive_vi_module._get_inference_input(batch)
background_inference_input, target_inference_input = (
    inference_input["background"],
    inference_input["target"],
)

In [8]:
# Call the private `_generic_inference` directly on the background inputs.
generic_inference_output = total_contrastive_vi_module._generic_inference(**background_inference_input)

In [9]:
# Full inference pass over the whole inference-input dict (background and target).
inference_outputs = total_contrastive_vi_module.inference(
    **inference_input
)

In [10]:
# Build generative inputs from the raw batch together with the inference outputs.
generative_input = total_contrastive_vi_module._get_generative_input(
    batch,
    inference_outputs,
)

In [11]:
# Background entry of the generative inputs; it is splatted into
# `_generic_generative` in a later cell.
generic_generative_input = generative_input["background"]

In [12]:
# These keys are passed as keyword arguments to `_generic_generative` via `**`,
# so pin them explicitly; the displayed value remains the cell's output.
assert set(generic_generative_input) == {"batch_index", "z", "s", "library_gene"}
generic_generative_input.keys()

dict_keys(['batch_index', 'z', 's', 'library_gene'])

In [13]:
# Call the private `_generic_generative` directly on the background generative inputs.
generic_generative_outputs = total_contrastive_vi_module._generic_generative(**generic_generative_input)

In [14]:
# Full generative pass over the whole generative-input dict.
generative_outputs = total_contrastive_vi_module.generative(
    **generative_input
)

In [15]:
# Pin the background generative outputs so API drift fails loudly on re-run.
assert set(generative_outputs["background"]) == {"px_", "py_", "log_pro_back_mean"}
generative_outputs["background"].keys()

dict_keys(['px_', 'py_', 'log_pro_back_mean'])

In [16]:
# Background half of the last batch.
# NOTE(review): `background_tensors` is not referenced by any later cell shown
# here; presumably kept for interactive inspection — consider removing.
background_tensors = batch["background"]

In [17]:
# Compute the module's loss for this batch from the inference and generative outputs.
loss_recorder = total_contrastive_vi_module.loss(
    batch,
    inference_outputs,
    generative_outputs,
)

In [18]:
# Sanity check: the loss must be finite (no NaN/inf) for this batch.
assert torch.isfinite(loss_recorder.loss).all(), "loss is not finite"
loss_recorder.loss

tensor(467.8974, grad_fn=<AddBackward0>)

In [19]:
# Reconstruction loss is expected to be a 1-D tensor (presumably one value per
# cell in the batch).
assert loss_recorder.reconstruction_loss.ndim == 1
loss_recorder.reconstruction_loss.shape

torch.Size([5])