In [None]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
import sys
from scipy.spatial import distance
from sklearn.metrics.cluster import adjusted_rand_score,normalized_mutual_info_score,homogeneity_score,contingency_matrix
from VGAE_GCN.adj import graph
from VGAE_GCN.train_VGAE import train_model
from VGAE_GCN.VGAE_model import VGAE
from VGAE_GCN.utils import *

In [None]:
def purity_score(y_true, y_pred):
    """Return the cluster purity of ``y_pred`` with respect to ``y_true``.

    Each predicted cluster is credited with the count of its most frequent
    true class; purity is the total of those counts divided by the number of
    samples.

    Args:
        y_true: ground-truth class labels.
        y_pred: predicted cluster labels, aligned with ``y_true``.

    Returns:
        The purity as a float.
    """
    # Rows of the contingency table index true classes, columns index predicted clusters.
    counts = contingency_matrix(y_true, y_pred)
    majority_per_cluster = counts.max(axis=0)
    return majority_per_cluster.sum() / counts.sum()

In [None]:
def _report_clustering_metrics(obs, pred_key, truth_key="Manual annotation"):
    """Print ARI, NMI, homogeneity and purity of ``obs[pred_key]`` against ``obs[truth_key]``.

    Spots whose ground-truth annotation is missing are excluded. Every metric is
    called as ``metric(labels_true, labels_pred)``: ARI and NMI are symmetric,
    but homogeneity and purity are not, so swapping the arguments would silently
    report completeness / inverse purity instead.

    Args:
        obs: the AnnData ``.obs`` frame holding both label columns.
        pred_key: column with predicted cluster labels; also used as the print tag.
        truth_key: column with the manual (ground-truth) annotation.
    """
    truth_all = obs[truth_key]
    labeled = truth_all.notna().to_numpy()
    truth = truth_all.to_numpy()[labeled]
    pred = obs[pred_key].to_numpy()[labeled]
    print("{} ari is: {:.4f}".format(pred_key, adjusted_rand_score(truth, pred)))
    print("{} nmi is: {:.4f}".format(pred_key, normalized_mutual_info_score(truth, pred)))
    print("{} hs is: {:.4f}".format(pred_key, homogeneity_score(truth, pred)))
    print("{} purity is: {:.4f}".format(pred_key, purity_score(truth, pred)))


seed_list = [0]
n_clusters = 20  # number of clusters requested from mclust
knn = n_clusters  # legacy alias: the consensus cell reads `knn` as its cluster count
n_top_genes = 3000  # highly variable genes kept; also passed as the model input_dim
data_dir = "../../../dataset/BRCA1"

for seed in seed_list:
    ann_data = sc.read_visium(os.path.join(data_dir, 'V1_Human_Breast_Cancer_Block_A_Section_1'))
    # load ground truth annotation and align it to the spots by barcode
    ann_df = pd.read_csv(os.path.join(data_dir, "metadata.tsv"), sep="\t", index_col=0)
    ann_data.obs.loc[:, "Manual annotation"] = ann_df.loc[ann_data.obs_names, 'fine_annot_type']
    ann_data.var_names_make_unique()
    print(ann_data)

    # The seurat_v3 flavor expects raw counts (per scanpy docs), so HVG selection
    # runs before normalisation / log transform.
    sc.pp.highly_variable_genes(ann_data, flavor="seurat_v3", n_top_genes=n_top_genes)
    sc.pp.normalize_total(ann_data, target_sum=1e4)
    sc.pp.log1p(ann_data)

    net = graph(ann_data)
    net.compute_spatial_net()
    net.Stats_Spatial_Net()

    ann_data = train_model(ann_data, input_dim=n_top_genes, seed=seed, rad_cutoff=300)

    # mclust_R must write ann_data.obs['mclust'] in place: the metrics below read it
    # from ann_data, not from the returned object.
    adata = mclust_R(ann_data, used_obsm='z', num_cluster=n_clusters)
    _report_clustering_metrics(ann_data.obs, "mclust")

    # Post-process the mclust labels with `refine` (from VGAE_GCN.utils), using the
    # pairwise Euclidean distances between spot coordinates.
    adj_2d = distance.cdist(ann_data.obsm['spatial'], ann_data.obsm['spatial'], 'euclidean')
    refined_pred = refine(sample_id=ann_data.obs.index.tolist(),
                          pred=ann_data.obs["mclust"].tolist(), dis=adj_2d, shape="hexagon")
    ann_data.obs["mclust_refine"] = refined_pred
    _report_clustering_metrics(ann_data.obs, "mclust_refine")

    out_dir = os.path.join('result', 'version_' + str(seed))
    os.makedirs(out_dir, exist_ok=True)  # np.save does not create missing directories
    refined_labels = ann_data.obs['mclust_refine']
    np.save(os.path.join(out_dir, 'pred.npy'), refined_labels)

In [None]:
# Spatial map of the unrefined mclust labels (the refined ones live in 'mclust_refine').
sc.pl.spatial(
    ann_data,
    color='mclust',
    show=False,
)

In [None]:
# kNN graph on the learned embedding obsm['z'] with cosine distance, then a UMAP of it.
# NOTE(review): with use_rep='z', scanpy keeps only the first n_pcs (=8) columns of z;
# confirm this truncation is intended if the embedding has more dimensions.
sc.pp.neighbors(
    ann_data,
    use_rep='z',
    metric='cosine',
    n_pcs=8,
)
sc.tl.umap(ann_data)

# UMAP coloured by the unrefined mclust labels.
sc.pl.umap(
    ann_data,
    color='mclust',
    show=False,
)

