In [None]:
import torch
import scanpy as sc

from sklearn.metrics import adjusted_rand_score, silhouette_score, adjusted_mutual_info_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.cluster import KMeans
import numpy as np
import os
from sklearn.decomposition import PCA
from scvi.model import SCVI
from scripts import constants
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()  # apply seaborn's default theme to every matplotlib figure drawn below

In [None]:
# Dataset key: selects both the preprocessed AnnData file and the per-method results
# directories (under constants.DEFAULT_DATA_PATH / DEFAULT_RESULTS_PATH) used below.
dataset = "fasolino_2021"

In [None]:
adata_path = os.path.join(
    constants.DEFAULT_DATA_PATH,
    f"{dataset}/preprocessed/adata_top_2000_genes.h5ad",
)
adata = sc.read_h5ad(adata_path)

# Target = every cell whose disease state is not "Control"; background = the "Control" cells.
is_control_cell = adata.obs["disease_state"] == "Control"
target_adata = adata[~is_control_cell].copy()
background_adata = adata[is_control_cell].copy()

# Independent copies, taken before the setup_anndata calls below register each object.
# NOTE(review): "trans" presumably refers to the transformed values stored in .X — confirm
# against the preprocessing script.
target_trans_adata = target_adata.copy()
background_trans_adata = background_adata.copy()

# Register the "count" layer for target_adata, and .X (no layer given) for the "trans" copy,
# which the cVAE models are queried with further down.
SCVI.setup_anndata(target_adata, layer="count")
SCVI.setup_anndata(target_trans_adata)

In [None]:
# "cuda:2" is the GPU used on the original machine; fall back to the CPU when fewer than
# three CUDA devices are visible so the notebook can still be re-run elsewhere.
torch_device = "cuda:2" if torch.cuda.device_count() > 2 else "cpu"
seeds = [123, 42, 789, 46, 999]


def _load_torch_checkpoints(method_name):
    """Load ``model.ckpt`` for every seed of one method, in ``seeds`` order.

    Paths are ``<DEFAULT_RESULTS_PATH>/<dataset>/<method_name>/<seed>/model.ckpt`` and every
    checkpoint is mapped onto ``torch_device``.

    ``torch.load`` unpickles the file, so the checkpoints must come from a trusted source.
    NOTE(review): PyTorch >= 2.6 defaults to ``weights_only=True``, which rejects pickled model
    objects like these; pass ``weights_only=False`` there if loading fails.
    """
    return [
        torch.load(
            os.path.join(constants.DEFAULT_RESULTS_PATH, f"{dataset}/{method_name}/{seed}/model.ckpt"),
            map_location=torch_device,
        )
        for seed in tqdm(seeds)
    ]


