In [None]:
import os
import matplotlib.pyplot as plt
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import scanpy as sc
import torch
import time
import json
import numpy as np

# Set Params

In [None]:
# Epoch budget for each surgery (query) training run, and the cell types that are
# removed from the reference set only (so they occur solely in the query batches).
n_epochs_surgery = 300
leave_out_cell_types = ['Pancreas Alpha', 'Pancreas Gamma']

# Seed the global NumPy and PyTorch RNGs so a Restart & Run All is repeatable.
# NOTE(review): assumes scarches/scanpy draw from these global generators — confirm.
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [None]:
# Studies used as the two sequential query batches (in surgery order), plus the
# obs column names that hold the study label and the cell-type label.
target_batches = ["Pancreas SS2", "Pancreas CelSeq2"]
cell_type_key = "cell_type"
batch_key = "study"

In [None]:
# Epoch budget for reference training, and the early-stopping / LR-reduction
# settings shared by the reference and both surgery training runs.
n_epochs_vae = 500
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [None]:
# Output directory for this run's results. The trailing slash matters: later
# cells build file paths as f'{dir_path}<name>'.
dir_path = os.path.expanduser('~/Documents/benchmarking_results/figure_1/trvae_mse/ood_2/')
# exist_ok=True replaces the racy "os.path.exists, then os.makedirs" pattern.
os.makedirs(dir_path, exist_ok=True)

# Sub-directory for the control UMAP plots (again with a trailing slash).
control_path = f'{dir_path}controlling/'
os.makedirs(control_path, exist_ok=True)

# AnnData handling

In [None]:
# Load the normalized pancreas benchmark dataset and pass it through
# remove_sparsity (presumably densifies a sparse .X — confirm), then display it.
adata = sc.read(os.path.expanduser('~/Documents/benchmarking_datasets/pancreas_normalized.h5ad'))
adata = remove_sparsity(adata)
adata

In [None]:
# Split cells by study: the target batches form the two sequential query sets and
# all remaining studies form the reference. The leave-out cell types are dropped
# from the reference only, so they are unseen until the query surgeries.
# Vectorised Series.isin replaces per-cell Python membership tests; .to_numpy()
# keeps the masks as plain boolean NumPy arrays, as before.
study_labels = adata.obs[batch_key]
query = study_labels.isin(target_batches).to_numpy()
query_1 = study_labels.isin([target_batches[0]]).to_numpy()
query_2 = study_labels.isin([target_batches[1]]).to_numpy()
adata_ref_full = adata[~query].copy()
adata_ref = adata_ref_full[~adata_ref_full.obs[cell_type_key].isin(leave_out_cell_types)].copy()
adata_query_1 = adata[query_1].copy()
adata_query_2 = adata[query_2].copy()

In [None]:
# Summary of the reference set: non-target studies with the leave-out cell types removed.
adata_ref

In [None]:
# Summary of the first query set: all cells from study target_batches[0].
adata_query_1

In [None]:
# Summary of the second query set: all cells from study target_batches[1].
adata_query_2

# Create trVAE model and train

In [None]:
# Reference trVAE on the reference cells, conditioned on the study column,
# with hidden layer sizes [128, 128] and an MSE reconstruction loss.
trvae = sca.models.TRVAE(
    adata=adata_ref,
    recon_loss='mse',
    hidden_layer_sizes=[128] * 2,
    condition_key=batch_key,
)

In [None]:
# Train the reference model and record its wall-clock training duration (seconds).
# perf_counter is monotonic, unlike time.time(). The separate start variable means
# ref_time is only ever assigned a duration, never a raw timestamp, even if
# training is interrupted.
ref_start = time.perf_counter()
trvae.train(
    n_epochs=n_epochs_vae,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
)
ref_time = time.perf_counter() - ref_start

# Reference Evaluation

In [None]:
# Latent representation of the reference cells, with their cell-type and study
# labels copied positionally (as lists) under the keys used by the plots.
adata_latent_r = sc.AnnData(trvae.get_latent())
for latent_key, source_key in (('celltype', cell_type_key), ('batch', batch_key)):
    adata_latent_r.obs[latent_key] = adata_ref.obs[source_key].tolist()

In [None]:
# kNN graph (8 neighbours) on the reference latent space, Leiden clustering and a
# UMAP embedding, then persist the annotated latent AnnData.
# NOTE(review): the query/full cells use scanpy's default n_neighbors, not 8.
reference_out_file = f'{dir_path}reference_data.h5ad'
sc.pp.neighbors(adata_latent_r, n_neighbors=8)
sc.tl.leiden(adata_latent_r)
sc.tl.umap(adata_latent_r)
adata_latent_r.write_h5ad(filename=reference_out_file)

