## Preparation

In [None]:
import sys

# Sanity check: display the interpreter's module search path, to confirm which
# environment (and therefore which STAligner install) this kernel is using.
sys.path

In [None]:
import warnings
warnings.filterwarnings("ignore")

import STAligner

# Location of the R installation that rpy2 uses for mclust clustering.
# These are set before rpy2 is imported in the next cell.
# Raw strings are required here: in a normal literal "\rpy2" is a carriage
# return followed by "py2", which silently corrupted R_USER, and "\e", "\S",
# "\L", "\R", "\s" are invalid escapes (SyntaxWarning on Python >= 3.12).
import os
os.environ['R_HOME'] = r"D:\anaconda\envs\STAligner\Lib\R"
os.environ['R_USER'] = r"D:\anaconda\envs\STAligner\Lib\site-packages\rpy2"

In [None]:
import rpy2.robjects as robjects
import rpy2.robjects.numpy2ri

In [None]:
import anndata as ad
import scanpy as sc
import pandas as pd
import numpy as np
import scipy.sparse as sp
import scipy.linalg

import torch

# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    used_device = torch.device('cuda:0')
else:
    used_device = torch.device('cpu')
print(used_device)

In [None]:
# Slices to integrate; each is read from input_dir + '<name>.h5ad'.
sample_names = ["10X","slide","stereo"]
input_dir = 'G:/dataset/06-Mouse olfactory bulb/input/25um/'
output_dir = 'G:/dataset/06-Mouse olfactory bulb/output/25um/STAligner/'
experiment_name = 'MouseOlfactoryBulb25um'

# Create the output folder up front. np.savetxt, plt.savefig and the h5ad
# writer below all fail if it is missing, and that failure would otherwise
# only surface after the long training step.
os.makedirs(output_dir, exist_ok=True)

## Load Data

In [None]:
from scipy import sparse

# Radius cutoff passed to STAligner.Cal_Spatial_Net for each slice; any slice
# not listed here (in this notebook: "stereo") uses DEFAULT_RAD_CUTOFF.
RAD_CUTOFFS = {'10X': 200, 'slide': 50}
DEFAULT_RAD_CUTOFF = 3

Batch_list = []
adj_list = []

for section_id in sample_names:
    print(section_id)

    adata = sc.read_h5ad(input_dir + section_id + '.h5ad')
    print(adata)
    # Spot coordinates come from the obs columns 'x' and 'y'.
    adata.obsm['spatial'] = adata.obs[['x', 'y']].astype('float32').values
    print(adata.obs.head())

    # Suffix spot names with the slice id so they stay unique after concatenation.
    adata.obs_names = [x + '_' + section_id for x in adata.obs_names]

    # Build the spatial neighbour network; it is stored in adata.uns['adj'].
    STAligner.Cal_Spatial_Net(adata, rad_cutoff=RAD_CUTOFFS.get(section_id, DEFAULT_RAD_CUTOFF))
    # STAligner.Stats_Spatial_Net(adata)  # optional: plot the number of spatial neighbours

    # HVG selection runs before normalisation because flavor="seurat_v3" expects
    # raw counts (assumes the input .h5ad X holds counts; confirm).
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=5000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    # .copy() materialises the HVG subset. A plain view would keep each slice's
    # full-gene matrix alive for as long as it sits in Batch_list.
    adata = adata[:, adata.var['highly_variable']].copy()

    adj_list.append(adata.uns['adj'])
    Batch_list.append(adata)

## Concatenate the AnnData objects of all slices

In [None]:
# Stack every slice into a single AnnData. ad.concat records each spot's source
# slice (one of sample_names) in obs['slice_name'].
adata_concat = ad.concat(
    Batch_list,
    keys=sample_names,
    label="slice_name",
)
# adata_concat.obs['Ground Truth'] = adata_concat.obs['Ground Truth'].astype('category')

# Categorical copy of the slice label, used as the batch key for plotting.
adata_concat.obs["batch_name"] = adata_concat.obs["slice_name"].astype('category')
print('adata_concat.shape: ', adata_concat.shape)

## Concatenate the spatial networks of all slices

In [None]:
# Block-diagonal adjacency over all slices. Spots are linked only within their
# own slice, and the block order follows adj_list, i.e. the concatenation order.
# This is built sparsely: a dense block_diag needs O(n_spots**2) memory, which
# is many GB once the slices total tens of thousands of spots.
adj_concat = sp.block_diag(adj_list, format='csr')
# Canonical CSR (sorted column indices, duplicates summed) makes nonzero()
# return edges in the same row-major order as np.nonzero on the dense matrix.
adj_concat.sum_duplicates()
# Cast to intp so the (rows, cols) tuple has the same dtype np.nonzero produced.
adata_concat.uns['edgeList'] = tuple(idx.astype(np.intp) for idx in adj_concat.nonzero())

## Running STAligner

In [None]:
%%time
adata_concat = STAligner.train_STAligner(adata_concat, verbose=True, knn_neigh = 50, device=used_device)

## Save embeddings

In [None]:
# Write each learned embedding (.obsm matrix) to its own comma-separated file,
# STAligner first and then STAGATE.
for embedding_key, file_suffix in (('STAligner', '_STAligner.csv'), ('STAGATE', '_STAGATE.csv')):
    np.savetxt(output_dir + experiment_name + file_suffix,
               adata_concat.obsm[embedding_key],
               delimiter=",")

