In [None]:
# @title Import dependencies and clear CUDA cache

import pickle
from tqdm import tqdm
import torch
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Dataset, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_networkx, to_dense_batch
import os

# Ask PyTorch's CUDA caching allocator to release unused cached blocks
# (a no-op when CUDA is not available / not initialised).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# @title Mount Google Drive (if needed)

# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# @title Import graphs and labels

# path = '/content/drive/My Drive/Colab Notebooks/data/ML4RG'
path = '../data'

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load graphs.pkl / binary_labels.pkl from a source you trust.
with open(f"{path}/graphs.pkl", "rb") as graphs_file:
    graphs = pickle.load(graphs_file)

print(f"number of graphs: {len(graphs)}")
# Preview at most five entries; a bare range(5) raises IndexError on tiny inputs.
for i in range(min(5, len(graphs))):
    print(f"{i}: {graphs[i]}")


with open(f"{path}/binary_labels.pkl", "rb") as labels_file:
    labels = pickle.load(labels_file)

print(f"number of labels: {len(labels)}")
for i in range(min(5, len(labels))):
    print(f"{i}: {labels[i]}")

In [None]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device = {device}")

In [None]:
# @title Hyperparameters

hyperparameters = dict(
    # --- model architecture ---
    num_classes=2,          # output logits per graph (labels come from binary_labels.pkl)
    node_features_dim=1,    # width of the constant node features used when a graph has none
    hidden_dim=8,           # model width; must be divisible by num_heads
    num_heads=2,            # attention heads per transformer layer
    num_layers=2,           # number of stacked transformer layers
    # --- optimisation ---
    num_epochs=20,          # maximum number of training epochs
    learning_rate=0.005,    # AdamW learning rate
    batch_size=1,           # graphs per mini-batch
    patience=5,             # epochs without validation-loss improvement before stopping
    weight_decay=1e-4,      # AdamW weight decay
    dropout=0.2,            # dropout probability in layers and classifier head
)

In [None]:
# @title Convert networkx graphs to PyG Data objects


class GraphDataset(InMemoryDataset):
    """In-memory PyG dataset built from networkx graphs and one label per graph.

    Each graph is converted with ``from_networkx``. Graphs that carry no node
    features (``data.x is None``) get an all-ones float feature matrix of shape
    ``[num_nodes, node_features_dim]``. Each label is stored as ``data.y``, a
    long tensor of shape ``[1]``.

    Args:
        graphs: Sequence of networkx graphs.
        labels: Sequence of integer class labels, aligned with ``graphs``.
        transform: Optional PyG transform applied on access.
        node_features_dim: Width of the constant features for featureless graphs.
            Defaults to ``hyperparameters["node_features_dim"]``.

    Raises:
        ValueError: If ``graphs`` and ``labels`` differ in length, or are empty.
    """

    def __init__(self, graphs, labels, transform=None, node_features_dim=None):
        super().__init__(None, transform)

        if len(graphs) != len(labels):
            raise ValueError(f"Number of graphs ({len(graphs)}) must match number of labels ({len(labels)})")
        # collate() on an empty list fails with an obscure error; report it clearly.
        if len(graphs) == 0:
            raise ValueError("GraphDataset needs at least one graph")

        if node_features_dim is None:
            node_features_dim = hyperparameters["node_features_dim"]

        data_list = []
        # zip() has no len(); pass total so tqdm can show progress and an ETA.
        for graph, label in tqdm(zip(graphs, labels), total=len(graphs)):
            data = from_networkx(graph)
            if data.x is None:
                data.x = torch.ones(
                    data.num_nodes,
                    node_features_dim,
                    dtype=torch.float,
                )

            data.y = torch.tensor([label], dtype=torch.long)
            data_list.append(data)

        self.data, self.slices = self.collate(data_list)

In [None]:
# @title Load the cached processed dataset, or build and cache it

dataset_path = f"{path}/processed_dataset.pkl"

