In [1]:
import os
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import time
import json

In [2]:
# Set all scanpy figure options in ONE call: every `set_figure_params` call
# re-applies its own defaults (e.g. dpi=80, frameon=True) for arguments it is not
# given, so the previous three separate calls silently undid each other's
# dpi/frameon settings and only figsize=(4, 4) survived.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [3]:
# Run configuration. `test_nr` selects the reference/query split (next cells);
# `eta` and `test_nr` are part of the results-directory path.
# NOTE(review): `surgery_epochs` and `early_stopping_kwargs` are not used in the
# cells visible in this notebook.
test_nr = 3
eta = 1000

surgery_epochs = 500

early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

results_root = '~/Documents/benchmarking_results/rqr'
dir_path = os.path.expanduser(f'{results_root}/tranvae_{eta}/pancreas/test_{test_nr}/')
surg_path = dir_path + 'query_unsupervised/'

In [4]:
# Load the normalized pancreas benchmark; its `.raw` layer is used further down.
pancreas_h5ad = os.path.expanduser('~/Documents/benchmarking_datasets/pancreas_normalized.h5ad')
adata_all = sc.read(pancreas_h5ad)
condition_key = 'study'      # obs column naming the study each cell comes from
cell_type_key = 'cell_type'  # obs column holding the cell-type annotation
# Studies in a fixed order: test `test_nr` uses the first `test_nr` studies as the
# reference and the remaining ones as the query (test 5 = everything reference,
# empty query). The previous if/elif chain never defined `query` for test 5 and
# left both names undefined for an unknown `test_nr`.
pancreas_studies = ['Pancreas inDrop', 'Pancreas SS2', 'Pancreas CelSeq2', 'Pancreas CelSeq',
                    'Pancreas Fluidigm C1']
if not 1 <= test_nr <= len(pancreas_studies):
    raise ValueError(f'test_nr must be between 1 and {len(pancreas_studies)}, got {test_nr}')
reference = pancreas_studies[:test_nr]
query = pancreas_studies[test_nr:]

In [5]:
# Use the `.raw` layer (remove_sparsity presumably densifies X — confirm), then
# split cells by study: everything not in `query` is the reference (source), the
# rest is the query (target). Uses `condition_key` rather than a hard-coded column.
adata = adata_all.raw.to_adata()
adata = remove_sparsity(adata)
is_query = adata.obs[condition_key].isin(query)
source_adata = adata[~is_query].copy()
target_adata = adata[is_query].copy()

In [6]:
# Directory of the saved model that the next cell loads.
surgery_path = surg_path + 'surg_model/'

In [7]:
# Rebuild the model saved at `surgery_path` for the full dataset (reference and
# query studies); the architecture printout below is emitted by this call.
new_tranvae = sca.models.TRANVAE.load_query_data(
    reference_model=surgery_path,
    adata=adata,
)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 1000 128 5
	Hidden Layer 1 in/out: 128 128
	Mean/Var Layer in/out: 128 10
Decoder Architecture:
	First Layer in, out and cond:  10 128 5
	Hidden Layer 1 in/out: 128 128
	Output Layer in/out:  128 1000 



In [8]:
# Labeled landmarks (presumably one row per reference cell type) and their shape.
print(new_tranvae.landmarks_labeled_, new_tranvae.landmarks_labeled_.shape, sep='\n')

[[ 0.05 -0.04 -0.03  0.04  0.   -0.07  0.19 -0.06 -0.13 -0.05]
 [ 0.06 -0.02 -0.04  0.08  0.03 -0.04  0.18 -0.02 -0.13 -0.04]
 [ 0.09 -0.02 -0.17  0.06 -0.03 -0.01  0.19 -0.1  -0.13 -0.08]
 [ 0.11 -0.08 -0.2  -0.01  0.07 -0.04  0.22 -0.09  0.07 -0.04]
 [ 0.08 -0.13 -0.14 -0.06 -0.02  0.02  0.17 -0.11 -0.17 -0.02]
 [ 0.1  -0.01 -0.07  0.06  0.05  0.05  0.15 -0.11 -0.12 -0.05]
 [ 0.04  0.03 -0.15  0.01 -0.1   0.05  0.19 -0.1  -0.11 -0.16]
 [ 0.09  0.01 -0.21  0.   -0.03 -0.05  0.06 -0.05 -0.13 -0.12]]
(8, 10)


In [9]:
# Re-wrap the model's stored landmarks as torch tensors.
# NOTE(review): presumably they are numpy arrays after loading — confirm.
for landmark_attr in ('landmarks_labeled', 'landmarks_unlabeled'):
    setattr(new_tranvae.model, landmark_attr,
            torch.tensor(getattr(new_tranvae.model, landmark_attr)))
new_tranvae.model.landmarks_labeled.size()

torch.Size([8, 10])

In [10]:
# Unlabeled landmarks and their shape.
print(new_tranvae.landmarks_unlabeled_, new_tranvae.landmarks_unlabeled_.shape, sep='\n')

