# Demo Run

In [1]:
# Enable IPython's autoreload extension so that edits to imported modules
# are picked up without restarting the kernel (mode 2).
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import warnings
import anndata as ad
import scanpy as sc
from pathlib import Path

# Append the parent of the current working directory to the import path.
# Presumably this is the repository root that holds the `tardis` package
# imported by later cells — confirm against the project layout.
notebook_dir = Path(os.getcwd()).resolve()
sys.path.append(str(notebook_dir.parents[0]))

In [3]:
import tardis

# Pick which of the package's configurations is active for this run:
# the local one when `local_run` is True, otherwise the server one.
local_run = True
tardis.config = tardis.config_local if local_run else tardis.config_server

In [4]:
# Location of the input `.h5ad` file inside the configured "processed" directory.
adata_file_path = os.path.join(tardis.config.io_directories["processed"], "dataset_subset_sample_status_1.h5ad")
# Fail early, with a clear message, if the expected input file is missing.
assert os.path.isfile(adata_file_path), f"File does not exist: `{adata_file_path}`"
# Name of the `obs` column used as the first colour key in the UMAP plot below.
metadata_of_interest = "integration_sample_status"

In [5]:
# Read the AnnData object from disk and show its summary as the cell output.
adata = ad.read_h5ad(filename=adata_file_path)
adata

AnnData object with n_obs × n_vars = 11279 × 2048
    obs: 'sample_ID', 'organ', 'age', 'cell_type', 'sex', 'sex_inferred', 'concatenated_integration_covariates', 'integration_donor', 'integration_biological_unit', 'integration_sample_status', 'integration_library_platform_coarse'

## Training

In [6]:
# One entry per target; each `key` names a column of `adata.obs`.
disentenglement_targets_configurations = [
    {"key": "integration_sample_status"},
    {"key": "sample_ID"},
]

# Keyword arguments forwarded to the model constructor.
model_params = {
    "n_hidden": 512,
    "n_layers": 3,
    "n_latent": 20,
    "gene_likelihood": "nb",
    "dropout_rate": 0.1,
}

# Keyword arguments forwarded to `train`; deliberately small for a quick demo run.
train_params = {
    "max_epochs": 3,
    "train_size": 0.2,
}

# Keyword arguments forwarded to `setup_anndata`.
dataset_params = {
    "layer": None,
    # "labels_key": "cell_type",
    "batch_key": "concatenated_integration_covariates",
    "disentenglement_targets_configurations": disentenglement_targets_configurations,
}

In [9]:
# Register `adata` with the model class, passing the layer, batch key and
# disentanglement target settings collected in `dataset_params`.
tardis.MyModel.setup_anndata(adata, **dataset_params)

In [10]:
# Optional experiment logging, disabled for this demo. Uncomment to call
# `setup_wandb` with the hyperparameters used in this notebook.
# NOTE(review): this references `tardis.config_local.wandb` directly rather than
# the `tardis.config` selected via `local_run` — confirm which is intended.
# tardis.MyModel.setup_wandb(
#     wandb_configurations=tardis.config_local.wandb,
#     hyperparams=dict(
#         model_params=model_params,
#         train_params=train_params,
#         dataset_params=dataset_params,
#     )
# )

In [11]:
# Build the model on the registered `adata` using the architecture settings in `model_params`.
vae = tardis.MyModel(adata, **model_params)

In [12]:
# Train with the settings in `train_params` (3 epochs, train_size=0.2).
vae.train(**train_params)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


Epoch 1/3:   0%|          | 0/3 [00:00<?, ?it/s]dict_keys(['X', 'batch', 'disentenglement_target', 'labels'])
tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
      

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 3/3: 100%|██████████| 3/3 [00:01<00:00,  1.75it/s, v_num=1, train_loss_step=157, train_loss_epoch=149]


In [13]:
# Inspect the class-level attribute holding the disentanglement target
# configurations; the displayed value should match the list defined above.
from tardis._mydatasplitter import CounteractiveMinibatchGenerator
CounteractiveMinibatchGenerator._disentenglement_targets_configurations

[{'key': 'integration_sample_status'}, {'key': 'sample_ID'}]

In [14]:
# Display the per-field state registry stored on the generator class
# (category mappings for batch, labels and the disentanglement targets).
CounteractiveMinibatchGenerator._anndata_manager_state_registry

{'batch': {'categorical_mapping': array(['Hrv25_Garcia_et_al_Cell_Frozen_5GEX',
         'Hrv27_Garcia_et_al_Cell_Fresh_5GEX'], dtype=object),
  'original_key': 'concatenated_integration_covariates'},
 'labels': {'categorical_mapping': array([0]), 'original_key': '_scvi_labels'},
 'size_factor': {},
 'extra_categorical_covs': {},
 'extra_continuous_covs': {},
 'disentenglement_target': {'mappings': {'integration_sample_status': array(['Fresh', 'Frozen'], dtype=object),
   'sample_ID': array(['Hrv25', 'Hrv27'], dtype=object)},
  'field_keys': ['integration_sample_status', 'sample_ID'],
  'n_cats_per_key': [2, 2]}}

In [67]:
# Number of categories registered for the "labels" field.
# Index the field's state registry directly rather than materialising a
# dict of every non-"X" field's state registry just to read a single entry.
len(vae.adata_manager._registry["field_registries"]["labels"]["state_registry"]["categorical_mapping"])

1

In [68]:
# Number of categories registered for the "batch" field.
# Index the field's state registry directly rather than materialising a
# dict of every non-"X" field's state registry just to read a single entry.
len(vae.adata_manager._registry["field_registries"]["batch"]["state_registry"]["categorical_mapping"])

2

## Visualization

In [None]:
# Store the model's latent representation on `adata`, then build the
# neighbour graph and the UMAP embedding from that representation.
latent_key = "X_scVI"
latent_representation = vae.get_latent_representation()
adata.obsm[latent_key] = latent_representation
sc.pp.neighbors(
    adata,
    n_neighbors=30,
    use_rep=latent_key,
)
sc.tl.umap(
    adata,
    min_dist=0.2,
)

In [None]:
# `obs` columns to colour the UMAP by, one panel each.
umap_color_keys = [metadata_of_interest, "cell_type", "concatenated_integration_covariates"]

# Warnings are silenced only for the duration of the plotting call.
# NOTE(review): a single empty-string `title` is passed for three panels —
# confirm that it applies to every panel as intended.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sc.pl.umap(
        adata,
        color=umap_color_keys,
        ncols=3,
        frameon=False,
        title="",
        legend_fontsize="xx-small",
    )