In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from matplotlib.colors import LinearSegmentedColormap
from scipy.spatial import distance
from sklearn.metrics.cluster import (
    adjusted_rand_score,
    contingency_matrix,
    homogeneity_score,
    normalized_mutual_info_score,
)

from VGAE_GCN.adj import graph
from VGAE_GCN.train_VGAE import train_model
from VGAE_GCN.VGAE_model import VGAE
from VGAE_GCN.utils import *

In [None]:
def purity_score(y_true, y_pred):
    """Cluster purity of ``y_pred`` measured against reference labels ``y_true``.

    For each predicted cluster, take the size of its most frequent reference
    class; purity is the sum of those sizes divided by the number of samples.
    The score is asymmetric: pass the reference labels first.
    """
    # Rows correspond to y_true classes, columns to y_pred clusters.
    table = contingency_matrix(y_true, y_pred)
    # Best-matching reference class count for every predicted cluster.
    majority_per_cluster = table.max(axis=0)
    return majority_per_cluster.sum() / table.sum()

In [None]:
# Pipeline: load one DLPFC section and its manual annotation, preprocess it,
# train the VGAE model, cluster the learned embedding with mclust, spatially
# refine the clusters, and score both labellings against the annotation.
# The whole pipeline is repeated once per seed; after the loop, `ann_data`
# holds the result of the LAST seed and is what the plotting cells use.
seed_list = [0]

DLPFC_dir = "../../../dataset/DLPFC"
section_id = '151673'
# Number of mclust components (passed as num_cluster) -- not a neighbour
# count despite the name. NOTE(review): presumably the number of annotated
# layers in this section; confirm.
knn = 7
# Number of highly variable genes. Also passed to train_model as input_dim,
# so the two values are kept in one place.
n_top_genes = 3000
# Set to True to save the refined labels to <section_id>/version_<seed>/.
save_predictions = False


def evaluate_clustering(obs, pred_key, truth_key="Manual annotation"):
    """Score the labels in ``obs[pred_key]`` against ``obs[truth_key]``.

    Spots without an annotation are excluded. Every metric receives the
    ground truth first (labels_true) and the prediction second (labels_pred).
    This order matters for the asymmetric scores: homogeneity (swapping the
    arguments yields completeness) and purity.

    Returns a dict mapping metric name ("ari", "nmi", "hs", "purity") to score.
    """
    annotated = obs[truth_key].notna()
    truth = obs.loc[annotated, truth_key]
    pred = obs.loc[annotated, pred_key]
    return {
        "ari": adjusted_rand_score(truth, pred),
        "nmi": normalized_mutual_info_score(truth, pred),
        "hs": homogeneity_score(truth, pred),
        "purity": purity_score(truth, pred),
    }


def print_scores(scores, pred_key):
    """Print every score as '<pred_key> <metric> is: <value>'."""
    for metric, value in scores.items():
        print("{} {} is: {:.4f}".format(pred_key, metric, value))


