In [1]:
!pip install torch_geometric pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.8.0+cu126.html

Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pyg_lib
  Downloading https://data.pyg.org/whl/torch-2.8.0%2Bcu126/pyg_lib-0.5.0%2Bpt28cu126-cp312-cp312-linux_x86_64.whl (4.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m56.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-2.8.0%2Bcu126/torch_scatter-2.1.2%2Bpt28cu126-cp312-cp312-linux_x86_64.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m80.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-2.8.0%2Bcu126/torch_sparse-0.6.18%2Bpt28cu126-cp312-cp312-linux_x86_64.whl (5.2 MB

In [2]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import pandas as pd

from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.loader import DataLoader
from sklearn import metrics

In [3]:
class GIN(nn.Module):
    """Graph Isomorphism Network that produces one score per graph.

    Node features pass through a stack of ``GINConv`` layers, each followed by
    batch normalisation, ReLU and dropout. The node embeddings are then summed
    per graph with ``global_add_pool`` and fed to a small head
    (linear -> batch norm -> ReLU -> dropout -> linear) with a single output.

    Args:
        input_dim: Width of the node feature vectors given to the first layer.
        hidden_dim: Output width of every GIN layer and of the head's hidden layer.
        num_layers: Number of GIN layers to stack.
        dropout: Dropout probability used after every GIN layer and in the head.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        super().__init__()

        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for layer_idx in range(num_layers):
            # Only the first layer consumes the raw features; every later
            # layer takes the previous layer's hidden_dim-wide output.
            in_features = input_dim if layer_idx == 0 else hidden_dim
            mlp = nn.Sequential(
                nn.Linear(in_features, 2 * hidden_dim),
                nn.BatchNorm1d(2 * hidden_dim),
                nn.ReLU(),
                nn.Linear(2 * hidden_dim, hidden_dim),
            )
            self.convs.append(GINConv(mlp))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Graph-level head applied after pooling.
        self.lin1 = nn.Linear(hidden_dim, hidden_dim)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim)
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, data):
        """Compute one unnormalised score per graph in ``data``.

        Args:
            data: Batch object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            A flattened 1-D tensor of classifier outputs; no sigmoid is
            applied here.
        """
        h = data.x
        edge_index = data.edge_index
        batch = data.batch

        for conv, norm in zip(self.convs, self.batch_norms):
            h = conv(h, edge_index)
            h = F.relu(norm(h))
            h = F.dropout(h, p=self.dropout, training=self.training)

        # Sum node embeddings graph by graph.
        pooled = global_add_pool(h, batch)

        hidden = self.lin1(pooled)
        hidden = F.relu(self.batch_norm1(hidden))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        scores = self.classifier(hidden)
        return scores.view(-1)

In [4]:
def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``train_loader``, ``device``,
    ``criterion`` and ``optimizer``.

    Returns:
        The batch losses, each weighted by the number of graphs in its batch,
        summed and divided by ``len(train_loader.dataset)``.
    """
    model.train()
    running_loss = 0

    for batch in train_loader:
        batch = batch.to(device)
        scores = model(batch)
        targets = batch.y.float()
        batch_loss = criterion(scores, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight by graph count so a smaller last batch does not skew the mean.
        running_loss += batch_loss.item() * batch.num_graphs

    num_samples = len(train_loader.dataset)
    return running_loss / num_samples


@torch.no_grad()
def test(loader):
    """Evaluate the notebook-level ``model`` on every batch of ``loader``.

    Args:
        loader: Iterable of batches that provide ``.to(device)`` and ``.y``.

    Returns:
        Tuple ``(accuracy, f1)`` computed with ``sklearn.metrics`` over all
        batches together.
    """
    model.eval()
    pred_chunks, label_chunks = [], []

    for batch in loader:
        batch = batch.to(device)
        scores = model(batch)
        # A positive score is counted as class 1, everything else as class 0.
        pred_chunks.append((scores > 0).float().cpu())
        label_chunks.append(batch.y.cpu())

    y_true = torch.cat(label_chunks)
    y_pred = torch.cat(pred_chunks)

    accuracy = metrics.accuracy_score(y_true, y_pred)
    f1 = metrics.f1_score(y_true, y_pred)

    return accuracy, f1

In [5]:
# Hyperparameters, one row per option: (flag, type, default, help text).
# The row order is the order in which the options are registered.
_ARG_SPECS = [
    ("--seed", int, 42,
     "Random seed for reproducibility (default: 42)"),
    ("--batch_size", int, 64,
     "Input batch size (default: 64)"),
    ("--epochs", int, 150,
     "Number of epochs to train (default: 150)"),
    ("--lr", float, 0.0008007016085176578,
     "Learning rate (default: 0.0008007016085176578)"),
    ("--weight_decay", float, 1.5408221478908417e-05,
     "Weight decay (default: 1.5408221478908417e-05)"),
    ("--hidden_dim", int, 512,
     "Hidden dimension size (default: 512)"),
    ("--num_layers", int, 2,
     "Number of GIN layers (default: 2)"),
    ("--dropout", float, 0.04821922755593036,
     "Dropout rate (default: 0.04821922755593036)"),
    ("--factor", float, 0.5,
     "Factor for learning rate scheduler (default: 0.5)"),
    ("--patience", int, 3,
     "Patience for learning rate scheduler (default: 3)"),
]

parser = argparse.ArgumentParser(
    description="GIN for partial automorphism extension problem"
)
for _flag, _type, _default, _help in _ARG_SPECS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# Parse an explicit empty argument list so that sys.argv is ignored and every
# option takes its default value.
args = parser.parse_args([])

In [6]:
# Seed first: the model created below draws its initial weights from the RNG.
torch_geometric.seed_everything(args.seed)

# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load dataset files from a trusted source.
train_dataset = torch.load(
    '/kaggle/input/graphs-with-automorphisms/train_dataset.pt',
    weights_only=False,
)
val_dataset = torch.load(
    '/kaggle/input/graphs-with-automorphisms/val_dataset.pt',
    weights_only=False,
)

# Only the training loader is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
)

