In [14]:
import scanpy as sc
import pandas as pd
from sklearn import metrics
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
import SEDR
import matplotlib.lines as mlines
import numpy as np
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment


# Seed used for this run; it is handed to SEDR.fix_seed on the next line.
# NOTE(review): fix_seed presumably seeds the Python/NumPy/torch RNGs — confirm against the SEDR source.
random_seed = 2023
SEDR.fix_seed(random_seed)

In [4]:
# Run on the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [5]:
# Sample identifier; it is interpolated into the input file paths when the data is loaded.
slicename = "151507"
# Number of clusters requested from the mclust clustering step.
n_clusters = 7


In [6]:
# Load one slice: spot coordinates, per-spot annotations and the 10x count matrix,
# then align all three on the spot barcode.
# NOTE(review): the absolute paths below are machine-specific; consider a configurable data directory.
spatial_data = pd.read_csv(
    fr"C:\E\JSU\BIO\file\SpaGCN-master\SpaGCN-master\tutorial\{slicename}\spatial\tissue_positions_list.csv", sep=",",
    header=None)
spatial_data.columns = ['barcode', 'in_tissue', 'row', 'col', 'pxl_row_in_fullres', 'pxl_col_in_fullres']
spatial_data = spatial_data[spatial_data['in_tissue'] == 1]  # keep only spots inside the tissue (flag == 1)
meta = pd.read_csv(fr"C:\E\JSU\BIO\file\SpaGCN-master\SpaGCN-master\tutorial\{slicename}\metadata.tsv", sep="\t",
                   index_col=False)
# Coordinates are taken from spatial_data, so drop the duplicated columns from the metadata
# (assignment instead of inplace=True keeps the statement free of hidden mutation).
meta = meta.drop(columns=['row', 'col'])

data = spatial_data.merge(meta, on='barcode', how='right')
data = data[['barcode', 'row', 'col', 'pxl_row_in_fullres', 'pxl_col_in_fullres', 'expr_chrM']]
data = data.dropna(subset=['expr_chrM'])  # drop spots that carry no annotation
# labels
# NOTE(review): the layer names are read from the column pandas exposes as 'expr_chrM';
# presumably the header is shifted by index_col=False — confirm against metadata.tsv.
label_mapping = {'L1': 1, 'L2': 2, 'L3': 3, 'L4': 4, 'L5': 5, 'L6': 6, 'WM': 0}  # DLPFC
labels = data['expr_chrM'].map(label_mapping)
# Annotated Data
file_path = fr"C:\E\JSU\BIO\file\SpaGCN-master\SpaGCN-master\tutorial\{slicename}\filtered_feature_bc_matrix.h5"
adata = sc.read_10x_h5(file_path)
# Select spots by barcode rather than by the merged frame's integer index: integer indexing on
# AnnData is positional, so it would silently pair expression rows with the wrong labels and
# coordinates whenever metadata.tsv is ordered differently from the count matrix.
# .copy() materialises the subset so that writing to .obsm does not happen on a view.
adata = adata[data['barcode'].values].copy()
adata.obsm['spatial'] = data[['row', 'col']].values

In [8]:
from sklearn.decomposition import PCA  # scikit-learn's PCA is used because scanpy's PCA is not stable.

# Project the expression matrix onto 200 principal components with a fixed random state.
# NOTE(review): adata.X is used as loaded; no normalisation or gene-filtering step is visible
# in this notebook — confirm that this is intended.
pca_model = PCA(n_components=200, random_state=42)
adata_X = pca_model.fit_transform(adata.X)
adata.obsm['X_pca'] = adata_X

### Constructing neighborhood graph
# The second argument (12) is presumably the number of spatial neighbours per spot — confirm
# against SEDR.graph_construction.
graph_dict = SEDR.graph_construction(adata, 12)
print(graph_dict)

{'adj_norm': tensor(indices=tensor([[   0,   82,  247,  ..., 4020, 4132, 4220],
                       [   0,    0,    0,  ..., 4220, 4220, 4220]]),
       values=tensor([0.0769, 0.0693, 0.0741,  ..., 0.0769, 0.0769, 0.0769]),
       size=(4221, 4221), nnz=55977, layout=torch.sparse_coo), 'adj_label': tensor(indices=tensor([[   0,    0,    0,  ..., 4220, 4220, 4220],
                       [   0,   82,  247,  ..., 4020, 4132, 4220]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       size=(4221, 4221), nnz=55977, dtype=torch.float64,
       layout=torch.sparse_coo), 'norm_value': 0.5015758523909648}


In [9]:
# Build the SEDR model on the PCA embedding and the graph dictionary.
sedr_net = SEDR.Sedr(adata.obsm['X_pca'], graph_dict, mode='clustering', device=device)

# Pick the training routine from the flag, then run it once with N=1.
using_dec = True
run_training = sedr_net.train_with_dec if using_dec else sedr_net.train_without_dec
run_training(N=1)

# process() yields four values; only the first is kept and stored as the 'SEDR' embedding.
sedr_feat, _, _, _ = sedr_net.process()
adata.obsm['SEDR'] = sedr_feat

100%|██████████| 200/200 [00:03<00:00, 52.23it/s] 
100%|██████████| 200/200 [00:01<00:00, 112.99it/s]


In [10]:
# Cluster the 'SEDR' embedding into n_clusters groups with mclust; the cluster ids are
# stored under the 'SEDR' key. Left as the cell's last expression so the result is displayed.
SEDR.mclust_R(
    adata,
    n_clusters,
    use_rep='SEDR',
    key_added='SEDR',
)

R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.



fitting ...


AnnData object with n_obs × n_vars = 4221 × 33538
    obs: 'SEDR'
    var: 'gene_ids', 'feature_types', 'genome'
    obsm: 'spatial', 'X_pca', 'SEDR'

In [12]:
# Score the clustering against the annotated labels (ARI and accuracy).
pred = adata.obs['SEDR']
pred = pred.values.astype(int) - 1  # shift cluster ids down by one so they can start at 0

# Cluster ids are arbitrary, so match each predicted cluster to a ground-truth class by
# maximising the confusion-matrix overlap (linear_sum_assignment minimises, hence the minus).
# The class set is derived from the data instead of a hard-coded range of 7, so samples are not
# silently dropped from the matrix when a slice uses a different or non-contiguous set of ids.
classes = np.union1d(np.asarray(labels), pred)
conf_mat = confusion_matrix(labels, pred, labels=classes)
row_ind, col_ind = linear_sum_assignment(-conf_mat)
# row_ind/col_ind are positions into `classes`; translate them back to the actual class ids.
mapping = {classes[pred_idx]: classes[true_idx] for true_idx, pred_idx in zip(row_ind, col_ind)}
y_pred = np.array([mapping[p] for p in pred])
adata.obs["pred"] = y_pred
# Both metrics are symmetric in their arguments; pass (y_true, y_pred) to follow the sklearn signature.
ARI_s = adjusted_rand_score(labels, y_pred)
acc_s = accuracy_score(labels, y_pred)
print("ARI_s:", ARI_s)
print("acc_s", acc_s)


ARI_s: 0.45153277809275477
acc_s 0.6226012793176973
