# Demo Run

In [1]:
%load_ext autoreload
%autoreload 2
# Automatically reload edited modules (e.g. the local tardis package) before each cell runs.
# Caveat: a reload may cause DisentenglementTargetManager to be re-imported, losing all of
# its stored data, e.g. the target configurations.

In [2]:
import os
import sys
import warnings
import anndata as ad
import scanpy as sc
from pathlib import Path

# Make the parent of the working directory importable so that `import tardis` below can
# resolve the local checkout (presumably the notebook lives one level below it).
# Guarded so that re-running this cell does not append duplicate sys.path entries.
project_root = str(Path.cwd().resolve().parent)
if project_root not in sys.path:
    sys.path.append(project_root)

import tardis

# Select the tardis configuration: local machine vs. server.
local_run = True
if local_run:
    tardis.config = tardis.config_local
else:
    tardis.config = tardis.config_server

adata_file_path = os.path.join(tardis.config.io_directories["processed"], "dataset_subset_sample_status_1.h5ad")
# Explicit raise instead of `assert`: assertions are stripped when Python runs with -O.
if not os.path.isfile(adata_file_path):
    raise FileNotFoundError(f"File does not exist: `{adata_file_path}`")
# Used as the first colour key of the UMAP plot in the Visualization section.
metadata_of_interest = "integration_sample_status"

adata = ad.read_h5ad(adata_file_path)

In [3]:
# NOTE(review): this repeats the import/config selection of the previous cell, presumably to
# re-apply it after %autoreload reloads tardis; consider keeping a single copy.
import tardis

local_run = True
if local_run:
    tardis.config = tardis.config_local
else:
    tardis.config = tardis.config_server


def make_disentanglement_target_configuration(obs_key, n_reserved_latent):
    """Build one disentanglement-target configuration for ``setup_anndata``.

    Every target in this demo uses identical counteractive-minibatch and auxiliary-loss
    settings; only the target column and the number of reserved latent dimensions differ.
    A new dict (with new nested dicts) is built on every call, so targets never share
    mutable state.

    Parameters
    ----------
    obs_key : str
        Value of the ``obs_key`` entry, i.e. the metadata column this target refers to.
    n_reserved_latent : int
        Value of the ``n_reserved_latent`` entry for this target.

    Returns
    -------
    dict
        The configuration, with keys ``obs_key``, ``n_reserved_latent``,
        ``counteractive_minibatch_settings`` and ``auxillary_losses``.
    """
    return dict(
        obs_key=obs_key,
        n_reserved_latent=n_reserved_latent,
        counteractive_minibatch_settings=dict(
            method="random",
            method_kwargs=dict(
                exclude_itself=True,
                exclude_group=True,
                group_size_aware=True,
                within_label=True,
                within_batch=True,
                # NOTE(review): "gglobal" looks like a typo for "global"; kept unchanged
                # because the seed values tardis accepts are not visible here — confirm.
                seed="gglobal",
            ),
        ),
        # Key spelling ("auxillary") is passed to tardis as-is; presumably it is what tardis expects.
        auxillary_losses=dict(
            loss_complete_latent={"apply": True, "method": "mse", "weight": 1.0, "negative_sign": True, "method_kwargs": {}},
            loss_subset_latent={"apply": False, "method": "cross_entropy", "weight": 2.0, "negative_sign": True, "method_kwargs": {}},
        ),
    )


disentenglement_targets_configurations = [
    make_disentanglement_target_configuration("integration_sample_status", n_reserved_latent=3),
    make_disentanglement_target_configuration("sample_ID", n_reserved_latent=5),
]
model_params = dict(
    n_hidden=512,
    n_layers=3,
    # Presumably the 3 + 5 reserved dimensions above are carved out of these 20.
    n_latent=20,
    gene_likelihood="nb",
    dropout_rate=0.1,
)
# Short demo run: 3 epochs on a 20% training split.
train_params = dict(
    max_epochs=3,
    train_size=0.2,
)
dataset_params = dict(
    layer=None,
    labels_key="cell_type",
    batch_key="concatenated_integration_covariates",
    disentenglement_targets_configurations=disentenglement_targets_configurations,
)

