### 1. Import SpaceFlow and supporting packages

In [13]:
import warnings
warnings.filterwarnings('ignore')  # silence library warnings in the notebook output
import numpy as np
import pandas as pd
import scanpy as sc
import squidpy as sq


In [14]:
from SpaceFlow import SpaceFlow  # the module SpaceFlow/SpaceFlow.py, which defines the SpaceFlow class

### 2. Load data

In [15]:
# Alternative: read a Visium section (ID 151671) directly from its 10x output folder (requires `import os`)
# section_id = '151671'
# input_dir = os.path.join('/home/workspace2/zhaofangyuan/SpaceFlow/SpaceFlow-master/Data', section_id)
# adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')
adata=sc.read_h5ad('/home/workspace2/zhaofangyuan/data_h5ad/without_groundtruth/Visium_demon/V1_Mouse_Brain_Sagittal_Posterior.h5ad')
adata.var_names_make_unique()  # make duplicated gene names unique
adata

AnnData object with n_obs × n_vars = 3355 × 32285
    obs: 'in_tissue', 'array_row', 'array_col'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'

### 3. Create SpaceFlow Object

We create a SpaceFlow object from the AnnData object, which holds the gene expression count matrix and the corresponding spatial locations of cells (or spots):

In [16]:
sf = SpaceFlow.SpaceFlow(adata=adata)
sf

<SpaceFlow.SpaceFlow.SpaceFlow at 0x7fac8957a490>

Parameters:
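- `adata`: an AnnData object holding the expression count matrix and the spot coordinates in `adata.obsm['spatial']` (the input used above).
- `expr_data`: the gene expression count matrix, a 2D numpy array of size (# of cells, # of genes); an alternative to `adata`.
- `spatial_locs`: the spatial locations of cells (or spots), matched to the rows of the count matrix, a 2D numpy array of size (# of cells, 2); passed together with `expr_data`.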
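When starting from raw arrays rather than an AnnData object, the shape requirements can be checked before constructing the object. `check_spaceflow_inputs` is a hypothetical helper, not part of SpaceFlow. It assumes two-column (x, y) coordinates, one row per spot, as stored in `adata.obsm['spatial']`.

```python
import numpy as np


def check_spaceflow_inputs(expr_data, spatial_locs):
    """Hypothetical helper: sanity-check raw inputs before building a SpaceFlow object.

    expr_data may be a dense numpy array or a scipy sparse matrix (both expose .shape);
    spatial_locs is assumed to hold one (x, y) pair per row of expr_data.
    Returns (n_cells, n_genes).
    """
    shape = getattr(expr_data, "shape", None)
    if shape is None or len(shape) != 2:
        raise ValueError("expr_data must be 2D: (n_cells, n_genes)")
    locs = np.asarray(spatial_locs, dtype=float)
    if locs.ndim != 2 or locs.shape[1] != 2:
        raise ValueError("spatial_locs must be 2D: (n_cells, 2)")
    if locs.shape[0] != shape[0]:
        raise ValueError("spatial_locs needs one row per cell (spot) in expr_data")
    return shape[0], shape[1]
```

For the data loaded above, `check_spaceflow_inputs(adata.X, adata.obsm['spatial'])` would be the corresponding call.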

### 4. Preprocessing the ST Data
Next, we preprocess the ST data by running:

In [17]:
sf.preprocessing_data(n_top_genes=3000)

Parameters:
- `n_top_genes`: the number of highly variable genes to keep.

Preprocessing normalizes and log-transforms the expression count matrix, selects the highly variable genes, and builds a spatial proximity graph from the spatial coordinates. (For details, see the `preprocessing_data` function in `SpaceFlow/SpaceFlow.py`.)
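The steps listed above can be sketched with plain numpy. This is an illustration under stated assumptions, not SpaceFlow's `preprocessing_data`: counts are scaled to 1e4 per spot before `log1p`, genes are ranked by the variance of the log-normalized data (scanpy's highly-variable-gene methods are more refined), and the spatial graph is a symmetric k-nearest-neighbor graph with a hypothetical `k` parameter. The dense pairwise-distance matrix makes memory grow quadratically with the number of spots, so this suits small examples only.

```python
import numpy as np


def preprocess_sketch(counts, coords, n_top_genes=3000, k=10):
    """Illustrative preprocessing (assumed steps, not SpaceFlow's code).

    Returns the log-normalized HVG matrix, the selected gene indices,
    and a boolean adjacency matrix of the spatial kNN graph.
    """
    counts = np.asarray(counts, dtype=float)
    coords = np.asarray(coords, dtype=float)
    # 1) library-size normalization to 1e4 counts per spot, then log1p
    lib = counts.sum(axis=1, keepdims=True)
    lib[lib == 0] = 1.0  # avoid division by zero for empty spots
    logx = np.log1p(counts / lib * 1e4)
    # 2) keep the n_top_genes genes with the highest variance
    n_keep = min(n_top_genes, logx.shape[1])
    hvg_idx = np.argsort(logx.var(axis=0))[::-1][:n_keep]
    x_hvg = logx[:, hvg_idx]
    # 3) spatial proximity graph: link each spot to its k nearest spots
    d = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)  # exclude self-loops
    k = min(k, len(coords) - 1)
    nn = np.argsort(d, axis=1)[:, :k]
    adj = np.zeros(d.shape, dtype=bool)
    rows = np.repeat(np.arange(len(coords)), k)
    adj[rows, nn.ravel()] = True
    adj = adj | adj.T  # make the graph undirected
    return x_hvg, hvg_idx, adj
```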


### 5. Train the deep graph network model

We then train a spatially regularized deep graph network model to learn a low-dimensional embedding that reflects both the expression similarity and the spatial proximity of cells in the ST data.
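To make "spatially regularized" concrete, the numpy sketch below scores one plausible form of such a penalty: it is large when spots that are far apart in the tissue lie close together in the embedding. The function name and the exact formula are assumptions for illustration, not SpaceFlow's loss. The optional `edge_subset_sz` argument imitates the edge-subsetting idea behind `regularization_acceleration`: a random sample of spot pairs is scored instead of all n² pairs.

```python
import numpy as np


def spatial_penalty_sketch(z, coords, edge_subset_sz=None, seed=0):
    """Illustrative spatial regularization term (assumed form, not SpaceFlow's code).

    Pairs that are far apart in space (large sp_dist) but close in the
    embedding (small z_dist) contribute the most to the penalty.
    """
    z = np.asarray(z, dtype=float)
    coords = np.asarray(coords, dtype=float)
    n = z.shape[0]
    if edge_subset_sz is None:
        # all ordered pairs (i, j); pairs with i == j have sp_dist 0 and add nothing
        i, j = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
        i, j = i.ravel(), j.ravel()
    else:
        # edge subsetting: score a random sample of pairs instead of all n**2
        rng = np.random.default_rng(seed)
        i = rng.integers(0, n, size=edge_subset_sz)
        j = rng.integers(0, n, size=edge_subset_sz)
    z_dist = np.linalg.norm(z[i] - z[j], axis=1)
    sp_dist = np.linalg.norm(coords[i] - coords[j], axis=1)
    # scale both distance sets to [0, 1]; guard against all-zero distances
    z_dist = z_dist / max(z_dist.max(), 1e-12)
    sp_dist = sp_dist / max(sp_dist.max(), 1e-12)
    return float(np.mean((1.0 - z_dist) * sp_dist))
```

In a training loop, a term like this would be multiplied by a weight such as `spatial_regularization_strength` and added to the main loss; a larger weight pushes the embedding toward spatial coherence.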


In [18]:
sf.train(spatial_regularization_strength=0.1, 
         z_dim=50, 
         lr=1e-3, 
         epochs=1000, 
         max_patience=50, 
         min_stop=100, 
         random_seed=42, 
         gpu=1, 
         regularization_acceleration=True, 
         edge_subset_sz=1000000)

Epoch 2/1000, Loss: 10.531244277954102
Epoch 12/1000, Loss: 1.5191911458969116
Epoch 22/1000, Loss: 1.5773530006408691
Epoch 32/1000, Loss: 1.478605031967163
Epoch 42/1000, Loss: 1.4916328191757202
Epoch 52/1000, Loss: 1.4785279035568237
Epoch 62/1000, Loss: 1.4595776796340942
Epoch 72/1000, Loss: 1.4390500783920288
Epoch 82/1000, Loss: 1.4085972309112549
Epoch 92/1000, Loss: 1.4495173692703247
Epoch 102/1000, Loss: 1.3247249126434326
Epoch 112/1000, Loss: 1.1839287281036377
Epoch 122/1000, Loss: 1.0184615850448608
Epoch 132/1000, Loss: 1.1569331884384155
Epoch 142/1000, Loss: 1.1278963088989258
Epoch 152/1000, Loss: 0.7133669257164001
Epoch 162/1000, Loss: 0.5549595952033997
Epoch 172/1000, Loss: 0.4323069155216217
Epoch 182/1000, Loss: 0.38233083486557007
Epoch 192/1000, Loss: 0.3534121811389923
Epoch 202/1000, Loss: 0.2911977171897888
Epoch 212/1000, Loss: 0.25484201312065125
Epoch 222/1000, Loss: 0.23788391053676605
Epoch 232/1000, Loss: 0.20471474528312683
Epoch 242/1000, Loss: 0.

array([[ -1.4690933,  -4.968678 ,  24.578875 , ...,  10.1328745,
         -6.662312 ,   1.384343 ],
       [ -1.7599701,  -4.7800403,  39.92357  , ...,  13.512619 ,
        -11.591892 ,   9.435382 ],
       [ -1.9447842,  -4.8705354,  37.380737 , ...,  13.049883 ,
         -9.610833 ,   8.610879 ],
       ...,
       [ -1.9743259,  -5.0803266,  42.98103  , ...,  14.443941 ,
        -12.594458 ,  10.349493 ],
       [ -1.9314985,  -5.074664 ,  42.21298  , ...,  14.327375 ,
        -12.289925 ,  10.132979 ],
       [ -1.878352 ,  -5.5857487,  37.157997 , ...,  13.330538 ,
        -10.15905  ,   6.477151 ]], dtype=float32)

Parameters:
- `spatial_regularization_strength`: the strength of the spatial regularization; the larger the value, the more spatially coherent the identified spatial domains and spatiotemporal patterns. (default: 0.1)
- `z_dim`: the dimension of the learned embedding. (default: 50)
- `lr`: the learning rate for optimizing the model. (default: 1e-3)
- `epochs`: the maximum number of training epochs. (default: 1000)
- `max_patience`: the maximum number of epochs to wait for the loss to decrease. If the loss does not decrease for more than this number of epochs, training stops, and the parameters with the lowest loss are kept as the best model. (default: 50)
- `min_stop`: the earliest epoch at which training can stop because the loss has not decreased for more than `max_patience` epochs. (default: 100)
- `random_seed`: the random seed for the random number generators of the `random`, `numpy`, and `torch` packages. (default: 42)
- `gpu`: the index of the NVIDIA GPU to use; if no GPU is available, the model is trained on the CPU, which is slower. (default: 0)
- `regularization_acceleration`: whether to speed up the computation of the regularization loss with an edge-subsetting strategy. (default: True)
- `edge_subset_sz`: the edge subset size used for regularization acceleration. (default: 1000000)
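The interaction of `max_patience` and `min_stop` can be traced on a recorded loss curve. The function below is a sketch of the stopping rule as described in the parameter list, not SpaceFlow's training loop; `early_stopping_epoch` is a hypothetical name, and epochs are 0-indexed.

```python
def early_stopping_epoch(losses, max_patience=50, min_stop=100):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch losses.

    Training stops once the loss has not improved for more than
    max_patience epochs, but never before epoch min_stop. best_epoch is
    the epoch whose parameters would be kept (lowest loss seen).
    """
    best_loss = float("inf")
    best_epoch = None
    patience = 0
    for epoch, loss in enumerate(losses):
        if loss < best_loss:
            # a real loop would snapshot the model weights here
            best_loss, best_epoch = loss, epoch
            patience = 0
        else:
            patience += 1
        if patience > max_patience and epoch >= min_stop:
            return epoch, best_epoch
    # ran out of epochs without triggering early stopping
    return len(losses) - 1, best_epoch
```

For example, a loss that improves for 10 epochs and then plateaus stops 6 epochs after the last improvement with `max_patience=5` and a small `min_stop`, but only at epoch `min_stop` when `min_stop` is larger.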

### 6. Domain segmentation of the ST data

After the model training, the learned low-dimensional embedding can be accessed through `sf.embedding`.

SpaceFlow uses this learned embedding to identify spatial domains with the [Leiden](https://www.nature.com/articles/s41598-019-41695-z) clustering algorithm. Below, the segmentation is repeated over a range of resolutions.


In [19]:
# Leiden resolutions to scan; each run writes its own label file and figure
res_list = [0.1, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5]
for i, res in enumerate(res_list, start=1):
    # segment the learned embedding at this resolution
    sf.segmentation(domain_label_save_filepath="./domains_{}.csv".format(i),
                    n_neighbors=50,
                    resolution=res)
    # plot the domains from the segmentation run just above
    sf.plot_segmentation(segmentation_figure_save_filepath="./domain_segmentation_{}.pdf".format(i),
                         colormap="tab20",
                         scatter_sz=1.,
                         rsz=4.,
                         csz=4.,
                         wspace=.4,
                         hspace=.5,
                         left=0.125,
                         right=0.9,
                         bottom=0.1,
                         top=0.9)

Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./domains_1.csv !
Plotting complete, segmentation figure saved at ./domain_segmentation_1.pdf !
Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./domains_2.csv !
Plotting complete, segmentation figure saved at ./domain_segmentation_2.pdf !
Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./domains_3.csv !
Plotting complete, segmentation figure saved at ./domain_segmentation_3.pdf !
Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./domains_4.csv !
Plotting complete, segmentation figure saved at ./domain_segmentation_4.pdf !
Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./domains_5.csv !
Plotting complete, segmentation figure saved at ./domain_segmentation_5.pdf !
Performing domain segmentation
Segmentation c

Parameters:

- `domain_label_save_filepath`: the file path for saving the identified domain labels. (default: "./domains.tsv")
- `n_neighbors`: the number of nearest neighbors per cell used to build the neighbor graph on the embedding, which Leiden then clusters. (default: 50)
- `resolution`: the resolution of the Leiden clustering; the larger the value, the finer (and more numerous) the domains. (default: 1.0)
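The direction of the `resolution` effect follows from the objective Leiden optimizes. Assuming the commonly used resolution-weighted modularity (Reichardt-Bornholdt configuration-model form), the quality of a partition that assigns spot $i$ to community $c_i$ is

$$
Q(\gamma) = \frac{1}{2m} \sum_{i,j} \left( A_{ij} - \gamma \, \frac{k_i k_j}{2m} \right) \delta(c_i, c_j),
$$

where $A$ is the (weighted) adjacency matrix of the neighbor graph built on the embedding, $k_i = \sum_j A_{ij}$ is the degree of spot $i$, $m = \tfrac{1}{2}\sum_{i,j} A_{ij}$ is the total edge weight, and $\gamma$ is the resolution. Every pair placed in the same community pays a penalty proportional to $\gamma$, so at larger resolutions large communities are less rewarding and the optimizer returns more, smaller domains.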


In [20]:
sf

<SpaceFlow.SpaceFlow.SpaceFlow at 0x7fac8957a490>

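We can also visualize an expert annotation for comparison, if one is available. This section has no expert annotation (`adata.obs` has no `celltype_mapped_refined` column), so the cell below is left commented out: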

In [21]:
# Could not assign the sf object to adata; wrote a new function, but it seems the data could not be passed in
# import scanpy as sc
# sc.pl.spatial(adata, 
#     color="celltype_mapped_refined",
#     spot_size=0.03)

In [42]:
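# Re-load the raw data and attach the domain labels from each resolution as obs columns
adata=sc.read_h5ad('/home/workspace2/zhaofangyuan/data_h5ad/without_groundtruth/Visium_demon/V1_Mouse_Brain_Sagittal_Posterior.h5ad')
adata.var_names_make_unique()

for i in range(1, 11):
    # domains_{i}.csv holds one label per spot, assumed to be in the same order as adata.obs
    pred = pd.read_csv('./domains_{}.csv'.format(i), header=None)
    # store the labels as categorical strings so scanpy treats them as discrete domains
    adata.obs['SpaceFlow_{}'.format(i)] = pd.Categorical(pred.iloc[:, 0].astype(str).to_numpy())

adata.write_h5ad('./SpaceFlow_V1_Mouse_Brain_Sagittal_Posterior(resolution).h5ad')
adata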


AnnData object with n_obs × n_vars = 3355 × 32285
    obs: 'in_tissue', 'array_row', 'array_col', 'SpaceFlow_1', 'SpaceFlow_2', 'SpaceFlow_3', 'SpaceFlow_4', 'SpaceFlow_5', 'SpaceFlow_6', 'SpaceFlow_7', 'SpaceFlow_8', 'SpaceFlow_9', 'SpaceFlow_10'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'
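With all ten label sets stored in `adata.obs`, a quick way to compare resolutions is to count the distinct domains in each column. `domains_per_resolution` is a hypothetical helper; it assumes the columns are named `SpaceFlow_<number>` as in the cell above.

```python
import pandas as pd


def domains_per_resolution(obs, prefix="SpaceFlow_"):
    """Count distinct domain labels in each obs column named prefix + number."""
    cols = [c for c in obs.columns if c.startswith(prefix)]
    # sort by the numeric suffix so SpaceFlow_10 comes after SpaceFlow_9
    cols.sort(key=lambda c: int(c[len(prefix):]))
    return pd.Series({c: obs[c].nunique() for c in cols}, name="n_domains")
```

Calling `domains_per_resolution(adata.obs)` gives one count per resolution in `res_list` order; higher resolutions are expected to yield more domains.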