### 1. Import the SpaceFlow and squidpy packages

In [12]:
import warnings
warnings.filterwarnings('ignore')
import squidpy as sq
import SpaceFlow
import scanpy as sc
import pandas as pd


In [13]:
from SpaceFlow import SpaceFlow 

### 2. Load the data

In [14]:
# section_id = '151671'
# input_dir = os.path.join('/home/workspace2/zhaofangyuan/SpaceFlow/SpaceFlow-master/Data', section_id)
# adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')
adata=sc.read_h5ad('/home/workspace2/zhaofangyuan/data_h5ad/without_groundtruth/Visium_demon/V1_Mouse_Brain_Sagittal_Posterior_Section_2.h5ad')
adata.var_names_make_unique()
adata

AnnData object with n_obs × n_vars = 3289 × 32285
    obs: 'in_tissue', 'array_row', 'array_col'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'

### 3. Create SpaceFlow Object

We create a SpaceFlow object from the AnnData object loaded above, which holds the gene expression count matrix and the corresponding spatial locations of cells (or spots):

In [15]:
sf = SpaceFlow.SpaceFlow(adata=adata)
sf

<SpaceFlow.SpaceFlow.SpaceFlow at 0x7f2aac11e690>

Parameters (the cell above passes an AnnData object through `adata` instead of the two arrays below):
- `expr_data`: the count matrix of gene expression, a 2D numpy array of size (# of cells, # of genes)
- `spatial_locs`: the spatial locations of cells (or spots), matched to the rows of the count matrix, 1D numpy array of size (n_locations,)

### 4. Preprocessing the ST Data
Next, we preprocess the ST data by running:

In [5]:
sf.preprocessing_data(n_top_genes=3000)

Parameters:
- `n_top_genes`: the number of top highly variable genes to keep.

The preprocessing includes normalization and log-transformation of the expression count matrix, selection of highly variable genes, and construction of a spatial proximity graph from the spatial coordinates. (For details, see the `preprocessing_data` function in `SpaceFlow/SpaceFlow.py`.)
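To make the three preprocessing steps concrete, here is a minimal NumPy sketch of the same sequence. It is not SpaceFlow's implementation: the function name `preprocess_sketch`, the `n_neighbors` and `target_sum` arguments, and the use of plain variance to rank genes are all assumptions made for illustration.

```python
import numpy as np


def preprocess_sketch(counts, coords, n_top_genes=3000, n_neighbors=10, target_sum=1e4):
    """Illustrative only: normalize, log-transform, pick variable genes, build a kNN graph."""
    counts = np.asarray(counts, dtype=float)
    coords = np.asarray(coords, dtype=float)

    # Library-size normalization: scale every cell (row) to the same total count.
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # avoid division by zero for empty cells
    logged = np.log1p(counts / totals * target_sum)

    # Keep the genes (columns) with the largest variance after log-transformation.
    # Assumption: real pipelines usually rank genes by a dispersion-based score.
    n_keep = min(n_top_genes, logged.shape[1])
    top_genes = np.sort(np.argsort(logged.var(axis=0))[::-1][:n_keep])
    expr = logged[:, top_genes]

    # Spatial proximity graph: connect each cell to its k nearest spatial neighbors.
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # a cell is not its own neighbor
    k = min(n_neighbors, len(coords) - 1)
    neighbors = np.argsort(dist, axis=1)[:, :k]
    edges = [(i, int(j)) for i in range(len(coords)) for j in neighbors[i]]
    return expr, top_genes, edges
```

The dense pairwise distance matrix keeps the sketch short; for large datasets a tree-based neighbor search would be the more practical choice.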


### 5. Train the deep graph network model

We then train a spatially regularized deep graph network model to learn a low-dimensional embedding that reflects both the expression similarity and the spatial proximity of cells in the ST data.


In [6]:
sf.train(spatial_regularization_strength=0.1, 
         z_dim=50, 
         lr=1e-3, 
         epochs=1000, 
         max_patience=50, 
         min_stop=100, 
         random_seed=42, 
         gpu=1, 
         regularization_acceleration=True, 
         edge_subset_sz=1000000)

Epoch 2/1000, Loss: 10.189188957214355
Epoch 12/1000, Loss: 1.6195722818374634
Epoch 22/1000, Loss: 1.516485571861267
Epoch 32/1000, Loss: 1.4355510473251343
Epoch 42/1000, Loss: 1.3794041872024536
Epoch 52/1000, Loss: 1.2694107294082642
Epoch 62/1000, Loss: 1.0864626169204712
Epoch 72/1000, Loss: 0.836830198764801
Epoch 82/1000, Loss: 0.6618722677230835
Epoch 92/1000, Loss: 0.4019826650619507
Epoch 102/1000, Loss: 0.3247548043727875
Epoch 112/1000, Loss: 0.23060129582881927
Epoch 122/1000, Loss: 0.20460256934165955
Epoch 132/1000, Loss: 0.19117172062397003
Epoch 142/1000, Loss: 0.17767533659934998
Epoch 152/1000, Loss: 0.13407307863235474
Epoch 162/1000, Loss: 0.13061141967773438
Epoch 172/1000, Loss: 0.1133759468793869
Epoch 182/1000, Loss: 0.11186105012893677
Epoch 192/1000, Loss: 0.10810726135969162
Epoch 202/1000, Loss: 0.09334376454353333
Epoch 212/1000, Loss: 0.09975259751081467
Epoch 222/1000, Loss: 0.08768016844987869
Epoch 232/1000, Loss: 0.08581777662038803
Epoch 242/1000, L

array([[-5.1645846 ,  5.642975  , 11.575688  , ...,  9.178689  ,
        -1.9711637 , 10.761907  ],
       [-5.0714226 ,  5.5076046 , 11.209752  , ...,  8.516841  ,
        -1.9660617 , 10.218894  ],
       [-4.9817624 ,  7.862143  ,  7.3670726 , ...,  6.973419  ,
        -2.5149636 ,  6.5497556 ],
       ...,
       [-5.9093585 , 13.145799  ,  1.49898   , ...,  7.506754  ,
        -3.9435444 ,  1.7111741 ],
       [-5.8447895 ,  9.390853  ,  8.1574745 , ...,  7.9970994 ,
        -3.0489066 ,  7.3244643 ],
       [-6.759815  , 16.571222  , -0.97939384, ...,  8.442     ,
        -4.851411  , -0.10696509]], dtype=float32)

Parameters:
- `spatial_regularization_strength`: the strength of spatial regularization; larger values give more spatially coherent domains and spatiotemporal patterns. (default: 0.1)
- `z_dim`: the dimensionality of the learned embedding. (default: 50)
- `lr`: the learning rate for optimizing the model. (default: 1e-3)
- `epochs`: the maximum number of training epochs. (default: 1000)
- `max_patience`: the maximum number of epochs to wait for the loss to decrease. If the loss has not decreased for more epochs than this threshold, training stops and the parameters with the minimal loss are kept as the best model. (default: 50)
- `min_stop`: the earliest epoch at which training can stop once the loss has gone more than `max_patience` epochs without decreasing. (default: 100)
- `random_seed`: the seed for the random generators of the `random`, `numpy`, and `torch` packages. (default: 42)
- `gpu`: the index of the Nvidia GPU to use; if no GPU is available, the model is trained on the CPU, which is slower. (default: 0)
- `regularization_acceleration`: whether to accelerate the computation of the regularization loss with an edge-subsetting strategy. (default: True)
- `edge_subset_sz`: the edge subset size used for regularization acceleration. (default: 1000000)
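The interaction between `max_patience` and `min_stop` can be sketched by replaying a list of per-epoch losses. This is a hedged illustration of the stopping rule as described above, not code taken from SpaceFlow; the function name `train_with_early_stopping` and the strict `>` comparisons are assumptions.

```python
def train_with_early_stopping(losses, max_patience=50, min_stop=100):
    """Replay per-epoch losses and report where training would stop.

    Returns (stop_epoch, best_epoch, best_loss); epochs are counted from 1.
    """
    best_loss = float("inf")
    best_epoch = 0
    patience = 0  # epochs elapsed since the loss last improved
    epoch = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            # New best: remember it and reset the patience counter.
            best_loss, best_epoch, patience = loss, epoch, 0
        else:
            patience += 1
        # Stop only when patience is exhausted AND the minimum epoch has passed.
        if patience > max_patience and epoch > min_stop:
            break
    return epoch, best_epoch, best_loss
```

For example, with `max_patience=3` and `min_stop=5`, a loss that last improves at epoch 2 triggers a stop at epoch 6; raising `min_stop` to 8 delays the stop to epoch 9 even though patience ran out earlier.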

### 6. Domain segmentation of the ST data

After the model training, the learned low-dimensional embedding can be accessed through `sf.embedding`.

SpaceFlow uses this learned embedding to identify the spatial domains with the [Leiden](https://www.nature.com/articles/s41598-019-41695-z) algorithm.


In [7]:
# import pandas as pd

res_list = [0.1, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5]
# Run the segmentation once per resolution; output files are numbered from 1.
for i, res in enumerate(res_list, start=1):
    # Cluster the embedding at this resolution and save the domain labels.
    sf.segmentation(domain_label_save_filepath="./2_domains_{}.csv".format(i),
                    n_neighbors=50,
                    resolution=res)
    # Plot the domains just computed and save the figure.
    sf.plot_segmentation(segmentation_figure_save_filepath="./2_domain_segmentation_{}.pdf".format(i),
                         colormap="tab20",
                         scatter_sz=1.,
                         rsz=4.,
                         csz=4.,
                         wspace=.4,
                         hspace=.5,
                         left=0.125,
                         right=0.9,
                         bottom=0.1,
                         top=0.9)
# pred=pd.read_csv('./domains.csv')
# pred

Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./2_domains_1.csv !
Plotting complete, segmentation figure saved at ./2_domain_segmentation_1.pdf !
Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./2_domains_2.csv !
Plotting complete, segmentation figure saved at ./2_domain_segmentation_2.pdf !
Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./2_domains_3.csv !
Plotting complete, segmentation figure saved at ./2_domain_segmentation_3.pdf !
Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./2_domains_4.csv !
Plotting complete, segmentation figure saved at ./2_domain_segmentation_4.pdf !
Performing domain segmentation
Segmentation complete, domain labels of cells or spots saved at ./2_domains_5.csv !
Plotting complete, segmentation figure saved at ./2_domain_segmentation_5.pdf !
Performing domain segment

Parameters:

- `domain_label_save_filepath`: the file path for saving the identified domain labels. (default: "./domains.tsv")
- `n_neighbors`: the number of nearest neighbors per cell used to build the graph, computed from the embedding, that Leiden clusters. (default: 50)
- `resolution`: the resolution of the Leiden clustering; larger values produce more, finer domains. (default: 1.0)
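When sweeping several resolutions, it helps to see how many domains each run produced. The sketch below counts the distinct labels in each saved file using only the standard library. It assumes the label files are headerless with one label per line and follow the `./2_domains_{}.csv` naming used in the loop above; `domain_sizes` and `summarize_sweep` are hypothetical helper names, not part of SpaceFlow.

```python
import csv
from collections import Counter


def domain_sizes(label_path):
    """Count spots per domain in a headerless, single-column label file."""
    with open(label_path, newline="") as handle:
        labels = [row[0] for row in csv.reader(handle) if row]
    return Counter(labels)


def summarize_sweep(resolutions, path_template="./2_domains_{}.csv"):
    """Map each resolution to the number of domains found in its label file.

    Files are assumed to be numbered from 1 in the same order as `resolutions`.
    """
    summary = {}
    for index, resolution in enumerate(resolutions, start=1):
        sizes = domain_sizes(path_template.format(index))
        summary[resolution] = len(sizes)
    return summary
```

Calling `summarize_sweep(res_list)` after the segmentation loop returns a dictionary from resolution to domain count, which makes it easy to pick a resolution that yields the expected number of domains.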


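When the dataset provides an expert annotation, we can also visualize it for comparison (this dataset has no annotation column in `adata.obs`, so the call is left commented out):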

In [9]:
# Could not assign the sf object to adata; wrote a new function, but it does not seem to get passed in
# import scanpy as sc
# sc.pl.spatial(adata, 
#     color="celltype_mapped_refined",
#     spot_size=0.03)

In [18]:
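import numpy as np
# Reload the original data so the predicted labels can be attached to it.
adata = sc.read_h5ad('/home/workspace2/zhaofangyuan/data_h5ad/without_groundtruth/Visium_demon/V1_Mouse_Brain_Sagittal_Posterior_Section_2.h5ad')

for i in range(10):
    # Each label file has no header and holds one domain label per spot.
    pred = pd.read_csv('./2_domains_{}.csv'.format(i+1), header=None)
    pred_list = pred.iloc[:, 0].to_list()

    # Store the labels for resolution i+1 as a new column in adata.obs.
    adata.obs['SpaceFlow_{}'.format(i+1)] = np.array(pred_list)

# Save the annotated object, then display it.
adata.write_h5ad('./SpaceFlow_V1_Mouse_Brain_Sagittal_Posterior_Section_2(resolution).h5ad')
adata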

AnnData object with n_obs × n_vars = 3289 × 32285
    obs: 'in_tissue', 'array_row', 'array_col', 'SpaceFlow_1', 'SpaceFlow_2', 'SpaceFlow_3', 'SpaceFlow_4', 'SpaceFlow_5', 'SpaceFlow_6', 'SpaceFlow_7', 'SpaceFlow_8', 'SpaceFlow_9', 'SpaceFlow_10'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'