In [1]:
import os
import os.path as osp
import sys
from math import ceil
import numpy as np

import torch
import torch.nn.functional as F
import torch_geometric
import torch_geometric.transforms as T
from dataset import GraphDataset
from models import *
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.model_selection import train_test_split
from torch.nn import Linear
from torch_geometric import utils
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.nn import DMoNPooling, GCNConv, GraphConv, Sequential, TopKPooling
from torch_geometric.nn.conv.gcn_conv import gcn_norm

In [2]:
dataset = GraphDataset(torch.load("../data/cycle_line_star_complete_1.pt"))

train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling


class TopKPool(torch.nn.Module):
    def __init__(
        self,
        mp_units,
        mp_act,
        in_channels,
        n_clusters,
        mlp_units=[],
        mlp_act="Identity",
    ):
        super().__init__()
        self.name = "MinCut"

        mp_act = getattr(torch.nn, mp_act)(inplace=True)
        mlp_act = getattr(torch.nn, mlp_act)(inplace=True)

        mp = [
            (
                GCNConv(in_channels, mp_units[0], normalize=False, cached=False),
                "x, edge_index -> x",
            ),
            mp_act,
        ]
        for i in range(len(mp_units) - 1):
            mp.append(
                (
                    GCNConv(
                        mp_units[i], mp_units[i + 1], normalize=False, cached=False
                    ),
                    "x, edge_index -> x",
                )
            )
            mp.append(mp_act)
        self.mp = Sequential("x, edge_index", mp)
        out_chan = mp_units[-1]

        self.mlp = torch.nn.Sequential()
        for units in mlp_units:
            self.mlp.append(Linear(out_chan, units))
            out_chan = units
            self.mlp.append(mlp_act)
        self.mlp.append(Linear(out_chan, n_clusters))

        self.pool = TopKPooling(mp_units[-1], ratio=8)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.mp(x, edge_index)
        s = self.mlp(x)
        adj = utils.to_dense_adj(edge_index)
        x_pooled, edge_index_pooled, edge_attr, _, perm, score = self.pool(
            x, edge_index, None, batch
        )

        return torch.softmax(s, dim=-1), perm

In [4]:
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TopKPool(
    [64] * 5,
    "ReLU",
    dataset.num_features,
    len(np.unique(dataset[0].node_classes)),
    [16],
    "ReLU",
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out, perm = model(data)

        target = [cls[0] for cls in data.node_classes[0]]
        target = torch.tensor(target, device=device, dtype=torch.long)
        # target = target[perm]
        target = target - 1
        loss = F.cross_entropy(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)


@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    total = 0
    for data in loader:
        data = data.to(device)
        out, perm = model(data)
        target_indices = set(range(8))
        selected_indices = set(perm.cpu().numpy())
        correct += len(target_indices.intersection(selected_indices))
        total += len(target_indices)
    accuracy = correct / total
    print(f"Correct: {correct}, Total: {total}, Accuracy: {accuracy:.4f}")
    return accuracy


for epoch in range(1, 101):
    train_loss = train()
    test_acc = test(test_loader)
    if epoch % 10 == 0:
        print(f"Epoch: {epoch:03d}, Loss: {train_loss:.4f}, Test Acc: {test_acc:.4f}")

ValueError: Expected input batch_size (642) to match target batch_size (11).