In [1]:
import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [2]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:Global seed set to 0
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Configure scanpy figure defaults in ONE call. set_figure_params re-applies its
# own defaults (dpi=80, frameon=True) for any argument that is not passed, so
# spreading the settings over several calls silently undid dpi=200 and
# frameon=False from the earlier calls.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))

# Compact tensor printing: 3 decimals, no scientific notation, and 7 items shown
# at each end of a dimension when a tensor is summarized.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [4]:
# --- Experiment configuration -------------------------------------------------
# obs column holding the batch label used as the trVAE condition.
condition_key = 'batch'
# obs column intended as the cell-type label (not used in the cells shown here).
cell_type_key = 'final_annotation'
# Batches held out from reference training and mapped afterwards as the query.
target_conditions = ['10X']

# NOTE(review): surgery_path is not referenced by the cells shown here; the
# surgery model is saved to 'surgery_model_proper' further down. Confirm which
# location is intended.
surgery_path = 'surgery_model'
# Input .h5ad file read by the data-loading cell (despite its name, it is an input).
output = 'Immune_ALL_human.h5ad'

# Epoch budgets; in the logged runs early stopping ended training well before these.
trvae_epochs = 500
surgery_epochs = 500

# Early-stopping settings shared by the reference and surgery training runs.
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,     # an improvement must exceed this to count
    patience=20,     # epochs without improvement before training stops
    reduce_lr=True,  # presumably enables LR reduction on plateau ("ADJUSTED LR" in logs)
    lr_patience=13,
    lr_factor=0.1,
)

In [5]:
adata_all = sc.read(output)

# Fail fast if a requested query batch is absent; otherwise the problem only
# surfaces later as an obscure error during query-model setup or training.
missing_targets = sorted(set(target_conditions) - set(adata_all.obs[condition_key]))
assert not missing_targets, (
    f"target_conditions not found in obs[{condition_key!r}]: {missing_targets}"
)

# Boolean mask over cells: True where the cell's condition is a query (target) batch.
is_target = adata_all.obs[condition_key].isin(target_conditions)

# Reference (source) = all non-target cells; query (target) = the held-out batches.
source_adata = adata_all[~is_target].copy()
target_adata = adata_all[is_target].copy()
assert source_adata.n_obs > 0, "every cell is in target_conditions; the reference set is empty"

# Conditions the reference model is trained on, in order of first appearance.
source_conditions = source_adata.obs[condition_key].unique().tolist()

In [6]:
# Reference trVAE fitted on the source batches only. `conditions` lists the batch
# labels the model is conditioned on (the encoder printout reports cond = 9).
trvae = sca.models.TRVAE(
    adata=source_adata,
    conditions=source_conditions,
    condition_key=condition_key,
    hidden_layer_sizes=[128] * 2,  # two hidden layers of width 128
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 12303 128 9
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 9
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 12303 



In [7]:
# Train the reference model with the shared early-stopping settings.
# NOTE(review): alpha_epoch_anneal=200 presumably sets the epoch count over which
# a loss weight is annealed; confirm against the scArches docs.
trvae.train(
    n_epochs=trvae_epochs,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=200,
)

 |█-------------------| 9.2%  - val_loss: 3240.0711669922 - val_recon_loss: 3197.9678141276 - val_kl_loss: 36.7187317742 - val_mmd_loss: 33.8416170544
ADJUSTED LR
 |███-----------------| 17.0%  - val_loss: 3239.4399278429 - val_recon_loss: 3190.5209689670 - val_kl_loss: 36.5147238837 - val_mmd_loss: 33.5827758577
ADJUSTED LR
 |████----------------| 23.2%  - val_loss: 3246.1266954210 - val_recon_loss: 3191.3292914497 - val_kl_loss: 36.0852419535 - val_mmd_loss: 34.0484169854
ADJUSTED LR
 |████----------------| 24.6%  - val_loss: 3252.3390977648 - val_recon_loss: 3196.6609293620 - val_kl_loss: 36.0569481320 - val_mmd_loss: 33.6834758123
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state of network...
Best State was in Epoch 101


In [8]:
# Directory holding the trained reference model; the query-loading cell reads
# the reference back from this path.
ref_path = 'reference_model/'
trvae.save(
    ref_path,
    overwrite=True,  # presumably allows replacing a model saved by a previous run
)

In [9]:
# Build the query ("surgery") model from the saved reference. The target batch is
# added as an extra condition: the encoder printout goes from cond = 9 to cond = 10.
new_trvae = sca.models.TRVAE.load_query_data(
    adata=target_adata,
    reference_model=ref_path,
)

AnnData object with n_obs × n_vars = 10727 × 12303
    obs: 'batch', 'chemistry', 'data_type', 'dpt_pseudotime', 'final_annotation', 'mt_frac', 'n_counts', 'n_genes', 'sample_ID', 'size_factors', 'species', 'study', 'tissue'
    layers: 'counts'

INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 12303 128 10
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 10
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 12303 



In [10]:
# Fine-tune the query model on the target batch, reusing the early-stopping setup.
new_trvae.train(
    n_epochs=surgery_epochs,
    weight_decay=0,  # NOTE(review): weight decay disabled for surgery; rationale not shown here
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
)

 |█████---------------| 28.8%  - val_loss: 3863.8521321615 - val_recon_loss: 3838.2415907118 - val_kl_loss: 35.8189591302 - val_mmd_loss: 0.0000000000
ADJUSTED LR
 |██████--------------| 33.6%  - val_loss: 3863.2724880642 - val_recon_loss: 3833.3790961372 - val_kl_loss: 35.8004112244 - val_mmd_loss: 0.0000000000
ADJUSTED LR
 |███████-------------| 36.8%  - val_loss: 3864.0731336806 - val_recon_loss: 3831.3164333767 - val_kl_loss: 35.7996584574 - val_mmd_loss: 0.0000000000
ADJUSTED LR
 |███████-------------| 39.6%  - val_loss: 3869.7326931424 - val_recon_loss: 3834.4635959201 - val_kl_loss: 35.8062333001 - val_mmd_loss: 0.0000000000
ADJUSTED LR
 |████████------------| 41.0%  - val_loss: 3870.6901312934 - val_recon_loss: 3834.8873155382 - val_kl_loss: 35.8028216892 - val_mmd_loss: 0.0000000000
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state

In [11]:
# Persist the query-adapted (surgery) model.
# NOTE(review): the config cell defines surgery_path = 'surgery_model', but this
# cell writes to a different directory; confirm which location downstream code expects.
surg_path = 'surgery_model_proper'
new_trvae.save(
    surg_path,
    overwrite=True,
)