tardis.MyModel.setup_anndata(adata, **dataset_params)

# Optional Weights & Biases logging (disabled); uncomment to record this run's hyperparameters.
# tardis.MyModel.setup_wandb(
#     wandb_configurations=tardis.config_local.wandb,
#     hyperparams=dict(
#         model_params=model_params,
#         train_params=train_params,
#         dataset_params=dataset_params,
#     )
# )

vae = tardis.MyModel(adata, **model_params)
vae.train(**train_params)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


Epoch 1/3:   0%|          | 0/3 [00:00<?, ?it/s]####
(11279, 2) <class 'pandas.core.frame.DataFrame'>
Garcia_FCA_GND9331970_AAACCTGAGACTAAGT    1
Garcia_FCA_GND9331970_AAACCTGAGCAGGTCA    1
Garcia_FCA_GND9331970_AAACCTGAGGACAGAA    1
Garcia_FCA_GND9331970_AAACCTGAGGCACATG    1
Garcia_FCA_GND9331970_AAACCTGCAAGCGAGT    1
                                         ..
Garcia_FCA_GND9295212_TTTGTCAGTAGCTCCG    0
Garcia_FCA_GND9295212_TTTGTCAGTTATCGGT    0
Garcia_FCA_GND9295212_TTTGTCAGTTTGACAC    0
Garcia_FCA_GND9295212_TTTGTCATCATCTGCC    0
Garcia_FCA_GND9295212_TTTGTCATCCACTGGG    0
Name: integration_sample_status, Length: 11279, dtype: int8
####
(11279, 2) <class 'pandas.core.frame.DataFrame'>
Garcia_FCA_GND9331970_AAACCTGAGACTAAGT    0
Garcia_FCA_GND9331970_AAACCTGAGCAGGTCA    0
Garcia_FCA_GND9331970_AAACCTGAGGACAGAA    0
Garcia_FCA_GND9331970_AAACCTGAGGCACATG    0
Garcia_FCA_GND9331970_AAACCTGCAAGCGAGT    0
                                         ..
Garcia_FCA_GND9295212_TTTGTCAGTAGCTC

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 3/3: 100%|██████████| 3/3 [00:01<00:00,  1.84it/s, v_num=1, train_loss_step=165, train_loss_epoch=151]


In [15]:
# Imported from a private module (leading underscore) only to inspect the manager's state in the cells below.
from tardis._disentenglementtargetmanager import DisentenglementTargetManager

In [12]:
# Show the latent indices the manager currently reports as unreserved.
# NOTE(review): with n_latent=20 and 3 + 5 dimensions reserved in the setup cell, a full
# 0..19 list presumably means the stored configurations were reset (see the %autoreload
# caveat at the top) — confirm before relying on this output.
DisentenglementTargetManager.configurations.unreserved_latent_indices

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

In [13]:
# Display the whole configuration object held by DisentenglementTargetManager.
# NOTE(review): an empty `items` list here, although two targets were passed to
# setup_anndata, would indicate the manager's state was reset (e.g. by %autoreload).
DisentenglementTargetManager.configurations

DisentenglementTargetConfigurations(items=[], unreserved_latent_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

## Visualization

In [None]:
# Store the trained model's latent representation for the cells in `adata`, then build
# the neighbour graph on that representation and embed it with UMAP.
latent_rep_key = "X_scVI"
adata.obsm[latent_rep_key] = vae.get_latent_representation()
sc.pp.neighbors(adata, use_rep=latent_rep_key, n_neighbors=30)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
# UMAP coloured by the disentanglement target, the cell type and the batch covariates.
umap_color_keys = [metadata_of_interest, "cell_type", "concatenated_integration_covariates"]
with warnings.catch_warnings():
    # Silence every Python warning emitted while plotting.
    warnings.simplefilter("ignore")
    # NOTE(review): a single-string `title` may be applied to only the first of the three
    # panels (the rest fall back to their colour key) — confirm the intended titles.
    sc.pl.umap(
        adata,
        color=umap_color_keys,
        ncols=3,
        frameon=False,
        title="",
        legend_fontsize="xx-small",
    )

# Playground