dataset = None
if os.path.exists(dataset_path):
    print("Loading preprocessed dataset...")
    # NOTE(security): only unpickle a cache file this notebook wrote itself.
    with open(dataset_path, "rb") as f:
        dataset = pickle.load(f)
    print(f"Loaded dataset with {len(dataset)} graphs")
    # The cache is keyed only by its file name. If the raw graphs changed size
    # since it was written, it is stale and must be rebuilt.
    if len(dataset) != len(graphs):
        print(f"Cached dataset has {len(dataset)} graphs but {len(graphs)} were loaded; rebuilding...")
        dataset = None

if dataset is None:
    print("Processing dataset for the first time...")
    dataset = GraphDataset(graphs, labels)

    # Write to a temporary file and atomically rename it, so an interrupted dump
    # never leaves a truncated cache that breaks the next run.
    tmp_path = f"{dataset_path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(dataset, f)
    os.replace(tmp_path, dataset_path)
    print(f"Saved processed dataset to {dataset_path}")

# Preview at most five graphs; a bare range(5) raises IndexError on tiny datasets.
for i in range(min(5, len(dataset))):
    print(f"{i}: {dataset[i]}")

In [None]:
# @title Determine max node degree

max_degree = 0
for graph_data in dataset:
    # Every node index appearing in either row of edge_index counts once per occurrence.
    # PositionalEncoding counts degrees the same way, so its table size matches.
    endpoints = graph_data.edge_index.reshape(-1)
    if endpoints.numel() == 0:
        # A graph without edges only has degree-0 nodes and cannot raise the maximum.
        continue
    max_degree = max(max_degree, int(torch.bincount(endpoints).max()))

print(f"Maximum node degree in the dataset: {max_degree}")

In [None]:
# @title Create DataLoader from dataset

n_total = len(dataset)
# 80 / 10 / 10 split; the test split absorbs any rounding remainder.
n_train, n_val = int(0.8 * n_total), int(0.1 * n_total)
n_test = n_total - (n_train + n_val)

print(n_total, n_train, n_val, n_test)

# Seeding the global RNG makes the split (and later random draws) reproducible.
torch.manual_seed(42)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [n_train, n_val, n_test]
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=hyperparameters["batch_size"], shuffle=is_training)
    for subset, is_training in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

print(len(train_loader), len(val_loader), len(test_loader))

In [None]:
# @title Graph Transformer Layer