[[ 0.03  0.02 -0.17  0.03 -0.1   0.07  0.17 -0.1  -0.1  -0.15]
 [ 0.1   0.   -0.09  0.06  0.06  0.06  0.15 -0.12 -0.12 -0.04]
 [ 0.1  -0.   -0.22  0.01 -0.02 -0.04  0.03 -0.05 -0.13 -0.11]
 [ 0.1  -0.09 -0.2   0.01  0.08 -0.04  0.22 -0.1   0.08 -0.02]
 [ 0.05 -0.03 -0.03  0.07  0.02 -0.04  0.18 -0.03 -0.13 -0.04]
 [ 0.07 -0.14 -0.14 -0.05 -0.04  0.04  0.16 -0.13 -0.15  0.  ]
 [ 0.1  -0.03 -0.21  0.1  -0.02  0.    0.18 -0.1  -0.13 -0.07]
 [ 0.09 -0.03 -0.15  0.05 -0.03 -0.02  0.17 -0.1  -0.13 -0.05]]
(8, 10)


In [11]:
# Latent embedding of the reference cells, conditioned on their study
# (uses `condition_key` instead of the hard-coded `study` column).
source_latent = new_tranvae.get_latent(source_adata.X, source_adata.obs[condition_key])
source_latent.shape

(13778, 10)

In [12]:
# Reference cell-type labels, read via the configured `cell_type_key`.
source_labels = source_adata.obs[cell_type_key]
source_labels.shape

(13778,)

In [13]:
# Distinct reference cell types, in order of first appearance.
source_label_uniq = source_adata.obs[cell_type_key].unique().tolist()
source_label_uniq

['Pancreas Endothelial',
 'Pancreas Acinar',
 'Pancreas Beta',
 'Pancreas Delta',
 'Pancreas Stellate',
 'Pancreas Ductal',
 'Pancreas Alpha',
 'Pancreas Gamma']

In [14]:
# Cell type under inspection. NOTE: the `alpha_*` variable names in the following
# cells refer to THIS cell type, not necessarily to alpha cells.
check_prob_ct: str = 'Pancreas Ductal'

In [15]:
# Positions of the reference cells labelled `check_prob_ct`.
alpha_check = np.flatnonzero(np.asarray(source_labels == check_prob_ct))
alpha_check.shape

(1675,)

In [16]:
# Latent rows of those reference cells.
alpha_latent = np.take(source_latent, alpha_check, axis=0)
alpha_latent.shape

(1675, 10)

In [17]:
# Per-dimension mean of the cell type's latent codes (an empirical landmark).
alpha_landmark = alpha_latent.mean(axis=0)
alpha_landmark

array([ 0.1 , -0.01, -0.07,  0.06,  0.05,  0.05,  0.15, -0.11, -0.12,
       -0.05], dtype=float32)

In [18]:
# Per-dimension standard deviation of the same latent codes.
alpha_std = alpha_latent.std(axis=0)
alpha_std

array([0.03, 0.03, 0.04, 0.03, 0.04, 0.04, 0.03, 0.04, 0.03, 0.03],
      dtype=float32)

In [19]:
# Independent (axis-aligned) normal per latent dimension, fitted to the cell type.
normal_dist = torch.distributions.Normal(
    loc=torch.tensor(alpha_landmark),
    scale=torch.tensor(alpha_std),
)
normal_dist.mean

tensor([ 0.101, -0.014, -0.074,  0.059,  0.049,  0.052,  0.153, -0.111, -0.122,
        -0.053])

In [None]:
# Positions of the query cells whose annotation is `check_prob_ct`
# (uses `cell_type_key` instead of the hard-coded `cell_type` column).
target_alpha_idx = np.where(target_adata.obs[cell_type_key] == check_prob_ct)[0]
target_alpha_idx.shape

In [None]:
# Latent embedding of the query cells, conditioned on their study
# (uses `condition_key` instead of the hard-coded `study` column).
target_latent = new_tranvae.get_latent(target_adata.X, target_adata.obs[condition_key])
target_latent.shape

In [None]:
# Per-dimension CDF of the first matching query cell under the fitted normal.
# NOTE(review): a CDF value is not a membership probability; values near 0.5 mean
# the cell is close to the landmark in that dimension. Raises IndexError if the
# query contains no cells of this type.
first_target_cell = torch.tensor(target_latent[target_alpha_idx[0]])
first_prob = normal_dist.cdf(first_target_cell)

In [None]:
# Show the per-dimension CDF values computed in the previous cell.
first_prob

In [None]:
# Average of the per-dimension CDF values.
first_prob.mean(dim=0)

In [None]:
# Density at the landmark in each dimension (the per-dimension peak); their mean is
# used as the normalizer for the density scores below.
mean_prob = normal_dist.log_prob(normal_dist.mean).exp()
normalize = mean_prob.mean(dim=0)

In [None]:
# Per-dimension density of every matching query cell, divided by the normalizer.
alpha_prob = (
    normal_dist.log_prob(torch.tensor(target_latent[target_alpha_idx, :]))
    .exp()
    .div(normalize)
)