## Clustering

In [None]:
# Display the AnnData summary of the integrated object after training.
adata_concat

In [None]:
import seaborn as sns

# Distinct ground-truth labels. Their count is used as the number of mclust
# clusters further down.
series = adata_concat.obs['Ground Truth'].astype("category")
celltypes = series.value_counts().index.tolist()
celltype_num = len(celltypes)
print(celltype_num)

# One hex colour per ground-truth label.
colors = sns.color_palette(n_colors = celltype_num).as_hex()

In [None]:
# Ask mclust for as many clusters as there are distinct ground-truth labels.
num_cluster = celltype_num

In [None]:
from STAligner import ST_utils
# Cluster the STAligner embedding with R's mclust (through rpy2). The labels are
# read back from adata_concat.obs['mclust'] in the following cells.
ST_utils.mclust_R(adata_concat, used_obsm='STAligner', num_cluster=num_cluster)
# Optional: drop spots whose ground-truth label is 'unknown' before evaluation.
# adata_concat = adata_concat[adata_concat.obs['Ground Truth']!='unknown']

In [None]:
from sklearn.metrics import adjusted_rand_score as ari_score
# Adjusted Rand index between the mclust domains and the annotation, with all
# slices pooled together.
print('mclust, ARI = %01.3f' % ari_score(labels_true=adata_concat.obs['Ground Truth'],
                                         labels_pred=adata_concat.obs['mclust']))

## Visualization

In [None]:
import matplotlib.pyplot as plt

# UMAP of the STAligner embedding, with fixed seeds for reproducibility.
sc.pp.neighbors(adata_concat, use_rep='STAligner', random_state=666)
sc.tl.umap(adata_concat, random_state=666)

# One colour per slice, assigned in the order of sample_names.
section_color = sns.color_palette(n_colors = len(sample_names)).as_hex()
section_color_dict = dict(zip(sample_names, section_color))
adata_concat.uns['batch_name_colors'] = [section_color_dict[x] for x in adata_concat.obs.batch_name.cat.categories]
# Replace the raw mclust ids with the output of ST_utils.match_cluster_labels,
# which presumably maps clusters onto ground-truth labels so the colours line
# up across panels (TODO confirm).
adata_concat.obs['mclust'] = pd.Series(ST_utils.match_cluster_labels(adata_concat.obs['Ground Truth'], adata_concat.obs['mclust'].values),
                                         index=adata_concat.obs.index, dtype='category')

plt.rcParams['font.sans-serif'] = "Arial"
plt.rcParams["figure.figsize"] = (3, 3)
plt.rcParams['font.size'] = 12

sc.pl.umap(adata_concat, color=['batch_name', 'Ground Truth', 'mclust'], ncols=3,
           wspace=0.5, show=False)

# bbox_inches='tight' stops the legends, which scanpy draws outside the axes,
# from being clipped in the saved PNG. Scanpy's own save= path does the same.
plt.savefig(output_dir + experiment_name + '_umap.png', dpi=300, bbox_inches='tight')

In [None]:
import matplotlib.pyplot as plt

# Split the integrated object back into per-slice views, in sample_names order.
Batch_list = [adata_concat[adata_concat.obs['batch_name'] == section_id] for section_id in sample_names]

spot_size = 1
title_size = 12
# Per-slice ARI between the mclust domains and the ground truth.
ARI_list = [round(ari_score(batch.obs['Ground Truth'], batch.obs['mclust']), 2) for batch in Batch_list]

# One panel per slice. The previous version created len(Batch_list) axes but
# drew only the first two slices, so the third slice ("stereo") was never
# plotted. squeeze=False keeps `ax` two-dimensional even for a single slice.
fig, ax = plt.subplots(1, len(Batch_list), figsize=(10, 5),
                       gridspec_kw={'wspace': 0.05, 'hspace': 0.1}, squeeze=False)
for panel_ax, batch, ari in zip(ax[0], Batch_list, ARI_list):
    panel = sc.pl.spatial(batch, img_key=None, color=['mclust'], title=[''],
                          legend_loc=None, legend_fontsize=12, show=False, ax=panel_ax, frameon=False,
                          spot_size=spot_size)
    panel[0].set_title("ARI=" + str(ari), size=title_size)

plt.savefig(output_dir + experiment_name + '_ARI.png', dpi=300, bbox_inches='tight')
plt.show()

## Save the integrated AnnData as h5ad

In [None]:
print(adata_concat)

# Cast every obs column to str before writing (presumably to avoid h5ad write
# errors on mixed-type columns).
adata_concat.obs = adata_concat.obs.astype('str')
# A tuple cannot be saved to h5ad, so store the (rows, cols) pair from nonzero() as a list.
adata_concat.uns['edgeList'] = list(adata_concat.uns['edgeList'])

# Write explicitly with write_h5ad. The previous approach (assigning
# adata_concat.filename) did write the file, but it stored X densely and left
# the object in backed mode with an open file handle.
h5ad_path = output_dir + experiment_name + '.h5ad'
adata_concat.write_h5ad(h5ad_path)
print('Saved', h5ad_path)