In [None]:
# @title Import dependencies

import pickle
from tqdm import tqdm
import torch
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.data import Dataset, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_networkx
import os

In [None]:
# @title Import graphs and labels

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by this project.
graphs_path = "../graphs/all_graphs.pkl"
with open(graphs_path, "rb") as graphs_file:
    graphs = pickle.load(graphs_file)

print(f"number of graphs: {len(graphs)}")
# Preview at most five entries so a small file does not raise IndexError.
for i in range(min(5, len(graphs))):
    print(f"{i}: {graphs[i]}")


# NOTE(review): the labels live under "all_graphs/" while the graphs are in
# "all_graphs.pkl" beside it — confirm these two files were produced together.
labels_path = "../graphs/all_graphs/binary_labels.pkl"
with open(labels_path, "rb") as labels_file:
    labels = pickle.load(labels_file)

print(f"number of labels: {len(labels)}")
for i in range(min(5, len(labels))):
    print(f"{i}: {labels[i]}")

# Graphs and labels are paired by position, so their counts must agree.
assert len(graphs) == len(labels), (
    f"{len(graphs)} graphs but {len(labels)} labels"
)

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device = {device}")

In [None]:
# @title Hyperparameters

# All tunable settings live in this one dict; later cells read it by key.
# Notes on the current values:
# - hidden_dim must be divisible by num_heads: each GATConv head emits
#   hidden_dim // num_heads channels (16 // 8 = 2 here).
# - patience (100) is larger than num_epochs (20), so the early-stopping
#   check in `train` cannot fire with these settings.
hyperparameters = dict(
    num_classes=2,  # number of output logits per graph
    node_features_dim=1,  # input node-feature width (also the all-ones fallback width)
    hidden_dim=16,
    num_heads=8,
    num_layers=4,
    num_epochs=20,
    learning_rate=0.001,
    batch_size=2,
    patience=100,  # epochs without val-loss improvement before stopping
    weight_decay=1e-4,  # passed to AdamW
)

In [None]:
#  @title Convert networkx to Data objects


class GraphDataset(InMemoryDataset):
    """In-memory PyG dataset built from networkx graphs and per-graph labels.

    Each graph is converted with ``from_networkx``. A graph that has no ``x``
    node-feature matrix gets a constant all-ones feature of width
    ``node_features_dim``. Every graph gets its label stored as ``data.y``
    (a 1-element ``torch.long`` tensor).

    Args:
        graphs: iterable of networkx graphs.
        labels: iterable of integer class labels, one per graph, same order.
        transform: optional transform forwarded to ``InMemoryDataset``.
        node_features_dim: width of the fallback all-ones features; defaults
            to ``hyperparameters["node_features_dim"]``.

    Raises:
        ValueError: if ``graphs`` and ``labels`` differ in length.
    """

    def __init__(self, graphs, labels, transform=None, node_features_dim=None):
        super().__init__(None, transform)
        if node_features_dim is None:
            node_features_dim = hyperparameters["node_features_dim"]

        graphs = list(graphs)
        labels = list(labels)
        # zip() would silently drop the unmatched tail; fail loudly instead.
        if len(graphs) != len(labels):
            raise ValueError(f"got {len(graphs)} graphs but {len(labels)} labels")

        data_list = []
        for graph, label in tqdm(zip(graphs, labels), total=len(graphs)):
            data = from_networkx(graph)
            if data.x is None:
                data.x = torch.ones(
                    data.num_nodes,
                    node_features_dim,
                    dtype=torch.float,
                )
            # The label is attached to every graph, not only to graphs that
            # needed fallback features.
            data.y = torch.tensor([label], dtype=torch.long)
            data_list.append(data)

        # Collate once after conversion; collating inside the loop re-packed
        # the whole growing list on every iteration.
        if data_list:
            self.data, self.slices = self.collate(data_list)

In [None]:
# @title Save and load the dataset

dataset_path = "../data/datasets/processed_dataset.pkl"

if os.path.exists(dataset_path):
    print("Loading preprocessed dataset...")
    # NOTE(review): pickle.load runs arbitrary code from the file, so only use
    # a cache you created. The cache is not invalidated when GraphDataset or
    # the input pickles change; delete the file to force a rebuild.
    with open(dataset_path, "rb") as f:
        dataset = pickle.load(f)
    print(f"Loaded dataset with {len(dataset)} graphs")
else:
    print("Processing dataset for the first time...")
    dataset = GraphDataset(graphs, labels)

    # open(..., "wb") does not create missing parent directories.
    os.makedirs(os.path.dirname(dataset_path) or ".", exist_ok=True)
    with open(dataset_path, "wb") as f:
        pickle.dump(dataset, f)
    print(f"Saved processed dataset to {dataset_path}")

# Preview at most five graphs so a small dataset does not raise IndexError.
for i in range(min(5, len(dataset))):
    print(f"{i}: {dataset[i]}")

In [None]:
# @title Create DataLoader from dataset

# 80/10/10 split by graph count; the test split absorbs any rounding remainder.
n_total = len(dataset)
n_train = int(0.8 * n_total)
n_val = int(0.1 * n_total)
n_test = n_total - (n_train + n_val)
print(n_total, n_train, n_val, n_test)

# random_split falls back to the global generator when none is passed, so
# seeding it here makes the split (and the train-loader shuffle) repeatable.
torch.manual_seed(42)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [n_train, n_val, n_test]
)

# Only the training loader shuffles; val/test keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=hyperparameters["batch_size"], shuffle=shuffle)
    for subset, shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

print(len(train_loader), len(val_loader), len(test_loader))

In [None]:
# @title GAT


class GAT(torch.nn.Module):
    """Graph-level classifier built from stacked GATConv layers.

    Pipeline: linear input projection -> ``num_layers`` x (GATConv -> ELU)
    -> mean pooling per graph -> two-layer MLP producing class logits.

    Args:
        num_classes: number of output logits per graph.
        node_features_dim: width of the input node features ``data.x``.
        hidden_dim: node-embedding width, kept constant across GAT layers.
            It must be a multiple of ``num_heads``: each head outputs
            ``hidden_dim // num_heads`` channels and the heads' outputs are
            concatenated.
        num_heads: attention heads per GATConv layer.
        num_layers: number of GATConv layers.
        dropout: attention dropout passed to every GATConv (default 0.2).

    Raises:
        ValueError: if ``hidden_dim`` is not a positive multiple of ``num_heads``.
    """

    def __init__(
        self,
        num_classes: int,
        node_features_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        dropout: float = 0.2,
    ):
        super().__init__()
        # Otherwise (hidden_dim // num_heads) * num_heads != hidden_dim and the
        # next layer / classifier receive the wrong input width.
        if num_heads <= 0 or hidden_dim <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        self.input_projection = torch.nn.Linear(node_features_dim, hidden_dim)
        self.layers = torch.nn.ModuleList(
            [
                GATConv(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim // num_heads,
                    heads=num_heads,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )
        # The original GAT architecture applies ELU between layers. Without an
        # activation, consecutive layers only mix features linearly (apart from
        # the attention weights). ELU has no parameters, so state_dict is unchanged.
        self.activation = torch.nn.ELU()
        self.pool = global_mean_pool
        # Guard against a zero-width hidden layer when hidden_dim == 1.
        classifier_hidden = max(hidden_dim // 2, 1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, classifier_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(classifier_hidden, num_classes),
        )

    def forward(self, data):
        """Compute class logits for every graph in ``data``.

        Args:
            data: Batch object from torch_geometric.loader.DataLoader, providing
                ``x``, ``edge_index`` and ``batch``.

        Returns:
            Logits with one row per graph and ``num_classes`` columns.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_projection(x)

        for layer in self.layers:
            x = self.activation(layer(x, edge_index))

        x = self.pool(x, batch)
        return self.classifier(x)

In [None]:
# @title Training helper functions


def train_epoch(model, device, train_loader, optimizer, loss_function):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module called as ``model(batch)`` and returning per-graph logits.
        device: device each batch is moved to with ``batch.to(device)``.
        train_loader: iterable of batches exposing a ``.y`` label tensor.
        optimizer: optimizer stepping the model's parameters.
        loss_function: criterion called as ``loss_function(output, batch.y)``.
            Assumed to use mean reduction (the CrossEntropyLoss default).

    Returns:
        ``(average_loss, accuracy)``, both averaged per sample rather than per
        batch, so a smaller final batch is not over-weighted. Both are NaN
        when the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        output = model(batch)
        loss = loss_function(output, batch.y)

        loss.backward()
        optimizer.step()

        batch_size = len(batch.y)
        # loss is a batch mean; scale back to a sum so the epoch mean is per sample.
        loss_sum += loss.item() * batch_size
        correct += (output.argmax(dim=1) == batch.y).sum().item()
        seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def validate_epoch(model, device, val_loader, loss_function):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: module called as ``model(batch)`` and returning per-graph logits.
        device: device each batch is moved to with ``batch.to(device)``.
        val_loader: iterable of batches exposing a ``.y`` label tensor.
        loss_function: criterion called as ``loss_function(output, batch.y)``.
            Assumed to use mean reduction (the CrossEntropyLoss default).

    Returns:
        ``(average_loss, accuracy)``, both averaged per sample. Both are NaN
        when the loader yields no samples (e.g. an empty validation split).
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            output = model(batch)
            loss = loss_function(output, batch.y)

            batch_size = len(batch.y)
            # Undo the per-batch mean so uneven batch sizes are weighted correctly.
            loss_sum += loss.item() * batch_size
            correct += (output.argmax(dim=1) == batch.y).sum().item()
            seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def train(
    model,
    device,
    train_loader,
    val_loader,
    hparams=None,
    return_history: bool = False,
):
    """Train ``model`` with AdamW and cross-entropy, keeping the best-validation weights.

    Args:
        model: module to train in place; expected to already be on ``device``.
        device: device batches are moved to.
        train_loader: batches passed to ``train_epoch``.
        val_loader: batches passed to ``validate_epoch``.
        hparams: dict with ``learning_rate``, ``weight_decay``, ``num_epochs``
            and ``patience``. Defaults to the notebook-level ``hyperparameters``.
        return_history: if True, also return the per-epoch metric lists.

    Returns:
        ``model`` with the weights from the epoch with the lowest validation
        loss loaded. If no epoch improved on ``inf``, the final weights are kept.
        When ``return_history`` is True, returns ``(model, history)``, where
        ``history`` maps ``train_loss``, ``val_loss``, ``train_acc`` and
        ``val_acc`` to per-epoch lists.
    """
    import copy

    if hparams is None:
        hparams = hyperparameters

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=hparams["learning_rate"],
        weight_decay=hparams["weight_decay"],
    )
    loss_function = torch.nn.CrossEntropyLoss()
    num_epochs = hparams["num_epochs"]
    patience = hparams["patience"]

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    best_val_loss = float("inf")
    patience_counter = 0
    best_model_state = None

    print("Starting training...")
    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(
            model, device, train_loader, optimizer, loss_function
        )
        val_loss, val_acc = validate_epoch(model, device, val_loader, loss_function)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # state_dict() tensors share storage with the live parameters,
            # which the optimizer updates in place. A shallow dict copy would
            # end up holding the latest weights, so take a deep copy.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        stopping = patience_counter >= patience
        if epoch % 10 == 0 or epoch == num_epochs - 1 or stopping:
            print(
                f"Epoch {epoch:3d} | Train Loss: {train_loss:.4f} | "
                f"Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | "
                f"Val Acc: {val_acc:.4f}"
            )

        if stopping:
            print(f"Early stopping at epoch {epoch}")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    print("Training completed.")
    if return_history:
        return model, history
    return model

In [None]:
# @title Training

# Build the model from the hyperparameter dict; Module.to() returns the module itself.
model = GAT(
    num_classes=hyperparameters["num_classes"],
    node_features_dim=hyperparameters["node_features_dim"],
    hidden_dim=hyperparameters["hidden_dim"],
    num_heads=hyperparameters["num_heads"],
    num_layers=hyperparameters["num_layers"],
).to(device)

print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")

torch.cuda.empty_cache()
# train() loads the best weights into `model` in place and returns the same object.
train(model, device, train_loader, val_loader)

In [None]:
# @title Additional Notes

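# This GAT is not a full Transformer.

# Compared with a Transformer, it lacks:
# - positional encodings (open question: how to define them for graphs?)
# - global attention: each GATConv attends only over a node's neighbours,
#   not densely over all nodes
# - feed-forward (FFN) sublayers between attention layers
# - residual connections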