In [3]:
import torch
from torch_geometric.datasets import TUDataset

# Load the MUTAG collection from the TU benchmark datasets; files live under data/TUDataset.
dataset = TUDataset(name='MUTAG', root='data/TUDataset')

In [18]:
# Width of the per-edge feature vectors (edge_attr); the captured output is 4.
# Note: the GCN model defined further down only consumes x and edge_index,
# so these edge features are not used for classification.
dataset.num_edge_features

4

In [19]:
# Inspect the target of the first graph: a single class index per graph
# (captured output: tensor([1])).
data = dataset[0]
data.y

tensor([1])

In [22]:
# Seed torch's global RNG so the permutation produced by shuffle() is reproducible.
# NOTE(review): `dataset` is rebound to its own shuffled copy, so re-running this
# cell shuffles already-shuffled data and can produce a different split.
torch.manual_seed(12345)
dataset = dataset.shuffle()

# First 150 graphs train the model; the remaining graphs are held out for testing.
train_dataset, test_dataset = dataset[:150], dataset[150:]

In [26]:
# Mini-batch loaders: each step collates several graphs into one `Batch`
# object (see the printout below: x, edge_index, batch, ptr, y are concatenated).
# `DataLoader` lives in `torch_geometric.loader` in current PyG releases; the
# `torch_geometric.data` location is deprecated, so use it only as a fallback
# for older installs.
try:
    from torch_geometric.loader import DataLoader
except ImportError:
    from torch_geometric.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Evaluation batches keep a fixed order.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Walk one training epoch to see how the graphs are grouped into batches.
# NOTE: the loop variable rebinds `data`, replacing the single graph from the
# earlier inspection cell with the last batch.
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print(f'Number of graphs in current step : {data.num_graphs}')
    print(data)

Step 1:
Number of graphs in current step : 64
Batch(batch=[1135], edge_attr=[2512, 4], edge_index=[2, 2512], ptr=[65], x=[1135, 7], y=[64])
Step 2:
Number of graphs in current step : 64
Batch(batch=[1136], edge_attr=[2498, 4], edge_index=[2, 2498], ptr=[65], x=[1136, 7], y=[64])
Step 3:
Number of graphs in current step : 22
Batch(batch=[411], edge_attr=[916, 4], edge_index=[2, 916], ptr=[23], x=[411, 7], y=[22])


In [28]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    """Three-layer GCN graph classifier with a mean-pooling readout.

    Node embeddings are computed by three ``GCNConv`` layers (ReLU after the
    first two), averaged per graph with ``global_mean_pool``, passed through
    dropout (p=0.5, active only in training mode) and mapped to class scores
    by a final ``Linear`` layer.

    Args:
        hidden_channels: width of every hidden GCN layer and of the pooled
            graph embedding.
        num_node_features: size of each node's input feature vector. When
            ``None`` (default), falls back to the notebook-level
            ``dataset.num_node_features``, read at construction time.
        num_classes: number of output classes. When ``None`` (default), falls
            back to the notebook-level ``dataset.num_classes``.

    Note:
        The constructor calls ``torch.manual_seed(12345)`` before creating any
        layer, so weight initialization is repeatable. As a side effect, this
        reseeds torch's *global* RNG on every instantiation.
    """

    def __init__(self, hidden_channels, num_node_features=None, num_classes=None):
        super().__init__()
        # Deliberate: fixed seed for repeatable weight init (resets the global RNG).
        torch.manual_seed(12345)
        # Keep the original behavior of sizing the model from the global
        # `dataset` unless the caller supplies explicit dimensions.
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return unnormalized class scores, one row per graph in the batch.

        Args:
            x: node feature matrix (the MUTAG batches above are ``[num_nodes, 7]``).
            edge_index: graph connectivity (``[2, num_edges]`` in the batches above).
            batch: vector assigning each node to its graph within the batch.

        Returns:
            Raw scores (no softmax applied) of width ``num_classes``.
        """
        # 1. Node-level message passing.
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout: average node embeddings into one embedding per graph.
        x = global_mean_pool(x, batch)

        # 3. Classifier head; dropout is only applied while self.training is True.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

# Build the classifier with 64 hidden channels; ending the cell with the bare
# object lets Jupyter render the module's layer summary as the cell output.
model = GCN(hidden_channels=64)
model

GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)