class GraphTransformerLayer(torch.nn.Module):
    """Pre-LayerNorm transformer block with full self-attention inside each graph.

    Node features of a mini-batch are packed into a zero-padded dense tensor
    with ``to_dense_batch``. A key mask stops nodes from attending to padding.
    Padded rows are dropped again before returning, so the output has the same
    flat node layout as the input.

    Args:
        hidden_dim: Feature width of the node embeddings; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied after attention and after the feed-forward net.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim: int, num_heads: int, dropout: float):
        super().__init__()

        # Validate here too: the layer can be built on its own, and a bad split would
        # otherwise only surface as an opaque shape error inside ``view`` in forward().
        if hidden_dim % num_heads != 0:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})")

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.query = torch.nn.Linear(hidden_dim, hidden_dim)
        self.key = torch.nn.Linear(hidden_dim, hidden_dim)
        self.value = torch.nn.Linear(hidden_dim, hidden_dim)

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim * 4),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim * 4, hidden_dim)
        )

        self.norm1 = torch.nn.LayerNorm(hidden_dim)
        self.norm2 = torch.nn.LayerNorm(hidden_dim)
        self.dropout = torch.nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor, batch_size: int, num_nodes: int) -> torch.Tensor:
        """[batch_size, num_nodes, hidden_dim] -> [batch_size, num_heads, num_nodes, head_dim]."""
        return t.view(batch_size, num_nodes, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, batch: torch.Tensor):
        """
        Args:
            x: Node features [num_nodes, hidden_dim]
            batch: Batch assignment [num_nodes]

        Returns:
            Updated node features [num_nodes, hidden_dim], in the same order as ``x``.
        """

        x_dense, mask = to_dense_batch(x, batch)
        batch_size, num_nodes, _ = x_dense.shape

        # --- multi-head self-attention sub-block (pre-norm + residual) ---
        residual = x_dense
        x_norm = self.norm1(x_dense)

        q = self._split_heads(self.query(x_norm), batch_size, num_nodes)
        k = self._split_heads(self.key(x_norm), batch_size, num_nodes)
        v = self._split_heads(self.value(x_norm), batch_size, num_nodes)

        attention_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # mask: [batch_size, num_nodes] -> [batch_size, 1, 1, num_nodes], masking keys only.
        # Padded query rows still see their graph's real keys, so no row is fully masked
        # (assuming every graph has at least one node). Those rows are dropped by out[mask].
        key_mask = mask[:, None, None, :]
        # Use the dtype's most negative finite value instead of a literal like -1e9,
        # which is out of range for float16 (e.g. under autocast / half precision).
        attention_scores = attention_scores.masked_fill(
            ~key_mask, torch.finfo(attention_scores.dtype).min
        )

        attention = torch.softmax(attention_scores, dim=-1)
        out = attention @ v
        out = out.transpose(1, 2).contiguous().view(batch_size, num_nodes, self.hidden_dim)
        out = self.dropout(out) + residual

        # --- feed-forward sub-block (pre-norm + residual) ---
        residual = out
        out = self.dropout(self.ffn(self.norm2(out))) + residual

        # Drop padding rows to return to the flat [num_nodes, hidden_dim] layout.
        return out[mask]

In [None]:
# @title Positional encoding based on node degree

class PositionalEncoding(torch.nn.Module):
    """Adds a learned embedding of each node's degree to its features.

    The "degree" here is the number of times a node index occurs in either row
    of ``edge_index``. If ``edge_index`` lists each undirected edge in both
    directions, this is twice the graph-theoretic degree. ``max_degree`` must be
    computed the same way.

    Args:
        hidden_dim: Embedding width; must match the feature width of ``x``.
        max_degree: Largest degree with its own embedding. Larger degrees share the last one.
    """

    def __init__(self, hidden_dim: int, max_degree: int):
        super().__init__()
        self.degree_embedding = torch.nn.Embedding(max_degree + 1, hidden_dim) # +1 for degree 0

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        """
        Args:
            x: Node features [num_nodes, hidden_dim]
            edge_index: Edge indices [2, num_edges]

        Returns:
            ``x`` plus the degree embedding of each node, same shape as ``x``.
        """

        # bincount with minlength gives isolated nodes, and graphs with no edges at all,
        # degree 0 without a special case.
        node_degrees = torch.bincount(edge_index.reshape(-1), minlength=x.size(0))

        # Graphs not seen when max_degree was computed may exceed the table. Clamp them
        # to the last embedding instead of raising IndexError inside nn.Embedding.
        max_embedded_degree = self.degree_embedding.num_embeddings - 1
        node_degrees = node_degrees.clamp(max=max_embedded_degree)

        return x + self.degree_embedding(node_degrees)

In [None]:
# @title Graph Transformer

class GraphTransformer(torch.nn.Module):
    """Graph-level classifier.

    Pipeline: linear projection of node features, degree-based positional encoding,
    a stack of ``GraphTransformerLayer`` blocks, per-graph mean pooling, and an MLP
    head producing class logits.

    Args:
        num_classes: Number of output logits per graph.
        node_features_dim: Width of the input node features ``data.x``.
        hidden_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Attention heads per layer.
        num_layers: Number of transformer layers.
        dropout: Dropout probability used in the layers and the classifier head.
        max_degree: Largest degree the positional encoding has a dedicated embedding for.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        num_classes: int,
        node_features_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        dropout: float,
        max_degree: int
    ):
        super().__init__()

        if hidden_dim % num_heads:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})")

        self.hidden_dim = hidden_dim
        # Submodules are created in this exact order: it determines parameter
        # initialisation under a given seed and the state_dict layout.
        self.input_projection = torch.nn.Linear(node_features_dim, hidden_dim)
        self.pos_encoding = PositionalEncoding(hidden_dim, max_degree)
        self.layers = torch.nn.ModuleList(
            GraphTransformerLayer(hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        )
        self.pool = global_mean_pool

        bottleneck_dim = hidden_dim // 2
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, bottleneck_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(bottleneck_dim, num_classes)
        )

    def forward(self, data):
        """Return class logits for a batch object exposing ``x``, ``edge_index`` and ``batch``."""
        node_to_graph = data.batch

        hidden = self.input_projection(data.x)
        hidden = self.pos_encoding(hidden, data.edge_index)
        for layer in self.layers:
            hidden = layer(hidden, node_to_graph)

        graph_embeddings = self.pool(hidden, node_to_graph)
        return self.classifier(graph_embeddings)

