CPS tutorial using the DLPFC dataset
===
1. Read the dataset
2. Construct the spatial graph (optionally, construct the multi-scale features)
3. Train the model
4. Generate spots at an arbitrary resolution
5. Visualize the results

In [1]:
import os
import sys
import warnings

# Make the parent of the current working directory importable, so the CPS
# package can be imported when the notebook is run from inside the repository.
parent_dir = os.path.realpath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)

# Silence all warnings for the rest of the session to keep cell outputs compact.
warnings.filterwarnings("ignore")

In [2]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import CPS

In [3]:
# Build the CPS argument parser and parse an explicit argument list, so the
# notebook does not pick up the kernel's own command-line arguments.
opt = CPS.config()
args = opt.parse_args(['--seed', '2025'])

# NOTE(review): presumably this turns off scaling inside CPS because the
# preprocessing cell calls sc.pp.scale itself -- confirm against the CPS docs.
args.prep_scale = False

# Dataset root: overridable through the CPS_DATASET_PATH environment variable so
# the notebook is not tied to a single machine; the default is the original path.
# Joining with '' guarantees a trailing separator whether or not one was supplied.
args.dataset_path = os.path.join(
    os.environ.get('CPS_DATASET_PATH', '/mnt/d/Dataset/SRT_Dataset/1-DLPFC/'), '')

# Seed using the parsed value so the run is reproducible.
CPS.set_random_seed(args.seed)
args

Namespace(batch_size=256, clusters=7, coord_dim=2, dataset_path='/mnt/d/Dataset/SRT_Dataset/1-DLPFC/', decoder='MLP', decoder_latent=[256, 512], distill=0.5, dropout=0.2, flow='source_to_target', freq=32, gpu=0, hvgs=3000, inr_latent=[256, 256, 256], k_list=[0, 1, 2, 3, 4, 5, 6, 7], latent_dim=64, lr=0.001, max_epoch=1000, max_neighbors=6, n_spot=0, num_heads=4, prep_scale=False, radius=150, seed=2025, self_loops=True, sh_weights=True, sigma=10.0, visual=True, weight_decay=0.0001)

Read the adata

In [4]:
# Identifier of the DLPFC section to load.
section = '151676'

# Read the Visium section. Path components are passed to os.path.join separately
# (not pre-concatenated) so the result is correct whether or not
# args.dataset_path ends with a path separator.
adata = sc.read_visium(os.path.join(args.dataset_path, section))

# Annotation file: tab-separated, no header row; the first column becomes the
# index and is matched against adata.obs_names below.
Ann_df = pd.read_csv(
    os.path.join(args.dataset_path, '1-DLPFC_annotations', section + '_truth.txt'),
    sep='\t', header=None, index_col=0)
Ann_df.columns = ['Ground Truth']

# Align the labels to the order of the observations in adata.
adata.obs['Ground Truth'] = Ann_df.loc[adata.obs_names, 'Ground Truth']
adata

AnnData object with n_obs × n_vars = 3460 × 33538
    obs: 'in_tissue', 'array_row', 'array_col', 'Ground Truth'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'

Preprocess adata

In [5]:
# Preprocess the expression matrix and store the highly-variable-gene (HVG)
# features in adata.obsm for later use.
# NOTE: these scanpy calls modify adata in place, so re-running this cell without
# reloading adata applies normalization/log/scale a second time.
adata.var_names_make_unique()
# HVG selection runs before normalization; scanpy documents that the
# "seurat_v3" flavor expects count data.
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=args.hvgs)
# Normalize each observation to a total of 1e4, then apply log1p.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Scale with zero-centering, capping values at 10.
sc.pp.scale(adata, zero_center=True, max_value=10)

# Keep only the genes flagged as highly variable and expose their processed
# matrix under adata.obsm['hvg_features'].
adata_hvg = adata[:, adata.var['highly_variable']].copy()
adata.obsm['hvg_features'] = adata_hvg.X

Construct the spatial graph

In [6]:
# Create the graph builder from the run configuration and build one graph
# for this section using the 'rknn' method.
spatial_edge = CPS.SpatialGraphBuilder(args)
pyg_data = spatial_edge.build_single_graph(adata, method='rknn')

# Sanity check: report the node count and gene count of the resulting graph.
graph_size = (pyg_data.num_nodes, pyg_data.num_genes)
print(*graph_size)

3460 3000


Train the CPS model

In [7]:
# Create the trainer from the run configuration and fit it on the graph built above.
# NOTE(review): the values printed during fitting are presumably the training
# loss -- confirm against CPSTrainer.fit.
trainer = CPS.CPSTrainer(args)
trainer.fit(pyg_data)

  0%|          | 0/999 [00:00<?, ?it/s]

1.4713244438171387
1.4097410440444946
1.3893407583236694
1.3783512115478516
1.3678979873657227
1.3578038215637207
1.3489861488342285
1.3418567180633545
1.335758924484253
1.330413818359375
1.3253411054611206
1.3197290897369385
1.3139832019805908
1.3076108694076538
1.3012821674346924
1.2940051555633545
1.2876754999160767
1.2836048603057861
1.2804367542266846
1.276856541633606
1.2738580703735352
1.2715327739715576
1.2696685791015625
1.2685606479644775
1.2674000263214111
1.2668397426605225
1.265986680984497
1.2657361030578613
1.264556646347046
1.2650012969970703
1.2640924453735352
1.2650649547576904
1.2652314901351929
1.2668567895889282
1.2672855854034424
1.2688945531845093
1.269211769104004
1.2702429294586182
1.271840214729309
1.2721827030181885
1.2744202613830566
1.2744550704956055
1.2750498056411743
1.2757185697555542
1.2763829231262207
1.2760915756225586
1.2772095203399658
1.2771894931793213
1.2769213914871216
1.2771360874176025
1.2765833139419556
1.2766450643539429
1.2758382558822632


Downstream analysis