In [1]:
import os
import torch
import pandas as pd
import scanpy as sc
import numpy as np
from sklearn import metrics
import multiprocessing as mp
from sklearn.metrics import adjusted_rand_score
from GraphST import GraphST
from sklearn.metrics import accuracy_score
from anndata import AnnData
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
# Run on the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# The location of R, which is necessary for the mclust algorithm.
# An R_HOME already present in the environment takes precedence; the path below is
# only a fallback and should be replaced with the local R installation path.
os.environ.setdefault('R_HOME', r"C:\Program Files\R\R-4.4.2")

In [2]:
# Identifier of the slice to analyse; it is interpolated into the input/output file names.
slicename = "BZ9"
# Number of clusters requested from the clustering step.
n_clusters = 4

In [3]:
# Input/output locations for the selected slice (single place to change the directory).
data_dir = r"C:\E\JSU\BIO\file\STrafer\params\starmap"
expr_path = os.path.join(data_dir, f"starmap_expr_{slicename}.csv")
spatial_path = os.path.join(data_dir, f"starmap_spatial_{slicename}.csv")
h5ad_path = os.path.join(data_dir, f"starmap_{slicename}.h5ad")

# `meta` is the expression table (its rows become observations, its columns var_names);
# `spatial_data` supplies the x/y coordinates and the z column used as labels.
meta = pd.read_csv(expr_path, index_col=0)
spatial_data = pd.read_csv(spatial_path, index_col=0)

# The right join keeps exactly the rows of `meta`, in `meta`'s order.
data = spatial_data.merge(meta, left_index=True, right_index=True, how='right')
# Reference labels: the z column with value 4 recoded to 0.
labels = data['z'].replace(4, 0)

# Align the coordinates to the expression rows by index label rather than by row
# position, so a differently ordered spatial file cannot silently mislabel cells.
# A label missing from the spatial file raises KeyError here.
coords = spatial_data.loc[meta.index, ['x', 'y']]
assert len(coords) == len(meta), "spatial index must match the expression index one-to-one"

# X, obs_names and the spatial coordinates all follow `meta`'s row order by construction,
# which is also the order of `labels`.
adata = AnnData(X=meta.values)
adata.obs_names = meta.index
adata.var_names = meta.columns
adata.obsm['spatial'] = coords.values
assert adata.n_obs == len(labels), "labels and observations must have the same length"

# Persist the assembled object; it is not read back because the in-memory object is
# the one that was just written.
adata.write_h5ad(h5ad_path)
adata.var_names_make_unique()

# pre-process
# Keep an independent copy of the unnormalised matrix: the seurat_v3 HVG selection below
# reads it from this layer, and it must not share memory with X, which is normalised next.
adata.layers['count'] = adata.X.copy()
sc.pp.filter_genes(adata, min_cells=50)
sc.pp.filter_genes(adata, min_counts=10)
sc.pp.normalize_total(adata, target_sum=1e6)
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=150)
# Materialise the HVG subset so that scaling operates on a real AnnData rather than a view
# (scaling a view triggers an implicit view-to-actual conversion warning).
adata = adata[:, adata.var['highly_variable']].copy()
sc.pp.scale(adata)

  view_to_actual(adata)


In [4]:
# define model
# Build the GraphST model on the preprocessed AnnData, running on the selected torch device.
model = GraphST.GraphST(adata, device=device)
# train model
# train() returns an AnnData that replaces `adata`; presumably it carries the learned
# representation consumed by the clustering step — confirm against the GraphST docs.
adata = model.train()

Begin to train ST data...


100%|██████████| 600/600 [00:02<00:00, 218.57it/s]

Optimization finished for ST data!





In [5]:
# Passed to clustering() as `radius`; presumably the neighbourhood size used by the
# optional refinement step — confirm against the GraphST documentation.
radius = 50

tool = 'mclust' # mclust, leiden, and louvain

# clustering
from GraphST.utils import clustering

# The following cells read the result from adata.obs['domain'].
if tool == 'mclust':
    # mclust is run with the optional refinement step enabled.
    clustering(adata, n_clusters, radius=radius, method=tool, refinement=True)
elif tool in ['leiden', 'louvain']:
    # start/end/increment are forwarded to clustering() unchanged; refinement is disabled.
    clustering(adata, n_clusters, radius=radius, method=tool, start=0.1, end=2.0, increment=0.01, refinement=False)
else:
    # Fail here instead of later with a missing 'domain' column.
    raise ValueError(f"Unknown clustering tool {tool!r}; expected 'mclust', 'leiden' or 'louvain'")

R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.



fitting ...


In [6]:
# Cast the 'domain' assignments to integers and shift them down by one.
# NOTE(review): the -1 presumably assumes 1-based cluster ids (as produced by mclust);
# confirm before switching to another clustering tool.
y_pred = adata.obs['domain'].values.astype(int)-1

In [7]:
# Confusion matrix over the label ids 0..n_clusters-1 (rows: reference labels,
# columns: predicted clusters). The class count follows n_clusters instead of a literal.
conf_mat = confusion_matrix(labels, y_pred, labels=np.arange(n_clusters))
# One-to-one matching of predicted clusters to reference labels that maximises the
# total number of agreeing cells.
row_ind, col_ind = linear_sum_assignment(conf_mat, maximize=True)
# row_ind indexes reference labels and col_ind the matched predicted clusters.
mapping = {pred_label: true_label for true_label, pred_label in zip(row_ind, col_ind)}
# Relabel every prediction with its matched reference label.
y_pred = np.array([mapping[p] for p in y_pred])

In [8]:
 
# Score the aligned predictions against the reference labels. Both metrics are
# symmetric in their two arguments; they are passed by keyword in the conventional
# (true, predicted) order for readability.
ARI_s = adjusted_rand_score(labels_true=labels, labels_pred=y_pred)
acc_s = accuracy_score(y_true=labels, y_pred=y_pred)
print("ARI_s:", ARI_s)
print("acc_s", acc_s)

ARI_s: 0.5627355183119359
acc_s 0.8138651471984806
