In [27]:
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv

In [28]:
dataset_name = 'Cora'
path = osp.join(".", "data", "Planetoid")
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]

In [29]:
dataset

Cora()

In [30]:
data

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

In [44]:
data.y

tensor([3, 4, 4,  ..., 3, 3, 3], device='cuda:0')

In [45]:
data.train_mask

tensor([ True,  True,  True,  ..., False, False, False], device='cuda:0')

In [46]:
print(data.y[data.test_mask])

tensor([3, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2,
        2, 2, 2, 1, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 5, 2, 2, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 3, 4, 4, 4, 4, 1, 1, 3, 1, 0, 3, 0,
        2, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 5, 5, 5, 5,
        5, 5, 2, 2, 2, 2, 1, 6, 6, 3, 0, 0, 5, 0, 5, 0, 3, 5, 3, 0, 0, 6, 0, 6,
        3, 3, 1, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 5, 5, 5, 5, 5, 5, 5, 5, 2, 2, 2, 4, 4, 4, 0, 3, 3, 2, 5, 5, 5, 5,
        6, 5, 5, 5, 5, 0, 4, 4, 4, 0, 0, 5, 0, 0, 6, 6, 6, 6, 6, 6, 0, 0, 0, 0,
        3, 0, 0, 0, 3, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 5, 5, 5, 5, 3, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,
        6, 6, 5, 6, 6, 3, 5, 5, 5, 0, 5, 0, 4, 4, 3, 3, 3, 2, 2, 1, 3, 3, 3, 3,
        3, 3, 5, 3, 3, 4, 4, 3, 3, 3, 3,

In [31]:
class GAT(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=False, dropout=0.6)
        
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [32]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [33]:
device

device(type='cuda')

In [34]:
model = GAT(dataset.num_features, dataset.num_classes).to(device)

In [35]:
print(model)

GAT(
  (conv1): GATConv(1433, 8, heads=8)
  (conv2): GATConv(64, 7, heads=1)
)


In [36]:
data = data.to(device)

In [37]:
data

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

In [38]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

In [39]:
def train(data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) # 半监督
    loss.backward()
    optimizer.step()

In [42]:
@torch.no_grad()
def test(data):
    model.eval()
    out, accs = model(data.x, data.edge_index), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        acc = int((out[mask].argmax(-1) == data.y[mask]).sum())/ int(mask.sum())
        accs.append(acc)
    return accs

In [43]:
for epoch in range(1, 201):
    train(data)
    train_acc, val_acc, test_acc = test(data)
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {test_acc:.4f}')

Epoch: 001, Train: 0.9857, Val: 0.7980, Test: 0.8050
Epoch: 002, Train: 0.9857, Val: 0.7960, Test: 0.8070
Epoch: 003, Train: 0.9857, Val: 0.7960, Test: 0.8070
Epoch: 004, Train: 0.9857, Val: 0.7960, Test: 0.8060
Epoch: 005, Train: 0.9857, Val: 0.7920, Test: 0.8060
Epoch: 006, Train: 0.9857, Val: 0.7940, Test: 0.8060
Epoch: 007, Train: 0.9857, Val: 0.7980, Test: 0.8140
Epoch: 008, Train: 0.9857, Val: 0.7980, Test: 0.8150
Epoch: 009, Train: 0.9857, Val: 0.7960, Test: 0.8160
Epoch: 010, Train: 0.9929, Val: 0.8000, Test: 0.8190
Epoch: 011, Train: 0.9929, Val: 0.8060, Test: 0.8230
Epoch: 012, Train: 0.9929, Val: 0.8120, Test: 0.8240
Epoch: 013, Train: 0.9929, Val: 0.8200, Test: 0.8240
Epoch: 014, Train: 0.9929, Val: 0.8200, Test: 0.8280
Epoch: 015, Train: 0.9857, Val: 0.8180, Test: 0.8320
Epoch: 016, Train: 0.9857, Val: 0.8160, Test: 0.8300
Epoch: 017, Train: 0.9857, Val: 0.8120, Test: 0.8290
Epoch: 018, Train: 0.9857, Val: 0.8120, Test: 0.8280
Epoch: 019, Train: 0.9857, Val: 0.8080, Test: 

Epoch: 169, Train: 1.0000, Val: 0.8040, Test: 0.8270
Epoch: 170, Train: 1.0000, Val: 0.8020, Test: 0.8270
Epoch: 171, Train: 1.0000, Val: 0.8000, Test: 0.8290
Epoch: 172, Train: 1.0000, Val: 0.8000, Test: 0.8300
Epoch: 173, Train: 1.0000, Val: 0.8020, Test: 0.8280
Epoch: 174, Train: 1.0000, Val: 0.8000, Test: 0.8270
Epoch: 175, Train: 0.9929, Val: 0.8000, Test: 0.8260
Epoch: 176, Train: 0.9929, Val: 0.8040, Test: 0.8240
Epoch: 177, Train: 0.9929, Val: 0.8000, Test: 0.8230
Epoch: 178, Train: 0.9929, Val: 0.7960, Test: 0.8240
Epoch: 179, Train: 0.9929, Val: 0.8000, Test: 0.8230
Epoch: 180, Train: 0.9929, Val: 0.8000, Test: 0.8220
Epoch: 181, Train: 1.0000, Val: 0.8020, Test: 0.8230
Epoch: 182, Train: 1.0000, Val: 0.8080, Test: 0.8250
Epoch: 183, Train: 1.0000, Val: 0.8140, Test: 0.8260
Epoch: 184, Train: 1.0000, Val: 0.8140, Test: 0.8310
Epoch: 185, Train: 1.0000, Val: 0.8080, Test: 0.8260
Epoch: 186, Train: 1.0000, Val: 0.8080, Test: 0.8260
Epoch: 187, Train: 1.0000, Val: 0.8040, Test: 