In [1]:
import numpy as np
import scanpy as sc
import torch

import matplotlib.pyplot as plt
from scarches.dataset.trvae.data_handling import remove_sparsity
from tranvae.model import EMBEDCVAE, TRANVAE
from sklearn.metrics import classification_report

# Figure styling for scanpy plots and compact tensor printing.
sc.settings.set_figure_params(dpi=200, frameon=False)
# NOTE(review): this second call omits `frameon`; it presumably resets it to
# scanpy's default and so undoes `frameon=False` above — confirm which of the
# two calls is intended and drop the other.
sc.set_figure_params(dpi=200)
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

# Enable IPython's autoreload extension (mode 2) while developing the model package.
%load_ext autoreload
%autoreload 2

In [2]:
# Record the PyTorch version this run was executed with.
print(torch.__version__)

1.7.0


In [3]:
# ---------------------------------------------------------------------------
# Experiment configuration: everything a reader may want to tune lives here.
# ---------------------------------------------------------------------------

# Reproducibility: seed every RNG used by the notebook up front so that the
# labeled/unlabeled sampling and the model initialisation are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# How the labeled/unlabeled split is made: "batch" (by study) or "ct"
# (a fixed number of labeled cells per cell type).
unlabeled_strat = "batch"
# Labeled cells drawn per cell type; only used when unlabeled_strat == "ct".
cells_per_ct = 2000

# Model Params
latent_dim = 10
use_mmd = False

# Training Params
alpha_epoch_anneal = 1e6
eta = 1
tau = 0
clustering_res = 2
labeled_loss_metric = "dist"
unlabeled_loss_metric = "dist"
class_metric = "dist"