In [None]:
# UMAP of the reference latent space coloured by study and cell type, saved to the
# control directory. No plt.figure() beforehand: scanpy's embedding plots open
# their own figure when no `ax` is given, so an extra call only leaves an empty
# figure open. The figure scanpy drew is the current one, which savefig writes.
sc.pl.umap(
    adata_latent_r,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False,
)
plt.savefig(f'{control_path}umap_reference.png', bbox_inches='tight')

In [None]:
# Persist the reference model; the first surgery step loads it from this path.
ref_path = f'{dir_path}ref_model/'
# exist_ok=True replaces the racy "os.path.exists, then os.makedirs" pattern.
os.makedirs(ref_path, exist_ok=True)
trvae.save(ref_path, overwrite=True)

# Run surgery on first query batch

In [None]:
# Surgery step 1: build a query model from the saved reference model for the first query batch.
new_trvae = sca.models.TRVAE.load_query_data(
    adata=adata_query_1,
    reference_model=ref_path,
)

In [None]:
# Train the first surgery model and record its wall-clock duration (seconds).
# perf_counter is monotonic, and the separate start variable means query_1_time
# never holds a raw timestamp if training is interrupted.
query_1_start = time.perf_counter()
new_trvae.train(
    n_epochs=n_epochs_surgery,
    frequency=1,
    early_stopping_kwargs=early_stopping_kwargs,
    weight_decay=0,
)
query_1_time = time.perf_counter() - query_1_start

# Evaluate surgery on query 1

In [None]:
# Latent representation from the first surgery model (get_latent() with no
# arguments — presumably the data it was loaded with, adata_query_1), labelled
# with query-1 cell types and studies, copied positionally.
adata_latent_q1 = sc.AnnData(new_trvae.get_latent())
for latent_key, source_key in (('celltype', cell_type_key), ('batch', batch_key)):
    adata_latent_q1.obs[latent_key] = adata_query_1.obs[source_key].tolist()

In [None]:
# kNN graph (scanpy's default n_neighbors), Leiden clustering and UMAP for the
# query-1 latent space, then persist it.
query_1_out_file = f'{dir_path}query_1_data.h5ad'
sc.pp.neighbors(adata_latent_q1)
sc.tl.leiden(adata_latent_q1)
sc.tl.umap(adata_latent_q1)
adata_latent_q1.write_h5ad(filename=query_1_out_file)

In [None]:
# UMAP of the query-1 latent space coloured by study and cell type, saved to the
# control directory. No plt.figure() beforehand: scanpy's embedding plots open
# their own figure when no `ax` is given, so an extra call only leaves an empty
# figure open. The figure scanpy drew is the current one, which savefig writes.
sc.pl.umap(
    adata_latent_q1,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False,
)
plt.savefig(f'{control_path}umap_query_1.png', bbox_inches='tight')

# Evaluate query 1 together with the reference

In [None]:
# Combine the reference cells with the query-1 cells for joint embedding.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in
# favour of anndata.concat (which differs in obs-name/batch-column handling) —
# check the pinned anndata version before migrating.
adata_full_1 = adata_ref.concatenate(adata_query_1)

In [None]:
# Embed reference + query-1 cells with the first surgery model, passing the
# expression matrix and study labels explicitly, then attach plotting labels.
full_1_latent = new_trvae.get_latent(adata_full_1.X, adata_full_1.obs[batch_key])
adata_latent_f1 = sc.AnnData(full_1_latent)
for latent_key, source_key in (('celltype', cell_type_key), ('batch', batch_key)):
    adata_latent_f1.obs[latent_key] = adata_full_1.obs[source_key].tolist()

In [None]:
# kNN graph (scanpy's default n_neighbors), Leiden clustering and UMAP for the
# joint reference + query-1 latent space, then persist it.
full_1_out_file = f'{dir_path}full_1_data.h5ad'
sc.pp.neighbors(adata_latent_f1)
sc.tl.leiden(adata_latent_f1)
sc.tl.umap(adata_latent_f1)
adata_latent_f1.write_h5ad(filename=full_1_out_file)

In [None]:
# UMAP of the joint reference + query-1 latent space coloured by study and cell
# type, saved to the control directory. No plt.figure() beforehand: scanpy's
# embedding plots open their own figure when no `ax` is given, so an extra call
# only leaves an empty figure open. savefig writes scanpy's (current) figure.
sc.pl.umap(
    adata_latent_f1,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False,
)
plt.savefig(f'{control_path}umap_full_1.png', bbox_inches='tight')

In [None]:
# Persist the first surgery model; the second surgery step loads it from this path.
surgery_1_path = f'{dir_path}surg_1_model/'
# exist_ok=True replaces the racy "os.path.exists, then os.makedirs" pattern.
os.makedirs(surgery_1_path, exist_ok=True)
new_trvae.save(surgery_1_path, overwrite=True)

