In [44]:
import numpy as np
import scanpy as sc
import torch

import matplotlib.pyplot as plt
from scarches.dataset.trvae.data_handling import remove_sparsity
from tranvae.model import EMBEDCVAE, TRANVAE
from sklearn.metrics import classification_report

# Global display settings for the notebook: scanpy figures at 200 dpi.
# NOTE(review): the second call does not pass `frameon`, so it may reset the
# `frameon=False` requested on the first call back to scanpy's default —
# confirm which setting is intended and drop the redundant call.
sc.settings.set_figure_params(dpi=200, frameon=False)
sc.set_figure_params(dpi=200)
# Print tensors with 3 decimals, no scientific notation, and 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

# Enable IPython autoreload (mode 2) so edits to imported modules are picked up
# without restarting the kernel.
# NOTE(review): the recorded output reports the extension was already loaded
# and suggests `%reload_ext autoreload` for re-runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# --- Reproducibility ---------------------------------------------------------
# The cell-type split below draws from numpy's global RNG and the models are
# trained with torch, so seed both once here to make re-runs comparable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Labeled / unlabeled split -----------------------------------------------
# "batch": split by study (reference vs. query); "ct": sample per cell type.
unlabeled_strat = "batch"
# Number of labeled cells drawn per cell type when unlabeled_strat == "ct".
cells_per_ct = 2000

# --- Model params ------------------------------------------------------------
latent_dim = 10
use_mmd = False

# --- Training params ---------------------------------------------------------
tranvae_epochs = 500
pretraining_epochs = 100
alpha_epoch_anneal = 1e6
eta = 1
tau = 0
clustering_res = 2
labeled_loss_metric = "dist"
unlabeled_loss_metric = "dist"
class_metric = "dist"

