In [12]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import DenseSAGEConv, dense_diff_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_adj, to_dense_batch

# Load a graph classification dataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)

# Split the dataset into train and test sets
torch.manual_seed(12345)
dataset = dataset.shuffle()
train_dataset = dataset[:540]
test_dataset = dataset[540:]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


ImportError: cannot import name 'DiffPool' from 'torch_geometric.nn' (/opt/homebrew/lib/python3.11/site-packages/torch_geometric/nn/__init__.py)

In [9]:
class DiffPoolNet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_classes, max_nodes):
        super(DiffPoolNet, self).__init__()
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)
        self.max_nodes = max_nodes

        # Adjust the input dimension of the first fully connected layer
        self.fc1 = torch.nn.Linear(out_channels * max_nodes, hidden_channels)
        self.fc2 = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, adj, mask=None):
        s = torch.softmax(self.conv1(x, adj), dim=1)
        x = F.relu(self.bn1(self.conv2(x, adj)))
        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)
        
        # Flatten and reshape x to match the input size of self.fc1
        x = x.view(-1, self.max_nodes * x.size(-1))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1), l1, e1

# Find the maximum number of nodes in any graph in the dataset
max_nodes = max(data.num_nodes for data in dataset)
model = DiffPoolNet(dataset.num_node_features, 64, 64, dataset.num_classes, max_nodes)


In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.NLLLoss()

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        x, mask = to_dense_batch(data.x, data.batch)
        adj = to_dense_adj(data.edge_index, data.batch)
        out, l1, e1 = model(x, adj, mask)
        loss = criterion(out, data.y.view(-1)) + l1 + e1
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)

def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        x, mask = to_dense_batch(data.x, data.batch)
        adj = to_dense_adj(data.edge_index, data.batch)
        out, _, _ = model(x, adj, mask)
        pred = out.max(dim=1)[1]
        correct += pred.eq(data.y.view(-1)).sum().item()
    return correct / len(loader.dataset)


for epoch in range(1, 51):
    loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3072x21 and 64x64)