# Demo Run

In [1]:
# Autoreload is intentionally left disabled:
# %load_ext autoreload
# %autoreload 2
# Reloading modules on every cell run may cause `DisentenglementTargetManager` to be re-imported,
# discarding all of its stored data (e.g. the registered configurations).

In [2]:
import os
import sys
import warnings
import anndata as ad
import scanpy as sc
import torch
from pathlib import Path

# Make the parent of the current working directory importable so `import tardis` resolves.
project_root = Path.cwd().resolve().parents[0]
sys.path.append(str(project_root))

import tardis

# Select the configuration profile (I/O directories, W&B settings): local machine vs. server.
local_run = False
tardis.config = tardis.config_local if local_run else tardis.config_server

# Sanity check: is a CUDA device visible to PyTorch?
torch.cuda.is_available()

True

In [3]:
# Load the pre-processed AnnData subset used for this demo run.
adata_file_path = os.path.join(tardis.config.io_directories["processed"], "dataset_subset_sample_status_1.h5ad")
# Raise explicitly rather than `assert`: assertions are stripped when Python runs with -O.
if not os.path.isfile(adata_file_path):
    raise FileNotFoundError(f"File does not exist: `{adata_file_path}`")
adata = ad.read_h5ad(adata_file_path)

In [4]:
def make_disentenglement_target_configuration(obs_key: str, n_reserved_latent: int) -> dict:
    """Return one disentanglement-target configuration for `tardis.MyModel.setup_anndata`.

    All targets in this demo share the same counteractive-minibatch and auxiliary-loss
    settings; only the `.obs` column and the number of reserved latent dimensions differ.
    Building the dict in a function keeps the shared settings in one place and gives every
    target its own fresh (unshared) nested dicts.

    Args:
        obs_key: Column of `adata.obs` used as the disentanglement target.
        n_reserved_latent: Number of reserved latent dimensions for this target
            (passed through unchanged as `n_reserved_latent`).

    Returns:
        A nested dict with keys ``obs_key``, ``n_reserved_latent``,
        ``counteractive_minibatch_settings`` and ``auxillary_losses``.
    """
    return dict(
        obs_key=obs_key,
        n_reserved_latent=n_reserved_latent,
        counteractive_minibatch_settings=dict(
            method="random",
            method_kwargs=dict(
                within_labels=True,
                within_batch=True,
                # Presumably one flag per entry of `categorical_covariate_keys` (["sex", "age"]) — TODO confirm.
                within_categorical_covs=[True, False],
                within_other_groups=True,
                # NOTE(review): "gglobal" looks like a typo for "global"; kept as-is — confirm against tardis.
                seed="gglobal",
            ),
        ),
        # Key spelled "auxillary_losses" as in the original configuration; presumably the name tardis expects.
        auxillary_losses=dict(
            loss_complete_latent=dict(
                apply=True,
                method="mse",
                weight=1.0,
                negative_sign=True,
                method_kwargs={},
            ),
            loss_subset_latent=dict(
                apply=False,
                method="cross_entropy",
                weight=2.0,
                negative_sign=True,
                method_kwargs={},
            ),
        ),
    )


# One configuration per disentanglement target; they differ only in `obs_key` and `n_reserved_latent`.
disentenglement_targets_configurations = [
    make_disentenglement_target_configuration("integration_sample_status", n_reserved_latent=3),
    make_disentenglement_target_configuration("sample_ID", n_reserved_latent=5),
]

# Hyperparameters passed to the `tardis.MyModel(...)` constructor.
model_params = dict(
    n_hidden=256,
    n_layers=3,
    n_latent=20,
    gene_likelihood="nb",
    dropout_rate=0.1,
)
# Arguments passed to `vae.train(...)`.
train_params = dict(
    max_epochs=100,
    train_size=0.8,
)
# AnnData field registration passed to `tardis.MyModel.setup_anndata(...)`.
dataset_params = dict(
    layer=None,
    labels_key="cell_type",
    batch_key="concatenated_integration_covariates",
    categorical_covariate_keys=["sex", "age"],
    disentenglement_targets_configurations=disentenglement_targets_configurations,
)