for seed in seed_list:
    print(section_id, knn)
    ann_data_raw = sc.read_visium(os.path.join(DLPFC_dir, section_id),
                                  count_file=section_id + '_filtered_feature_bc_matrix.h5')
    # Load the ground-truth annotation: tab-separated, no header, barcode in
    # the first column used as the index.
    ann_df = pd.read_csv(os.path.join(DLPFC_dir, section_id, section_id + "_truth.txt"),
                         sep="\t", header=None, index_col=0)
    ann_df.columns = ["Manual annotation"]
    ann_data_raw.obs.loc[:, "Manual annotation"] = ann_df.loc[ann_data_raw.obs_names, 'Manual annotation']
    ann_data_raw.var_names_make_unique()
    print(ann_data_raw)

    ann_data = ann_data_raw.copy()

    # seurat_v3 HVG selection expects count data, so it runs before normalisation.
    sc.pp.highly_variable_genes(ann_data, flavor="seurat_v3", n_top_genes=n_top_genes)
    sc.pp.normalize_total(ann_data, target_sum=1e4)
    sc.pp.log1p(ann_data)

    net = graph(ann_data)
    net.compute_spatial_net()
    net.Stats_Spatial_Net()

    ann_data = train_model(ann_data, input_dim=n_top_genes, seed=seed)

    # mclust_R is relied upon to write its labels into ann_data.obs['mclust'].
    mclust_R(ann_data, used_obsm='z', num_cluster=knn)
    print_scores(evaluate_clustering(ann_data.obs, 'mclust'), 'mclust')

    # Pairwise Euclidean distances between spot coordinates, consumed by refine().
    adj_2d = distance.cdist(ann_data.obsm['spatial'], ann_data.obsm['spatial'], 'euclidean')
    ann_data.obs["mclust_refine"] = refine(sample_id=ann_data.obs.index.tolist(),
                                           pred=ann_data.obs["mclust"].tolist(),
                                           dis=adj_2d, shape="hexagon")
    print_scores(evaluate_clustering(ann_data.obs, 'mclust_refine'), 'mclust_refine')

    if save_predictions:
        out_dir = os.path.join(section_id, 'version_' + str(seed))
        os.makedirs(out_dir, exist_ok=True)
        np.save(os.path.join(out_dir, section_id + '_pred.npy'),
                ann_data.obs['mclust_refine'].to_numpy())

In [None]:
# Colour map shared by the spatial expression plots.
newcmp = LinearSegmentedColormap.from_list('new', ['#450756', '#451464', '#426189', '#2BD32B', '#F9F871'], N=1000)

# Genes plotted before ("raw") and after ("denoised") reconstruction.
list_genes = ['AQP1', 'KRT17', 'NTNG1', 'KCNG1', 'MT1H']

for gene_raw in list_genes:
    # "(raw)" means "not denoised": the pipeline cell normalised and log1p'd
    # ann_data.X, so these are presumably log-normalised values rather than
    # raw counts (assuming train_model keeps that X -- confirm).
    # obs_vector returns one gene as a 1-D array; the previous
    # X.todense()[:, idx] densified the entire spots x genes matrix on every
    # iteration and failed outright if X was already dense.
    ann_data.obs[f'{gene_raw}(raw)'] = ann_data.obs_vector(gene_raw)
    sc.pl.spatial(ann_data, img_key="hires",
                  color=f'{gene_raw}(raw)',
                  title=gene_raw,
                  color_map=newcmp)

In [None]:
# Plot obsm['ReX'] (presumably the model's reconstructed, "denoised"
# expression) for the same genes with the same colour map.
gene_names = ann_data.var.index.tolist()
reconstructed = ann_data.obsm['ReX']
for gene in list_genes:
    # Position of the gene among ann_data's var names.
    # NOTE(review): assumes obsm['ReX'] columns follow ann_data.var order -- confirm.
    column = gene_names.index(gene)
    key = f'{gene}(denoised)'
    ann_data.obs[key] = reconstructed[:, column]
    sc.pl.spatial(
        ann_data,
        img_key="hires",
        color=key,
        title=gene,
        color_map=newcmp,
    )

In [None]:
# Expose the ReX matrix as a layer so stacked_violin can plot it by name.
# AnnData requires layers to match the (n_obs, n_vars) shape.
ann_data.layers['ReX'] = ann_data.obsm['ReX']
sc.pl.stacked_violin(
    ann_data,
    var_names=list_genes,
    groupby='Manual annotation',
    figsize=(8, 4),
    swap_axes=True,
    layer='ReX',
)

In [None]:
# Same per-annotation violin plot, drawn from ann_data's main expression matrix
# (no layer argument) for comparison with the ReX layer above.
sc.pl.stacked_violin(
    ann_data,
    var_names=list_genes,
    groupby='Manual annotation',
    figsize=(8, 4),
    swap_axes=True,
)