# Forwarded unchanged to every `.train(...)` call in this notebook.
early_stopping_kwargs = {
    "early_stopping_metric": "val_classifier_loss",
    "mode": "min",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

# Kept as a list because the models take `cell_type_keys`; the split code uses
# the first entry as the obs column holding the cell-type label.
cell_type_key = ["cell_type"]

In [17]:
# Directory holding the input data, relative to this notebook.
DATA_DIR = '../data'
adata = sc.read(f'{DATA_DIR}/benchmark_pbmc_shrinked.h5ad')

# Name of the obs column handed to the models as the condition label.
condition_key = 'condition'

# Study names used by the "batch" split: query studies become the unlabeled
# set, reference studies the labeled set.
query = ['Freytag']
reference = ['Oetjen', '10X', 'Sun']

In [20]:
# presumably converts a sparse expression matrix to a dense one — confirm
# against scarches' `remove_sparsity`.
adata = remove_sparsity(adata)

# Positional index of every cell; the model receives labeled cells as a list
# of these positions.
indices = np.arange(len(adata))

# Stratified labeled/unlabeled split, selected by `unlabeled_strat`.
if unlabeled_strat == "batch":
    # Reference studies are labeled, query studies are unlabeled. Cells from
    # studies in neither list end up in neither subset.
    is_reference = adata.obs.study.isin(reference)
    labeled_ind = indices[is_reference].tolist()
    labeled_adata = adata[is_reference].copy()
    unlabeled_adata = adata[adata.obs.study.isin(query)].copy()
elif unlabeled_strat == "ct":
    # Sample up to `cells_per_ct` labeled cells from every cell type.
    labeled_ind = []
    ct_labels = adata.obs[cell_type_key[0]]
    for celltype in ct_labels.unique().tolist():
        ct_indices = indices[ct_labels.isin([celltype])]
        # Sampling without replacement raises if more cells are requested
        # than exist, so cap the draw for rare cell types.
        n_selected = min(cells_per_ct, len(ct_indices))
        ct_sel_ind = np.random.choice(ct_indices, size=n_selected, replace=False)
        labeled_ind += ct_sel_ind.tolist()
        print(celltype, len(ct_indices), len(ct_sel_ind), len(labeled_ind))
    # Everything that was not sampled as labeled is treated as unlabeled.
    unlabeled_ind = np.delete(indices, labeled_ind).tolist()
    labeled_adata = adata[labeled_ind].copy()
    unlabeled_adata = adata[unlabeled_ind].copy()
else:
    # Fail here rather than with a NameError on `labeled_ind` further down.
    raise ValueError(
        f"Unknown unlabeled_strat {unlabeled_strat!r}; expected 'batch' or 'ct'."
    )

In [49]:
# Build the EMBEDCVAE model on the full dataset, with the condition injected
# into both encoder and decoder. Cells listed in `labeled_ind` are the labeled
# ones.
embed = EMBEDCVAE(
    # data and labels
    adata=adata,
    labeled_indices=labeled_ind,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    unknown_ct_names=None,
    # condition handling
    inject_condition=['encoder', 'decoder'],
    embedding_dim=10,
    # network size
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)

Encoder Architecture:
	Input Layer in, out and cond: 4000 128 10
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 10
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [46]:
# Build the TRANVAE model on the full dataset with the same layer sizes and
# latent dimension as the EMBEDCVAE model, for a like-for-like comparison.
tranvae = TRANVAE(
    # data and labels
    adata=adata,
    labeled_indices=labeled_ind,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    unknown_ct_names=None,
    # network size
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 9
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 9
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [51]:
# Train TRANVAE; every hyperparameter comes from the configuration cell.
tranvae.train(
    # schedule
    n_epochs=tranvae_epochs,
    pretraining_epochs=pretraining_epochs,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=alpha_epoch_anneal,
    # loss configuration
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
    clustering_res=clustering_res,
    eta=eta,
    tau=tau,
)

 |████----------------| 20.0%  - val_loss: 1328.6840585562 - val_trvae_loss: 1328.6840585562
Initializing unlabeled landmarks with Leiden-Clustering with an unknown number of clusters.
Leiden Clustering succesful. Found 39 clusters.
 |███████-------------| 37.6%  - val_loss: 1333.2449153020 - val_trvae_loss: 1332.6018160306 - val_classifier_loss: 0.6430961856 - val_unlabeled_loss: 1.6528987609 - val_labeled_loss: 0.6414432892
ADJUSTED LR
 |█████████-----------| 46.4%  - val_loss: 1329.4798630934 - val_trvae_loss: 1329.0014883188 - val_classifier_loss: 0.4783808532 - val_unlabeled_loss: 1.4617866644 - val_labeled_loss: 0.4769190664
ADJUSTED LR
 |█████████-----------| 47.8%  - val_loss: 1325.5334331806 - val_trvae_loss: 1325.0342829778 - val_classifier_loss: 0.4991648667 - val_unlabeled_loss: 1.4250707718 - val_labeled_loss: 0.4977397999
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different para

In [57]:
# Train EMBEDCVAE. The schedule is given as literals (100 epochs, 20 of them
# pretraining) rather than the `tranvae_epochs` / `pretraining_epochs`
# configuration values; all other settings come from the configuration cell.
embed.train(
    # schedule
    n_epochs=100,
    pretraining_epochs=20,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=alpha_epoch_anneal,
    # loss configuration
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
    clustering_res=clustering_res,
    eta=eta,
    tau=tau,
)

 |████----------------| 20.0%  - val_loss: 1311.4796142578 - val_trvae_loss: 1311.4796142578
Initializing unlabeled landmarks with Leiden-Clustering with an unknown number of clusters.
Leiden Clustering succesful. Found 38 clusters.
 |████████████████----| 82.0%  - val_loss: 1314.9159358098 - val_trvae_loss: 1314.3373084435 - val_classifier_loss: 0.5786209450 - val_unlabeled_loss: 0.6284089868 - val_labeled_loss: 0.5779925356
ADJUSTED LR
 |████████████████████| 100.0%  - val_loss: 1312.0196486253 - val_trvae_loss: 1311.5202636719 - val_classifier_loss: 0.4993728285 - val_unlabeled_loss: 0.5916365270 - val_labeled_loss: 0.4987811916
Saving best state of network...
Best State was in Epoch 98


In [64]:
# Show the trained embedding weight matrix as a NumPy array (detached from
# autograd and moved to the CPU first).
embedding_weight = embed.model.embedding.weight
embedding_weight.detach().cpu().numpy()

array([[ 2.11e-01, -6.08e-02, -1.31e-03,  1.05e-01, -2.35e-02, -1.62e-03,
         1.04e-02, -5.92e-03,  2.50e-02,  2.67e-02],
       [ 1.42e-01,  2.03e-03, -1.68e-04, -1.93e-02,  6.06e-03,  1.69e-03,
        -1.89e-01,  1.90e-03, -3.68e-03, -6.43e-03],
       [ 2.04e-01,  1.09e-01, -1.64e-03,  7.70e-02,  5.83e-02,  6.50e-03,
         2.00e-02, -1.48e-03, -6.41e-02, -1.34e-02],
       [-2.13e-02, -2.73e-02,  3.16e-01,  9.28e-02,  3.83e-02,  2.43e-03,
        -8.39e-02,  8.09e-05, -4.30e-02,  6.10e-02],
       [-2.81e-02, -2.68e-03,  1.85e-04,  2.14e-01,  3.29e-03,  2.65e-03,
        -6.26e-02,  3.53e-03,  3.46e-03, -1.93e-01],
       [-8.59e-02, -1.54e-02, -4.78e-02,  1.64e-01, -5.57e-02,  1.27e-01,
        -1.06e-01,  2.43e-03, -1.56e-01,  1.12e-01],
       [-7.69e-02, -7.21e-02, -7.75e-02,  1.61e-01,  1.84e-01, -2.11e-02,
        -1.02e-01, -2.58e-03, -6.15e-02,  1.10e-01],
       [-5.12e-03,  1.30e-01,  1.11e-02,  1.81e-01, -8.33e-04,  3.22e-02,
        -6.48e-02,  1.05e-03,  1.64e-

In [67]:
# Embed the unlabeled cells with the EMBEDCVAE model, wrap the latent matrix in
# an AnnData object tagged with each cell's condition, then classify the same
# cells with the TRANVAE model.
# NOTE(review): the recorded output of this cell is a RuntimeError from a torch
# embedding lookup that received a DoubleTensor where Long indices are
# required. The traceback is not shown; presumably `embed.get_latent` raised,
# since it is the first call in the cell — TODO confirm and fix the dtype of
# the condition indices before relying on anything below.
data_latent = embed.get_latent(
    unlabeled_adata.X, 
    unlabeled_adata.obs[condition_key].values
)
adata_latent = sc.AnnData(data_latent)
adata_latent.obs['batch'] = unlabeled_adata.obs[condition_key].tolist()
# NOTE(review): the latent space comes from `embed` but the classification from
# `tranvae`, and the condition is passed as a Series here versus `.values`
# above — confirm both are intentional.
results_dict = tranvae.classify(
    unlabeled_adata.X, 
    unlabeled_adata.obs[condition_key], 
    metric=class_metric
)

RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.cuda.DoubleTensor instead (while checking arguments for embedding)

In [None]:
# EMBEDCVAE variant with the condition injected into the decoder only; every
# other setting matches the `embed` model built earlier.
embed_dec = EMBEDCVAE(
    # data and labels
    adata=adata,
    labeled_indices=labeled_ind,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    unknown_ct_names=None,
    # condition handling
    inject_condition=['decoder'],
    embedding_dim=10,
    # network size
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)

In [None]:
# Train the decoder-only EMBEDCVAE variant; every hyperparameter comes from
# the configuration cell, matching the TRANVAE training call.
embed_dec.train(
    # schedule
    n_epochs=tranvae_epochs,
    pretraining_epochs=pretraining_epochs,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=alpha_epoch_anneal,
    # loss configuration
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
    clustering_res=clustering_res,
    eta=eta,
    tau=tau,
)