In [1]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SAGEConv

In [2]:
# Run on Apple's Metal (MPS) backend when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Using device: mps


In [3]:
# NOTE: Cora consists of one undirected citation graph.
dataset = Planetoid(
    root='data/Cora',
    name='Cora',
)

In [4]:
# Dataset-level metadata, one fact per output line.
print(
    f'Dataset: {dataset.name}',
    f'Number of graphs: {len(dataset)}',
    f'Number of classes: {dataset.num_classes}',
    f'Number of node features: {dataset.num_node_features}',
    f'Number of edge features: {dataset.num_edge_features}',
    sep='\n',
)

Dataset: Cora
Number of graphs: 1
Number of classes: 7
Number of node features: 1433
Number of edge features: 0


In [5]:
# Take the first entry of the dataset as the graph to work with and print its
# summary (field names with their tensor shapes).
data = dataset[0]
print(data)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])


In [6]:
# Report the tensor shapes of the graph and the size of each node split.
print(
    f'Node features shape: {data.x.shape}',
    f'Edge index shape: {data.edge_index.shape}',
    f'Labels shape: {data.y.shape}',
    f'Number of training nodes: {data.train_mask.sum().item()}',
    f'Number of validation nodes: {data.val_mask.sum().item()}',
    f'Number of test nodes: {data.test_mask.sum().item()}',
    sep='\n',
)
# NOTE: The graph has 2708 nodes. Its 10556 edge_index columns correspond to
# 10556 / 2 = 5278 undirected edges (each edge counted once per direction).

Node features shape: torch.Size([2708, 1433])
Edge index shape: torch.Size([2, 10556])
Labels shape: torch.Size([2708])
Number of training nodes: 140
Number of validation nodes: 500
Number of test nodes: 1000


In [7]:
class Net(torch.nn.Module):
    """Two-layer ``SAGEConv`` network that returns per-node log-probabilities.

    Args:
        hidden_channels: Output size of the first layer / input size of the second.
        dropout: Dropout probability applied to the hidden activations; only
            active while the module is in training mode.
        layer_kwargs: Optional dict of extra keyword arguments forwarded
            unchanged to both ``SAGEConv`` layers.
        in_channels: Input feature size. When ``None`` (default) it is read from
            the notebook-level ``dataset.num_node_features``.
        out_channels: Output size of the second layer. When ``None`` (default)
            it is read from the notebook-level ``dataset.num_classes``.
    """

    def __init__(self, hidden_channels=16, dropout=0.5, layer_kwargs=None,
                 in_channels=None, out_channels=None):
        super().__init__()
        # Fall back to the global dataset only when sizes are not given, so the
        # model can also be built for other inputs without touching globals.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if out_channels is None:
            out_channels = dataset.num_classes
        # Resolve the optional kwargs once and reuse them for both layers.
        conv_kwargs = dict(layer_kwargs or {})
        self.conv1 = SAGEConv(in_channels, hidden_channels, **conv_kwargs)
        self.conv2 = SAGEConv(hidden_channels, out_channels, **conv_kwargs)
        self.dropout = dropout

    def forward(self, data):
        """Compute log-softmax scores for every node in ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tensor of log-probabilities, normalised along ``dim=1``.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # `training=self.training` turns dropout off automatically in eval mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [8]:
# Seed PyTorch's RNG before building the model so that weight initialisation
# and dropout start from the same state on every Restart & Run All.
# NOTE(review): some accelerator kernels may still be non-deterministic, so
# results on MPS may not be bit-identical between runs; confirm if that matters.
torch.manual_seed(42)

model = Net().to(device)
data = dataset[0].to(device)
# Adam with L2 regularisation via weight_decay.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [9]:
def train():
    """Perform one optimisation step on the whole graph.

    The loss is the negative log-likelihood over the nodes selected by
    ``data.train_mask`` only.

    Returns:
        The training loss for this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    train_mask = data.train_mask
    log_probs = model(data)
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])

    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test():
    """Evaluate the model on the train, validation and test node splits.

    Runs under ``torch.no_grad()`` because no gradients are needed for
    evaluation; this avoids building an autograd graph on every call.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` with the fraction of correctly
        classified nodes in each split.
    """
    model.eval()
    out = model(data)
    # Predicted class = index of the highest score for each node.
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask] == data.y[mask]
        accs.append(int(correct.sum()) / int(mask.sum()))
    return accs

In [10]:
# Train for 300 epochs. Every epoch is evaluated, but only every 10th is logged.
for epoch in range(1, 301):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % 10 != 0:
        continue
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

Epoch: 010, Loss: 0.4914, Train: 0.9929, Val: 0.7400, Test: 0.7500
Epoch: 020, Loss: 0.1075, Train: 1.0000, Val: 0.7660, Test: 0.7700
Epoch: 030, Loss: 0.0551, Train: 1.0000, Val: 0.7600, Test: 0.7570
Epoch: 040, Loss: 0.0254, Train: 1.0000, Val: 0.7560, Test: 0.7660
Epoch: 050, Loss: 0.0220, Train: 1.0000, Val: 0.7620, Test: 0.7750
Epoch: 060, Loss: 0.0158, Train: 1.0000, Val: 0.7640, Test: 0.7840
Epoch: 070, Loss: 0.0174, Train: 1.0000, Val: 0.7600, Test: 0.7810
Epoch: 080, Loss: 0.0115, Train: 1.0000, Val: 0.7640, Test: 0.7880
Epoch: 090, Loss: 0.0284, Train: 1.0000, Val: 0.7680, Test: 0.7920
Epoch: 100, Loss: 0.0181, Train: 1.0000, Val: 0.7640, Test: 0.7850
Epoch: 110, Loss: 0.0125, Train: 1.0000, Val: 0.7680, Test: 0.7890
Epoch: 120, Loss: 0.0261, Train: 1.0000, Val: 0.7860, Test: 0.7940
Epoch: 130, Loss: 0.0251, Train: 1.0000, Val: 0.7680, Test: 0.7820
Epoch: 140, Loss: 0.0146, Train: 1.0000, Val: 0.7760, Test: 0.7790
Epoch: 150, Loss: 0.0147, Train: 1.0000, Val: 0.7840, Test: 0.