In [1]:
import numpy as np
import scanpy as sc
import torch

import matplotlib.pyplot as plt
from scarches.dataset.trvae.data_handling import remove_sparsity
from lataq.models import TRANVAE
from sklearn.metrics import classification_report

# Figure styling. This must be a single call: `sc.set_figure_params` is the same
# function as `sc.settings.set_figure_params`, and every call resets the
# arguments it is not given to their defaults, so a follow-up call with only
# `dpi=200` would silently turn `frameon` back on.
sc.settings.set_figure_params(dpi=200, frameon=False)
# Compact tensor printing (3 decimals, no scientific notation) for debug output.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

I found a bug that appears when the model is re-instantiated. I need this to work for hyperparameter tuning. For some reason it only happens with the scvelo and lung datasets.

The outputs of the cells I ran show two debug messages (I have since deleted the print statements from the script, so you need to re-add them if you want to see them in your run):
- 1 - print statement on line 373 of lataq/trainers/lataq.py
- 2 - print statement on line 383 of lataq/trainers/lataq.py

It seems the issue is in update_labeled_landmarks(), but I am not sure why this happens when reinstantiating a new model, and with these two datasets.

In [2]:
# Enable IPython's autoreload extension; per the IPython docs, mode 2 reloads all
# modules before each cell executes, so edits to the library source (e.g. debug
# print statements) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# --- Experiment setup ---------------------------------------------------------
cell_type_key = ["cell_type"]  # passed to TRANVAE as `cell_type_keys`
unlabeled_strat = "batch"
cells_per_ct = 2000

# --- Model hyperparameters ----------------------------------------------------
latent_dim = 10
use_mmd = False

# --- Training hyperparameters -------------------------------------------------
alpha_epoch_anneal = 1e6
eta, tau = 1, 0
clustering_res = 2
# All three metric settings share the same value.
labeled_loss_metric = unlabeled_loss_metric = class_metric = "dist"

# Early-stopping configuration handed to every `.train()` call.
early_stopping_kwargs = dict(
    early_stopping_metric="val_classifier_loss",
    mode="min",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [4]:
# Values of the `study` obs column that make up the reference and query sets.
condition_key = 'study'
reference = ["12.5", "13.5"]
query = ["14.5", "15.5"]

# Read the dataset from the data directory.
DATA_DIR = '../data'
adata = sc.read(f'{DATA_DIR}/benchmark_scvelo_shrinked.h5ad')

# Flag cells belonging to a query study, stored as a categorical obs column.
adata.obs['query'] = (
    adata.obs['study']
    .isin(query)
    .astype('category')
)

In [5]:
# Independent copies of the reference-study and query-study cells, in that order.
source_adata, target_adata = (
    adata[adata.obs["study"].isin(studies)].copy()
    for studies in (reference, query)
)

In [6]:
# First instantiation of the reference model, built on the reference studies.
reference_model_kwargs = dict(
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)
tranvae = TRANVAE(adata=source_adata, **reference_model_kwargs)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 2
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 2
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [7]:
# First training run of the reference model; this run completes without error
# in the recorded output.
reference_train_kwargs = dict(
    n_epochs=20,
    early_stopping_kwargs=early_stopping_kwargs,
    pretraining_epochs=4,
    alpha_epoch_anneal=alpha_epoch_anneal,
    eta=eta,
    tau=tau,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)
tranvae.train(**reference_train_kwargs)

None
 |████----------------| 20.0%  - val_loss: 1498.1462402344 - val_trvae_loss: 1498.14624023441
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
 |█████---------------| 25.0%  - val_loss: 1492.2831373948 - val_trvae_loss: 1481.6147648738 - val_classifier_loss: 10.6683737681 - val_labeled_loss: 10.66837376811
1
1
1
1
1
1
1
1
1
1
 |██████--------------| 30.0%  - val_loss: 1478.5383112981 - val_trvae_loss: 1471.7208721454 - val_classifier_loss: 6.8174372086 - val_labeled_loss: 6.81743720861
1
1
1
1
1
1
1
1
1
1
 |███████-------------| 35.0%  - val_loss: 1472.7420372596 - val_trvae_loss: 1467.3430833083 - val_classifier_loss: 5.3989509436 - val_labeled_loss: 5.39895094361
1
1
1
1
1
1
1
1
1
1
 |████████------------| 40.0%  - val_loss: 1471.6360520583 - val_trvae_loss: 1467.5347524790 - val_classifier_loss: 4.1013013950 - val_labeled_loss: 4.10130139501
1
1
1
1
1
1
1
1
1
1
 |█████████-----------| 45.0%  - val_loss: 1457.3917893630 - val_trvae_loss: 1453.3532339243 - val_classifier_loss: 4.0385620

In [8]:
# Location the trained reference model is saved to (plain string: the original
# f-string had no placeholders). Any existing save there is overwritten.
ref_path = 'tmp_model'
tranvae.save(ref_path, overwrite=True)

In [9]:
# Build the query model from the saved reference model. The path is taken from
# `ref_path` (set when the model was saved) instead of repeating the literal, so
# the save and load locations cannot drift apart.
# `labeled_indices=[]` — presumably marks every query cell as unlabeled; confirm
# against the lataq API.
tranvae_query = TRANVAE.load_query_data(
    adata=target_adata,
    reference_model=ref_path,
    labeled_indices=[],
)

AnnData object with n_obs × n_vars = 10128 × 4000
    obs: 'cell_type', 'study', 'query'

INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 4
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 4
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [10]:
# Train the query model with the shared hyperparameters, weight decay disabled.
query_train_kwargs = dict(
    n_epochs=20,
    early_stopping_kwargs=early_stopping_kwargs,
    pretraining_epochs=5,
    eta=eta,
    tau=tau,
    weight_decay=0,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)
tranvae_query.train(**query_train_kwargs)

None
 |█████---------------| 25.0%  - val_loss: 1591.2893371582 - val_trvae_loss: 1591.2893371582
Initializing unlabeled landmarks with Leiden-Clustering with an unknown number of clusters.
Leiden Clustering succesful. Found 32 clusters.
 |████████████████████| 100.0%  - val_loss: 1537.8473358154 - val_trvae_loss: 1537.8471221924 - val_classifier_loss: 0.0002085557 - val_unlabeled_loss: 0.2085556593
Saving best state of network...
Best State was in Epoch 6


In [11]:
# Bug reproduction: construct the reference model a second time, with exactly
# the same arguments as the first instantiation.
reinstantiated_model_kwargs = dict(
    condition_key=condition_key,
    cell_type_keys=cell_type_key,
    hidden_layer_sizes=[128, 128],
    latent_dim=latent_dim,
    use_mmd=use_mmd,
)
tranvae = TRANVAE(adata=source_adata, **reinstantiated_model_kwargs)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 2
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 2
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [12]:
# Bug reproduction: train the re-instantiated model with the same arguments as
# the first run. In the recorded output this call ends with
# "TypeError: 'NoneType' object is not subscriptable".
retrain_kwargs = dict(
    n_epochs=20,
    early_stopping_kwargs=early_stopping_kwargs,
    pretraining_epochs=4,
    alpha_epoch_anneal=alpha_epoch_anneal,
    eta=eta,
    tau=tau,
    clustering_res=clustering_res,
    labeled_loss_metric=labeled_loss_metric,
    unlabeled_loss_metric=unlabeled_loss_metric,
)
tranvae.train(**retrain_kwargs)

None
 |████----------------| 20.0%  - val_loss: 1494.1521371695 - val_trvae_loss: 1494.15213716951
1
1
1
1
1
1
1
1
2


TypeError: 'NoneType' object is not subscriptable