alpha_prob.size()

In [None]:
# Collapse the per-dimension scores to one score per cell. The name is reused
# for the 1-D result, so the dim() guard keeps this cell idempotent: re-running
# it must not try to average an already-reduced tensor a second time.
if alpha_prob.dim() == 2:
    alpha_prob = torch.mean(alpha_prob, dim=1)
alpha_prob.size()

In [None]:
# Lowest per-cell score.
torch.min(alpha_prob)

In [None]:
# Highest per-cell score.
torch.max(alpha_prob)

In [None]:
# Predicted cell type of each query cell, conditioned on its study
# (uses `condition_key` instead of the hard-coded `study` column).
target_pred = new_tranvae.classify(target_adata.X, target_adata.obs[condition_key])
target_pred.shape

In [None]:
# Query cells predicted as `check_prob_ct`.
alpha_pred_idx = np.where(target_pred == check_prob_ct)[0]

# Split the truly-`check_prob_ct` query cells by whether the classifier agreed.
# np.isin is a single vectorized membership test; the previous `idx in
# alpha_pred_idx` scanned the whole prediction array once per cell (quadratic).
is_correct = np.isin(target_alpha_idx, alpha_pred_idx)
correct_idx = target_alpha_idx[is_correct].tolist()
incorrect_idx = target_alpha_idx[~is_correct].tolist()

In [None]:
# Number of query cells of this type that were classified correctly.
len(correct_idx)

In [None]:
# Number of query cells of this type that were misclassified.
len(incorrect_idx)

In [None]:
def _mean_relative_density(cell_idx):
    """Score query cells by their normalized density under `normal_dist`.

    For each row of `target_latent` selected by `cell_idx`, the per-dimension
    density is divided by `normalize` and averaged over latent dimensions.

    Returns:
        1-D tensor with one score per entry of `cell_idx`.
    """
    latent = torch.tensor(target_latent[cell_idx, :])
    return torch.mean(normal_dist.log_prob(latent).exp() / normalize, dim=1)


# One helper instead of two copy-pasted expressions, so both groups are
# guaranteed to be scored identically.
correct_alpha_probs = _mean_relative_density(correct_idx)
incorrect_alpha_probs = _mean_relative_density(incorrect_idx)

In [None]:
# One score per correctly classified cell.
correct_alpha_probs.size()

In [None]:
# Range (min, then max) of the scores of correctly classified cells.
for correct_stat in (correct_alpha_probs.min(), correct_alpha_probs.max()):
    print(correct_stat)

In [None]:
# One score per misclassified cell.
incorrect_alpha_probs.size()

In [None]:
# Range (min, then max) of the scores of misclassified cells.
for incorrect_stat in (incorrect_alpha_probs.min(), incorrect_alpha_probs.max()):
    print(incorrect_stat)

In [None]:
# Scores as numpy arrays for matplotlib: [correct, incorrect].
data = [scores.detach().numpy() for scores in (correct_alpha_probs, incorrect_alpha_probs)]

In [None]:
def set_axis_style(ax, labels, xlabel='Sample name'):
    """Put one categorical tick per violin on the x axis of `ax`.

    Violins drawn by `Axes.violinplot` sit at x = 1..len(labels), so the ticks
    are placed there and the limits padded by 0.75 on each side.

    Args:
        ax: matplotlib Axes to style.
        labels: tick labels, in the same order as the plotted datasets.
        xlabel: x-axis label; defaults to the previous hard-coded 'Sample name'.
    """
    ax.get_xaxis().set_tick_params(direction='out')
    ax.xaxis.set_ticks_position('bottom')
    ax.set_xticks(np.arange(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_xlim(0.25, len(labels) + 0.75)
    ax.set_xlabel(xlabel)


# Compare the scores of query cells of `check_prob_ct` that the classifier got
# right vs. wrong. Titles/labels describe this data rather than the generic
# placeholders ("Default violin plot", "Observed values", "Sample name").
labels = ['Correct', 'Incorrect']

# violinplot cannot draw an empty dataset; fail with a readable message instead.
empty_groups = [name for name, values in zip(labels, data) if len(values) == 0]
if empty_groups:
    raise ValueError(f'No cells in group(s) {empty_groups}; cannot draw violin plots.')

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4), sharey=True)

ax1.set_title(f'{check_prob_ct}: with extrema')
ax1.set_ylabel('Mean normalized density')
ax1.violinplot(data)

ax2.set_title(f'{check_prob_ct}: distribution only')
parts = ax2.violinplot(
        data, showmeans=False, showmedians=False,
        showextrema=False)

for pc in parts['bodies']:
    pc.set_facecolor('#D43F3A')
    pc.set_edgecolor('black')
    pc.set_alpha(1)

# set style for the axes, then replace the generic 'Sample name' x label it sets
for ax in [ax1, ax2]:
    set_axis_style(ax, labels)
    ax.set_xlabel('Classifier prediction')

plt.subplots_adjust(bottom=0.15, wspace=0.05)
plt.show()