In [1]:
import scanpy as sc
import pandas as pd
from sklearn import metrics
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
import SEDR
from anndata import AnnData
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import accuracy_score
import numpy as np


# Single seed for the whole notebook, handed to SEDR's seeding helper.
# NOTE(review): presumably fix_seed seeds the Python/NumPy/torch RNGs so that
# reruns are repeatable — confirm against the SEDR package.
random_seed = 2023
SEDR.fix_seed(random_seed)

In [2]:
# Run on the first CUDA GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Slice identifier; it is interpolated into the expression/spatial CSV file
# names and into the name of the .h5ad file written by the loading cell.
slicename = 'BZ14'
# Number of clusters requested from SEDR.mclust_R.
n_clusters = 4

In [4]:
# ---- Locate the input files ------------------------------------------------
# NOTE(review): absolute, machine-specific directory; change DATA_DIR to run
# this notebook anywhere else.
DATA_DIR = r"C:\E\JSU\BIO\file\STrafer\params\starmap"
expr_path = fr"{DATA_DIR}\starmap_expr_{slicename}.csv"
spatial_path = fr"{DATA_DIR}\starmap_spatial_{slicename}.csv"
h5ad_path = fr"{DATA_DIR}\starmap_{slicename}.h5ad"

# ---- Read expression (cells x genes) and per-cell spatial annotations -------
meta = pd.read_csv(expr_path, index_col=0)
spatial_data = pd.read_csv(spatial_path, index_col=0)

# Join on the cell ID, keeping exactly the cells that have expression data.
data = spatial_data.merge(meta, left_index=True, right_index=True, how='right')
if len(data) != len(meta):
    raise ValueError(
        f"Merging {spatial_path} with {expr_path} produced {len(data)} rows for "
        f"{len(meta)} cells; cell IDs are probably duplicated in one of the files."
    )
missing_annotation = data[['x', 'y', 'z']].isna().any(axis=1)
if missing_annotation.any():
    raise ValueError(
        f"{int(missing_annotation.sum())} cells in {expr_path} have no x/y/z "
        f"entry in {spatial_path}."
    )

# Ground-truth labels come from column 'z'; value 4 is recoded as 0.
# NOTE(review): presumably this makes the labels 0..n_clusters-1 — confirm.
labels = data['z'].replace(4, 0)

# ---- Build the AnnData object -----------------------------------------------
# Every per-cell field is taken from the merged frame `data`, so the expression
# rows, coordinates, cell names and `labels` all share one row order. Taking X
# from `meta` and coordinates/names from `spatial_data` by position silently
# mismatches cells whenever the two CSVs differ in row order or row set.
adata = AnnData(X=data[meta.columns].to_numpy())
adata.obsm['spatial'] = data[['x', 'y']].to_numpy()
adata.var_names = meta.columns
adata.obs_names = data.index.astype(str)  # AnnData expects string cell names
# Persist the raw object for reuse; it is not read back because the in-memory
# object is already identical to what was just written.
adata.write_h5ad(h5ad_path)
adata.var_names_make_unique()

# ---- Pre-process -------------------------------------------------------------
# Keep an independent copy of the raw values; the seurat_v3 HVG flavor below is
# run on this layer rather than on the normalized X.
adata.layers['count'] = adata.X.copy()
sc.pp.filter_genes(adata, min_cells=50)
sc.pp.filter_genes(adata, min_counts=10)
sc.pp.normalize_total(adata, target_sum=1e6)
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=140)
# Materialize the HVG subset so that scaling works on a real object, not a view.
adata = adata[:, adata.var['highly_variable']].copy()
sc.pp.scale(adata)

In [5]:
from sklearn.decomposition import PCA  # sklearn PCA is used because PCA in scanpy is not stable.

# Attach the ground-truth labels by position. Assigning the Series itself would
# align on index, and adata.obs is indexed by string cell names: a numeric CSV
# index would then match nothing and silently yield an all-NaN column. The rows
# of `adata` follow the row order of the merged frame that `labels` came from,
# so a positional assignment is the correct pairing.
if len(labels) != adata.n_obs:
    raise ValueError(
        f"labels has {len(labels)} entries but adata has {adata.n_obs} cells"
    )