tardis.MyModel.setup_anndata(adata, **dataset_params)

# Record all three hyperparameter groups with the W&B run.
tardis.MyModel.setup_wandb(
    wandb_configurations=tardis.config.wandb,
    hyperparams=dict(
        model_params=model_params,
        train_params=train_params,
        dataset_params=dataset_params,
    ),
)

vae = tardis.MyModel(adata, **model_params)
vae.train(**train_params)

CUDA backend failed to initialize: Found cuDNN version 8700, but JAX was built against version 8800, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 1/100:   0%|                                                                                                                    | 0/100 [00:00<?, ?it/s]

  possible_indices = CachedPossibleGroupDefinitionIndices.get(
  possible_indices = CachedPossibleGroupDefinitionIndices.get(
  possible_indices = CachedPossibleGroupDefinitionIndices.get(
  possible_indices = CachedPossibleGroupDefinitionIndices.get(


Epoch 100/100: 100%|█████████████████████████████████████████████████| 100/100 [01:58<00:00,  1.15s/it, v_num=x7_1, train_loss_step=129, train_loss_epoch=118]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|█████████████████████████████████████████████████| 100/100 [01:58<00:00,  1.19s/it, v_num=x7_1, train_loss_step=129, train_loss_epoch=118]
W&B logger finalized with the following parameters: 
Exit Code: 0
Entity: inecik-academic
Project: tardis_experimental_runs
ID: bmbehex7
Name: young-gorge-29
Tags: tardis, experimental, development
Notes: Development runs for tardis.
URL: https://wandb.ai/inecik-academic/tardis_experimental_runs/runs/bmbehex7
Directory: /home/icb/kemal.inecik/work/codes/tardis/training/server/wandb/run-20240408_140438-bmbehex7/files



__Note: Debugging takes a significant amount of time.__

In [6]:
from tardis._DEBUG import DEBUG

In [7]:
# Top-level keys of the tensors captured in `tardis._DEBUG.DEBUG`
# (presumably the most recent training minibatch — confirm in tardis._DEBUG).
DEBUG.tensors.keys()

dict_keys(['X', 'batch', 'disentenglement_target', 'extra_categorical_covs', 'labels', 'disentenglement_target_tensors'])

In [8]:
# Per-target debug tensors; the keys are the `obs_key` of each disentanglement target.
DEBUG.tensors["disentenglement_target_tensors"].keys()

dict_keys(['integration_sample_status', 'sample_ID'])

In [9]:
# Tensors stored for the `sample_ID` target. They carry the same fields as the top-level tensors;
# presumably this is the counteractive minibatch drawn for that target — TODO confirm.
DEBUG.tensors["disentenglement_target_tensors"]["sample_ID"].keys()

dict_keys(['X', 'batch', 'disentenglement_target', 'extra_categorical_covs', 'labels'])

## Visualization

In [None]:
# Store the trained model's latent embedding, then build the kNN graph and UMAP on top of it.
latent_key = "X_scVI"
adata.obsm[latent_key] = vae.get_latent_representation()
sc.pp.neighbors(adata, use_rep=latent_key, n_neighbors=30)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
# Color the UMAP by both disentanglement targets, the cell-type labels and the batch key.
umap_color_keys = ["integration_sample_status", "sample_ID", "cell_type", "concatenated_integration_covariates"]

with warnings.catch_warnings():
    # Blanket-silence Python warnings raised while plotting; narrow this filter if one matters.
    warnings.simplefilter("ignore")
    # Leave `title` unset so every panel is labeled with its color key.
    sc.pl.umap(
        adata,
        color=umap_color_keys,
        ncols=3,
        frameon=False,
        legend_fontsize="xx-small",
    )

# Playground