# Demo Run

In [None]:
# Load IPython's autoreload extension and switch it to mode 2, which (per the
# IPython documentation) re-imports modules before each cell runs so that edits
# to the package source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import warnings
import anndata as ad
import scanpy as sc
from pathlib import Path

# Append the parent of the current working directory to the import path.
# NOTE(review): presumably this is where the `tardis` package lives — confirm.
_parent_of_cwd = Path(os.getcwd()).resolve().parents[0]
sys.path.append(str(_parent_of_cwd))

In [None]:
import tardis

# Choose which configuration object the rest of the notebook reads through
# `tardis.config`: the local one when `local_run` is set, the server one otherwise.
local_run = True
tardis.config = tardis.config_local if local_run else tardis.config_server

In [None]:
# Path of the input .h5ad file inside the configured "processed" I/O directory.
adata_file_path = os.path.join(tardis.config.io_directories["processed"], "dataset_subset_sample_status_1.h5ad")
# Fail fast with a real exception rather than `assert`: assertions are stripped
# when Python runs with -O, and the previous message was ungrammatical.
if not os.path.isfile(adata_file_path):
    raise FileNotFoundError(f"File does not exist: `{adata_file_path}`")
# Key passed to `sc.pl.umap(color=...)` in the visualization cell.
metadata_of_interest = "integration_sample_status"

In [None]:
# Read the AnnData object from `adata_file_path`.
adata = ad.read_h5ad(adata_file_path)
# Bare expression on the cell's last line so the notebook displays the object.
adata

## Training

In [None]:
# One entry per metadata key handed to the model as a disentanglement target.
disentenglement_targets_configurations = [
    {"key": "integration_sample_status"},
    {"key": "sample_ID"},
]

# Keyword arguments forwarded to the model constructor.
model_params = {
    "n_hidden": 512,
    "n_layers": 3,
    "n_latent": 20,
    "gene_likelihood": "nb",
    "dropout_rate": 0.1,
}

# Keyword arguments forwarded to `train`.
train_params = {
    "max_epochs": 3,
    "train_size": 0.2,
}

# Keyword arguments forwarded to `setup_anndata`.
dataset_params = {
    "layer": None,
    # labels_key="cell_type",
    "batch_key": "concatenated_integration_covariates",
    "disentenglement_targets_configurations": disentenglement_targets_configurations,
}

In [None]:
# Hand `adata` and the dataset settings collected in `dataset_params`
# (layer, batch key, disentanglement targets) to the model class's `setup_anndata`.
tardis.MyModel.setup_anndata(adata, **dataset_params)

In [None]:
# tardis.MyModel.setup_wandb(
#     wandb_configurations=tardis.config_local.wandb,
#     hyperparams=dict(
#         model_params=model_params,
#         train_params=train_params,
#         dataset_params=dataset_params,
#     )
# )

In [None]:
# Build the model on `adata` using the constructor settings in `model_params`.
vae = tardis.MyModel(adata, **model_params)

In [None]:
# Run training with the settings in `train_params` (max_epochs, train_size).
vae.train(**train_params)

In [None]:
# Display a class-level attribute of the minibatch generator.
# NOTE(review): this reads a leading-underscore (private) attribute, which looks
# like a debugging peek rather than supported API — confirm before relying on it.
from tardis._mydatasplitter import CounteractiveMinibatchGenerator
CounteractiveMinibatchGenerator._disentenglement_targets_configurations

In [None]:
# Display another private class-level attribute of the minibatch generator.
# NOTE(review): presumably populated during setup/training — confirm.
CounteractiveMinibatchGenerator._anndata_manager_state_registry

## Visualization

In [None]:
# Store the model's latent representation on `adata`, then build the neighbour
# graph on that representation and compute the UMAP embedding from it.
latent_representation = vae.get_latent_representation()
adata.obsm["X_scVI"] = latent_representation
sc.pp.neighbors(adata, use_rep="X_scVI", n_neighbors=30)
sc.tl.umap(adata, min_dist=0.2)

In [None]:
# Keys used to colour the UMAP panels, one panel per key.
umap_color_keys = [
    metadata_of_interest,
    "cell_type",
    "concatenated_integration_covariates",
]

# Silence every warning raised while plotting; the filter is restored on exit.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sc.pl.umap(
        adata,
        color=umap_color_keys,
        ncols=3,
        frameon=False,
        title="",
        legend_fontsize="xx-small",
    )