adata.obs['layer_guess'] = labels.to_numpy()

# NOTE(review): n_components equals the number of highly variable genes kept
# during pre-processing (140), so this rotates the data without reducing its
# dimensionality — confirm that this is intended.
adata_X = PCA(n_components=140, random_state=42).fit_transform(adata.X)
adata.obsm['X_pca'] = adata_X

In [6]:
# Build the graph inputs SEDR needs from `adata`; the second argument (12) is
# passed through to SEDR.graph_construction unchanged.
graph_dict = SEDR.graph_construction(adata, 12)
print(graph_dict)

# Set up the SEDR model on the PCA representation.
sedr_net = SEDR.Sedr(adata.obsm['X_pca'], graph_dict, mode='clustering', device=device)

# Pick the training routine once, then invoke it with the same argument.
using_dec = True
run_training = sedr_net.train_with_dec if using_dec else sedr_net.train_without_dec
run_training(N=1)

# Only the first value returned by process() is kept; it is stored as the
# 'SEDR' representation of the cells.
sedr_feat, *_ = sedr_net.process()
adata.obsm['SEDR'] = sedr_feat

# Cluster the SEDR representation into `n_clusters` groups with mclust (via R);
# the assignment is stored under the 'SEDR' key.
SEDR.mclust_R(adata, n_clusters, use_rep='SEDR', key_added='SEDR')


{'adj_norm': tensor(indices=tensor([[   0,    2,    7,  ..., 1084, 1086, 1087],
                       [   0,    0,    0,  ..., 1087, 1087, 1087]]),
       values=tensor([0.0769, 0.0769, 0.0741,  ..., 0.0741, 0.0769, 0.0769]),
       size=(1088, 1088), nnz=15802, layout=torch.sparse_coo), 'adj_label': tensor(indices=tensor([[   0,    0,    0,  ..., 1087, 1087, 1087],
                       [   0,    2,    7,  ..., 1084, 1086, 1087]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       size=(1088, 1088), nnz=15802, dtype=torch.float64,
       layout=torch.sparse_coo), 'norm_value': 0.5067648907223132}


100%|██████████| 200/200 [00:02<00:00, 78.86it/s] 
100%|██████████| 200/200 [00:01<00:00, 101.08it/s]
R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.



fitting ...


AnnData object with n_obs × n_vars = 1088 × 140
    obs: 'layer_guess', 'SEDR'
    var: 'n_cells', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'mean', 'std'
    uns: 'hvg'
    obsm: 'spatial', 'X_pca', 'SEDR'
    layers: 'count'

In [7]:
# mclust cluster ids are shifted down by one so they start at 0.
cluster_ids = adata.obs['SEDR'].values.astype(int) - 1

# Rows are true labels, columns are predicted clusters. The label set is driven
# by n_clusters rather than a hard-coded 4, so changing the configuration cell
# cannot leave this step evaluating the wrong number of classes.
conf_mat = confusion_matrix(labels, cluster_ids, labels=np.arange(n_clusters))

# One-to-one matching between clusters and true labels that maximizes the
# number of agreeing cells (Hungarian algorithm).
row_ind, col_ind = linear_sum_assignment(conf_mat, maximize=True)
mapping = {pred_label: true_label for true_label, pred_label in zip(row_ind, col_ind)}

# Predictions re-expressed in the ground-truth label space.
y_pred = np.array([mapping[cluster] for cluster in cluster_ids])

In [8]:
pred = y_pred

# Score the label-matched predictions against the ground truth: adjusted Rand
# index and plain accuracy.
ARI_s = adjusted_rand_score(pred, labels)
acc_s = accuracy_score(pred, labels)

for tag, score in (("ARI_s:", ARI_s), ("acc_s", acc_s)):
    print(tag, score)

ARI_s: 0.2673438840753168
acc_s 0.5818014705882353
