In [1]:
import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [2]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:Global seed set to 0
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Figure defaults for every scanpy/matplotlib plot in this notebook.
# All options are passed in ONE call on purpose: each invocation of
# set_figure_params re-applies its own defaults for any argument that is
# omitted (dpi=80, frameon=True), so the previous sequence of three calls
# ended up with dpi=80 and frames enabled instead of the intended
# dpi=200 / frameon=False.
sc.settings.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))

# Compact tensor printing: 3 decimals, no scientific notation, and up to
# 7 leading/trailing items shown per dimension before eliding.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [4]:
# --- Dataset layout ---------------------------------------------------------
# obs column holding the batch/condition label of every cell.
condition_key = 'batch'
# obs column holding the cell-type annotation.
cell_type_key = 'final_annotation'
# Conditions held out from the reference and used as the query (target) set.
target_conditions = ['Oetjen_A', 'Sun_sample1_CS']

# Despite its name, this is the path the full dataset is READ from below.
output = 'Immune_ALL_human.h5ad'

# --- Training budget --------------------------------------------------------
# Upper bounds on epochs; early stopping may end training sooner.
trvae_epochs = 500
surgery_epochs = 500

# --- Early stopping / LR schedule -------------------------------------------
# Passed unchanged to both the reference and the surgery `.train()` calls.
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [5]:
# Load the complete dataset from disk.
adata_all = sc.read(output)

# Per-cell boolean mask: True where the cell belongs to a held-out target condition.
is_target_cell = adata_all.obs[condition_key].isin(target_conditions)

# Independent copies, so neither split is a view onto `adata_all`.
target_adata = adata_all[is_target_cell].copy()
source_adata = adata_all[~is_target_cell].copy()

# Condition labels that remain in the reference (source) split.
source_conditions = source_adata.obs[condition_key].unique().tolist()

In [6]:
# Persist both splits as .h5ad files in the current working directory.
for _split, _h5ad_name in (
    (source_adata, 'source__pbmc.h5ad'),
    (target_adata, 'target__pbmc.h5ad'),
):
    _split.write_h5ad(_h5ad_name)

In [7]:
# Reference model built on the source split only.
# Encoder and decoder each get two hidden layers of 128 units; the model is
# told which condition labels exist in the reference data.
trvae = sca.models.TRVAE(
    source_adata,
    hidden_layer_sizes=[128, 128],
    conditions=source_conditions,
    condition_key=condition_key,
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 12303 128 8
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 8
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 12303 



In [8]:
# Fit the reference model for at most `trvae_epochs` epochs; the shared
# early-stopping settings may end the run earlier.
# alpha_epoch_anneal=200 — presumably the number of epochs over which the
# KL weight is annealed; confirm against the scArches documentation.
trvae.train(
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=200,
    n_epochs=trvae_epochs,
)

 |██------------------| 14.2%  - val_loss: 3590.7711022418 - val_recon_loss: 3545.1896972656 - val_kl_loss: 26.6363117384 - val_mmd_loss: 36.2586589482
ADJUSTED LR
 |████----------------| 24.8%  - val_loss: 3594.2870669158 - val_recon_loss: 3539.5747282609 - val_kl_loss: 18.2229265130 - val_mmd_loss: 43.5052359208
ADJUSTED LR
 |█████---------------| 27.8%  - val_loss: 3590.4464058254 - val_recon_loss: 3540.0509298573 - val_kl_loss: 17.3613858430 - val_mmd_loss: 38.4161085046
ADJUSTED LR
 |██████--------------| 32.8%  - val_loss: 3593.2402237602 - val_recon_loss: 3539.5494968580 - val_kl_loss: 16.7590221737 - val_mmd_loss: 40.0321217413
ADJUSTED LR
 |██████--------------| 34.2%  - val_loss: 3593.9598654042 - val_recon_loss: 3540.1260296365 - val_kl_loss: 16.6795218924 - val_mmd_loss: 39.6562781956
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best 

In [9]:
# Directory the trained reference model is saved to; the same path is later
# handed to `load_query_data` as `reference_model`.
ref_path = 'reference_model_pbmc/'
# overwrite=True — presumably replaces a model already saved at this path
# from a previous run; confirm against the scArches documentation.
trvae.save(ref_path, overwrite=True)

In [10]:
# Build the query ("surgery") model from the saved reference model and the
# held-out target cells. Per the printed architecture, the condition input
# grows from 8 (reference) to 10 once the two target conditions are added.
new_trvae = sca.models.TRVAE.load_query_data(
    reference_model=ref_path,
    adata=target_adata,
)

AnnData object with n_obs × n_vars = 4311 × 12303
    obs: 'batch', 'chemistry', 'data_type', 'dpt_pseudotime', 'final_annotation', 'mt_frac', 'n_counts', 'n_genes', 'sample_ID', 'size_factors', 'species', 'study', 'tissue'
    layers: 'counts'

INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 12303 128 10
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 10
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 12303 



In [11]:
# Fine-tune the query model on the target cells for at most `surgery_epochs`
# epochs, using the same early-stopping settings as the reference run.
# weight_decay is explicitly set to 0 for this stage.
# alpha_epoch_anneal=200 — presumably the number of epochs over which the
# KL weight is annealed; confirm against the scArches documentation.
new_trvae.train(
    weight_decay=0,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=200,
    n_epochs=surgery_epochs,
)

 |██████--------------| 31.6%  - val_loss: 2805.0277099609 - val_recon_loss: 2793.0004882812 - val_kl_loss: 14.7646112442 - val_mmd_loss: 0.4369621277
ADJUSTED LR
 |██████--------------| 34.4%  - val_loss: 2818.3168334961 - val_recon_loss: 2805.2718505859 - val_kl_loss: 14.7420217991 - val_mmd_loss: 0.4404945374
ADJUSTED LR
 |███████-------------| 37.6%  - val_loss: 2806.5301513672 - val_recon_loss: 2792.3743286133 - val_kl_loss: 14.6795158386 - val_mmd_loss: 0.4304623604
ADJUSTED LR
 |████████------------| 41.2%  - val_loss: 2814.7688598633 - val_recon_loss: 2799.6450805664 - val_kl_loss: 14.7022061348 - val_mmd_loss: 0.4215445518
ADJUSTED LR
 |████████------------| 42.6%  - val_loss: 2825.6467285156 - val_recon_loss: 2810.4708251953 - val_kl_loss: 14.7250926495 - val_mmd_loss: 0.4507498741
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state

In [12]:
# Directory the fine-tuned query ("surgery") model is saved to.
surg_path = 'surgery_model_pbmc'
# overwrite=True — presumably replaces a model already saved at this path
# from a previous run; confirm against the scArches documentation.
new_trvae.save(surg_path, overwrite=True)