In [30]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GraphConv, global_mean_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

class GCN(torch.nn.Module):
    """Graph-level classifier: two GraphConv layers, mean pooling, linear head.

    NOTE: despite the class name, the message-passing layers are
    ``GraphConv`` (imported above), not ``GCNConv``.

    Args:
        num_features: Dimensionality of the input node features.
        num_classes: Number of output classes.
        hidden_channels: Width of both hidden GraphConv layers. Defaults to
            128, the value that was previously hard-coded.
        dropout: Dropout probability applied to the pooled embedding while
            the module is in training mode. Defaults to 0.5, the default of
            ``F.dropout`` that was previously used implicitly.
    """

    def __init__(self, num_features, num_classes, hidden_channels=128, dropout=0.5):
        super().__init__()
        self.dropout_p = dropout
        self.conv1 = GraphConv(num_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        """Return log-probabilities over classes (log-softmax along dim 1).

        Args:
            data: A batch object exposing ``x`` (node features),
                ``edge_index`` (graph connectivity) and ``batch`` (assignment
                of nodes to graphs, consumed by ``global_mean_pool``).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Two rounds of message passing with ReLU non-linearities.
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Collapse node embeddings into one embedding per graph via `batch`.
        x = global_mean_pool(x, batch)
        # Tie dropout to the module's train/eval mode so eval is deterministic.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.fc(x)

        # Log-probabilities, to be paired with F.nll_loss in training.
        return F.log_softmax(x, dim=1)

# Select the compute device explicitly. `device` is used below and inside
# train(), and without this line the cell raises NameError unless some
# earlier cell happened to define it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the ENZYMES dataset (TU collection), stored under /tmp/ENZYMES.
# NOTE(review): the model is trained on the entire dataset with no held-out
# split, so the logged loss says nothing about generalization. Consider a
# train/test split before drawing conclusions.
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize the network on the chosen device, plus its Adam optimizer.
model = GCN(dataset.num_node_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# One optimization pass over the training data.
def train():
    """Train `model` for a single epoch over `loader`.

    Each batch is moved to `device`, scored with the negative log-likelihood
    of the model's log-probabilities against `batch.y`, and used for one
    optimizer step.

    Returns:
        The arithmetic mean of the per-batch loss values (every batch
        weighted equally, regardless of its size).
    """
    model.train()
    batch_losses = []

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(batch), batch.y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    return sum(batch_losses) / len(loader)

# Training loop: run `num_epochs` epochs (numbered 1..num_epochs inclusive)
# and report the mean batch loss every 100 epochs. The previous
# `range(1, 1000)` stopped one short, at epoch 999.
num_epochs = 1000
for epoch in range(1, num_epochs + 1):
    loss = train()
    if epoch % 100 == 0:
        print(f'Epoch: {epoch}, Loss: {loss:.4f}')


Epoch: 100, Loss: 1.5687
Epoch: 200, Loss: 1.4333
Epoch: 300, Loss: 1.3568
Epoch: 400, Loss: 1.2901
Epoch: 500, Loss: 1.2186
Epoch: 600, Loss: 1.1556
Epoch: 700, Loss: 1.1132
Epoch: 800, Loss: 1.0740
Epoch: 900, Loss: 1.0430