In [None]:
# @title Training helper functions


def _run_batches(model, device, loader, loss_function, optimizer=None):
    """Run ``model`` over every batch in ``loader`` and return (loss, accuracy).

    When ``optimizer`` is given, each batch also takes one optimisation step.
    The loss is the per-sample mean over the whole epoch. Each batch's mean loss
    is weighted by its number of samples, so a smaller final batch is not
    over-weighted. This assumes ``loss_function`` uses mean reduction, like the
    default ``CrossEntropyLoss``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch in loader:
        batch = batch.to(device)
        if optimizer is not None:
            optimizer.zero_grad()

        output = model(batch)
        loss = loss_function(output, batch.y)

        if optimizer is not None:
            loss.backward()
            optimizer.step()

        batch_samples = batch.y.size(0)
        total_loss += loss.item() * batch_samples
        total_correct += (output.argmax(dim=1) == batch.y).sum().item()
        total_samples += batch_samples

    if total_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch loss/accuracy")

    return total_loss / total_samples, total_correct / total_samples


def train_epoch(model, device, train_loader, optimizer, loss_function):
    """Train ``model`` for one epoch; returns (mean per-sample loss, accuracy)."""
    model.train()
    return _run_batches(model, device, train_loader, loss_function, optimizer)


def validate_epoch(model, device, val_loader, loss_function):
    """Evaluate ``model`` without gradients; returns (mean per-sample loss, accuracy)."""
    model.eval()
    with torch.no_grad():
        return _run_batches(model, device, val_loader, loss_function)


def train(model, device, train_loader, val_loader):
    """Train ``model`` with AdamW and early stopping on validation loss.

    Learning rate, weight decay, epoch budget and patience are read from the
    global ``hyperparameters`` dict. When training ends, the weights from the
    epoch with the lowest validation loss are loaded back into ``model`` in place.

    Args:
        model: Network to train; updated in place.
        device: Device each batch is moved to.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.

    Returns:
        dict with per-epoch lists under "train_loss", "val_loss", "train_acc"
        and "val_acc" (one entry per completed epoch).
    """
    import copy

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=hyperparameters["learning_rate"],
        weight_decay=hyperparameters["weight_decay"],
    )
    loss_function = torch.nn.CrossEntropyLoss()

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    best_val_loss = float("inf")
    patience_counter = 0
    best_model_state = None

    print("Starting training...")
    for epoch in range(hyperparameters["num_epochs"]):
        train_loss, train_acc = train_epoch(
            model, device, train_loader, optimizer, loss_function
        )
        val_loss, val_acc = validate_epoch(model, device, val_loader, loss_function)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # state_dict() returns tensors that share storage with the live parameters.
            # A shallow dict .copy() would keep changing as the optimizer steps and end
            # up holding the *final* weights. Deep-copy to freeze a real snapshot.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        stop_early = patience_counter >= hyperparameters["patience"]

        if epoch % 10 == 0 or stop_early:
            print(
                f"Epoch {epoch:3d} | Train Loss: {train_loss:.4f} | "
                f"Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | "
                f"Val Acc: {val_acc:.4f}"
            )

        if stop_early:
            print(f"Early stopping at epoch {epoch}")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    print("Training completed.")
    return history

In [None]:
# @title Training


model = GraphTransformer(
    num_classes=hyperparameters["num_classes"],
    node_features_dim=hyperparameters["node_features_dim"],
    hidden_dim=hyperparameters["hidden_dim"],
    num_heads=hyperparameters["num_heads"],
    num_layers=hyperparameters["num_layers"],
    dropout=hyperparameters["dropout"],
    max_degree=max_degree,
).to(device)

num_parameters = sum(p.numel() for p in model.parameters())
print(f"Model initialized with {num_parameters} parameters")

torch.cuda.empty_cache()
# Keep whatever `train` returns (e.g. per-epoch loss/accuracy history) so it can be
# inspected or plotted in later cells instead of being discarded.
history = train(model, device, train_loader, val_loader)