In [None]:
import torch
import scanpy as sc

from sklearn.metrics import adjusted_rand_score
from sklearn.cluster import KMeans
import pickle
import numpy as np
import os

from scvi.model import SCVI
from scripts import constants
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from rich.console import Console
# Apply seaborn's default theme globally so the matplotlib figures below pick it up.
sns.set()

In [None]:
# Load the preprocessed Haber et al. 2017 AnnData file from the project data directory.
adata_path = os.path.join(
    constants.DEFAULT_DATA_PATH,
    "haber_2017/preprocessed/adata_top_2000_genes.h5ad",
)
adata = sc.read_h5ad(adata_path)

# Split cells on the `condition` annotation: 'Control' cells form the background
# set, every other cell belongs to the target set. Copies are taken so the two
# subsets are independent AnnData objects rather than views of `adata`.
is_control = adata.obs['condition'] == 'Control'
background_adata = adata[is_control].copy()
target_adata = adata[~is_control].copy()

# Register the target cells with scVI, reading expression values from the "count" layer.
SCVI.setup_anndata(target_adata, layer="count")

In [None]:
# Random seeds of the independently trained model replicates; one checkpoint
# directory exists per seed and per method.
seeds = [123, 42, 789, 46, 999]

# NOTE(security): torch.load and pickle.load both unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints produced by this
# project's own training runs.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled model objects. If these checkpoints store whole
# models, weights_only=False may be required there; confirm against the
# installed torch version.
contrastiveVI_models = [
    torch.load(
        os.path.join(
            constants.DEFAULT_RESULTS_PATH,
            f"haber_2017/contrastiveVI/{seed}/model.chkpt",
        )
    )
    for seed in seeds
]
scVI_models = [
    torch.load(
        os.path.join(
            constants.DEFAULT_RESULTS_PATH,
            f"haber_2017/scVI/{seed}/model.chkpt",
        )
    )
    for seed in seeds
]

# Read the single PCPCA model inside a context manager so the file handle is
# closed deterministically (the previous bare open() call leaked it).
with open(
    os.path.join(constants.DEFAULT_RESULTS_PATH, "haber_2017/PCPCA/model.pkl"), "rb"
) as pcpca_file:
    PCPCA_model = pickle.load(pcpca_file)

In [None]:
# Embed the target cells with every scVI replicate (one embedding per seed,
# in the same order as `scVI_models`).
scVI_latent_representations = []
for scvi_model in tqdm(scVI_models):
    scVI_latent_representations.append(
        scvi_model.get_latent_representation(adata=target_adata)
    )

# Same for contrastiveVI, requesting the "salient" representation.
contrastiveVI_salient_latent_representations = []
for contrastive_model in tqdm(contrastiveVI_models):
    contrastiveVI_salient_latent_representations.append(
        contrastive_model.get_latent_representation(
            adata=target_adata, representation_kind="salient"
        )
    )

# PCPCA receives the transposed target and background matrices; the first
# element of its output is transposed back and kept as the target embedding.
# NOTE(review): presumably PCPCA works on features x samples input; confirm
# against the PCPCA implementation.
target_matrix_t = target_adata.X.transpose()
background_matrix_t = background_adata.X.transpose()
pcpca_outputs = PCPCA_model.transform(target_matrix_t, background_matrix_t)
PCPCA_salient_representations = pcpca_outputs[0].transpose()

In [None]:
# Adjusted Rand index (ARI) between the true condition labels and a 2-cluster
# k-means partition of each embedding.
#
# k-means is stochastic: without a fixed random_state the scores change on every
# kernel restart. n_init is pinned as well because scikit-learn changed its
# default (10 restarts -> "auto") in version 1.4, which would otherwise make the
# result depend on the installed version.
kmeans_kwargs = dict(n_clusters=2, n_init=10, random_state=0)
condition_labels = target_adata.obs['condition']

# One ARI per scVI replicate.
scVI_aris = [
    adjusted_rand_score(
        condition_labels,
        KMeans(**kmeans_kwargs).fit(latent).labels_,
    )
    for latent in tqdm(scVI_latent_representations)
]

# One ARI per contrastiveVI replicate (salient latent space).
contrastiveVI_aris = [
    adjusted_rand_score(
        condition_labels,
        KMeans(**kmeans_kwargs).fit(salient_latent).labels_,
    )
    for salient_latent in tqdm(contrastiveVI_salient_latent_representations)
]

# PCPCA has a single model, hence a single score.
PCPCA_ari = adjusted_rand_score(
    condition_labels,
    KMeans(**kmeans_kwargs).fit(PCPCA_salient_representations).labels_,
)

In [None]:
from sklearn.metrics import silhouette_score
from sklearn.metrics import adjusted_mutual_info_score
from sklearn.preprocessing import LabelEncoder


# Average silhouette width of each embedding with respect to the true condition
# labels (no clustering involved, so this cell is deterministic).
#
# The integer encoding of the labels does not depend on the embedding, so it is
# computed once here instead of once per model.
encoded_conditions = LabelEncoder().fit_transform(target_adata.obs['condition'])

# One score per scVI replicate.
scVI_silhouettes = [
    silhouette_score(latent, encoded_conditions)
    for latent in tqdm(scVI_latent_representations)
]

