In [1]:
import os
import torch
import pandas as pd
import scanpy as sc
from sklearn import metrics
import multiprocessing as mp
from sklearn.metrics import adjusted_rand_score
from GraphST import GraphST
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
import numpy as np

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# the location of R, which is necessary for mclust algorithm. Please replace the path below with local R installation path
os.environ['R_HOME'] = r"C:\Program Files\R\R-4.4.2"

In [3]:
n_clusters = 8
slicename = '29'

In [4]:
adata = sc.read_h5ad(fr"C:\E\JSU\BIO\file\STrafer\params\merfish\{slicename}.h5ad")
label_mapping = {'MPA': 1, 'MPN': 2, 'BST': 3, 'fx': 4, "PVH":5,"PVT":6,"V3":7, 'PV': 0}
labels = adata.obs['ground_truth'].map(label_mapping)

adata.var_names_make_unique()
# pre-process

adata.layers['count'] = adata.X
sc.pp.filter_genes(adata, min_cells=50)
sc.pp.filter_genes(adata, min_counts=10)
sc.pp.normalize_total(adata, target_sum=1e6)
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=150)
adata = adata[:, adata.var['highly_variable'] == True]
sc.pp.scale(adata)

  view_to_actual(adata)


In [5]:
# define model
model = GraphST.GraphST(adata, device=device)
# train model
adata = model.train()

Begin to train ST data...


100%|██████████| 600/600 [00:13<00:00, 44.20it/s]

Optimization finished for ST data!





In [6]:
# set radius to specify the number of neighbors considered during refinement
radius = 50

tool = 'mclust' # mclust, leiden, and louvain

# clustering
from GraphST.utils import clustering

if tool == 'mclust':
   clustering(adata, n_clusters, radius=radius, method=tool, refinement=True) # For DLPFC dataset, we use optional refinement step.
elif tool in ['leiden', 'louvain']:
   clustering(adata, n_clusters, radius=radius, method=tool, start=0.1, end=2.0, increment=0.01, refinement=False)

R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.



fitting ...


In [7]:
y_pred = adata.obs['domain'].values.astype(int)-1

In [8]:
conf_mat = confusion_matrix(labels, y_pred, labels=np.arange(8))
row_ind, col_ind = linear_sum_assignment(-conf_mat)
mapping = {pred_label: true_label for true_label, pred_label in zip(row_ind, col_ind)}
y_pred = np.array([mapping[p] for p in y_pred])
adata.obs["pred"]= y_pred

In [9]:
pred = y_pred

ARI_s = adjusted_rand_score(pred, labels)
acc_s = accuracy_score(pred, labels)
print("ARI_s:", ARI_s)
print("acc_s", acc_s)

ARI_s: 0.26905598555354493
acc_s 0.5661194299116002


In [23]:
folder_path = fr'C:\E\JSU\BIO\file\STrafer\params\merfish\GraphST\{slicename}'
os.makedirs(folder_path, exist_ok=True)
pred_labels_list = pd.DataFrame({
    'spot': list(range(1, len(adata.obs['domain']) + 1)),
    'pred': y_pred
})
file_path_pred = os.path.join(folder_path, 'pred_labels.csv')
pred_labels_list.to_csv(file_path_pred, index=False)
# pred = adata.obs['domain']

pred = pd.read_csv(fr"C:\E\JSU\BIO\file\STrafer\params\merfish\GraphST\{slicename}\pred_labels.csv",delimiter=',')


In [8]:
import matplotlib.lines as mlines
import matplotlib.pyplot as plt

slicename = '25'
pred = pd.read_csv(fr"C:\E\JSU\BIO\file\STrafer\params\merfish\GraphST\{slicename}\pred_labels.csv",delimiter=',')

color_mapping = {
    0: "#5698D3",  # 蓝色
    1: "#C2A16C",  # 浅褐色
    2: "#6F6DAF",  # 深紫
    3: "#8B2A2A",  # 暗红
    4: "#A65B8D",  # 紫红
    5: "#E6C44D",  # 金黄
    6: "#66B2A6",  # 青绿色
    7: "#D17CA3",  # 粉紫
}

adata = sc.read_h5ad(fr"C:\E\JSU\BIO\file\STrafer\params\merfish\{slicename}.h5ad")
label_mapping = {'MPA': 1, 'MPN': 2, 'BST': 3, 'fx': 4, "PVH":5,"PVT":6,"V3":7, 'PV': 0}
labels = adata.obs['ground_truth'].map(label_mapping)

adata.var_names_make_unique()
# pre-process

adata.layers['count'] = adata.X
sc.pp.filter_genes(adata, min_cells=50)
sc.pp.filter_genes(adata, min_counts=10)
sc.pp.normalize_total(adata, target_sum=1e6)
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=150)
adata = adata[:, adata.var['highly_variable'] == True]
sc.pp.scale(adata)

In [20]:
y_pred = pd.read_csv(fr"C:\E\JSU\BIO\file\STrafer\params\merfish\GraphST\{slicename}\pred_labels.csv",delimiter=',')
y_pred = y_pred['pred']

In [21]:
conf_mat = confusion_matrix(labels, y_pred, labels=np.arange(8))
row_ind, col_ind = linear_sum_assignment(-conf_mat)
mapping = {pred_label: true_label for true_label, pred_label in zip(row_ind, col_ind)}
y_pred = np.array([mapping[p] for p in y_pred])
adata.obs["pred"] = y_pred
pred = y_pred

ARI = adjusted_rand_score(pred, labels)
acc = accuracy_score(pred, labels)
print("ARI_s:", ARI_s)
print("acc_s", acc_s)

ARI_s: 0.1420817587115786
acc_s 0.37518221574344024
