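# Demo Run

Trains `tardis.MyModel` for a few epochs on a processed AnnData subset, inspects the resulting disentanglement-target configuration, and visualizes the learned latent space with UMAP.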

In [1]:
%load_ext autoreload
# Reload all imported modules (e.g. the local `tardis` package) before each cell executes,
# so edits to the library source are picked up without restarting the kernel.
%autoreload 2

In [2]:
import os
import sys
import warnings
import anndata as ad
import scanpy as sc
from pathlib import Path

# Make the parent of the current working directory (typically the folder that
# holds this notebook) importable, so the local `tardis` checkout can be imported.
# Guarded so that re-running this cell does not keep appending duplicates to sys.path.
project_root = str(Path.cwd().resolve().parent)
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
import tardis

# Choose the configuration profile: the local-machine settings when `local_run`
# is True, otherwise the server settings.
local_run = True
tardis.config = tardis.config_local if local_run else tardis.config_server

In [4]:
# Path of the processed AnnData subset used by this demo, resolved from the active config profile.
adata_file_path = os.path.join(tardis.config.io_directories["processed"], "dataset_subset_sample_status_1.h5ad")
# Fail fast with a clear error if the input is missing. An explicit raise is used
# instead of `assert` because asserts are stripped when Python runs with -O.
if not os.path.isfile(adata_file_path):
    raise FileNotFoundError(f"File does not exist: `{adata_file_path}`")
# `adata.obs` column highlighted in the UMAP plots.
metadata_of_interest = "integration_sample_status"

In [5]:
# Load the processed AnnData subset (fully in memory; no `backed` mode) and display its summary.
adata = ad.read_h5ad(adata_file_path)
adata

AnnData object with n_obs × n_vars = 11279 × 2048
    obs: 'sample_ID', 'organ', 'age', 'cell_type', 'sex', 'sex_inferred', 'concatenated_integration_covariates', 'integration_donor', 'integration_biological_unit', 'integration_sample_status', 'integration_library_platform_coarse'

## Training

In [6]:
def _make_disentenglement_target(obs_key, n_reserved_latent):
    """Build one disentanglement-target configuration entry.

    Both targets in this demo share the same minibatch-generator and auxiliary-loss
    settings; only the `adata.obs` column (`obs_key`) and the number of reserved
    latent dimensions differ. A fresh dict tree is returned on every call, so
    entries never share mutable state.

    NOTE(review): "example_method" / "param1" look like placeholders — confirm the
    intended counteractive minibatch generator before enabling these targets.
    The key spellings ("auxillary_losses", etc.) are kept exactly as-is; presumably
    they are what tardis reads, so do not "fix" them without checking the library.
    """
    return {
        "obs_key": obs_key,
        "n_reserved_latent": n_reserved_latent,
        "counteractive_minibatch_generator": {
            "method": "example_method",
            "method_kwargs": {"param1": "value1", "param2": True},
        },
        "auxillary_losses": {
            "loss_complete_latent": {"apply": True, "method": "mse", "weight": 1.0, "negative_sign": True, "method_kwargs": {}},
            "loss_subset_latent": {"apply": False, "method": "cross_entropy", "weight": 2.0, "negative_sign": True, "method_kwargs": {}},
        },
    }


disentenglement_targets_configurations = [
    _make_disentenglement_target("integration_sample_status", n_reserved_latent=3),
    _make_disentenglement_target("sample_ID", n_reserved_latent=4),
]
model_params = dict(
    n_hidden=512,
    n_layers=3,
    n_latent=20,
    gene_likelihood="nb",
    dropout_rate=0.1,
)
train_params = dict(
    max_epochs=3,
    # Fraction of cells used for training — kept small here, presumably to make the demo fast.
    train_size=0.2,
)
# The targets above were previously defined but never passed to the model.
# Set this to True to register them. False reproduces the current run, in which
# no latent dimensions are reserved.
USE_DISENTENGLEMENT_TARGETS = False
dataset_params = dict(
    layer=None,
    # labels_key="cell_type",
    batch_key="concatenated_integration_covariates",
    disentenglement_targets_configurations=(
        disentenglement_targets_configurations if USE_DISENTENGLEMENT_TARGETS else None
    ),
)

In [7]:
# Register `adata` with the model class using `dataset_params` (layer, batch key,
# disentanglement targets). In scvi-tools-style APIs this is presumably required
# before the model is instantiated — confirm against tardis.
tardis.MyModel.setup_anndata(adata, **dataset_params)

In [8]:
# Optional Weights & Biases setup, recording the hyperparameters used in this run.
# Disabled by default, which matches the previous behaviour of this cell.
# Uses the active `tardis.config` profile rather than hardcoding `config_local`,
# so the local/server selection is respected.
# NOTE(review): assumes `tardis.config_server` also exposes a `.wandb` entry — confirm.
use_wandb = False
if use_wandb:
    tardis.MyModel.setup_wandb(
        wandb_configurations=tardis.config.wandb,
        hyperparams=dict(
            model_params=model_params,
            train_params=train_params,
            dataset_params=dataset_params,
        ),
    )

In [9]:
# Instantiate the model on the registered `adata`, with the architecture given by `model_params`.
vae = tardis.MyModel(adata, **model_params)

In [10]:
# Train the model with the settings in `train_params`.
vae.train(**train_params)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


Epoch 3/3: 100%|██████████| 3/3 [00:01<00:00,  1.92it/s, v_num=1, train_loss_step=150, train_loss_epoch=151]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 3/3: 100%|██████████| 3/3 [00:01<00:00,  1.91it/s, v_num=1, train_loss_step=150, train_loss_epoch=151]


In [11]:
from tardis._disentenglementtargetmanager import DisentenglementTargetManager

In [12]:
# Latent dimensions that are not reserved by any registered disentanglement target.
DisentenglementTargetManager.configurations.unreserved_latent_indices

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

In [13]:
# Full configuration state: the registered targets (`items`) and the unreserved latent indices.
DisentenglementTargetManager.configurations

DisentenglementTargetConfigurations(items=[], unreserved_latent_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

In [14]:
# Reserved latent indices for each registered target (an empty list when no targets are registered).
[target_config.reserved_latent_indices for target_config in DisentenglementTargetManager.configurations.items]

[]

In [15]:
# Sanity check: no latent index may be both reserved by a target and listed as
# unreserved. The cell then displays the unreserved indices, as before.
_dtm_configurations = DisentenglementTargetManager.configurations
_reserved_indices = {
    index
    for target_config in _dtm_configurations.items
    for index in target_config.reserved_latent_indices
}
_overlap = _reserved_indices.intersection(_dtm_configurations.unreserved_latent_indices)
assert not _overlap, f"latent indices are both reserved and unreserved: {sorted(_overlap)}"
_dtm_configurations.unreserved_latent_indices

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

In [16]:
from tardis._mytrainingplan import TrainingStepLogger
# Print the per-step-type counters recorded by `TrainingStepLogger` (e.g. forward / training / validation).
TrainingStepLogger.print_steps()

forward = 54
gglobal = 53
training = 54
validation = 0
test = 0
predict = 0


## Visualization

In [None]:
# Store the model's latent embedding in `adata.obsm`, then build the kNN graph
# and the UMAP layout on top of that representation.
UMAP_N_NEIGHBORS = 30
UMAP_MIN_DIST = 0.2
latent_key = "X_scVI"
adata.obsm[latent_key] = vae.get_latent_representation()
sc.pp.neighbors(adata, n_neighbors=UMAP_N_NEIGHBORS, use_rep=latent_key)
sc.tl.umap(adata, min_dist=UMAP_MIN_DIST)

In [None]:
# One UMAP panel per annotation: the metadata of interest, cell type, and integration covariates.
umap_color_keys = [metadata_of_interest, "cell_type", "concatenated_integration_covariates"]
with warnings.catch_warnings():
    # Silence Python warnings of every category while the panels are drawn.
    warnings.simplefilter("ignore")
    # NOTE(review): scanpy appears to apply a single-string `title` to the first panel only,
    # with later panels falling back to their color key — confirm whether all panels should be untitled.
    sc.pl.umap(
        adata,
        color=umap_color_keys,
        ncols=3,
        frameon=False,
        title="",
        legend_fontsize="xx-small",
    )

## Playground