In [1]:
import scanpy as sc
import pandas as pd
from sklearn import metrics
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
import SEDR

import numpy as np
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment


random_seed = 2023
SEDR.fix_seed(random_seed)

In [2]:
# gpu
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# path
slicename = '26'
n_clusters = 8
adata = sc.read_h5ad(fr"C:\E\JSU\BIO\file\STrafer\params\merfish\{slicename}.h5ad")

print("Class labels:", adata.obs['ground_truth'].unique())

label_mapping = {'MPA': 1, 'MPN': 2, 'BST': 3, 'fx': 4, "PVH": 5, "PVT": 6, "V3": 7, 'PV': 0}
labels = adata.obs['ground_truth'].map(label_mapping)


Class labels: ['MPA', 'MPN', 'BST', 'fx', 'PVH', 'PVT', 'V3', 'PV']
Categories (8, object): ['BST', 'MPA', 'MPN', 'PV', 'PVH', 'PVT', 'V3', 'fx']


In [3]:
adata.var_names_make_unique()
# pre-process

adata.layers['count'] = adata.X
sc.pp.filter_genes(adata, min_cells=50)
sc.pp.filter_genes(adata, min_counts=10)
sc.pp.normalize_total(adata, target_sum=1e6)
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=150)


adata = adata[:, adata.var['highly_variable'] == True]
sc.pp.scale(adata)
adata.obs['layer_guess'] = labels

from sklearn.decomposition import PCA  # sklearn PCA is used because PCA in scanpy is not stable.
adata_X = PCA(n_components=150, random_state=42).fit_transform(adata.X)
adata.obsm['X_pca'] = adata_X


In [4]:
graph_dict = SEDR.graph_construction(adata, 12)
print(graph_dict)

sedr_net = SEDR.Sedr(adata.obsm['X_pca'], graph_dict, mode='clustering', device=device)
using_dec = True
if using_dec:
    sedr_net.train_with_dec(N=1)
else:
    sedr_net.train_without_dec(N=1)
sedr_feat, _, _, _ = sedr_net.process()
adata.obsm['SEDR'] = sedr_feat


SEDR.mclust_R(adata, n_clusters, use_rep='SEDR', key_added='SEDR')



# sub_adata = adata[~pd.isnull(adata.obs['layer_guess'])]
# ARI = metrics.adjusted_rand_score(sub_adata.obs['layer_guess'], sub_adata.obs['SEDR'])


{'adj_norm': tensor(indices=tensor([[   0,    2,    3,  ..., 5178, 5554, 5556],
                       [   0,    0,    0,  ..., 5556, 5556, 5556]]),
       values=tensor([0.0714, 0.0741, 0.0714,  ..., 0.0741, 0.0741, 0.0769]),
       size=(5557, 5557), nnz=80015, layout=torch.sparse_coo), 'adj_label': tensor(indices=tensor([[   0,    0,    0,  ..., 5556, 5556, 5556],
                       [   0,    2,    3,  ..., 5178, 5554, 5556]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       size=(5557, 5557), nnz=80015, dtype=torch.float64,
       layout=torch.sparse_coo), 'norm_value': 0.5012989349366631}


100%|██████████| 200/200 [00:05<00:00, 37.33it/s]
100%|██████████| 200/200 [00:03<00:00, 56.65it/s]
R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.



fitting ...


AnnData object with n_obs × n_vars = 5557 × 150
    obs: 'cell_class', 'neuron_class', 'domain', 'Region', 'ground_truth', 'layer_guess', 'SEDR'
    var: 'n_cells', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'mean', 'std'
    uns: 'domain_colors', 'hvg'
    obsm: 'spatial', 'X_pca', 'SEDR'
    layers: 'count'

In [9]:
pred = adata.obs['SEDR']
pred = pred.values.astype(int) - 1
conf_mat = confusion_matrix(labels, pred, labels=np.arange(8))
row_ind, col_ind = linear_sum_assignment(-conf_mat)
mapping = {pred_label: true_label for true_label, pred_label in zip(row_ind, col_ind)}
y_pred = np.array([mapping[p] for p in pred])
adata.obs["pred"] = y_pred

In [10]:
ARI_s = adjusted_rand_score(pred, labels)
acc_s = accuracy_score(pred, labels)
print("ARI_s:", ARI_s)
print("acc_s", acc_s)
folder_path = fr'C:\E\JSU\BIO\file\STrafer\params\merfish\SEDR\{slicename}'
os.makedirs(folder_path, exist_ok=True)
pred_labels_list = pd.DataFrame({
    'spot': list(range(1, len(adata.obs['SEDR']) + 1)),
    'pred': y_pred
})
file_path_pred = os.path.join(folder_path, 'pred_labels.csv')
pred_labels_list.to_csv(file_path_pred, index=False)
# pred = adata.obs['domain']

pred = pd.read_csv(fr"C:\E\JSU\BIO\file\STrafer\params\merfish\SEDR\{slicename}\pred_labels.csv", delimiter=',')


ARI_s: 0.323758828553087
acc_s 0.3904984703976966
