In [None]:
import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [None]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

In [None]:
# Configure scanpy figures in a single call: `set_figure_params` resets every
# argument it is not given to its default (dpi=80, frameon=True), so splitting
# these settings across several calls would silently undo the earlier ones.
sc.settings.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

# Seed the RNGs used during model training so reruns are comparable
# (GPU kernels may still introduce some nondeterminism).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# --- Dataset keys ----------------------------------------------------------
# obs column used as the trVAE condition (one condition per study).
condition_key = 'study'
# obs column with (presumably) per-study cell-type annotations; unused in the cells shown.
cell_type_key = 'original_ann_nonharmonized'
# Studies excluded from the reference split below.
target_conditions = ['Meyer_2021_5prime', 'Meyer_2021_3prime']

# --- Input -----------------------------------------------------------------
# Despite its name, this is the AnnData file that is *read* by the next cell.
output = 'HLCA_v1_extended_raw_counts_2000hvgs.h5ad'

# --- Training --------------------------------------------------------------
trvae_epochs = surgery_epochs = 500

# Shared early-stopping / LR-schedule settings for reference and surgery training.
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [None]:
# Load the full dataset from which the reference and query splits are drawn.
# Check up front so a missing file yields a clear message rather than a
# low-level HDF5 open error.
if not os.path.exists(output):
    raise FileNotFoundError(f"Input AnnData file not found: {os.path.abspath(output)}")
adata_all = sc.read_h5ad(output)

In [None]:
# Split the data into a healthy reference (source) and a non-healthy query
# (target). Studies in `target_conditions` are excluded from BOTH splits.
# NOTE(review): the target mask negates `isin(target_conditions)`, so the
# Meyer_2021 studies are never used. Confirm this is intended, and that the
# query should not instead be `adata_all.obs[condition_key].isin(target_conditions)`.
outside_targets = ~adata_all.obs[condition_key].isin(target_conditions)
is_healthy = adata_all.obs["condition"].isin(['Healthy'])

source_adata = adata_all[outside_targets & is_healthy].copy()
target_adata = adata_all[outside_targets & ~is_healthy].copy()

# Fail fast: an empty split would otherwise only surface as an obscure
# error once model construction or training starts.
assert source_adata.n_obs > 0, "reference (source) split is empty"
assert target_adata.n_obs > 0, "query (target) split is empty"

# Every study present in the reference becomes one trVAE condition.
source_conditions = source_adata.obs[condition_key].unique().tolist()

In [None]:
# Reference trVAE: fit on the source cells only, with one condition per
# study found in `source_adata` and two hidden layers of width 128.
trvae = sca.models.TRVAE(
    hidden_layer_sizes=[128, 128],
    conditions=source_conditions,
    condition_key=condition_key,
    adata=source_adata,
)

In [None]:
# Train the reference model. `alpha_epoch_anneal` is presumably the number of
# epochs over which the KL weight is ramped up (see scArches docs).
trvae.train(
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=200,
    n_epochs=trvae_epochs,
)

In [None]:
# Persist the trained reference; the query-surgery step reloads it from here.
ref_path = 'reference_model_hlca/'
trvae.save(
    ref_path,
    overwrite=True,
)

In [None]:
# Architecture surgery: initialise a query model from the saved reference
# so it can be fine-tuned on the target cells.
new_trvae = sca.models.TRVAE.load_query_data(
    reference_model=ref_path,
    adata=target_adata,
)

In [None]:
# Fine-tune the query model on the target cells, reusing the reference's
# annealing and early-stopping settings; weight decay is explicitly disabled.
new_trvae.train(
    weight_decay=0,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=200,
    n_epochs=surgery_epochs,
)

In [None]:
# Persist the surgery (query) model next to the reference model.
surg_path = 'surgery_model_hlca'
new_trvae.save(
    surg_path,
    overwrite=True,
)