# SIGNN 数据 → ego_nets_plus 子图（参考 ESAN）

用最少代码：加载 SIGNN 的一张图 → 按 ESAN 的 ego_nets_plus 为每个节点生成子图 → 拼成 SubgraphData 格式（带 `subgraph_batch`、`num_subgraphs`、`num_nodes_per_subgraph` 等字段的 PyG `Data`）。

In [2]:
import sys
from pathlib import Path
ROOT = Path.cwd() if (Path.cwd() / "src").is_dir() else Path.cwd().parent
sys.path.insert(0, str(ROOT))

import torch
from torch_geometric.data import Data, Batch
from src.data.multi_grid_dataloader import MultiGridPowerGridDataset
from src.data.subgraph import ego_nets_plus_subgraphs

In [3]:
# Uses the public API of src.data.subgraph (the same entry point the trainer uses).
def to_subgraph_data(sample, num_hops=2):
    """Convert a SIGNN PowerGridGraphData sample into one batched subgraph ``Data``.

    Intended for visualisation / understanding the structure, not for training.

    Args:
        sample: The original graph; ``sample.y`` is copied through unchanged.
        num_hops: Radius k passed to ``ego_nets_plus_subgraphs``.

    Returns:
        A ``Data`` whose ``x`` / ``edge_index`` / ``edge_attr`` are the
        concatenation of all subgraphs produced by ``ego_nets_plus_subgraphs``.
        It also carries:

        * ``subgraph_batch`` -- maps every row of ``x`` to its subgraph index;
        * ``y`` -- the original sample's labels;
        * ``num_subgraphs`` -- a 0-dim long tensor;
        * ``num_nodes_per_subgraph`` -- the node count of each subgraph.

        Returns ``None`` when ``ego_nets_plus_subgraphs`` yields no subgraphs.
    """
    subs = ego_nets_plus_subgraphs(sample, num_hops=num_hops)
    if not subs:
        return None
    # Each entry is an (x_plus, edge_index, edge_attr, y) tuple; wrap each one as
    # a PyG Data so Batch can concatenate them and offset edge indices for us.
    data_list = [
        Data(x=x_plus, edge_index=ei, edge_attr=ea, y=y_sub)
        for x_plus, ei, ea, y_sub in subs
    ]
    batch = Batch.from_data_list(data_list)
    # Read each subgraph's size from its own feature matrix instead of assuming
    # every subgraph has sample.num_nodes rows; this keeps the bookkeeping
    # consistent with subgraph_batch even if the subgraph helper changes.
    nodes_per_subgraph = torch.tensor(
        [x_plus.size(0) for x_plus, _, _, _ in subs],
        dtype=torch.long,
        device=batch.x.device,
    )
    return Data(
        x=batch.x,
        edge_index=batch.edge_index,
        edge_attr=batch.edge_attr,
        subgraph_batch=batch.batch,
        y=sample.y,
        num_subgraphs=torch.tensor(len(subs)),
        num_nodes_per_subgraph=nodes_per_subgraph,
    )

In [4]:
# Load every grid/scenario found under <ROOT>/dataset and inspect the first sample.
data_root = str(ROOT / "dataset")
dataset = MultiGridPowerGridDataset(data_root)
sample = dataset[0]
print("SIGNN 原图:", sample)
x_shape, ei_shape, y_shape = sample.x.shape, sample.edge_index.shape, sample.y.shape
print("  x:", x_shape, " edge_index:", ei_shape, " y:", y_shape)

INFO:src.data.multi_grid_dataloader:Loading CSV files...
INFO:src.data.multi_grid_dataloader:Found 159 grids, using 159 grids
INFO:src.data.multi_grid_dataloader:Using 11 scenarios: ['status_2', 'status_1', 'status_6', 'status_4', 'status_5', 'status_7', 'status_0', 'status_9', 'status_8', 'status_10', 'status_3']
INFO:src.data.multi_grid_dataloader:Total samples: 1749 (159 grids × 11 scenarios)
INFO:src.data.multi_grid_dataloader:Prepared data for 159 grids


SIGNN 原图: PowerGridGraphData(grid_id=0, scenario=status_2, num_nodes=1237, num_edges=1204)
  x: torch.Size([1237, 29])  edge_index: torch.Size([2, 1204])  y: torch.Size([1204])


In [5]:
sg = to_subgraph_data(sample, num_hops=2)
# Describe the batched subgraph object only when one was produced.
if sg is not None:
    print("SubgraphData (ego_nets_plus):")
    print("  x:", sg.x.shape, "(节点特征前 2 维为中心标记)")
    print("  edge_index:", sg.edge_index.shape)
    print("  subgraph_batch:", sg.subgraph_batch.shape)
    # num_subgraphs may be a 0-dim tensor or a plain number; unwrap tensors.
    subgraph_count = sg.num_subgraphs
    if hasattr(subgraph_count, 'item'):
        subgraph_count = subgraph_count.item()
    print("  num_subgraphs:", subgraph_count)
    print("  num_nodes_per_subgraph:", sg.num_nodes_per_subgraph.shape)
    print("  y (沿用原图边标签):", sg.y.shape)
else:
    print("该样本无边，跳过")

SubgraphData (ego_nets_plus):
  x: torch.Size([1530169, 31]) (节点特征前 2 维为中心标记)
  edge_index: torch.Size([2, 5468])
  subgraph_batch: torch.Size([1530169])
  num_subgraphs: 1237
  num_nodes_per_subgraph: torch.Size([1237])
  y (沿用原图边标签): torch.Size([1204])


In [6]:
# Summarise how one source graph maps onto the batched subgraph representation.
# Guard against sg being None (to_subgraph_data found no subgraphs), which the
# previous cell already handles; without this the cell raises AttributeError.
if sg is None:
    print("该样本无边，跳过")
else:
    # int() unwraps a 0-dim tensor and also accepts a plain int.
    n_sub = int(sg.num_subgraphs)
    print("数量关系: 1 张原图 →", n_sub, "张子图 (每节点一个 k-hop 邻域)")
    print("结构: 所有子图节点已拼成一张大图 → x 行数 =", sg.x.shape[0], "= num_subgraphs × num_nodes =", n_sub * sample.num_nodes)

数量关系: 1 张原图 → 1237 张子图 (每节点一个 k-hop 邻域)
结构: 所有子图节点已拼成一张大图 → x 行数 = 1530169 = num_subgraphs × num_nodes = 1530169


### 在训练中启用子图模式

在 `src/training/run_train.py` 中把 `config.use_subgraph` 设为 `True` 即可用 ego_nets_plus 子图训练：

```python
config.use_subgraph = True   # 启用子图训练
config.subgraph_num_hops = 2 # k-hop 邻域
```

模型输入维度已统一为 `node_features + 2`（中心标记）；非子图模式下由 `pad_center_marker` 补零。