# NOTE(review): input_dim=3 presumably matches the node feature width of the
# stored graphs — confirm against the dataset.
model = GIN(
    input_dim=3,
    hidden_dim=args.hidden_dim,
    num_layers=args.num_layers,
    dropout=args.dropout,
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
criterion = nn.BCEWithLogitsLoss()

# mode='max' because the scheduler is stepped with a score to be maximised.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=args.factor,
    patience=args.patience,
)

In [7]:
# Metrics of the best epoch so far, stored as
# [train_loss, train_acc, train_f1, val_acc, val_f1].
# Index 3 (validation accuracy) decides which epoch counts as best.
best_model_stats = [0.0, 0.0, 0.0, 0.0, 0.0]
training_history = []
# Early-stopping patience, in epochs without a new best validation accuracy.
# This is separate from args.patience, which configures the LR scheduler.
patience = 15
patience_counter = 0

for epoch in range(1, args.epochs + 1):
    train_loss = train()
    train_acc, train_f1 = test(train_loader)
    val_acc, val_f1 = test(val_loader)
    scheduler.step(val_acc)

    training_history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "train_f1": train_f1,
        "val_acc": val_acc,
        "val_f1": val_f1,
        # Read after scheduler.step(), so it already includes any reduction
        # triggered by this epoch's validation accuracy.
        "learning_rate": optimizer.param_groups[0]['lr'],
    })

    if val_acc > best_model_stats[3]:
        best_model_stats = [train_loss,
                            train_acc, train_f1, val_acc, val_f1]
        patience_counter = 0
        # Keep only the weights of the best epoch on disk.
        torch.save(model.state_dict(), "/kaggle/working/best_model.pt")
    else:
        patience_counter += 1

    # Report the epoch before the early-stopping check; otherwise the metrics
    # of the final (stopping) epoch would never be printed even though they
    # are recorded in training_history.
    print(f"Epoch {epoch:02d} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Train Acc: {train_acc:.4f} | "
          f"Train F1:  {train_f1:.4f} | "
          f"Val Acc:   {val_acc:.4f} | "
          f"Val F1:    {val_f1:.4f}")

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}.")
        break


# Save the per-epoch metrics collected above.
history_df = pd.DataFrame(training_history)
history_df.to_csv("/kaggle/working/training_history.csv", index=False)

print("================================\n")
print("Best Model Stats:")
print(f"Train Loss: {best_model_stats[0]:.4f} | "
      f"Train Acc: {best_model_stats[1]:.4f} | "
      f"Train F1:  {best_model_stats[2]:.4f} | "
      f"Val Acc:   {best_model_stats[3]:.4f} | "
      f"Val F1:    {best_model_stats[4]:.4f}")

Epoch 01 | Train Loss: 0.6431 | Train Acc: 0.6769 | Train F1:  0.7127 | Val Acc:   0.6758 | Val F1:    0.7096
Epoch 02 | Train Loss: 0.5782 | Train Acc: 0.7073 | Train F1:  0.7408 | Val Acc:   0.7030 | Val F1:    0.7368
Epoch 03 | Train Loss: 0.5553 | Train Acc: 0.7267 | Train F1:  0.7725 | Val Acc:   0.7247 | Val F1:    0.7708
Epoch 04 | Train Loss: 0.5380 | Train Acc: 0.7301 | Train F1:  0.7644 | Val Acc:   0.7283 | Val F1:    0.7619
Epoch 05 | Train Loss: 0.5272 | Train Acc: 0.7400 | Train F1:  0.7886 | Val Acc:   0.7369 | Val F1:    0.7867
Epoch 06 | Train Loss: 0.5160 | Train Acc: 0.7505 | Train F1:  0.7924 | Val Acc:   0.7505 | Val F1:    0.7925
Epoch 07 | Train Loss: 0.5063 | Train Acc: 0.7531 | Train F1:  0.7923 | Val Acc:   0.7505 | Val F1:    0.7889
Epoch 08 | Train Loss: 0.4984 | Train Acc: 0.7563 | Train F1:  0.7990 | Val Acc:   0.7522 | Val F1:    0.7952
Epoch 09 | Train Loss: 0.4925 | Train Acc: 0.7644 | Train F1:  0.8015 | Val Acc:   0.7621 | Val F1:    0.7990
Epoch 10 |