In [None]:
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
import sys
import time

In [None]:
from stabox.model import STAGATE

In [None]:
# Root of the Stereo-seq mouse olfactory bulb dataset. Set the STABOX_MOB_DATA_DIR
# environment variable to run elsewhere; the default is the original location.
DATA_DIR = os.environ.get(
    'STABOX_MOB_DATA_DIR',
    '/mnt/disk1/LZJ/project/STABox/STABox_Data/Stero-seq/Dataset1_LiuLongQi_MouseOlfactoryBulb',
)
# Tab-separated gene expression counts and per-spot coordinate table.
counts_file = os.path.join(DATA_DIR, 'Data', 'RNA_counts.tsv')
coor_file = os.path.join(DATA_DIR, 'position.tsv')

In [None]:
# Load the expression table (first column holds the row labels) and the coordinate table,
# both tab-separated, then report their dimensions.
tsv_opts = dict(sep='\t')
counts = pd.read_csv(counts_file, index_col=0, **tsv_opts)
coor_df = pd.read_csv(coor_file, **tsv_opts)
print(counts.shape, coor_df.shape)

In [None]:
# Give spots a shared "Spot_<id>" key in both tables so they can be aligned later:
# expression columns are spots, coordinate rows are keyed by the 'label' column.
counts.columns = 'Spot_' + counts.columns.astype(str)
coor_df.index = 'Spot_' + coor_df['label'].astype(str)
coor_df = coor_df[['x', 'y']]

In [None]:
# Preview the first rows of the relabelled coordinate table.
coor_df.head(n=5)

In [None]:
# Transpose so spots become observations (rows) and genes become variables (columns),
# then de-duplicate any repeated gene names.
adata = sc.AnnData(counts.transpose())
adata.var_names_make_unique()

In [None]:
# Display the AnnData summary (dimensions and stored annotation keys).
adata

In [None]:
# Reorder coordinates to match AnnData's spot order and store them as obsm["spatial"].
# NOTE(review): columns are taken as (y, x), i.e. swapped on purpose — presumably to get
# the desired plot orientation; confirm this is intended.
spot_order = adata.obs_names
coor_df = coor_df.loc[spot_order, ['y', 'x']]
adata.obsm["spatial"] = coor_df.values
# Per-spot / per-gene QC metrics (e.g. n_genes_by_counts) used in the QC plots.
sc.pp.calculate_qc_metrics(adata, inplace=True)

In [None]:
# Spatial map of detected genes per spot, before restricting to the used barcodes.
plt.rcParams["figure.figsize"] = (5, 4)
# With show=False and a single colour key, scanpy returns the Axes; act on it directly
# rather than relying on pyplot's "current axes".
ax = sc.pl.embedding(adata, basis="spatial", color="n_genes_by_counts", show=False)
ax.set_title("")
# Returns None, so the cell does not echo the axis-limits tuple that plt.axis('off') returns.
ax.set_axis_off()

In [None]:
# Restrict the data to the spots listed in used_barcodes.txt (one barcode per line, no header).
used_barcode_file = os.path.join(
    '/mnt/disk1/LZJ/project/STABox/STABox_Data/Stero-seq/Dataset1_LiuLongQi_MouseOlfactoryBulb',
    'used_barcodes.txt',
)
used_barcode = pd.read_csv(used_barcode_file, sep='\t', header=None)[0]
# Slicing an AnnData returns a view; .copy() materialises it so the in-place preprocessing
# that follows (filter_genes, highly_variable_genes, normalize_total, log1p) works on a real
# object instead of implicitly converting a view.
adata = adata[used_barcode, :].copy()

In [None]:
# Spatial map of detected genes per spot, after restricting to the used barcodes.
plt.rcParams["figure.figsize"] = (5, 4)
# With show=False and a single colour key, scanpy returns the Axes; act on it directly
# rather than relying on pyplot's "current axes".
ax = sc.pl.embedding(adata, basis="spatial", color="n_genes_by_counts", show=False)
ax.set_title("")
# Returns None, so the cell does not echo the axis-limits tuple that plt.axis('off') returns.
ax.set_axis_off()

In [None]:
# Keep only genes detected in at least 50 spots.
sc.pp.filter_genes(adata, min_cells=50)
print('After filtering: ', adata.shape)

In [None]:
# Normalization
# flavor="seurat_v3" expects raw counts, so highly variable genes are selected
# before library-size normalisation and the log transform.
N_TOP_GENES = 3000
TARGET_SUM = 1e4
sc.pp.highly_variable_genes(adata, n_top_genes=N_TOP_GENES, flavor="seurat_v3")
sc.pp.normalize_total(adata, target_sum=TARGET_SUM)
sc.pp.log1p(adata)

