In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
import numpy as np
import scanpy as sc
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report

sc.settings.set_figure_params(dpi=200, frameon=False)
sc.set_figure_params(dpi=200)
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

def set_axis_style(ax, labels):
    ax.get_xaxis().set_tick_params(direction='out')
    ax.xaxis.set_ticks_position('bottom')
    ax.set_xticks(np.arange(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_xlim(0.25, len(labels) + 0.75)
    ax.set_xlabel('Sample name')

In [2]:
import scarches
from scarches.dataset import remove_sparsity
from lataq.models import TRANVAE

# Create reference data without Delta cells and query data with Delta cells

In [3]:
adata = sc.read(
    f'../data/pancreas.h5ad'
)
adata

AnnData object with n_obs × n_vars = 16382 × 4000
    obs: 'study', 'cell_type', 'pred_label', 'pred_score'
    obsm: 'X_seurat', 'X_symphony'

In [4]:
condition_key = "study"
cell_type_keys = ["cell_type"]
remove_cts = ["delta"]

reference = ["inDrop1", "inDrop2", "inDrop3", "inDrop4", 
             "fluidigmc1", "smartseq2", "smarter"]
query = ["celseq", "celseq2"]

In [5]:
cell_type_key = cell_type_keys[-1]
adata = remove_sparsity(adata)
ref_adata = adata[adata.obs.study.isin(reference)].copy()
ref_adata = ref_adata[~ref_adata.obs.cell_type.isin(remove_cts)].copy()
query_adata = adata[adata.obs.study.isin(query)].copy()
adata = ref_adata.concatenate(query_adata)

#### Full Data processed

In [6]:
cts = adata.obs[cell_type_key].unique().tolist()
print(adata)
for celltype in cts:
    print(celltype, len(adata[adata.obs.cell_type.isin([celltype])]))

AnnData object with n_obs × n_vars = 15580 × 4000
    obs: 'study', 'cell_type', 'pred_label', 'pred_score', 'batch'
    obsm: 'X_seurat', 'X_symphony'
gamma 699
acinar 1669
alpha 5493
beta 4169
ductal 2142
endothelial 313
activated_stellate 464
schwann 25
mast 42
macrophage 79
epsilon 32
quiescent_stellate 193
t_cell 7
delta 253


#### Reference Data

In [7]:
cts = ref_adata.obs[cell_type_key].unique().tolist()
print(ref_adata)
for celltype in cts:
    print(celltype, len(ref_adata[ref_adata.obs.cell_type.isin([celltype])]))

AnnData object with n_obs × n_vars = 12291 × 4000
    obs: 'study', 'cell_type', 'pred_label', 'pred_score'
    obsm: 'X_seurat', 'X_symphony'
gamma 571
acinar 1167
alpha 4459
beta 3563
ductal 1557
endothelial 287
activated_stellate 355
schwann 20
mast 35
macrophage 63
epsilon 27
quiescent_stellate 180
t_cell 7


#### Query Data

In [8]:
cts = query_adata.obs[cell_type_key].unique().tolist()
print(query_adata)
for celltype in cts:
    print(celltype, len(query_adata[query_adata.obs.cell_type.isin([celltype])]))

AnnData object with n_obs × n_vars = 3289 × 4000
    obs: 'study', 'cell_type', 'pred_label', 'pred_score'
    obsm: 'X_seurat', 'X_symphony'
gamma 128
acinar 502
alpha 1034
delta 253
beta 606
ductal 585
endothelial 26
activated_stellate 109
schwann 5
mast 7
macrophage 16
epsilon 5
quiescent_stellate 13


#### Use the next cell only to simulate a semi labeled reference

In [9]:
#cells_per_ct = 500
#cell_type_key = "cell_type"
#
#indices = np.arange(len(ref_adata))
#labeled_ind = []
#cts = ref_adata.obs[cell_type_key].unique().tolist()
#for celltype in cts:
#    ct_indices = indices[ref_adata.obs[cell_type_key].isin([celltype])]
#    ct_sample_size = cells_per_ct
#    if cells_per_ct > len(ct_indices):
#        ct_sample_size = len(ct_indices)
#    ct_sel_ind = np.random.choice(
#        ct_indices, 
#        size=ct_sample_size, 
#        replace=False
#    )
#    labeled_ind += ct_sel_ind.tolist()
#    print(celltype, len(ct_indices), len(ct_sel_ind), len(labeled_ind))

# Supervised Reference Training (100% labels used)

In [10]:
tranvae_ref = TRANVAE(
    adata=ref_adata,
    condition_key=condition_key,
    cell_type_keys=cell_type_keys,
    hidden_layer_sizes=[128, 128],
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 7
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 7
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [11]:
early_stopping_kwargs = {
    "early_stopping_metric": "val_landmark_loss",
    "mode": "min",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

tranvae_ref.train(
    n_epochs=500,
    early_stopping_kwargs=early_stopping_kwargs,
    pretraining_epochs=200,
    alpha_epoch_anneal=1e6,
    eta=1,
    clustering_res=2,
)

 |███████████---------| 58.4%  - val_loss: 998.6225585938 - val_trvae_loss: 998.0628112793 - val_landmark_loss: 0.5597425550 - val_labeled_loss: 0.559742555004
ADJUSTED LR
 |███████████---------| 59.8%  - val_loss: 997.0893371582 - val_trvae_loss: 996.5737976074 - val_landmark_loss: 0.5155391365 - val_labeled_loss: 0.5155391365
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state of network...
Best State was in Epoch 277


# Unsupervised Query Training (0 labels used)

In [12]:
tranvae = TRANVAE.load_query_data(
    adata=query_adata,
    reference_model=tranvae_ref,
    labeled_indices=[],
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 9
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 9
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



In [13]:
early_stopping_kwargs = {
    "early_stopping_metric": "val_loss",
    "mode": "min",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

tranvae.train(
    n_epochs=500,
    early_stopping_kwargs=early_stopping_kwargs,
    pretraining_epochs=0,
    eta=1,
    weight_decay=0,
    clustering_res=2,
)

Therefore integer value of those labels is set to -1
Therefore integer value of those labels is set to -1

Initializing unlabeled landmarks with Leiden-Clustering with an unknown number of clusters.
Leiden Clustering succesful. Found 23 clusters.
 |███████████████-----| 77.0%  - val_loss: 1643.7260335286 - val_trvae_loss: 1643.7259114583 - val_landmark_loss: 0.0001158808 - val_unlabeled_loss: 0.1158808370
ADJUSTED LR
 |███████████████-----| 78.4%  - val_loss: 1641.0191650391 - val_trvae_loss: 1641.0190429688 - val_landmark_loss: 0.0001170005 - val_unlabeled_loss: 0.1170005451
Stopping early: no improvement of more than 0 nats in 20 epochs
If the early stopping criterion is too strong, please instantiate it with different parameters in the train method.
Saving best state of network...
Best State was in Epoch 370


In [14]:
tranvae.save(
    'tmp',
    overwrite=True
)

# Reloading point if trained model already exists

In [15]:
tranvae = TRANVAE.load(
    dir_path='tmp',
    adata=query_adata,
)

AnnData object with n_obs × n_vars = 3289 × 4000
    obs: 'study', 'cell_type', 'pred_label', 'pred_score', 'trvae_size_factors', 'trvae_labeled'
    obsm: 'X_seurat', 'X_symphony'

INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 4000 128 9
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 9
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 4000 



# Visualizing resulting adata latent representation and landmarks

#### Unlabeled Query Data accuracy

In [16]:
results_dict_q = tranvae.classify(
    #metric="euclidean"
)

preds_q = results_dict_q[cell_type_key]['preds']
probs_q = results_dict_q[cell_type_key]['probs']
print(classification_report(
    y_true=query_adata.obs[cell_type_key],
    y_pred=preds_q,
    labels=np.array(query_adata.obs[cell_type_key].unique().tolist())
))

                    precision    recall  f1-score   support

             gamma       0.61      0.96      0.75       128
            acinar       0.95      0.97      0.96       502
             alpha       1.00      0.97      0.98      1034
             delta       0.00      0.00      0.00       253
              beta       0.77      1.00      0.87       606
            ductal       0.97      0.96      0.96       585
       endothelial       1.00      1.00      1.00        26
activated_stellate       0.92      0.99      0.96       109
           schwann       0.62      1.00      0.77         5
              mast       1.00      0.71      0.83         7
        macrophage       1.00      0.94      0.97        16
           epsilon       0.22      1.00      0.36         5
quiescent_stellate       1.00      0.69      0.82        13

          accuracy                           0.90      3289
         macro avg       0.77      0.86      0.79      3289
      weighted avg       0.85      0.9

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [17]:
results_dict_q = tranvae.classify(
    metric="gaussian"
)

preds_q = results_dict_q[cell_type_key]['preds']
probs_q = results_dict_q[cell_type_key]['probs']
print(classification_report(
    y_true=query_adata.obs[cell_type_key],
    y_pred=preds_q,
    labels=np.array(query_adata.obs[cell_type_key].unique().tolist())
))

                    precision    recall  f1-score   support

             gamma       0.76      0.95      0.85       128
            acinar       0.98      0.93      0.95       502
             alpha       0.99      0.98      0.98      1034
             delta       0.00      0.00      0.00       253
              beta       0.76      1.00      0.86       606
            ductal       0.93      0.97      0.95       585
       endothelial       1.00      1.00      1.00        26
activated_stellate       0.91      0.98      0.95       109
           schwann       0.83      1.00      0.91         5
              mast       1.00      0.71      0.83         7
        macrophage       1.00      0.94      0.97        16
           epsilon       0.09      0.80      0.16         5
quiescent_stellate       0.82      0.69      0.75        13

          accuracy                           0.89      3289
         macro avg       0.77      0.84      0.78      3289
      weighted avg       0.84      0.8

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### Semi-Labeled Full Data accuracy

In [18]:
results_dict = tranvae.classify(
    adata.X,
    adata.obs[condition_key],
    metric="gaussian",
)

preds = results_dict[cell_type_key]['preds']
probs = results_dict[cell_type_key]['probs']
print(classification_report(
    y_true=adata.obs[cell_type_key],
    y_pred=preds,
    labels=np.array(adata.obs[cell_type_key].unique().tolist())
))

                    precision    recall  f1-score   support

             gamma       0.94      0.99      0.96       699
            acinar       0.91      0.96      0.93      1669
             alpha       0.99      0.99      0.99      5493
              beta       0.95      1.00      0.97      4169
            ductal       0.96      0.92      0.94      2142
       endothelial       1.00      0.98      0.99       313
activated_stellate       0.97      0.97      0.97       464
           schwann       0.89      0.96      0.92        25
              mast       0.97      0.90      0.94        42
        macrophage       0.97      0.99      0.98        79
           epsilon       0.40      0.97      0.56        32
quiescent_stellate       0.93      0.95      0.94       193
            t_cell       0.70      1.00      0.82         7
             delta       0.00      0.00      0.00       253

          accuracy                           0.96     15580
         macro avg       0.83      0.9

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### Get information for both landmark sets

In [19]:
labeled_set = tranvae.get_landmarks_info(
    metric="gaussian",
)
labeled_set

AnnData object with n_obs × n_vars = 13 × 10
    obs: 'study', 'cell_type', 'cell_type_pred', 'cell_type_prob'

In [20]:
unlabeled_set = tranvae.get_landmarks_info(
    landmark_set='unlabeled',
    metric="gaussian",
)
unlabeled_set

AnnData object with n_obs × n_vars = 23 × 10
    obs: 'study', 'cell_type', 'cell_type_pred', 'cell_type_prob'

#### Create adata object with cell  and landmark informations

In [21]:
data_latent = tranvae.get_latent()
adata_latent = sc.AnnData(data_latent)
adata_latent.obs[condition_key] = query_adata.obs[condition_key].tolist()
adata_latent.obs[cell_type_key] = query_adata.obs[cell_type_key].tolist()
adata_latent.obs[f'{cell_type_key}_pred'] = preds_q.tolist()
adata_latent.obs[f'{cell_type_key}_prob'] = probs_q.tolist()
adata_latent.obs['query'] = adata_latent.obs
adata_latent

ValueError: Wrong number of items passed 4, placement implies 1

In [None]:
full_latent = adata_latent.concatenate(labeled_set, unlabeled_set)
full_latent

#### Visualize Combined adata object

In [None]:
sc.pp.neighbors(full_latent, n_neighbors=8)
sc.tl.leiden(full_latent)
sc.tl.umap(full_latent)

In [None]:
full_latent.obs

In [None]:
sc.pl.umap(
    full_latent,
    size=10,
    color='query'
)

In [None]:
landmark_keys = ["Landmark-Set Unlabeled","Landmark-Set Labeled"]
sc.pl.umap(
    full_latent[~full_latent.obs[condition_key].isin(landmark_keys)],
    size=10,
    color=[cell_type_key],
)

In [None]:
ax = sc.pl.umap(full_latent, size=10, show=False)
sc.pl.umap(
    full_latent[full_latent.obs[condition_key] == "Landmark-Set Labeled"],
    size=100,
    color=[cell_type_key],
    ax=ax
)

In [None]:
ax = sc.pl.umap(full_latent, size=10, show=False)
sc.pl.umap(
    full_latent[full_latent.obs[condition_key] == "Landmark-Set Unlabeled"],
    size=100,
    color=[cell_type_key + '_pred'],
    ax=ax
)

# Unseen cell type detection

#### Check for novel cell types

In [None]:
print(unlabeled_set.obs["cell_type_prob"])

In [None]:
print(unlabeled_set.obs["cell_type_pred"])

In [None]:
check = np.array([4,19 ,21])
check_shifted = check + len(adata_latent) + len(labeled_set)

In [None]:
ax = sc.pl.umap(full_latent, size=10, show=False)

sc.pl.umap(
    full_latent[check_shifted],
    size=100,
    color=[cell_type_key + '_pred'],
    ax=ax
)

#### Add novel cell type with corresponding landmark(s) to the model

In [None]:
tranvae.add_new_cell_type(
    "delta", 
    cell_type_key, 
    check.tolist()
)

# Visualize results with updated model

In [None]:
results_dict = tranvae.classify(
    metric="gaussian"
)

preds = results_dict[cell_type_key]['preds']
probs = results_dict[cell_type_key]['probs']
print(classification_report(
    y_true=query_adata.obs[cell_type_key],
    y_pred=preds,
    labels=np.array(query_adata.obs[cell_type_key].unique().tolist())
))

In [None]:
labeled_set = tranvae.get_landmarks_info(
    metric="gaussian",
)
unlabeled_set = tranvae.get_landmarks_info(
    landmark_set='unlabeled',
    metric="gaussian",
)
labeled_set

In [None]:
results_dict = tranvae.classify(
    adata.X,
    adata.obs[condition_key],
    metric="gaussian",
)

preds = results_dict[cell_type_key]['preds']
probs = results_dict[cell_type_key]['probs']

data_latent = tranvae.get_latent(
    adata.X,
    adata.obs[condition_key],
)
adata_latent = sc.AnnData(data_latent)
adata_latent.obs[condition_key] = adata.obs[condition_key].tolist()
adata_latent.obs[cell_type_key] = adata.obs[cell_type_key].tolist()
adata_latent.obs[f'{cell_type_key}_pred'] = preds.tolist()
adata_latent.obs[f'{cell_type_key}_prob'] = probs.tolist()
full_latent = adata_latent.concatenate(labeled_set, unlabeled_set)

In [None]:
sc.pp.neighbors(full_latent, n_neighbors=8)
sc.tl.leiden(full_latent)
sc.tl.umap(full_latent)

In [None]:
landmark_keys = ["Landmark-Set Unlabeled","Landmark-Set Labeled"]
sc.pl.umap(
    full_latent[~full_latent.obs[condition_key].isin(landmark_keys)],
    size=10,
    color=[cell_type_key],
)

In [None]:
ax = sc.pl.umap(full_latent, size=10, show=False)
sc.pl.umap(
    full_latent[full_latent.obs[condition_key] == "Landmark-Set Labeled"],
    size=100,
    color=[cell_type_key + '_pred'],
    ax=ax
)