In [1]:
import os
import warnings

# Move the working directory up one level; every relative path used later in
# the notebook is resolved from there.
# NOTE(review): not idempotent -- re-running this cell climbs one more level.
os.chdir('../')

# Silence FutureWarning and UserWarning messages for the rest of the session.
for ignored_category in (FutureWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=ignored_category)

In [2]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:Global seed set to 0
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Configure scanpy figures in a single call. set_figure_params re-applies its
# defaults for every argument that is not passed, so splitting the settings
# over several calls (dpi/frameon, then dpi, then figsize) silently reset
# frameon back to True and dpi back to the default. Passing everything at once
# keeps dpi=200, frameon=False and figsize=(4, 4) all in effect.
sc.settings.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))

# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [4]:
# Google Drive location of the pancreas dataset and the local file name it is
# stored under.
url = 'https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd'
output = 'pancreas.h5ad'

# Only fetch the file (~126 MB) when it is not already on disk, so that a
# "Restart & Run All" does not re-download it every time.
# NOTE(review): an interrupted earlier download would leave a partial file that
# this check accepts -- delete the file manually in that case.
if not os.path.exists(output):
    gdown.download(url, output, quiet=False)

# Display the local path as the cell result.
output

Downloading...
From (uriginal): https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd
From (redirected): https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd&confirm=t&uuid=67b6bc55-3610-4a1f-942b-097804229a81
To: /mnt/c/Helmholtz/archmap_data/scarches-api/pancreas.h5ad
100%|██████████| 126M/126M [00:12<00:00, 10.4MB/s] 


'pancreas.h5ad'

In [4]:
# --- Column names in adata.obs -------------------------------------------
condition_key = 'study'        # column holding the study / batch label
cell_type_key = 'cell_type'    # column holding the cell-type annotation

# Studies held out of the reference and used as the query data set.
target_conditions = ['Pancreas CelSeq2', 'Pancreas SS2']

# Local file name of the downloaded dataset.
output = 'pancreas.h5ad'

# --- Training budget -------------------------------------------------------
# Upper bounds on the number of epochs for reference training and for surgery.
trvae_epochs = 500
surgery_epochs = 500

# --- Early stopping --------------------------------------------------------
# Passed unchanged to both train() calls.
# NOTE(review): presumably training stops after `patience` epochs without an
# improvement larger than `threshold` on the monitored metric, and the learning
# rate is multiplied by `lr_factor` after `lr_patience` such epochs -- confirm
# against the scArches documentation.
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [5]:
# Load the complete dataset from the downloaded file.
adata_all = sc.read(output)

# Row mask that is True for cells whose study is one of the query studies.
is_target = adata_all.obs[condition_key].isin(target_conditions)

# Query set: the held-out studies. Reference set: everything else.
# .copy() turns each slice into an independent object instead of a view.
target_adata = adata_all[is_target].copy()
source_adata = adata_all[~is_target].copy()

# Distinct study labels present in the reference set.
source_conditions = source_adata.obs[condition_key].unique().tolist()

In [7]:
# Persist the reference and query subsets as separate .h5ad files
# (written in this order: source first, then target).
split_files = {
    'source__pancreas.h5ad': source_adata,
    'target__pancreas.h5ad': target_adata,
}
for split_filename, split_adata in split_files.items():
    split_adata.write_h5ad(split_filename)

In [7]:
# Widths of the hidden layers handed to the model.
reference_hidden_layers = [128, 128]

# Build the reference TRVAE on the source data, conditioned on the study label
# and restricted to the studies present in the reference set.
trvae = sca.models.TRVAE(
    adata=source_adata,
    hidden_layer_sizes=reference_hidden_layers,
    conditions=source_conditions,
    condition_key=condition_key,
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 1000 128 3
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 3
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 1000 



In [8]:
# Train the reference model for at most `trvae_epochs` epochs, using the shared
# early-stopping settings.
# NOTE(review): alpha_epoch_anneal=200 is presumably the number of epochs over
# which the KL weight is ramped up -- confirm in the scArches docs.
trvae.train(early_stopping_kwargs=early_stopping_kwargs, alpha_epoch_anneal=200, n_epochs=trvae_epochs)

 |██------------------| 10.6%  - val_loss: 658.2319335938 - val_recon_loss: 648.7462293837 - val_kl_loss: 19.0855223338 - val_mmd_loss: 4.5234735277
ADJUSTED LR
 |███-----------------| 17.8%  - val_loss: 669.1723564996 - val_recon_loss: 656.3727077908 - val_kl_loss: 16.2815569772 - val_mmd_loss: 5.6357627445
ADJUSTED LR
 |███-----------------| 19.2%  - val_loss: 663.8201633030 - val_recon_loss: 650.9449666341 - val_kl_loss: 15.9713758892 - val_mmd_loss: 5.2887816959
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state of network...
Best State was in Epoch 74


In [9]:
# Directory the trained reference model is stored in; the same path is later
# passed to load_query_data as the reference model.
ref_path = 'reference_model_pancreas/'
# overwrite=True: replace a model already saved at this path by an earlier run.
trvae.save(ref_path, overwrite=True)

In [10]:
# Create the query ("surgery") model from the reference model saved at
# `ref_path`, attached to the held-out query data.
new_trvae = sca.models.TRVAE.load_query_data(adata=target_adata, reference_model=ref_path)

AnnData object with n_obs × n_vars = 5387 × 1000
    obs: 'batch', 'study', 'cell_type', 'size_factors'

INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 1000 128 5
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 5
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 1000 



In [11]:
# Fine-tune the query model for at most `surgery_epochs` epochs with the same
# early-stopping settings as the reference run; weight decay is set to 0 here.
# NOTE(review): alpha_epoch_anneal=200 mirrors the reference training call --
# confirm its exact meaning in the scArches docs.
new_trvae.train(
    weight_decay=0,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=200,
    n_epochs=surgery_epochs,
)

 |████----------------| 22.2%  - val_loss: 911.1574096680 - val_recon_loss: 900.9128417969 - val_kl_loss: 17.5934219360 - val_mmd_loss: 0.56819458010
ADJUSTED LR
 |████----------------| 23.6%  - val_loss: 909.3723754883 - val_recon_loss: 898.4927490234 - val_kl_loss: 17.5979537964 - val_mmd_loss: 0.5848316193
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state of network...
Best State was in Epoch 96


In [12]:
# Directory the fine-tuned query (surgery) model is stored in.
surg_path = 'surgery_model_pancreas'
# overwrite=True: replace a model already saved at this path by an earlier run.
new_trvae.save(surg_path, overwrite=True)