In [1]:
# uncomment this if you want to use interactive plot (only works in Jupyter not works in VScode)
# %matplotlib widget

import scanpy as sc
import numpy as np
import pandas as pd

import scSLAT
from scSLAT.model import Cal_Spatial_Net, load_anndatas, run_SLAT, spatial_match
from scSLAT.viz import match_3D_multi, hist, Sankey
from scSLAT.metrics import region_statistics

In [2]:
sc.set_figure_params(scanpy=True, dpi=100, dpi_save=150, frameon=True, vector_friendly=True, fontsize=14)

In [15]:
datasets = ['151675', '151676']
input_dir = 'D:/dataset/'  # Replace it with your file path
output_dir = 'G:/dataset/1_DLPFC/output/SLAT/'

adata_list = []

for dataset in datasets:
    adata = sc.read_visium(input_dir + dataset)
    adata.var_names_make_unique()
    adata.obs_names_make_unique()
    adata_label = pd.read_csv(input_dir + dataset + '/' + 'truth.csv', index_col=0)
    adata.obs['batch'] = adata_label['batch']
    adata.obs['annotation'] = adata_label['ground.truth']
    adata_list.append(adata)

Cal_Spatial_Net(adata_list[0], k_cutoff=10, model='KNN')
Cal_Spatial_Net(adata_list[1], k_cutoff=10, model='KNN')

edges, features = load_anndatas([adata_list[0], adata_list[1]], feature='DPCA', check_order=False)

embd0, embd1, time = run_SLAT(features, edges)


np.savetxt(output_dir + datasets[0] + '_' + datasets[1] + '__' + datasets[0] + '_SLAT_embeddeing.csv',embd0.cpu().detach().numpy(), delimiter=',')
np.savetxt(output_dir + datasets[0] + '_' + datasets[1] + '__' + datasets[1] + '_SLAT_embeddeing.csv',embd1.cpu().detach().numpy(), delimiter=',')


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


Calculating spatial neighbor graph ...
The graph contains 37544 edges, 3592 cells.
10.452115812917596 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 36220 edges, 3460 cells.
10.468208092485549 neighbors per cell on average.
Use DPCA feature to format graph



See the tutorial for concat at: https://anndata.readthedocs.io/en/latest/concatenation.html
  view_to_actual(adata)
  view_to_actual(adata)
  view_to_actual(adata)


Choose GPU:0 as device
Running
---------- epochs: 1 ----------
---------- epochs: 2 ----------
---------- epochs: 3 ----------
---------- epochs: 4 ----------
---------- epochs: 5 ----------
---------- epochs: 6 ----------
Training model time: 2.06


In [4]:
print(embd0)

tensor([[ 1.8645,  2.8196, -0.3658,  ..., -0.1658, -2.0866,  2.2218],
        [-0.0311,  0.0182, -0.7954,  ...,  0.4941,  0.0616, -0.0624],
        [ 0.3433,  0.6250, -0.2759,  ...,  0.6012,  0.1874,  1.3433],
        ...,
        [ 0.0812, -0.9032,  0.1259,  ...,  0.1541,  0.1325,  1.0335],
        [-0.3637,  0.1392, -0.6755,  ...,  0.5062,  0.4794, -0.3499],
        [ 0.9065,  0.6386,  0.2656,  ...,  0.2948,  0.1795,  1.1589]],
       device='cuda:0', grad_fn=<AddmmBackward0>)


In [None]:
best, index, distance = spatial_match(features, adatas=[adata_list[0],adata_list[1]], reorder=False)

adata1_df = pd.DataFrame({'index': range(embd0.shape[0]),
                        'x': adata_list[0].obsm['spatial'][:,0],
                        'y': adata_list[0].obsm['spatial'][:,1],
                        'celltype': adata_list[0].obs['annotation']})
adata2_df = pd.DataFrame({'index': range(embd1.shape[0]),
                        'x': adata_list[1].obsm['spatial'][:,0],
                        'y': adata_list[1].obsm['spatial'][:,1],
                        'celltype': adata_list[1].obs['annotation']})
print(adata1_df)
print(adata2_df)

matching = np.array([range(index.shape[0]), best])
best_match = distance[:,0]

multi_align = match_3D_multi(adata1_df, adata2_df, matching,meta='celltype',
                            scale_coordinate=True, subsample_size=300)
multi_align.draw_3D(size=[7, 8], line_width=1, point_size=[1.5,1.5], hide_axis=True)

%matplotlib inline
hist(best_match, cut=0.8)

In [18]:
distance

array([[0.7575611 , 0.75636154, 0.74787045, ..., 0.6993324 , 0.69892645,
        0.6970204 ],
       [0.8028911 , 0.5620574 , 0.5379353 , ..., 0.44516423, 0.44315094,
        0.44096938],
       [0.6922504 , 0.6554412 , 0.6330521 , ..., 0.5615372 , 0.5610854 ,
        0.55364746],
       ...,
       [0.8993812 , 0.88761175, 0.87537175, ..., 0.7974684 , 0.7950389 ,
        0.79437107],
       [0.878142  , 0.86335534, 0.8330437 , ..., 0.72919446, 0.72818345,
        0.727426  ],
       [0.66471416, 0.58902645, 0.57814384, ..., 0.48240066, 0.48009697,
        0.47809172]], dtype=float32)