# Run surgery on second query batch

In [None]:
# Surgery step 2: build a query model from the saved first-surgery model for the second query batch.
new_trvae_2 = sca.models.TRVAE.load_query_data(
    adata=adata_query_2,
    reference_model=surgery_1_path,
)

In [None]:
# Train the second surgery model and record its wall-clock duration (seconds).
# perf_counter is monotonic, and the separate start variable means query_2_time
# never holds a raw timestamp if training is interrupted.
query_2_start = time.perf_counter()
new_trvae_2.train(
    n_epochs=n_epochs_surgery,
    frequency=1,
    early_stopping_kwargs=early_stopping_kwargs,
    weight_decay=0,
)
query_2_time = time.perf_counter() - query_2_start

# Evaluate surgery on query 2

In [None]:
# Latent representation from the second surgery model (get_latent() with no
# arguments — presumably the data it was loaded with, adata_query_2), labelled
# with query-2 cell types and studies, copied positionally.
adata_latent_q2 = sc.AnnData(new_trvae_2.get_latent())
for latent_key, source_key in (('celltype', cell_type_key), ('batch', batch_key)):
    adata_latent_q2.obs[latent_key] = adata_query_2.obs[source_key].tolist()

In [None]:
# kNN graph (scanpy's default n_neighbors), Leiden clustering and UMAP for the
# query-2 latent space, then persist it.
query_2_out_file = f'{dir_path}query_2_data.h5ad'
sc.pp.neighbors(adata_latent_q2)
sc.tl.leiden(adata_latent_q2)
sc.tl.umap(adata_latent_q2)
adata_latent_q2.write_h5ad(filename=query_2_out_file)

In [None]:
# UMAP of the query-2 latent space coloured by study and cell type, saved to the
# control directory. No plt.figure() beforehand: scanpy's embedding plots open
# their own figure when no `ax` is given, so an extra call only leaves an empty
# figure open. The figure scanpy drew is the current one, which savefig writes.
sc.pl.umap(
    adata_latent_q2,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False,
)
plt.savefig(f'{control_path}umap_query_2.png', bbox_inches='tight')

# Evaluate queries 1 and 2 together with the reference

In [None]:
# Combine reference + query-1 cells with the query-2 cells for joint embedding.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in
# favour of anndata.concat (which differs in obs-name/batch-column handling) —
# check the pinned anndata version before migrating.
adata_full_2 = adata_full_1.concatenate(adata_query_2)

In [None]:
# Embed reference + both query sets with the second surgery model, passing the
# expression matrix and study labels explicitly, then attach plotting labels.
full_2_latent = new_trvae_2.get_latent(adata_full_2.X, adata_full_2.obs[batch_key])
adata_latent_f2 = sc.AnnData(full_2_latent)
for latent_key, source_key in (('celltype', cell_type_key), ('batch', batch_key)):
    adata_latent_f2.obs[latent_key] = adata_full_2.obs[source_key].tolist()

In [None]:
# kNN graph (scanpy's default n_neighbors), Leiden clustering and UMAP for the
# joint reference + query-1 + query-2 latent space, then persist it.
full_2_out_file = f'{dir_path}full_2_data.h5ad'
sc.pp.neighbors(adata_latent_f2)
sc.tl.leiden(adata_latent_f2)
sc.tl.umap(adata_latent_f2)
adata_latent_f2.write_h5ad(filename=full_2_out_file)

In [None]:
# UMAP of the joint reference + query-1 + query-2 latent space coloured by study
# and cell type, saved to the control directory. No plt.figure() beforehand:
# scanpy's embedding plots open their own figure when no `ax` is given, so an
# extra call only leaves an empty figure open. savefig writes scanpy's figure.
sc.pl.umap(
    adata_latent_f2,
    color=["batch", "celltype"],
    frameon=False,
    ncols=1,
    show=False,
)
plt.savefig(f'{control_path}umap_full_2.png', bbox_inches='tight')

In [None]:
# Persist the second (final) surgery model.
surgery_2_path = f'{dir_path}surg_2_model/'
# exist_ok=True replaces the racy "os.path.exists, then os.makedirs" pattern.
os.makedirs(surgery_2_path, exist_ok=True)
new_trvae_2.save(surgery_2_path, overwrite=True)

In [None]:
# Write the reference and per-surgery training durations (seconds), plus their
# sum, as a JSON object to results_times.txt in the output directory.
times = {
    "ref_time": ref_time,
    "query_1_time": query_1_time,
    "query_2_time": query_2_time,
}
times["full_time"] = ref_time + query_1_time + query_2_time
times_file = f'{dir_path}results_times.txt'
with open(times_file, 'w') as filehandle:
    json.dump(times, filehandle)