# 导包

In [12]:
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal
from torch_geometric_temporal.nn.recurrent import GConvLSTM

# 生成动态图数据集
def generate_dynamic_graph(num_nodes=10, num_timesteps=8):
    edge_indices, edge_weights, features, targets = [], [], [], []
    for t in range(num_timesteps):
        g = nx.erdos_renyi_graph(num_nodes, p=0.6, seed=t)
        g.add_nodes_from(range(num_nodes))

        # 过滤无效边
        max_node_id = num_nodes - 1
        edges = [(u, v) for u, v in g.edges() if u <= max_node_id and v <= max_node_id]

        if len(edges) == 0:
            edge_index = np.empty((2, 0), dtype=np.int64)
            edge_attr = np.empty((0, 2), dtype=np.float32)
        else:
            edge_index = np.array(edges, dtype=np.int64).T
            edge_attr = np.random.rand(edge_index.shape[1], 2).astype(np.float32)

        x = np.random.rand(num_nodes, 3).astype(np.float32)
        y = (edge_attr.sum(axis=1) > 1.0).astype(np.int64) if edge_attr.shape[0] > 0 else np.array([], dtype=np.int64)
        w = edge_attr[:, 0].astype(np.float32) if edge_attr.shape[0] > 0 else np.array([], dtype=np.float32)

        edge_indices.append(edge_index)
        edge_weights.append(w)
        features.append(x)
        targets.append(y)

    return DynamicGraphTemporalSignal(edge_indices, edge_weights, features, targets), num_timesteps

# 定义动态图边分类模型
class DynamicEdgeClassifier(torch.nn.Module):
    def __init__(self, node_features, hidden_dim, edge_features=1):
        super(DynamicEdgeClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        self.recurrent = GConvLSTM(node_features, hidden_dim, 2)
        self.edge_mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim + edge_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, edge_index, edge_weight, h, c):
        # 确保隐藏状态与节点数量匹配
        if h is not None and h.size(0) != x.size(0):
            h = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
            c = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
        elif h is None:
            h = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
            c = torch.zeros(x.size(0), self.hidden_dim, device=x.device)

        num_nodes = x.size(0)
        num_edges = edge_index.size(1)

        # 处理空边情况
        if num_edges == 0:
            return torch.empty(0, 1, device=x.device), (h.detach(), c.detach())

        # 确保边权重与边数量匹配
        if edge_weight.numel() == 0:
            edge_weight = torch.zeros(num_edges, dtype=torch.float32, device=x.device)
        elif edge_weight.size(0) != num_edges:
            edge_weight = edge_weight[:num_edges] if edge_weight.size(0) > num_edges else \
                         torch.cat([edge_weight, torch.zeros(num_edges - edge_weight.size(0), dtype=torch.float32, device=x.device)])

        # 计算有效掩码并确保形状匹配
        valid_mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)

        # 检查掩码和边权重形状是否匹配
        if valid_mask.size(0) != edge_weight.size(0):
            valid_mask = valid_mask[:edge_weight.size(0)] if valid_mask.size(0) > edge_weight.size(0) else \
                         torch.cat([valid_mask, torch.zeros(edge_weight.size(0) - valid_mask.size(0), dtype=torch.bool, device=x.device)])

        # 应用掩码前再次检查
        if edge_weight.size(0) > 0 and valid_mask.size(0) > 0:
            min_size = min(edge_weight.size(0), valid_mask.size(0))
            edge_weight = edge_weight[:min_size]
            valid_mask = valid_mask[:min_size]
            edge_index = edge_index[:, :min_size]

            # 应用掩码
            edge_index = edge_index[:, valid_mask]
            edge_weight = edge_weight[valid_mask]
        else:
            edge_index = edge_index[:, :0]
            edge_weight = edge_weight[:0]

        # 如果清理后没有有效边，返回空预测
        if edge_index.size(1) == 0:
            return torch.empty(0, 1, device=x.device), (h.detach(), c.detach())

        # 处理GConvLSTM不同版本的返回值
        try:
            (h, c) = self.recurrent(x, edge_index, edge_weight, h, c)
            x = h
        except ValueError:
            try:
                x, (h, c) = self.recurrent(x, edge_index, edge_weight, h, c)
            except ValueError:
                h = self.recurrent(x, edge_index, edge_weight, h, c)
                c = h
                x = h

        try:
            edge_u = x[edge_index[0]]
            edge_v = x[edge_index[1]]
        except IndexError as e:
            print(f"索引错误: {e}, 边索引: {edge_index.shape}, 节点数量: {x.size(0)}")
            return torch.empty(0, 1, device=x.device), (h.detach(), c.detach())

        # 拼接特征并预测
        edge_features = torch.cat([edge_u, edge_v, edge_weight.unsqueeze(1)], dim=1)
        edge_pred = self.edge_mlp(edge_features)
        return edge_pred, (h, c)