# Forwarded unchanged to every `.train(...)` call in this notebook.
early_stopping_kwargs = {
    "early_stopping_metric": "val_classifier_loss",
    "mode": "min",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

# Passed as `cell_type_keys` to the models; the first entry is also used by the
# "ct" split strategy.
cell_type_key = ["cell_type"]

In [4]:
# Load the PBMC benchmark and split it into reference and query studies.
DATA_DIR = '../data'
adata = sc.read(
    f'{DATA_DIR}/benchmark_pbmc_shrinked.h5ad'
)
# Column handed to the models as the batch/condition covariate.
condition_key = 'condition'
# `reference` and `query` hold *study* names, so the split has to be made on
# the `study` column (as the other cells of this notebook do), not on
# `condition_key`.
reference = ['Oetjen', '10X', 'Sun']
query = ['Freytag']
adata_ref = adata[adata.obs['study'].isin(reference)]
adata_query = adata[adata.obs['study'].isin(query)]

In [5]:
# Densify the expression matrix, then split the cells into a labeled and an
# unlabeled set according to `unlabeled_strat`.
adata = remove_sparsity(adata)

indices = np.arange(len(adata))
# stratified labeled/unlabeled split
if unlabeled_strat == "batch":
    # Reference studies are labeled, query studies are unlabeled.
    labeled_ind = indices[adata.obs.study.isin(reference)].tolist()
    labeled_adata = adata[adata.obs.study.isin(reference)].copy()
    unlabeled_adata = adata[adata.obs.study.isin(query)].copy()
elif unlabeled_strat == "ct":
    # Draw up to `cells_per_ct` labeled cells from every cell type.
    labeled_ind = []
    cts = adata.obs[cell_type_key[0]].unique().tolist()
    for celltype in cts:
        ct_indices = indices[adata.obs[cell_type_key[0]].isin([celltype])]
        # Sampling without replacement fails when more cells are requested
        # than exist, so rare cell types contribute all of their cells.
        n_select = min(cells_per_ct, len(ct_indices))
        ct_sel_ind = np.random.choice(ct_indices, size=n_select, replace=False)
        labeled_ind += ct_sel_ind.tolist()
        print(celltype, len(ct_indices), len(ct_sel_ind), len(labeled_ind))
    unlabeled_ind = np.delete(indices, labeled_ind).tolist()
    labeled_adata = adata[labeled_ind].copy()
    unlabeled_adata = adata[unlabeled_ind].copy()
else:
    # Fail loudly instead of leaving the split variables undefined.
    raise ValueError(
        f"Unknown unlabeled_strat {unlabeled_strat!r}; expected 'batch' or 'ct'."
    )

In [6]:
# Embedding-based conditional VAE with the condition embedding injected into
# both the encoder and the decoder.
embed = EMBEDCVAE(
    adata=adata,
    labeled_indices=labeled_ind,
    unknown_ct_names=None,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    inject_condition=['encoder', 'decoder'],
    embedding_dim=10,
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)

Embedding dictionary:
 	Num conditions: 9
 	Embedding dim: 10
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 10
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 10
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [7]:
# Baseline TRANVAE model built on the same data, split and layer sizes as the
# embedding variants.
tranvae = TRANVAE(
    adata=adata,
    labeled_indices=labeled_ind,
    unknown_ct_names=None,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 9
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 9
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [8]:
# Same as `embed`, but the condition embedding is injected into the decoder only.
embed_dec = EMBEDCVAE(
    adata=adata,
    labeled_indices=labeled_ind,
    unknown_ct_names=None,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    inject_condition=['decoder'],
    embedding_dim=10,
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)

Embedding dictionary:
 	Num conditions: 9
 	Embedding dim: 10
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 0
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 10
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [11]:
# Train the TRANVAE baseline: 50 epochs, the first 10 of them pretraining.
tranvae.train(
    n_epochs=50,
    pretraining_epochs=10,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=alpha_epoch_anneal,
    eta=eta,
    tau=tau,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)

cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda
cuda


KeyboardInterrupt: 

In [None]:
# Encode the labeled cells with the TRANVAE model and wrap the result in an AnnData.
data_latent = tranvae.get_latent(labeled_adata.X, labeled_adata.obs[condition_key].values)
adata_latent_tranvae = sc.AnnData(data_latent)
adata_latent_tranvae.obs['batch'] = labeled_adata.obs[condition_key].tolist()

# Classify the same cells and print the result.
results_dict = tranvae.classify(
    labeled_adata.X, labeled_adata.obs[condition_key], metric=class_metric
)
print(results_dict)

# UMAP of the latent space, coloured by condition.
sc.pp.neighbors(adata_latent_tranvae)
sc.tl.umap(adata_latent_tranvae)
sc.pl.umap(adata_latent_tranvae, color='batch')

In [15]:
# Short run of the encoder+decoder embedding model: 10 epochs, 2 of them pretraining.
embed.train(
    n_epochs=10,
    pretraining_epochs=2,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=alpha_epoch_anneal,
    eta=eta,
    tau=tau,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)

 |████----------------| 20.0%  - val_loss: 1333.9616980919 - val_trvae_loss: 1333.9616980919
Initializing unlabeled landmarks with Leiden-Clustering with an unknown number of clusters.
Leiden Clustering succesful. Found 38 clusters.
 |████████████████████| 100.0%  - val_loss: 1333.6302067683 - val_trvae_loss: 1331.3054480919 - val_classifier_loss: 2.3247675621 - val_unlabeled_loss: 0.8368574810 - val_labeled_loss: 2.3239368384
Saving best state of network...
Best State was in Epoch 9


AttributeError: 'NoneType' object has no attribute 'unsqueeze'

In [13]:
# Encode the labeled cells with the `embed` model and wrap the result in an AnnData.
data_latent = embed.get_latent(labeled_adata.X, labeled_adata.obs[condition_key].values)
adata_latent_embed = sc.AnnData(data_latent)
adata_latent_embed.obs['batch'] = labeled_adata.obs[condition_key].tolist()

# Classify the same cells and print the result.
results_dict_embed = embed.classify(
    labeled_adata.X, labeled_adata.obs[condition_key], metric=class_metric
)
print(results_dict_embed)

# UMAP of the latent space, coloured by condition.
sc.pp.neighbors(adata_latent_embed)
sc.tl.umap(adata_latent_embed)
sc.pl.umap(adata_latent_embed, color='batch')

TypeError: 'NoneType' object is not subscriptable

In [None]:
# Full training run of the decoder-only embedding model: 500 epochs, 100 pretraining.
embed_dec.train(
    n_epochs=500,
    pretraining_epochs=100,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=alpha_epoch_anneal,
    eta=eta,
    tau=tau,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)

In [12]:
# Encode the labeled cells with the `embed_dec` model and wrap the result in an AnnData.
data_latent = embed_dec.get_latent(labeled_adata.X, labeled_adata.obs[condition_key].values)
adata_latent_embed_dec = sc.AnnData(data_latent)
adata_latent_embed_dec.obs['batch'] = labeled_adata.obs[condition_key].tolist()

# Classify the same cells and print the result.
results_dict_embed_dec = embed_dec.classify(
    labeled_adata.X, labeled_adata.obs[condition_key], metric=class_metric
)
print(results_dict_embed_dec)

# UMAP of the latent space, coloured by condition.
sc.pp.neighbors(adata_latent_embed_dec)
sc.tl.umap(adata_latent_embed_dec)
sc.pl.umap(adata_latent_embed_dec, color='batch')

TypeError: 'NoneType' object is not subscriptable

In [None]:
# PCA of the condition-embedding weights of the `embed` model.
embedding = embed.model.embedding.weight.detach().cpu().numpy()
embedding_adata = sc.AnnData(embedding)
# NOTE(review): assumes the embedding rows follow the order of
# adata.obs['condition'].unique() — confirm against the model's condition encoding.
embedding_adata.obs['condition'] = adata.obs['condition'].unique()
sc.pp.pca(embedding_adata)
sc.pl.pca(embedding_adata, color='condition', size=50)

In [None]:
# PCA of the condition-embedding weights of the `embed_dec` model.
embedding_dec = embed_dec.model.embedding.weight.detach().cpu().numpy()
embedding_dec_adata = sc.AnnData(embedding_dec)
# NOTE(review): assumes the embedding rows follow the order of
# adata.obs['condition'].unique() — confirm against the model's condition encoding.
embedding_dec_adata.obs['condition'] = adata.obs['condition'].unique()
sc.pp.pca(embedding_dec_adata)
sc.pl.pca(embedding_dec_adata, color='condition', size=50)

In [None]:
# Control model: same setup as the other EMBEDCVAE variants, but with an empty
# `inject_condition` list.
embed_ctrl = EMBEDCVAE(
    adata=adata,
    labeled_indices=labeled_ind,
    unknown_ct_names=None,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    inject_condition=[],
    embedding_dim=10,
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)

In [None]:
# Full training run of the control model: 500 epochs, 100 pretraining.
embed_ctrl.train(
    n_epochs=500,
    pretraining_epochs=100,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=alpha_epoch_anneal,
    eta=eta,
    tau=tau,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)

In [None]:
# Encode the labeled cells with the control model and wrap the result in an AnnData.
data_latent = embed_ctrl.get_latent(labeled_adata.X, labeled_adata.obs[condition_key].values)
adata_latent_embed_ctrl = sc.AnnData(data_latent)
adata_latent_embed_ctrl.obs['batch'] = labeled_adata.obs[condition_key].tolist()

# Classify the same cells and print the result.
results_dict_embed_ctrl = embed_ctrl.classify(
    labeled_adata.X, labeled_adata.obs[condition_key], metric=class_metric
)
print(results_dict_embed_ctrl)

# UMAP of the latent space, coloured by condition.
sc.pp.neighbors(adata_latent_embed_ctrl)
sc.tl.umap(adata_latent_embed_ctrl)
sc.pl.umap(adata_latent_embed_ctrl, color='batch')

In [None]:
# Persist the decoder-only embedding model.
# NOTE(review): a later cell saves `embed_surg` to './../tranvae_benchmarks/embed.tar',
# which appears to be this same file, so that save would overwrite this one —
# consider distinct file names.
embed_dec.save('../tranvae_benchmarks/embed.tar', overwrite=True)

In [None]:
# NOTE(review): `embed_dec_query` is only created further down (by the
# EMBEDCVAE.load_query_data cell), so on a fresh kernel this cell raises
# NameError when run in document order — move it below that cell or delete it.
embed_dec_query.model

In [None]:
# Reload the dataset from disk for the surgery experiment and split it by study.
# X is presumably sparse again at this point (later cells call `.X.A`).
DATA_DIR = '../data'
condition_key = 'condition'
reference = ['Oetjen', '10X', 'Sun']
query = ['Freytag']

adata = sc.read(f'{DATA_DIR}/benchmark_pbmc_shrinked.h5ad')
adata_ref = adata[adata.obs['study'].isin(reference)]
adata_query = adata[adata.obs['study'].isin(query)]

In [None]:
# Reference model for the surgery experiment: built on the reference studies
# only, decoder-only condition embedding, MSE reconstruction loss.
embed_surg = EMBEDCVAE(
    adata=adata_ref,
    unknown_ct_names=None,
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    inject_condition=['decoder'],
    embedding_dim=10,
    recon_loss='mse',
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
)

In [None]:
# Full training run of the reference model: 500 epochs, 100 pretraining.
embed_surg.train(
    n_epochs=500,
    pretraining_epochs=100,
    early_stopping_kwargs=early_stopping_kwargs,
    alpha_epoch_anneal=alpha_epoch_anneal,
    eta=eta,
    tau=tau,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)

In [None]:
# Display the condition-embedding module of the reference model.
embed_surg.model.embedding

In [None]:
# NOTE(review): in document order `embedding_dec_adata` still refers to the
# object built from `embed_dec` in an earlier cell; the `embed_surg` version is
# only created in the next cell. Move this inspection below it if that is what
# should be shown.
embedding_dec_adata.obs

In [None]:
# PCA of the condition-embedding weights of the reference model.
embedding_dec = embed_surg.model.embedding.weight.detach().cpu().numpy()
embedding_dec_adata = sc.AnnData(embedding_dec)
# NOTE(review): assumes the embedding rows follow the order of
# adata_ref.obs['condition'].unique() — confirm against the model's condition encoding.
embedding_dec_adata.obs['condition'] = adata_ref.obs['condition'].unique()
sc.pp.pca(embedding_dec_adata)
sc.pl.pca(embedding_dec_adata, color='condition', size=50)

In [None]:
# Persist the reference model; the query cell below loads
# '../tranvae_benchmarks/embed.tar' as its reference model.
# NOTE(review): an earlier cell saves `embed_dec` to what appears to be this
# same file — consider distinct file names.
embed_surg.save('./../tranvae_benchmarks/embed.tar', overwrite=True)

In [None]:
# Build the query model from the saved reference model, treating every query
# cell as unlabeled (empty `labeled_indices`).
embed_dec_query = EMBEDCVAE.load_query_data(
    adata=adata_query,
    reference_model='../tranvae_benchmarks/embed.tar',
    labeled_indices=[],
)

# Train on the query data. Unlike the other runs, no `alpha_epoch_anneal` is
# passed and `weight_decay` is set to 0.
embed_dec_query.train(
    n_epochs=500,
    pretraining_epochs=100,
    early_stopping_kwargs=early_stopping_kwargs,
    weight_decay=0,
    eta=eta,
    tau=tau,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)

In [None]:
# PCA of the condition-embedding weights of the query model.
embedding_dec = embed_dec_query.model.embedding.weight.detach().cpu().numpy()
embedding_dec_adata = sc.AnnData(embedding_dec)
# NOTE(review): assumes the embedding rows are the reference conditions
# followed by the query conditions, each in `.unique()` order — confirm.
embedding_dec_adata.obs['condition'] = (
    adata_ref.obs['condition'].unique().tolist()
    + adata_query.obs['condition'].unique().tolist()
)
embedding_dec_adata.obs['condition'] = embedding_dec_adata.obs['condition'].astype('category')
sc.pp.pca(embedding_dec_adata)
sc.pl.pca(embedding_dec_adata, color='condition', size=50)

In [None]:
# Evaluate the query model on the reference cells.
# The `*_embed_ctrl` variable names are kept as they were, although the values
# now come from `embed_dec_query`.
data_latent = embed_dec_query.get_latent(
    adata_ref.X.A,
    adata_ref.obs[condition_key].values
)
adata_latent_embed_ctrl = sc.AnnData(data_latent)
adata_latent_embed_ctrl.obs['batch'] = adata_ref.obs[condition_key].tolist()
# Classify with the same model that produced the latent space; this previously
# called `embed_ctrl.classify`, so the printed results described a different
# model than the plotted embedding.
results_dict_embed_ctrl = embed_dec_query.classify(
    adata_ref.X.A,
    adata_ref.obs[condition_key],
    metric=class_metric
)
print(results_dict_embed_ctrl)
# UMAP of the latent space, coloured by condition.
sc.pp.neighbors(adata_latent_embed_ctrl)
sc.tl.umap(adata_latent_embed_ctrl)
sc.pl.umap(adata_latent_embed_ctrl, color='batch')

In [None]:
# Evaluate the query model on the query cells.
# The `*_embed_ctrl` variable names are kept as they were, although the values
# now come from `embed_dec_query`.
data_latent = embed_dec_query.get_latent(
    adata_query.X.A,
    adata_query.obs[condition_key].values
)
adata_latent_embed_ctrl = sc.AnnData(data_latent)
adata_latent_embed_ctrl.obs['batch'] = adata_query.obs[condition_key].tolist()
# Classify with the same model that produced the latent space; this previously
# called `embed_ctrl.classify`, so the printed results described a different
# model than the plotted embedding.
results_dict_embed_ctrl = embed_dec_query.classify(
    adata_query.X.A,
    adata_query.obs[condition_key],
    metric=class_metric
)
print(results_dict_embed_ctrl)
# UMAP of the latent space, coloured by condition.
sc.pp.neighbors(adata_latent_embed_ctrl)
sc.tl.umap(adata_latent_embed_ctrl)
sc.pl.umap(adata_latent_embed_ctrl, color='batch')

In [None]:
# Evaluate the query model on the full dataset (reference + query cells).
# The `*_embed_ctrl` variable names are kept as they were, although the values
# now come from `embed_dec_query`.
data_latent = embed_dec_query.get_latent(
    adata.X.A,
    adata.obs[condition_key].values
)
adata_latent_embed_ctrl = sc.AnnData(data_latent)
adata_latent_embed_ctrl.obs['batch'] = adata.obs[condition_key].tolist()
# Classify with the same model that produced the latent space; this previously
# called `embed_ctrl.classify`, so the printed results described a different
# model than the plotted embedding.
results_dict_embed_ctrl = embed_dec_query.classify(
    adata.X.A,
    adata.obs[condition_key],
    metric=class_metric
)
print(results_dict_embed_ctrl)
# UMAP of the latent space, coloured by condition.
sc.pp.neighbors(adata_latent_embed_ctrl)
sc.tl.umap(adata_latent_embed_ctrl)
sc.pl.umap(adata_latent_embed_ctrl, color='batch')