In [None]:
import glob
import itertools
from typing import List
from scipy.spatial import distance
from scipy.cluster import hierarchy


def labels_connectivity_mat(labels: np.ndarray):
    """Build the co-membership (connectivity) matrix of one labelling.

    Entry ``[i, j]`` is 1.0 when samples ``i`` and ``j`` carry the same label and
    0.0 otherwise, so the diagonal is all ones.

    Args:
        labels: array of cluster labels; flattened to 1-D. Any dtype supporting
            ``==`` works (the old code required numeric labels for ``np.min``).

    Returns:
        ``(n, n)`` float64 array where ``n = labels.size``.
    """
    # Vectorised pairwise comparison. The previous per-class loop crashed on
    # singleton clusters (np.squeeze turned a one-element index array into a 0-d
    # array that itertools.product cannot iterate) and built O(n_i^2) Python
    # tuples per class.
    flat = np.asarray(labels).ravel()
    return (flat[:, None] == flat[None, :]).astype(np.float64)


def consensus_matrix(labels_list: List[np.ndarray]):
    """Average the connectivity matrices of several labellings of the same samples.

    Entry ``[i, j]`` of the result is the fraction of labellings in which samples
    ``i`` and ``j`` were assigned the same label.

    Args:
        labels_list: iterable of label arrays, all describing the same samples.

    Returns:
        ``(n, n)`` float array of co-clustering frequencies.

    Raises:
        ValueError: if ``labels_list`` is empty (previously an opaque ZeroDivisionError).
    """
    # Materialise so generators work and len() is defined.
    labels_list = list(labels_list)
    if not labels_list:
        raise ValueError("consensus_matrix needs at least one labelling")
    total = 0
    for labels in labels_list:
        total += labels_connectivity_mat(labels)
    return total / float(len(labels_list))


def plot_consensus_map(cmat, method="average", return_linkage=True, **kwargs):
    """Draw a hierarchically clustered heatmap of a (consensus) matrix.

    Rows and columns are each clustered on the Euclidean distances between their
    vectors.

    Args:
        cmat: 2-D matrix to plot.
        method: linkage method forwarded to ``scipy.cluster.hierarchy.linkage``.
        return_linkage: if True, also return the row and column linkages.
        **kwargs: forwarded to ``sns.clustermap``.

    Returns:
        ``(row_linkage, col_linkage, figure)`` when ``return_linkage`` is True,
        otherwise only ``figure`` (the object returned by ``sns.clustermap``).
    """
    row_linkage = hierarchy.linkage(distance.pdist(cmat), method=method)
    # A consensus matrix is exactly symmetric, so its column distances equal its
    # row distances; reuse the row linkage instead of repeating the O(n^3)
    # pdist + linkage on the transpose.
    if np.array_equal(cmat, np.transpose(cmat)):
        col_linkage = row_linkage.copy()
    else:
        col_linkage = hierarchy.linkage(distance.pdist(cmat.T), method=method)
    # NOTE(review): `sns` is never imported in this notebook's visible cells; it
    # presumably arrives via `from VGAE_GCN.utils import *` — confirm, or import
    # seaborn explicitly.
    figure = sns.clustermap(cmat, row_linkage=row_linkage, col_linkage=col_linkage, **kwargs)
    if return_linkage:
        return row_linkage, col_linkage, figure
    return figure

In [None]:
save_dir = os.path.join('result')
name = "pred.npy"
num_cluster = knn  # `knn` holds the cluster count chosen in the training cell

# Presumably raised so that the recursive dendrogram code used by the clustermap
# does not hit Python's default recursion limit on large matrices.
sys.setrecursionlimit(100000)
# Sorted for a deterministic load order across runs.
label_files = sorted(glob.glob(os.path.join(save_dir, "version_*", name)))
if not label_files:
    raise FileNotFoundError(
        "no per-seed predictions matching {!r}; run the training cell first".format(
            os.path.join(save_dir, "version_*", name)))
labels_list = [np.load(path) for path in label_files]
cons_mat = consensus_matrix(labels_list)
# Hierarchical clustering of the consensus matrix plus the clustered heatmap.
row_linkage, _, figure = plot_consensus_map(cons_mat, return_linkage=True)
figure.savefig(os.path.join(save_dir, "consensus_clustering.png"), dpi=300)  # save the figure
# Cut the row dendrogram into num_cluster groups -> consensus labels (y*).
consensus_labels = hierarchy.cut_tree(row_linkage, num_cluster).squeeze()
np.save(os.path.join(save_dir, "consensus_labels"), consensus_labels)

In [None]:
pred = np.load(os.path.join('result', 'consensus_labels.npy'))
assert pred.shape[0] == ann_data.n_obs, "consensus labels do not match the spots in ann_data"

# Keep only spots that have a manual annotation.
indices = ann_data.obs["Manual annotation"].notna().to_numpy()
ground_truth = ann_data.obs["Manual annotation"].to_numpy()[indices]
pred_labeled = pred[indices]

# sklearn metrics take (labels_true, labels_pred). ARI and NMI are symmetric, but
# homogeneity and purity are not: the previous (pred, truth) order reported
# completeness and inverse purity.
ari = adjusted_rand_score(ground_truth, pred_labeled)
print("ari is: {:.4f}".format(ari))
nmi = normalized_mutual_info_score(ground_truth, pred_labeled)
print("nmi is: {:.4f}".format(nmi))
hs = homogeneity_score(ground_truth, pred_labeled)
print("hs is: {:.4f}".format(hs))
purity = purity_score(ground_truth, pred_labeled)
print("purity is: {:.4f}".format(purity))