In [1]:
import os
import warnings

# Work from the parent directory for the rest of the notebook.
# NOTE(review): a relative chdir is not idempotent -- re-running this cell in
# the same kernel moves one more level up each time.
os.chdir('../')

# Hide FutureWarning and UserWarning messages in the cell outputs.
for _category in (FutureWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=_category)

In [2]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:Global seed set to 0
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Configure the scanpy/matplotlib figure defaults in ONE call.
# `sc.set_figure_params` is the same function as
# `sc.settings.set_figure_params`, and every call re-applies the default value
# of each argument that is not passed (dpi=80, frameon=True). The previous
# three separate calls therefore undid `frameon=False` and `dpi=200`, leaving
# only the figure size in effect.
sc.settings.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))

# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [4]:
# --- Column names in `adata.obs` -------------------------------------------
condition_key = 'study'        # column used to tell studies/conditions apart
cell_type_key = 'cell_type'    # column holding the cell-type label

# Studies that are held out of the reference and used as the query data.
target_conditions = ['Pancreas CelSeq2', 'Pancreas SS2']

# Path of the AnnData file loaded with `sc.read` below (despite the name,
# it is read as input in this notebook).
output = 'pancreas.h5ad'

# --- Epoch budgets passed as `n_epochs` to the two `.train(...)` calls ------
trvae_epochs = 500
surgery_epochs = 500

# Options forwarded as `early_stopping_kwargs` to both `.train(...)` calls.
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [5]:
# Load the full dataset, then split it into reference (source) and query
# (target) subsets.
adata_all = sc.read(output)

# Row masks over `adata_all.obs`.
in_target_study = adata_all.obs[condition_key].isin(target_conditions)
is_alpha_cell = adata_all.obs[cell_type_key].isin(["Pancreas Alpha"])

# Source: every cell that is neither from a target study nor labelled
# "Pancreas Alpha".
source_adata = adata_all[~(in_target_study | is_alpha_cell)].copy()

# Target: all cells from the target studies (no cell-type filtering here).
target_adata = adata_all[in_target_study].copy()

# Study labels that remain in the source subset.
source_conditions = source_adata.obs[condition_key].unique().tolist()

In [6]:
# Persist both subsets to disk as .h5ad files.
# NOTE(review): the target file name also says "no_alpha", but only the source
# subset had "Pancreas Alpha" cells filtered out above -- confirm the name.
for _adata, _h5ad_path in (
    (source_adata, 'source__pancreas_no_alpha.h5ad'),
    (target_adata, 'target__pancreas_no_alpha.h5ad'),
):
    _adata.write_h5ad(_h5ad_path)

In [6]:
# Build the reference TRVAE on the source subset, conditioned on the study
# labels found in it, with two hidden layers of width 128.
trvae_init_kwargs = dict(
    condition_key=condition_key,
    conditions=source_conditions,
    hidden_layer_sizes=[128, 128],
)
trvae = sca.models.TRVAE(adata=source_adata, **trvae_init_kwargs)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 1000 128 3
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 3
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 1000 



In [7]:
# Train the reference model with the epoch budget and early-stopping options
# defined in the configuration cell.
reference_train_kwargs = dict(
    n_epochs=trvae_epochs,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
)
trvae.train(**reference_train_kwargs)

 |██████--------------| 34.8%  - val_loss: 665.8823038737 - val_recon_loss: 650.7600301107 - val_kl_loss: 12.3072001139 - val_mmd_loss: 4.4765326182
ADJUSTED LR
 |███████-------------| 38.6%  - val_loss: 665.4104817708 - val_recon_loss: 649.3063252767 - val_kl_loss: 11.5611745516 - val_mmd_loss: 5.0054187775
ADJUSTED LR
 |████████------------| 41.4%  - val_loss: 665.9808959961 - val_recon_loss: 649.3853251139 - val_kl_loss: 11.4941643079 - val_mmd_loss: 5.1014048258
ADJUSTED LR
 |████████------------| 42.8%  - val_loss: 665.7905883789 - val_recon_loss: 649.6856486003 - val_kl_loss: 11.4868397713 - val_mmd_loss: 4.6181033452
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state of network...
Best State was in Epoch 213


In [8]:
# Directory the trained reference model is saved to; the same path is passed
# as `reference_model` when the query model is created.
ref_path = 'reference_model_pancreas_no_alpha/'

trvae.save(
    ref_path,
    overwrite=True,
)

In [9]:
# Create the query model from the saved reference model and the target data.
# The printed architecture below reports 5 conditions, versus 3 for the
# reference model.
new_trvae = sca.models.TRVAE.load_query_data(
    adata=target_adata,
    reference_model=ref_path,
)

AnnData object with n_obs × n_vars = 5387 × 1000
    obs: 'batch', 'study', 'cell_type', 'size_factors'

INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 1000 128 5
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 5
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 1000 



In [10]:
# Train the query model on the target data. Same annealing and early-stopping
# settings as the reference run, plus an explicit weight_decay of 0.
surgery_train_kwargs = dict(
    n_epochs=surgery_epochs,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
    weight_decay=0,
)
new_trvae.train(**surgery_train_kwargs)

 |███-----------------| 15.2%  - val_loss: 929.1384277344 - val_recon_loss: 922.5340087891 - val_kl_loss: 16.1461505890 - val_mmd_loss: 0.54959564214
ADJUSTED LR
 |████----------------| 22.0%  - val_loss: 920.4357055664 - val_recon_loss: 911.1913208008 - val_kl_loss: 15.8712984085 - val_mmd_loss: 0.5945198059
ADJUSTED LR
 |████----------------| 24.8%  - val_loss: 939.7165161133 - val_recon_loss: 929.3240478516 - val_kl_loss: 15.9574316025 - val_mmd_loss: 0.5786411285
ADJUSTED LR
 |█████---------------| 26.2%  - val_loss: 929.7299316406 - val_recon_loss: 918.5905517578 - val_kl_loss: 15.9682592392 - val_mmd_loss: 0.7600162506
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state of network...
Best State was in Epoch 109


In [11]:
# Directory the trained query model is saved to.
surg_path = 'surgery_model_pancreas_no_alpha'

new_trvae.save(
    surg_path,
    overwrite=True,
)