In [None]:
from stabox.model._utils import Cal_Spatial_Net, Stats_Spatial_Net

In [None]:
# Build the spatial neighbour network with a radius cutoff of 50 (presumably in the units
# of obsm["spatial"] — confirm), then summarise the resulting graph.
SPATIAL_RADIUS = 50
Cal_Spatial_Net(adata, rad_cutoff=SPATIAL_RADIUS)
Stats_Spatial_Net(adata)

In [None]:
# NOTE(review): this model_dir repeats the "LZJ/project/STABox" segment
# (".../STABox/lzj/LZJ/project/STABox/...") — looks like a copy-paste slip; confirm the directory.
# NOTE(review): in_features=3000 presumably has to match n_top_genes in the
# highly_variable_genes call — keep the two in sync.
STAGATE_MODEL_DIR = "/mnt/disk1/LZJ/project/STABox/lzj/LZJ/project/STABox/STABox_Data/Stero-seq"
stagate_ = STAGATE(
    model_dir=STAGATE_MODEL_DIR,
    in_features=3000,
    hidden_dims=[512, 30],
)


In [None]:
# Train STAGATE; the returned AnnData is presumably expected to hold the embedding in
# obsm["STAGATE"], which sc.pp.neighbors(use_rep='STAGATE') reads later — TODO confirm.
adata = stagate_.train(adata)

In [None]:
# NOTE(review): this is a second training pass on top of stagate_.train(adata). If
# train_subgraph is meant as an alternative mode rather than a follow-up, running both may
# be redundant — whichever result is returned here is the one used downstream. Confirm.
adata = stagate_.train_subgraph(adata)

In [None]:
# kNN graph on the STAGATE embedding (obsm['STAGATE']) and its UMAP layout.
sc.pp.neighbors(adata, use_rep='STAGATE')
sc.tl.umap(adata)
# The PCA baseline later recomputes obsm['X_umap']; keep the STAGATE layout under its own
# key so it stays recoverable after that cell runs.
adata.obsm['X_umap_STAGATE'] = adata.obsm['X_umap'].copy()

In [None]:
# Louvain clustering on the STAGATE neighbour graph (stored in obs['louvain']).
sc.tl.louvain(adata, resolution=0.8)
# obs['louvain'] is overwritten by the PCA baseline clustering later on; keep a copy of the
# STAGATE labels under their own key.
adata.obs['louvain_STAGATE'] = adata.obs['louvain'].copy()

In [None]:
# Spatial map of the current obs['louvain'] clusters, titled STAGATE.
plt.rcParams["figure.figsize"] = (3, 3)
# With show=False and a single colour key, scanpy returns the Axes; act on it directly.
ax = sc.pl.embedding(adata, basis="spatial", color="louvain", s=6, show=False, title='STAGATE')
# Returns None, so the cell does not echo the axis-limits tuple that plt.axis('off') returns.
ax.set_axis_off()

In [None]:
# UMAP coloured by the current obs['louvain'] clusters, titled STAGATE.
sc.pl.umap(adata, title='STAGATE', color='louvain')

In [None]:
# 30-component PCA of the log-normalised data; obsm['X_pca'] feeds the SCANPY baseline.
# Depending on the scanpy version, PCA may be restricted to var['highly_variable'] genes.
N_PCS = 30
sc.pp.pca(adata, n_comps=N_PCS)

In [None]:
# SCANPY baseline: same clustering pipeline, but on the PCA embedding instead of STAGATE.
# These calls overwrite the shared keys (neighbour graph, obs['louvain'], obsm['X_umap'])
# produced by the STAGATE cells; the baseline results are also stored under explicit
# SCANPY-suffixed keys so both runs can be compared side by side.
sc.pp.neighbors(adata, use_rep='X_pca')
sc.tl.louvain(adata, resolution=0.8)
sc.tl.umap(adata)
adata.obs['louvain_SCANPY'] = adata.obs['louvain'].copy()
adata.obsm['X_umap_SCANPY'] = adata.obsm['X_umap'].copy()

In [None]:
# Spatial map of the current obs['louvain'] clusters, titled SCANPY.
plt.rcParams["figure.figsize"] = (3, 3)
# With show=False and a single colour key, scanpy returns the Axes; act on it directly.
ax = sc.pl.embedding(adata, basis="spatial", color="louvain", s=6, show=False, title='SCANPY')
# Returns None, so the cell does not echo the axis-limits tuple that plt.axis('off') returns.
ax.set_axis_off()

In [None]:
# UMAP coloured by the current obs['louvain'] clusters, titled SCANPY.
sc.pl.umap(adata, title='SCANPY', color='louvain')