In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch.optim.lr_scheduler import StepLR

# Load the Cora dataset; files are stored under the given root directory.
dataset = Planetoid(name='Cora', root='D:\\temp\\Cora')
# The first element of the dataset is the object the models, train() and
# test() below read as the module-level `data`.
data = dataset[0]

# Define the models
class GCN(torch.nn.Module):
    """Two-layer GCN node classifier built from ``GCNConv`` layers.

    Layer sizes come from the module-level ``dataset``: the first layer maps
    ``dataset.num_node_features`` to 16 hidden channels, the second maps 16
    to ``dataset.num_classes``.

    Args:
        activation: Callable applied to the output of the first layer
            (defaults to ``F.relu``).
    """

    def __init__(self, activation=F.relu):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
        self.activation = activation

    def forward(self, data):
        """Return per-node log-probabilities (log-softmax over dim 1).

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes.
        """
        hidden = self.conv1(data.x, data.edge_index)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(self.activation(hidden), training=self.training)
        scores = self.conv2(hidden, data.edge_index)
        return F.log_softmax(scores, dim=1)

class GAT(torch.nn.Module):
    """Two-layer GAT node classifier built from ``GATConv`` layers.

    The first layer uses 8 heads of 16 channels each, so the second layer
    takes ``16 * 8`` input channels; it uses a single head with
    ``concat=False`` and outputs ``dataset.num_classes`` channels. Both
    layers are constructed with ``dropout=0.6``. Layer sizes come from the
    module-level ``dataset``.

    Args:
        activation: Callable applied to the output of the first layer
            (defaults to ``F.elu``).
    """

    def __init__(self, activation=F.elu):
        super().__init__()
        self.conv1 = GATConv(dataset.num_node_features, 16, heads=8, dropout=0.6)
        self.conv2 = GATConv(16 * 8, dataset.num_classes, heads=1, concat=False, dropout=0.6)
        self.activation = activation

    def forward(self, data):
        """Return per-node log-probabilities (log-softmax over dim 1).

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes.
        """
        hidden = self.conv1(data.x, data.edge_index)
        # Feature dropout uses F.dropout's default probability and is only
        # active while the module is in training mode.
        hidden = F.dropout(self.activation(hidden), training=self.training)
        scores = self.conv2(hidden, data.edge_index)
        return F.log_softmax(scores, dim=1)

class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE node classifier built from ``SAGEConv`` layers.

    Layer sizes come from the module-level ``dataset``: the first layer maps
    ``dataset.num_node_features`` to 16 hidden channels, the second maps 16
    to ``dataset.num_classes``.

    Args:
        activation: Callable applied to the output of the first layer
            (defaults to ``F.relu``).
    """

    def __init__(self, activation=F.relu):
        super().__init__()
        self.conv1 = SAGEConv(dataset.num_node_features, 16)
        self.conv2 = SAGEConv(16, dataset.num_classes)
        self.activation = activation

    def forward(self, data):
        """Return per-node log-probabilities (log-softmax over dim 1).

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes.
        """
        hidden = self.conv1(data.x, data.edge_index)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(self.activation(hidden), training=self.training)
        scores = self.conv2(hidden, data.edge_index)
        return F.log_softmax(scores, dim=1)

# Train the model
def train(model, optimizer):
    """Run one full-batch optimisation step on the module-level ``data``.

    Args:
        model: Module called as ``model(data)``; its output is passed to
            ``F.nll_loss``.
        optimizer: Optimizer used to update the model's parameters.

    Returns:
        The training loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    log_probs = model(data)
    # Only the nodes selected by the training mask contribute to the loss.
    train_mask = data.train_mask
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])

    loss.backward()
    optimizer.step()
    return loss.item()

# Evaluate the model
def test(model):
    """Evaluate ``model`` on the train/val/test masks of the module-level ``data``.

    The forward pass runs in eval mode and under ``torch.no_grad()`` so that
    no autograd graph is built for an evaluation-only pass.

    Args:
        model: Module called as ``model(data)``; returns one row of class
            scores per node.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` of accuracies as floats.
    """
    model.eval()
    with torch.no_grad():
        logits = model(data)
        accs = []
        for mask in (data.train_mask, data.val_mask, data.test_mask):
            # Predicted class = index of the highest score in each row.
            pred = logits[mask].argmax(dim=1)
            correct = pred.eq(data.y[mask]).sum().item()
            accs.append(correct / mask.sum().item())
    return accs

# Run and evaluate each model.
# One freshly constructed instance per architecture, keyed by class name
# ('GCN', 'GAT', 'GraphSAGE') and built in that order.
models = {
    model_cls.__name__: model_cls()
    for model_cls in (GCN, GAT, GraphSAGE)
}

# Train each model for up to 200 epochs with early stopping on validation
# accuracy, then report the test accuracy of the best-validation epoch.
for model_name, model in models.items():
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    scheduler = StepLR(optimizer, step_size=50, gamma=0.5)  # halve the learning rate every 50 epochs
    best_val_acc = 0
    # Test accuracy recorded at the epoch with the best validation accuracy.
    # Reporting this (rather than the last epoch's value) keeps model
    # selection based on validation data only.
    best_test_acc = 0.0
    patience = 10  # epochs without validation improvement before stopping
    patience_counter = 0

    print(f'Running {model_name}...')

    for epoch in range(200):
        loss = train(model, optimizer)
        train_acc, val_acc, test_acc = test(model)
        scheduler.step()

        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

        # Early-stopping bookkeeping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered")
                break

    print(f'Final Test Accuracy for {model_name}: {best_test_acc:.4f}')
    print('-'*50)


Running GCN...
Epoch: 000, Loss: 1.9624, Train Acc: 0.6214, Val Acc: 0.3380, Test Acc: 0.3820
Epoch: 001, Loss: 1.8509, Train Acc: 0.7429, Val Acc: 0.4620, Test Acc: 0.4840
Epoch: 002, Loss: 1.7288, Train Acc: 0.8071, Val Acc: 0.5000, Test Acc: 0.5180
Epoch: 003, Loss: 1.5787, Train Acc: 0.8000, Val Acc: 0.5120, Test Acc: 0.5030
Epoch: 004, Loss: 1.4589, Train Acc: 0.8071, Val Acc: 0.5060, Test Acc: 0.5070
Epoch: 005, Loss: 1.2983, Train Acc: 0.8143, Val Acc: 0.5240, Test Acc: 0.5140
Epoch: 006, Loss: 1.1849, Train Acc: 0.8571, Val Acc: 0.5420, Test Acc: 0.5430
Epoch: 007, Loss: 1.0986, Train Acc: 0.8714, Val Acc: 0.5820, Test Acc: 0.5820
Epoch: 008, Loss: 0.9332, Train Acc: 0.9143, Val Acc: 0.6260, Test Acc: 0.6230
Epoch: 009, Loss: 0.8674, Train Acc: 0.9357, Val Acc: 0.6740, Test Acc: 0.6810
Epoch: 010, Loss: 0.7500, Train Acc: 0.9500, Val Acc: 0.7160, Test Acc: 0.7300
Epoch: 011, Loss: 0.7055, Train Acc: 0.9714, Val Acc: 0.7380, Test Acc: 0.7540
Epoch: 012, Loss: 0.5670, Train Acc: 