def _load_pickle(relative_path):
    """Unpickle ``<DEFAULT_RESULTS_PATH>/<relative_path>`` and close the file afterwards.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(os.path.join(constants.DEFAULT_RESULTS_PATH, relative_path), "rb") as handle:
        return pickle.load(handle)


contrastiveVI_models = _load_torch_checkpoints("contrastiveVI")
scVI_models = _load_torch_checkpoints("scVI")
tc_contrastive_vi_models = _load_torch_checkpoints("TC_contrastiveVI")
cvae_models = _load_torch_checkpoints("cVAE")
cplvm_models = [
    _load_pickle(f"{dataset}/CPLVM/{seed}/model.pkl")
    for seed in tqdm(seeds)
]
# PCPCA and cPCA were fitted once (no per-seed subdirectory).
PCPCA_model = _load_pickle(f"{dataset}/PCPCA/model.pkl")
cpca_model = _load_pickle(f"{dataset}/cPCA/model.pkl")

In [None]:
def _collect_latents(models, adata_for_inference, **latent_kwargs):
    """Call ``get_latent_representation`` on every model (one per seed), preserving order."""
    return [
        fitted_model.get_latent_representation(adata=adata_for_inference, **latent_kwargs)
        for fitted_model in tqdm(models)
    ]


scVI_latent_representations = _collect_latents(scVI_models, target_adata)

contrastiveVI_salient_latent_representations = _collect_latents(
    contrastiveVI_models, target_adata, representation_kind="salient"
)

tc_contrastive_vi_salient_representations = _collect_latents(
    tc_contrastive_vi_models, target_adata, representation_kind="salient"
)
# cVAE is queried with the copy registered against .X rather than the "count" layer.
cvae_salient_representations = _collect_latents(
    cvae_models, target_trans_adata, representation_kind="salient"
)
# CPLVM results are looked up by key; the stored "qty_mean" matrix is transposed.
cplvm_salient_representations = [
    cplvm_fit["qty_mean"].transpose() for cplvm_fit in tqdm(cplvm_models)
]

In [None]:
def _standardized_transposed(counts):
    """Standardize each column with a freshly fitted StandardScaler, then transpose.

    NOTE(review): assumes AnnData's (cells, genes) layout, so the result is (genes, cells).
    """
    return StandardScaler().fit_transform(counts).transpose()


# PCPCA projects the standardized target data using the standardized background data;
# the first returned element is transposed back and wrapped in a one-element list so it
# matches the per-seed lists of the other methods.
pcpca_projection = PCPCA_model.transform(
    _standardized_transposed(target_adata.layers["count"]),
    _standardized_transposed(background_adata.layers["count"]),
)
PCPCA_salient_representations = [pcpca_projection[0].transpose()]

# cPCA works on .X and returns one projection per requested alpha; keep only the first.
cpca_projection = cpca_model.transform(target_adata.X, n_alphas_to_return=1)
cpca_salient_representations = [cpca_projection[0]]

In [None]:
# Method name -> list of latent representations of the target cells (one per seed;
# cPCA and PCPCA contribute a single entry). Insertion order sets the bar-chart order.
salient_representation_dict = dict(
    scVI=scVI_latent_representations,
    cPCA=cpca_salient_representations,
    PCPCA=PCPCA_salient_representations,
    cVAE=cvae_salient_representations,
    CPLVM=cplvm_salient_representations,
    contrastiveVI=contrastiveVI_salient_latent_representations,
    TC_contrastiveVI=tc_contrastive_vi_salient_representations,
)

In [None]:
# Ground-truth disease state of every target cell. LabelEncoder yields integer codes
# (not a one-hot matrix — the variable name is kept because it is part of the notebook state).
labels = target_adata.obs["disease_state"]
one_hot_labels = LabelEncoder().fit_transform(labels)
# Cluster into as many groups as there are true disease states.
n_label_classes = len(np.unique(one_hot_labels))

# Fixed KMeans settings so that re-running this cell reproduces the same scores;
# n_init is pinned explicitly because scikit-learn changed its default across versions.
KMEANS_RANDOM_STATE = 0
KMEANS_N_INIT = 10

silhouette_results_dict = {}
ari_results_dict = {}
nmi_results_dict = {}  # holds *adjusted* mutual information (AMI) scores
for model_name, latent_list in salient_representation_dict.items():
    print(f"Evaluating {model_name} representations...")
    silhouette_results = []
    ari_results = []
    nmi_results = []

    for latent in tqdm(latent_list):
        latent_clustering = KMeans(
            n_clusters=n_label_classes,
            n_init=KMEANS_N_INIT,
            random_state=KMEANS_RANDOM_STATE,
        ).fit(latent).labels_
        # Silhouette scores how well the *true* labels separate in the latent space;
        # ARI / AMI compare the unsupervised KMeans clusters with the true labels.
        silhouette_results.append(silhouette_score(latent, one_hot_labels))
        ari_results.append(adjusted_rand_score(one_hot_labels, latent_clustering))
        nmi_results.append(adjusted_mutual_info_score(one_hot_labels, latent_clustering))

    silhouette_results_dict[model_name] = silhouette_results
    ari_results_dict[model_name] = ari_results
    nmi_results_dict[model_name] = nmi_results

print("Done!")

In [None]:
# For plotting, take the representation from the first seed of each deep model and
# rotate it onto its principal components.
_first_seed_representations = {
    "scVI": scVI_latent_representations[0],
    "contrastiveVI": contrastiveVI_salient_latent_representations[0],
    "TC_contrastiveVI": tc_contrastive_vi_salient_representations[0],
}
pca_dict = {
    model_name: PCA().fit_transform(representation)
    for model_name, representation in _first_seed_representations.items()
}
scvi_pca = pca_dict["scVI"]
contrastive_vi_pca = pca_dict["contrastiveVI"]
tc_contrastive_vi_pca = pca_dict["TC_contrastiveVI"]
target_labels = target_adata.obs["disease_state"]

In [None]:
# One panel per model: first two PCs of its latent space, colored by disease state.
# squeeze=False keeps `axes` 2-D even when pca_dict has a single entry.
fig, axes = plt.subplots(figsize=(15, 5), nrows=1, ncols=len(pca_dict), squeeze=False)
colors = ["purple", "goldenrod"]

unique_target_labels = target_labels.unique()
# Fall back to a seaborn palette if there are more disease states than hand-picked colors.
label_colors = (
    colors
    if len(unique_target_labels) <= len(colors)
    else sns.color_palette(n_colors=len(unique_target_labels))
)
target_label_values = target_labels.to_numpy()

for ax, (model_name, pca) in zip(axes[0], pca_dict.items()):
    for label, color in zip(unique_target_labels, label_colors):
        in_label = target_label_values == label
        ax.scatter(
            pca[in_label, 0],
            pca[in_label, 1],
            s=1,
            label=label,
            color=color,
            alpha=0.5,
        )
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    # Tick labels are hidden; only the relative layout of the cells is shown.
    ax.tick_params(labelbottom=False, labelleft=False)
    ax.legend(markerscale=5)
    ax.set_title(model_name)

In [None]:
# Grouped bar chart: one group per method, one bar per metric (mean ± std over seeds).
model_labels = list(salient_representation_dict.keys())
metric_results = [
    (silhouette_results_dict, "Silhouette score"),
    (ari_results_dict, "Adjusted Rand index"),
    (nmi_results_dict, "Adjusted mutual information"),
]
width = 0.25  # Bar width
x = np.arange(len(model_labels))
# plt.subplots keeps tick labels and title inside the figure (add_axes([0, 0, 1, 1]) clips them).
fig, ax = plt.subplots(figsize=(10, 5))
for offset, (results_dict, metric_label) in enumerate(metric_results):
    scores_per_model = [results_dict[model_label] for model_label in model_labels]
    ax.bar(
        x + width * offset,
        [np.mean(scores) for scores in scores_per_model],
        width,
        # Spread across seeds; methods evaluated once (cPCA, PCPCA) get a zero-length bar.
        yerr=[np.std(scores) for scores in scores_per_model],
        capsize=10,
        label=metric_label,
    )

# Center each tick label under its group of bars.
ax.set_xticks(x + width * (len(metric_results) - 1) / 2)
ax.set_xticklabels(model_labels, rotation=15)
ax.set_ylabel("Score")
ax.legend()
ax.set_title("Clustering Performance");