# One score per contrastiveVI replicate (salient latent space).
contrastiveVI_silhouettes = [
    silhouette_score(salient_latent, encoded_conditions)
    for salient_latent in tqdm(contrastiveVI_salient_latent_representations)
]

# PCPCA has a single model, hence a single score.
PCPCA_silhouette = silhouette_score(
    PCPCA_salient_representations,
    encoded_conditions,
)

In [None]:
from sklearn.metrics import adjusted_mutual_info_score
from sklearn.preprocessing import LabelEncoder

# Adjusted mutual information (AMI) between the true condition labels and a
# 2-cluster k-means partition of each embedding.
# NOTE: the `*_nmis` / `PCPCA_nmi` names are historical; the values are AMI, and
# the names are kept because the plotting cell reads them.
#
# k-means is stochastic: random_state is fixed so scores are reproducible, and
# n_init is pinned because scikit-learn changed its default (10 -> "auto") in
# version 1.4.
kmeans_kwargs = dict(n_clusters=2, n_init=10, random_state=0)

# Encode the labels once; all three scores below share the same encoding
# (AMI itself is invariant to how the labels are named).
encoded_conditions = LabelEncoder().fit_transform(target_adata.obs['condition'])

# One AMI per scVI replicate.
scVI_nmis = [
    adjusted_mutual_info_score(
        encoded_conditions,
        KMeans(**kmeans_kwargs).fit(latent).labels_,
    )
    for latent in tqdm(scVI_latent_representations)
]

# One AMI per contrastiveVI replicate (salient latent space).
contrastiveVI_nmis = [
    adjusted_mutual_info_score(
        encoded_conditions,
        KMeans(**kmeans_kwargs).fit(salient_latent).labels_,
    )
    for salient_latent in tqdm(contrastiveVI_salient_latent_representations)
]

# PCPCA has a single model, hence a single score.
PCPCA_nmi = adjusted_mutual_info_score(
    encoded_conditions,
    KMeans(**kmeans_kwargs).fit(PCPCA_salient_representations).labels_,
)

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Project one replicate of each method onto its principal components for
# plotting. Index 1 selects the second entry of each list, i.e. the model
# loaded for the second seed in `seeds`.
scVI_pca, contrastiveVI_pca = (
    PCA().fit_transform(replicate_embeddings[1])
    for replicate_embeddings in (
        scVI_latent_representations,
        contrastiveVI_salient_latent_representations,
    )
)

In [None]:
# Per-cell `condition` annotation of the target cells; the plotting cell uses it
# to split and color the PCA scatter points.
target_labels = target_adata.obs['condition']

In [None]:
def _scatter_first_two_pcs(ax, embedding, condition_labels, title, show_legend=False):
    """Scatter the first two columns of ``embedding`` on ``ax``, one color per condition.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    embedding : 2-D array whose columns 0 and 1 are plotted as PC1 / PC2.
    condition_labels : pandas Series of per-row labels, aligned with ``embedding``.
    title : panel title.
    show_legend : when True, add a legend with enlarged markers.
    """
    # Look the palette up once instead of once per condition.
    palette = sns.color_palette("Set1")
    for i, label in enumerate(condition_labels.unique()):
        # Boolean row mask for this condition, computed once and reused for both axes.
        mask = np.asarray(condition_labels == label)
        ax.scatter(
            embedding[mask, 0],
            embedding[mask, 1],
            s=1,
            label=label,
            color=palette[i],
        )

    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    # PC coordinates carry no interpretable scale, so tick labels are hidden.
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    if show_legend:
        ax.legend(markerscale=5)
    ax.set_title(title)


fig, (ax1, ax2, ax3) = plt.subplots(figsize=(15, 5), nrows=1, ncols=3)

# Panels 1-2: PCA view of one replicate per method. Only the first panel
# carries the (shared) condition legend.
_scatter_first_two_pcs(ax1, contrastiveVI_pca, target_labels, "contrastiveVI", show_legend=True)
_scatter_first_two_pcs(ax2, scVI_pca, target_labels, "scVI")

# Panel 3: grouped bars, one group per method and one bar per metric.
labels = ['scVI', 'PCPCA', 'contrastiveVI']
width = 0.25  # the width of the bars
x = np.arange(len(labels))

# (legend label, scVI scores per seed, single PCPCA score, contrastiveVI scores per seed)
metric_table = [
    ('ASW', scVI_silhouettes, PCPCA_silhouette, contrastiveVI_silhouettes),
    ('ARI', scVI_aris, PCPCA_ari, contrastiveVI_aris),
    ('AMI', scVI_nmis, PCPCA_nmi, contrastiveVI_nmis),
]

# Bar height is the mean over seeds and the error bar is the standard deviation
# over seeds. PCPCA has a single model, so its error bar is 0.
rects1, rects2, rects3 = [
    ax3.bar(
        x + offset * width,
        [np.mean(scvi_scores), pcpca_score, np.mean(contrastive_scores)],
        width,
        yerr=[np.std(scvi_scores), 0, np.std(contrastive_scores)],
        capsize=10,
        label=metric_name,
    )
    for offset, (metric_name, scvi_scores, pcpca_score, contrastive_scores)
    in enumerate(metric_table)
]

# Center each method's tick under the middle bar of its group.
ax3.set_xticks(x + width)
ax3.set_xticklabels(labels)
ax3.legend()
ax3.set_title("Clustering Metrics")

plt.tight_layout()
plt.show()
    