In [None]:
import os
import sys

# cuBLAS needs a fixed workspace configuration for deterministic GEMMs once
# torch.use_deterministic_algorithms(True) is enabled; set it before any CUDA work.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Make the project root and ./src importable, with ./src taking precedence.
# Earlier copies are removed first so re-running this cell does not keep
# growing sys.path with duplicate entries (lookup order is unchanged).
for _path in (os.path.abspath('.'), os.path.abspath('./src')):
    while _path in sys.path:
        sys.path.remove(_path)
    sys.path.insert(0, _path)
del _path

import yaml
import torch
import numpy as np
import random


def fix_seed(seed):
    """Seed every RNG used by the notebook and switch PyTorch to deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's legacy global
            RNG, and PyTorch (CPU generator plus the current and all CUDA devices).

    Side effects:
        * exports ``PYTHONHASHSEED``. Python reads this at interpreter start-up,
          so it only affects subprocesses spawned afterwards, not this kernel.
        * puts cuDNN into deterministic, non-benchmark mode.
        * enables ``torch.use_deterministic_algorithms``. PyTorch then raises an
          error for ops that have no deterministic implementation.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Seed in a fixed order: stdlib, NumPy, torch CPU, current GPU, all GPUs.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels run-to-run; keep it off.
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


# Global seed for the whole run; change it here rather than at the call site.
SEED = 0
fix_seed(SEED)

from src.model import DeepST
from src.data_loader import SpatialDataLoader
from src.trainer import Trainer
from src.evaluator import Evaluator
from src.utils import Transfer_pytorch_Data

In [None]:
# ========== Load Configuration ==========
# config.yaml is resolved relative to the current working directory.
# Read it as UTF-8 explicitly so the result does not depend on the platform's
# default encoding.
with open('./config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# safe_load returns None for an empty file; fail with a clear message instead of
# a TypeError on the first subscript below.
if not isinstance(config, dict):
    raise ValueError(f"./config.yaml must contain a mapping, got {type(config).__name__}")

print('Configuration loaded:')
print(f"  Model: {config['model']['name']}")
print(f"  Pre-epochs: {config['training']['pre_epochs']}")
print(f"  Epochs: {config['training']['epochs']}")
# `device` is optional (the device-setup cell defaults it to 'cuda'), so use the
# same default here instead of raising KeyError.
print(f"  Device: {config['training'].get('device', 'cuda')}")

In [None]:
# ========== Setup Device ==========
# Honour the configured device ('cuda', 'cuda:N' or 'cpu'). If CUDA was
# requested but is not available, fall back to the CPU and say so explicitly.
device_name = config['training'].get('device', 'cuda')
_wants_cuda = isinstance(device_name, str) and device_name.startswith('cuda')
if _wants_cuda and torch.cuda.is_available():
    # A bare 'cuda' keeps the previous behaviour of pinning to GPU 0.
    device = torch.device('cuda:0' if device_name == 'cuda' else device_name)
    print(f'Using GPU: {torch.cuda.get_device_name(device)}')
else:
    if _wants_cuda:
        print(f'Requested device {device_name!r}, but CUDA is not available')
    device = torch.device('cpu')
    print('Using CPU')
del _wants_cuda

In [None]:
# ========== Data Loading & Preprocessing ==========
# Build the project's SpatialDataLoader from the full config, then load and
# preprocess the dataset located at config['data']['base_path'].
# `adata` exposes .n_obs / .n_vars / .obsm (presumably an AnnData object — TODO
# confirm in src.data_loader); `data` is the preprocessed matrix that is later
# passed to Transfer_pytorch_Data together with `adata`.
loader = SpatialDataLoader(config)
adata, data = loader.load_and_preprocess(config['data']['base_path'])

print(f'Loaded {adata.n_obs} spots x {adata.n_vars} genes')
print(f'Processed data shape: {data.shape}')

In [None]:
# ========== Prepare Graph Data ==========
# Convert the loaded data into graph form via src.utils.Transfer_pytorch_Data.
# It returns a graph object with .x (input features) and .edge_index, plus
# `adj_label` and `norm`; all of these are passed to Trainer.train later.
# NOTE(review): `pyg_data` is presumably a torch_geometric Data object and `norm`
# presumably a loss-normalisation factor for the adjacency reconstruction —
# confirm in src.utils.
pyg_data, adj_label, norm = Transfer_pytorch_Data(adata, data)

print(f'Input features: {pyg_data.x.shape}')
print(f'Edge index: {pyg_data.edge_index.shape}')
print(f'Adjacency label: {adj_label.shape}')

In [None]:
# ========== Update Config Dimensions ==========
# Derive the layer sizes that depend on the data instead of trusting the YAML:
#   encoder in       = number of input features
#   graph encoder in = encoder output (last hidden dim)
#   decoder in       = encoder output + graph-encoder output (concatenated)
#   decoder out      = number of input features (reconstruction target)
# Temporary names are underscore-prefixed and deleted so no new globals leak.
_n_features = pyg_data.x.shape[1]
_model_cfg = config['model']
_model_cfg['encoder']['architecture']['in_dim'] = _n_features

_latent_dim = _model_cfg['encoder']['architecture']['hidden_dims'][-1]
_model_cfg['graph_encoder']['architecture']['in_dim'] = _latent_dim

_graph_out = _model_cfg['graph_encoder']['architecture']['out_dim']
_decoder_arch = _model_cfg['decoder']['architecture']
_decoder_arch['in_dim'] = _latent_dim + _graph_out
_decoder_arch['out_dim'] = _n_features

del _n_features, _model_cfg, _latent_dim, _graph_out, _decoder_arch

print('Config dimensions updated')

In [None]:
# ========== Create Model ==========
# Instantiate DeepST from the dimension-updated config and print a summary of
# its components and total parameter count.
model = DeepST(config)

_model_cfg = config['model']
print(f"Model: {_model_cfg['name']} v{_model_cfg['version']}")
print(f"  Encoder: {_model_cfg['encoder']['type']}")
print(f"  Graph Encoder: {_model_cfg['graph_encoder']['type']}")
print(f"  Decoder: {_model_cfg['decoder']['type']}")

_n_params = 0
for _param in model.parameters():
    _n_params += _param.numel()
print(f"Total parameters: {_n_params:,}")

del _model_cfg, _n_params
if '_param' in globals():
    del _param

In [None]:
# ========== Model Training ==========
# Train on the graph built above; Trainer.train returns the learned embeddings,
# which are stored on `adata` for the clustering step.
trainer = Trainer(model, config, device)
embeddings = trainer.train(
    x=pyg_data.x,
    edge_index=pyg_data.edge_index,
    adj_label=adj_label,
    norm=norm,
    adata=adata
)

# The `output` section is optional. Fall back to the default key when it is
# missing or empty (YAML `output:` with no body loads as None), matching the
# `.get` default already used for the key itself.
embedding_key = (config.get('output') or {}).get('embedding_key', 'DeepST_embed')
adata.obsm[embedding_key] = embeddings

print(f'Embeddings shape: {embeddings.shape}')

In [None]:
# ========== Clustering & Evaluation ==========
# Evaluator.cluster returns an updated `adata` that replaces the current one;
# compute_metrics then returns a mapping of metric name -> value for it.
evaluator = Evaluator(config)
adata = evaluator.cluster(adata)
metrics = evaluator.compute_metrics(adata)

print('\n========== Results ==========')
# NOTE(review): the `:.4f` format requires every metric value to be a real
# number; a None or string value (e.g. a metric that could not be computed)
# would raise here — confirm what src.evaluator returns in that case.
for metric_name, value in metrics.items():
    print(f'{metric_name.upper()}: {value:.4f}')