# 训练模型
def train_model(dataset, num_timesteps, num_epochs=200):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DynamicEdgeClassifier(node_features=3, hidden_dim=32).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        # 每个epoch重新初始化隐藏状态
        h, c = None, None

        # 转换为列表以便获取长度
        dataset_list = list(dataset)
        time_steps = len(dataset_list)

        for time, snapshot in enumerate(dataset_list):
            x = torch.tensor(snapshot.x, dtype=torch.float32).to(device)
            edge_index = torch.tensor(snapshot.edge_index, dtype=torch.long).to(device)

            # 处理边权重
            if snapshot.edge_weight is None or len(snapshot.edge_weight) == 0:
                edge_weight = torch.empty(0, dtype=torch.float32).to(device)
            else:
                edge_weight = torch.tensor(snapshot.edge_weight, dtype=torch.float32).to(device)

            y = torch.tensor(snapshot.y, dtype=torch.float32).to(device)

            # 前向传播
            optimizer.zero_grad()
            edge_pred, (h, c) = model(x, edge_index, edge_weight, h, c)

            # 确保目标与预测匹配
            if edge_pred.numel() == 0 or y.numel() == 0:
                # 分离隐藏状态，避免空边情况影响计算图
                h = h.detach()
                c = c.detach()
                continue

            if edge_pred.shape[0] != y.shape[0]:
                y = y[:edge_pred.shape[0]] if y.shape[0] > edge_pred.shape[0] else \
                    torch.cat([y, torch.zeros(edge_pred.shape[0] - y.shape[0], dtype=torch.float32, device=device)])

            # 计算损失
            loss = criterion(edge_pred.squeeze(), y)
            total_loss += loss.item()

            # 反向传播
            loss.backward()
            optimizer.step()

            # 分离隐藏状态，防止计算图累积
            h = h.detach()
            c = c.detach()

        if (epoch + 1) % 10 == 0:
            # 使用实际的时间步数计算平均损失
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/time_steps:.4f}')

    return model

# 评估模型
def evaluate_model(model, dataset):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    total_correct = 0
    total_samples = 0

    # 转换为列表以便迭代
    dataset_list = list(dataset)

    with torch.no_grad():
        h, c = None, None
        for time, snapshot in enumerate(dataset_list):
            x = torch.tensor(snapshot.x, dtype=torch.float32).to(device)
            edge_index = torch.tensor(snapshot.edge_index, dtype=torch.long).to(device)

            if snapshot.edge_weight is None or len(snapshot.edge_weight) == 0:
                edge_weight = torch.empty(0, dtype=torch.float32).to(device)
            else:
                edge_weight = torch.tensor(snapshot.edge_weight, dtype=torch.float32).to(device)

            y = torch.tensor(snapshot.y, dtype=torch.float32).to(device)

            edge_pred, (h, c) = model(x, edge_index, edge_weight, h, c)

            if edge_pred.numel() == 0 or y.numel() == 0:
                continue

            if edge_pred.shape[0] != y.shape[0]:
                y = y[:edge_pred.shape[0]] if y.shape[0] > edge_pred.shape[0] else \
                    torch.cat([y, torch.zeros(edge_pred.shape[0] - y.shape[0], dtype=torch.float32, device=device)])

            pred = (torch.sigmoid(edge_pred) > 0.5).float().squeeze()
            correct = (pred == y).sum().item()

            total_correct += correct
            total_samples += y.size(0)

    accuracy = total_correct / total_samples if total_samples > 0 else 0
    print(f'Evaluation Accuracy: {accuracy:.4f}')
    return accuracy

# 主函数
def main():
    # 生成训练和测试数据，同时返回时间步数
    train_dataset, train_timesteps = generate_dynamic_graph(num_nodes=15, num_timesteps=20)
    test_dataset, test_timesteps = generate_dynamic_graph(num_nodes=15, num_timesteps=10)

    # 训练模型
    print("开始训练模型...")
    model = train_model(train_dataset, train_timesteps)

    # 评估模型
    print("\n在训练集上评估:")
    evaluate_model(model, train_dataset)

    print("\n在测试集上评估:")
    evaluate_model(model, test_dataset)

if __name__ == "__main__":
    main()


开始训练模型...


  x = torch.tensor(snapshot.x, dtype=torch.float32).to(device)
  edge_index = torch.tensor(snapshot.edge_index, dtype=torch.long).to(device)
  y = torch.tensor(snapshot.y, dtype=torch.float32).to(device)


Epoch [10/200], Loss: 0.6929
Epoch [20/200], Loss: 0.6928
Epoch [30/200], Loss: 0.6928
Epoch [40/200], Loss: 0.6910
Epoch [50/200], Loss: 0.6869
Epoch [60/200], Loss: 0.6826
Epoch [70/200], Loss: 0.6847
Epoch [80/200], Loss: 0.6659
Epoch [90/200], Loss: 0.6608
Epoch [100/200], Loss: 0.6386
Epoch [110/200], Loss: 0.6204
Epoch [120/200], Loss: 0.5778
Epoch [130/200], Loss: 0.5686
Epoch [140/200], Loss: 0.5282
Epoch [150/200], Loss: 0.4772
Epoch [160/200], Loss: 0.4380
Epoch [170/200], Loss: 0.4754
Epoch [180/200], Loss: 0.4539
Epoch [190/200], Loss: 0.2740
Epoch [200/200], Loss: 0.2816

在训练集上评估:
Evaluation Accuracy: 0.8709

在测试集上评估:
Evaluation Accuracy: 0.5024


  x = torch.tensor(snapshot.x, dtype=torch.float32).to(device)
  edge_index = torch.tensor(snapshot.edge_index, dtype=torch.long).to(device)
  y = torch.tensor(snapshot.y, dtype=torch.float32).to(device)
