# In [8]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.transforms import NormalizeFeatures

# Load the Cora dataset from the Planetoid collection, rooted at /tmp/Cora,
# with NormalizeFeatures applied as the dataset transform.
dataset = Planetoid(
    root='/tmp/Cora',
    name='Cora',
    transform=NormalizeFeatures(),
)
# The first entry of the dataset is the graph object used for training and
# evaluation in the rest of this script.
data = dataset[0]

class GATWithPooling(torch.nn.Module):
    """Two-layer graph attention network that returns per-node outputs.

    The first ``GATConv`` layer uses ``heads`` attention heads; the second
    layer therefore takes ``heads * hidden_channels`` input features and
    uses a single head with ``concat=False``.

    Note:
        Despite the class name, no pooling is applied: ``forward`` accepts a
        ``batch`` argument but ignores it and returns the second attention
        layer's output unchanged.

    Args:
        in_channels: Input feature size given to the first ``GATConv``.
        hidden_channels: Per-head output size of the first ``GATConv``.
        out_channels: Output size of the second ``GATConv``.
        heads: Number of attention heads in the first layer (default 8).
        dropout: Dropout probability applied to the input of each layer in
            ``forward`` and also passed as ``dropout=`` to both ``GATConv``
            layers (default 0.6).
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        heads: int = 8,
        dropout: float = 0.6,
    ) -> None:
        super().__init__()
        self.dropout = dropout
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        # The input width is derived from ``heads`` so both layers stay
        # consistent when the head count is changed.
        self.gat2 = GATConv(
            heads * hidden_channels,
            out_channels,
            heads=1,
            concat=False,
            dropout=dropout,
        )

    def forward(self, x, edge_index, batch=None):
        """Run both attention layers and return the second layer's output.

        Args:
            x: Node features passed to the first attention layer.
            edge_index: Graph connectivity passed to both attention layers.
            batch: Unused. Kept (now optional) so existing callers that pass
                it keep working.

        Returns:
            The raw output of the second ``GATConv`` layer; no softmax or
            pooling is applied here.
        """
        # F.dropout is a no-op when ``self.training`` is False (eval mode).
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.gat2(x, edge_index)

# Build the model with 8 hidden channels per attention head and one output
# channel per class, and optimize it with Adam plus L2 weight decay.
model = GATWithPooling(dataset.num_features, 8, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

# Full-graph training: every epoch runs one forward pass over the whole
# graph, but only the nodes selected by ``train_mask`` contribute to the loss.
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.batch)
    # cross_entropy combines log_softmax and nll_loss in a single call.
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Evaluation: switch dropout off and skip autograd bookkeeping, since no
# gradients are needed for inference.
model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index, data.batch).argmax(dim=1)
correct = int((pred[data.test_mask] == data.y[data.test_mask]).sum())
accuracy = correct / int(data.test_mask.sum())
print(f'Accuracy: {accuracy:.4f}')


# Output: Accuracy: 0.8220
