# [1] Basic settings

In [None]:
# Computation device: the first CUDA GPU when one is visible, otherwise the CPU.
# A GPU is recommended.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# R is required by the mclust clustering algorithm; point R_HOME at your own
# local R installation (replace the path below).
os.environ['R_HOME'] = '/content/R-bag/R-4.0.3'

# Experiment parameters.
dataset = '151673'   # slice ID; also used as the data sub-folder name
n_clusters = 7       # number of spatial domains to find
# Passed to GAAEST as save_model_file (presumably where trained weights go — confirm).
save_model_file = '/content/GAAEST-main/weights.pth'

# [2] Read data

In [None]:
# Each slice lives in its own sub-folder named after its ID.
# Replace the base directory with your own download location.
file_fold = f'/content/GAAEST-main/Data/{dataset!s}'

# Load the Visium counts together with the tissue images.
adata = sc.read_visium(
    file_fold,
    count_file='filtered_feature_bc_matrix.h5',
    load_images=True,
)
# Ensure gene names (var_names) are unique before training.
adata.var_names_make_unique()

# [3] Train model

In [None]:
# Build the GAAEST model on the loaded slice; the epoch count and the
# alpha/beta/gama/lane hyper-parameters are passed straight to the constructor.
model = GAAEST.GAAEST(
    adata,
    device=device,
    epochs=600,
    alpha=10,
    beta=1,
    gama=1,
    lane=1,
    save_model_file=save_model_file,
)

# train() returns an AnnData that replaces `adata` for all downstream cells
# (it presumably carries the learned embedding in obsm['emb'] — confirm).
adata = model.train()

# [4] Clustering

In [None]:
from GAAEST.utils import clustering

# Radius that sets the number of neighbours considered during refinement.
radius = 50
# Clustering backend: 'mclust' (requires R, see R_HOME), 'leiden' or 'louvain'.
tool = 'mclust'

if tool == 'mclust':
    # For the DLPFC dataset we use the optional refinement step.
    clustering(adata, n_clusters, radius=radius, method=tool, refinement=True)
elif tool in ('leiden', 'louvain'):
    # start/end/increment are forwarded to GAAEST.utils.clustering; presumably
    # they define the resolution search range used to reach n_clusters — confirm.
    clustering(adata, n_clusters, radius=radius, method=tool,
               start=0.1, end=1.0, increment=0.01, refinement=True)
else:
    # Fail loudly: otherwise no clustering runs and later cells break with a
    # confusing KeyError on adata.obs['domain'].
    raise ValueError(
        f"Unknown clustering tool {tool!r}; expected 'mclust', 'leiden' or 'louvain'."
    )

# [5] Add ground truth

In [None]:
import pandas as pd

# Attach the manual 'layer_guess' annotation as adata.obs['ground_truth'].
# NOTE(review): labels are assigned by position, so metadata.tsv must list spots
# in the same order as adata.obs — confirm, or align on barcodes instead.
df_meta = pd.read_csv(file_fold + '/metadata.tsv', sep='\t')
if len(df_meta) != adata.n_obs:
    raise ValueError(
        f'metadata.tsv has {len(df_meta)} rows but adata has {adata.n_obs} spots; '
        'cannot align ground-truth labels by position.'
    )
df_meta_layer = df_meta['layer_guess']
adata.obs['ground_truth'] = df_meta_layer.values

# Filter out unannotated spots. .copy() materialises the subset so later cells
# that write into adata (uns scores, neighbors, umap) act on a real AnnData
# instead of implicitly converting a view.
adata = adata[~pd.isnull(adata.obs['ground_truth'])].copy()

# [6] Calculate metrics

In [None]:
# Compare the predicted spatial domains ('domain') with the manual
# annotation ('ground_truth') using four clustering-agreement scores.
pred_labels = adata.obs['domain']
true_labels = adata.obs['ground_truth']

ARI = metrics.adjusted_rand_score(pred_labels, true_labels)
NMI = metrics.normalized_mutual_info_score(pred_labels, true_labels)
AMI = metrics.adjusted_mutual_info_score(pred_labels, true_labels)
FM = metrics.fowlkes_mallows_score(pred_labels, true_labels)

scores = (('ARI', ARI), ('NMI', NMI), ('AMI', AMI), ('FM', FM))

# Store each score on the AnnData object under its short name.
for name, value in scores:
    adata.uns[name] = value

print('Dataset:', dataset)
for name, value in scores:
    print(name + ':', value)

# [7] Plot spatial clustering results


① Spatial domain recognition

In [None]:
# Spatial maps on the high-resolution image: one panel for the manual
# annotation and one for the predicted domains, titled with their scores.
score_title = f"ARI={ARI:.4f} NMI={NMI:.4f} AMI={AMI:.4f} FM={FM:.4f}"
sc.pl.spatial(
    adata,
    img_key="hires",
    color=["ground_truth", "domain"],
    title=["ground_truth", score_title],
    show=True,
)

② UMAP and PAGA

In [None]:
# UMAP: build the neighbour graph from the learned embedding obsm['emb'],
# then compute the 2-D UMAP layout from that graph.
sc.pp.neighbors(adata=adata, use_rep='emb')
sc.tl.umap(adata=adata)

In [None]:
# Spots with a usable manual label. Missing labels may be stored either as a
# real NaN or as the string 'nan', so exclude both. .copy() gives the PAGA step
# a real AnnData to write its results into, rather than a view of `adata`.
gt_labels = adata.obs['ground_truth']
used_adata = adata[gt_labels.notna() & (gt_labels.astype(str) != 'nan')].copy()

In [None]:
# PAGA: abstracted connectivity graph between the predicted domains,
# computed on the labelled spots only.
sc.tl.paga(adata=used_adata, groups='domain')

In [None]:
import matplotlib.pyplot as plt

# PAGA comparison plot for the predicted domains. The small figure size is
# applied only inside this context, so it no longer leaks into every later
# figure through the global matplotlib rcParams.
with plt.rc_context({"figure.figsize": (4, 3)}):
    sc.pl.paga_compare(
        used_adata,
        legend_fontsize=10,
        frameon=False,
        size=20,
        title=dataset + '_GAAEST',
        legend_fontoutline=2,
        threshold=0.3,
        show=True,
    )