# Text Classification - Training a GAT with PyTorch Geometric and Distance Sampler


## $\color{blue}{Sections:}$

* Preamble
* Admin
* Dataset
* Model
* Sampling
* Train - Validate
* Test Predictions


## $\color{blue}{Preamble:}$

We now train a GAT in PyTorch Geometric, keeping the network quite close to the previous version. Note that the GAT performs poorly on this problem. Mutually compatible versions of torch, torch-sparse, and torch-geometric are required.

## $\color{blue}{Admin}$
* Install relevant Libraries
* Import relevant Libraries

In [None]:
import torch
import pandas as pd
from google.colab import drive
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
drive.mount("/content/drive")
%cd '/content/drive/MyDrive'

Mounted at /content/drive
/content/drive/MyDrive


In [None]:
import torch
!pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

[0mLooking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_scatter-2.1.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m44.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt25cu124
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_sparse-0.6.18%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (5.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.18+pt25cu124
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-cluster
  Downloading https://da

In [None]:
import torch_geometric

## $\color{blue}{Data}$

* Connect to Drive
* Load the data
* Load adjacency matrices
* Instantiate PyTorch Geometric Data objects

In [None]:
# Load the pickled train / dev / test DataFrames (paths are relative to the current working directory).
# NOTE(review): unpickling can execute arbitrary code — only read files from a trusted source.
path = 'class/datasets/'
df_train = pd.read_pickle(path + 'df_train')
df_dev = pd.read_pickle(path + 'df_dev')
df_test = pd.read_pickle(path + 'df_test')

In [None]:
# Path template for the saved adjacency tensors; `{}` is filled with "<split>_<relation>".
# Note: this rebinds `path`, which previously pointed at the datasets directory.
path = 'class/tensors/adj_{}.pt'

# train
train_people = torch.load(path.format('train_people'), weights_only=True)
train_locations = torch.load(path.format('train_locations'), weights_only=True)
train_entities = torch.load(path.format('train_entities'), weights_only=True)

# dev
dev_people = torch.load(path.format('dev_people'), weights_only=True)
dev_locations = torch.load(path.format('dev_locations'), weights_only=True)
dev_entities = torch.load(path.format('dev_entities'), weights_only=True)

# val (contains the adjacency matrix for both the training and the development set)
val_people = torch.load(path.format('val_people.1'), weights_only=True)
val_locations = torch.load(path.format('val_locations.1'), weights_only=True)
val_entities = torch.load(path.format('val_entities.1'), weights_only=True)

In [None]:
# Add the identity to the train and val entity adjacency matrices, in place, so every node
# has a non-zero diagonal entry (a self-loop once `.nonzero()` is taken below).
# NOTE(review): `+=` mutates the loaded tensors, so re-running this cell adds the identity again;
# dev_entities is not given self-loops here.
train_entities +=  torch.eye(train_entities.size(0), device=train_entities.device)  # Identity matrix
val_entities +=  torch.eye(val_entities.size(0), device=val_entities.device)  # Identity matrix

In [None]:
# Build the validation frame: dev rows first, then train rows, restricted to the three
# columns used below. Because dev comes first, dev rows occupy positions 0..len(df_dev)-1.
df1 = df_train[['index', 'chapter_idx', 'vanilla_embedding.1']]
df2 = df_dev[['index', 'chapter_idx', 'vanilla_embedding.1']]
df_val = pd.concat([df2,df1])

In [None]:
# inputs
# For each split: stack the per-row embeddings into one feature tensor (H_*) and collect the
# `chapter_idx` column as integer class labels (labels_*); both are moved to `device`.
H_train = torch.stack([torch.tensor(el) for el in list(df_train['vanilla_embedding.1'])]).to(device)
labels_train = torch.LongTensor(list(df_train['chapter_idx'])).to(device)

H_dev = torch.stack([torch.tensor(el) for el in list(df_dev['vanilla_embedding.1'])]).to(device)
labels_dev = torch.LongTensor(list(df_dev['chapter_idx'])).to(device)

H_val = torch.stack([torch.tensor(el) for el in list(df_val['vanilla_embedding.1'])]).to(device)
labels_val = torch.LongTensor(list(df_val['chapter_idx'])).to(device)

  H_train = torch.stack([torch.tensor(el) for el in list(df_train['vanilla_embedding.1'])]).to(device)
  H_dev = torch.stack([torch.tensor(el) for el in list(df_dev['vanilla_embedding.1'])]).to(device)
  H_val = torch.stack([torch.tensor(el) for el in list(df_val['vanilla_embedding.1'])]).to(device)


In [None]:
# Convert each dense entity adjacency matrix into a (2, num_edges) long tensor:
# `nonzero(as_tuple=True)` returns (row_indices, col_indices), so after stacking,
# edge k links node edge_index[0][k] to node edge_index[1][k].
train_edge_index = train_entities.nonzero(as_tuple=True)
train_edge_index = torch.stack(train_edge_index).long().to(device)
# train_edge_relation = torch.zeros(train_edge_index.size(1), dtype=torch.long)

dev_edge_index = dev_entities.nonzero(as_tuple=True)
dev_edge_index = torch.stack(dev_edge_index).long().to(device)
# dev_edge_relation = torch.zeros(dev_edge_index.size(1), dtype=torch.long)

val_edge_index = val_entities.nonzero(as_tuple=True)
val_edge_index = torch.stack(val_edge_index).long().to(device)
# val_edge_relation = torch.zeros(val_edge_index.size(1), dtype=torch.long)

In [None]:
from torch_geometric.data import Data

# Bundle node features (x), connectivity (edge_index) and labels (y) per split.
# val_data covers dev + train nodes (see df_val / val_entities above).
train_data = Data(x=H_train, edge_index=train_edge_index, y=labels_train)
dev_data = Data(x=H_dev, edge_index=dev_edge_index, y=labels_dev)
val_data = Data(x=H_val, edge_index=val_edge_index, y=labels_val)

In [None]:
import torch
import torch.nn.functional as F

def create_closest_neighbors_dict(embedding_matrix, adjacency_matrix, k=4):
    """
    Rank every node's graph neighbours by cosine similarity of their embeddings.

    Parameters:
    - embedding_matrix: (n x d) tensor where n is the number of nodes and d is the embedding dimension.
    - adjacency_matrix: (n x n) tensor; a non-zero entry (i, j) marks j as a neighbour of i.
    - k: Maximum number of closest neighbours to keep for each node.

    Returns:
    - closest_neighbor_indices_dict: dict mapping each node index to a list of
      (neighbor_index, cosine_similarity) tuples, ordered by decreasing similarity and
      holding at most k entries. Nodes without neighbours map to an empty list.
    """
    # Index with tensors on the embeddings' own device rather than a notebook-level global.
    target_device = embedding_matrix.device
    num_nodes = embedding_matrix.size(0)
    closest_neighbor_indices_dict = {}

    for i in range(num_nodes):
        # Only nodes linked to i in the adjacency matrix are candidates.
        neighbor_indices = adjacency_matrix[i].nonzero(as_tuple=True)[0].to(target_device)

        if neighbor_indices.numel() == 0:
            closest_neighbor_indices_dict[i] = []
            continue

        similarities = F.cosine_similarity(
            embedding_matrix[i].unsqueeze(0), embedding_matrix[neighbor_indices], dim=1
        )
        # topk cannot ask for more entries than exist, so cap k at the neighbour count.
        top_k_vals, top_k_indices = similarities.topk(min(k, similarities.size(0)))

        closest_neighbors = neighbor_indices[top_k_indices].tolist()
        closest_neighbor_indices_dict[i] = list(zip(closest_neighbors, top_k_vals.tolist()))

    return closest_neighbor_indices_dict

In [None]:
# Precompute, for the train graph and the combined val graph, each node's top-k (default k=4)
# most similar neighbours; these dictionaries drive CustomNeighborSampler below.
train_closest_neighbors = create_closest_neighbors_dict(H_train, train_entities)
val_closest_neighbors = create_closest_neighbors_dict(H_val, val_entities)

## $\color{blue}{Model}$


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GATLayer(MessagePassing):
    """Attention-style message-passing layer with separate query / key / value projections.

    Parameters:
    - in_features: size of each input node feature vector.
    - out_channels: size of the projected query / key / value vectors and of the output.
    - dropout: dropout probability applied to the aggregated output (training mode only).
    """

    def __init__(self, in_features, out_channels, dropout=0.45):
        super().__init__(aggr='add')  # Messages arriving at a node are summed.
        self.in_features = in_features
        self.out_features = out_channels
        self.dropout = dropout

        # Linear transformations for queries, keys, and values
        self.Wq = nn.Linear(in_features, out_channels)
        self.Wk = nn.Linear(in_features, out_channels)
        self.Wv = nn.Linear(in_features, out_channels)

        # Learnable scalar added to the attention score of self-loop edges, initialised to 1.
        self.self_bias = nn.Parameter(torch.ones(1))
        self.leakyrelu = nn.LeakyReLU()

        self.batch_norm = nn.BatchNorm1d(out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the three projection weight matrices (biases keep their defaults)."""
        for layer in [self.Wq, self.Wk, self.Wv]:
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x, edge_index):
        """Apply the layer.

        Parameters:
        - x: node feature matrix, one row per node.
        - edge_index: sparse adjacency exposing `.coo()` (as produced by the sampler in this
          notebook); an object without `.coo()` is assumed to already be a (2, E) index tensor.

        Returns:
        - Tensor with one row of `out_channels` features per node.
        """
        # Transform node features into Q, K, V
        H_q = self.Wq(x)  # (N, out_channels)
        H_k = self.Wk(x)  # (N, out_channels)
        H_v = self.Wv(x)  # (N, out_channels)

        # Unpack the COO representation once instead of once per coordinate.
        if hasattr(edge_index, 'coo'):
            coo = edge_index.coo()
            edge_index = torch.stack([coo[0], coo[1]])

        # Propagate messages using the edge index
        out = self.propagate(edge_index, x=H_v, H_q=H_q, H_k=H_k)

        # F.dropout defaults to training=True; tie it to the module mode so that
        # model.eval() really disables dropout during validation / testing.
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.batch_norm(out)

        return out

    def message(self, x_j, H_q, H_k, edge_index):
        """Weight each edge's value vector by its attention coefficient.

        `x_j` holds one value vector per edge (gathered by `propagate` from `x=H_v`);
        `H_q` / `H_k` are the full per-node query / key matrices.
        """
        # Dense (N, N) score matrix between every query and every key.
        E = torch.matmul(H_q, H_k.transpose(0, 1))

        n = H_q.size(0)
        I = torch.eye(n, device=x_j.device)  # Identity matrix
        E += self.self_bias * I  # Boost the score of self-loops

        row = edge_index[0]
        col = edge_index[1]

        attention = E[row, col]  # One score per edge

        # NOTE(review): this softmax normalises over ALL edges in the batch, not per
        # target node as in a standard GAT — confirm whether that is intended.
        attention = F.softmax(attention, dim=-1)

        # Weight the neighbor features by the attention coefficients
        weighted_messages = attention.view(-1, 1) * x_j

        return weighted_messages


In [None]:
class GNNModel(nn.Module):
    """Stack of GATLayer blocks followed by a two-layer classification head.

    Parameters:
    - d: node feature size (kept unchanged by every GAT layer).
    - h: hidden size of the first fully connected layer.
    - c: number of output classes.
    - num_layers: how many GAT layers to stack.
    - dropout_rate: dropout probability used in the GAT layers and in the head.
    """

    def __init__(self, d, h, c, num_layers=2, dropout_rate=0.42):
        super().__init__()
        self.num_layers = num_layers

        gat_stack = []
        for _ in range(num_layers):
            gat_stack.append(GATLayer(d, d, dropout_rate))
        self.gnn_layers = nn.ModuleList(gat_stack)

        # Classification head: Linear -> Dropout -> BatchNorm -> ReLU -> Linear.
        self.fc1 = nn.Linear(d, h)
        self.batch_norm_fc1 = nn.BatchNorm1d(h)
        self.fc2 = nn.Linear(h, c)
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Return the class logits for every node in `x`."""
        hidden = x
        for gat in self.gnn_layers:
            hidden = gat(hidden, edge_index)

        hidden = self.fc1(hidden)
        hidden = self.dropout(hidden)
        hidden = self.batch_norm_fc1(hidden)
        hidden = self.relu(hidden)
        return self.fc2(hidden)

    def forward_layer(self, x, edge_index, layer_idx):
        """Run only the GAT layer at position `layer_idx` and return its output."""
        return self.gnn_layers[layer_idx](x, edge_index)

In [None]:
# Model dimensions. `d` is used by GNNModel both as the GAT layer width and as the
# input size of the first fully connected layer.
d = 768
h = 400   # hidden dimension of fully connected layer
c = 70   # number of classes
num_relations = 2   # number of relationship types (not used by any code visible in this notebook)

# Model, Loss, Optimizer
# Note: tv_run() below builds its own criterion and optimizer; the two objects created
# here are not passed to it.
model = GNNModel(d,h,c)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

In [None]:
def count_parameters_per_module(model):
    """Print the number of trainable parameters held by each named sub-module.

    The root module (empty name) is skipped, and so is any sub-module whose trainable
    parameter count is zero. Nothing is returned.
    """
    print("Module and parameter counts:")

    for name, module in model.named_modules():
        is_named_submodule = isinstance(module, nn.Module) and name != ""
        if is_named_submodule:
            trainable_sizes = [p.numel() for p in module.parameters() if p.requires_grad]
            param_count = sum(trainable_sizes)

            # Parameter-free modules (activations, dropout, ...) are not reported.
            if param_count > 0:
                print(f"{name}: {param_count} parameters")

In [None]:
# Report trainable parameter counts for every sub-module of the freshly built model.
count_parameters_per_module(model)

Module and parameter counts:
gnn_layers: 3546626 parameters
gnn_layers.0: 1773313 parameters
gnn_layers.0.Wq: 590592 parameters
gnn_layers.0.Wk: 590592 parameters
gnn_layers.0.Wv: 590592 parameters
gnn_layers.0.batch_norm: 1536 parameters
gnn_layers.1: 1773313 parameters
gnn_layers.1.Wq: 590592 parameters
gnn_layers.1.Wk: 590592 parameters
gnn_layers.1.Wv: 590592 parameters
gnn_layers.1.batch_norm: 1536 parameters
fc1: 307600 parameters
batch_norm_fc1: 800 parameters
fc2: 28070 parameters


## $\color{blue}{Sampling}$


In [None]:
import torch
from torch import Tensor
from typing import Callable, List, NamedTuple, Optional, Tuple, Union
from torch_geometric.loader import NeighborSampler
from torch_geometric.typing import SparseTensor
from collections import defaultdict

class EdgeIndex(NamedTuple):
    """Connectivity of one sampled sub-graph: edges, optional edge ids, and its size."""
    edge_index: Tensor
    e_id: Optional[Tensor]
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        """Return a copy whose tensors went through `.to(*args, **kwargs)`; `size` is kept."""
        moved_edge_index = self.edge_index.to(*args, **kwargs)
        moved_e_id = None
        if self.e_id is not None:
            moved_e_id = self.e_id.to(*args, **kwargs)
        return EdgeIndex(moved_edge_index, moved_e_id, self.size)

class Adj(NamedTuple):
    """Sparse-adjacency counterpart of EdgeIndex: adjacency, optional edge ids, and size."""
    adj_t: SparseTensor
    e_id: Optional[Tensor]
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        """Return a copy whose tensors went through `.to(*args, **kwargs)`; `size` is kept."""
        moved_adj_t = self.adj_t.to(*args, **kwargs)
        moved_e_id = None
        if self.e_id is not None:
            moved_e_id = self.e_id.to(*args, **kwargs)
        return Adj(moved_adj_t, moved_e_id, self.size)

class CustomNeighborSampler(NeighborSampler):
    """NeighborSampler that expands seeds through precomputed nearest-neighbour lists.

    Instead of sampling neighbours at random, each seed is expanded two hops using
    `closest_neighbor_indices_dict`, which maps a node id to a similarity-ranked list of
    (neighbor_id, similarity) tuples (see create_closest_neighbors_dict).

    Parameters:
    - edge_index: graph connectivity, forwarded to NeighborSampler.
    - closest_neighbor_indices_dict: precomputed ranked neighbours per node.
    - k_neighbor: number of ranked neighbours taken per node and hop.
    - **kwargs: remaining NeighborSampler / DataLoader arguments (node_idx, sizes, batch_size, ...).
    """

    def __init__(self, edge_index, closest_neighbor_indices_dict, k_neighbor=4, **kwargs):
        super().__init__(edge_index, **kwargs)
        self.closest_neighbor_indices_dict = closest_neighbor_indices_dict
        self.k_neighbor = k_neighbor

    def _top_neighbors(self, node):
        """Return the ids of the first `k_neighbor` ranked neighbours of `node` (may be empty)."""
        ranked = self.closest_neighbor_indices_dict.get(node, [])
        return [neighbor for neighbor, _similarity in ranked[:self.k_neighbor]]

    def sample(self, node_idx):
        """Build the two-hop sub-graph around one mini-batch of seed nodes.

        Parameters:
        - node_idx: the seed node ids of this mini-batch (list of ints or 1-D tensor).

        Returns:
        - batch_size: number of seed ids received.
        - n_id: long tensor of all node ids in the sub-graph. The seeds come first, in the
          order received (duplicates removed), followed by the remaining nodes in ascending
          order, so for duplicate-free seeds `n_id[:batch_size]` are exactly the seeds.
        - adjs: `[[EdgeIndex]]` — a single-element list whose element is the list returned
          by create_adj_tensor (callers read `adjs[0][0]`).
        """
        if isinstance(node_idx, Tensor):
            seeds = node_idx.view(-1).tolist()
        else:
            seeds = [int(node) for node in node_idx]
        batch_size = len(seeds)

        # First hop: top-k ranked neighbours of every seed.
        first_hop_neighbors_dict = defaultdict(list)
        for seed in seeds:
            first_hop_neighbors_dict[seed].extend(self._top_neighbors(seed))

        first_hop_node_ids_set = set()
        for node_neighbors in first_hop_neighbors_dict.values():
            first_hop_node_ids_set.update(node_neighbors)

        # Second hop: top-k ranked neighbours of every first-hop node, keeping only nodes
        # that are not already part of the first hop.
        second_hop_neighbors_dict = defaultdict(list)
        for node in first_hop_node_ids_set:
            new_neighbors = [n for n in self._top_neighbors(node) if n not in first_hop_node_ids_set]
            second_hop_neighbors_dict[node].extend(new_neighbors)

        second_hop_node_ids_set = {
            neighbor for neighbors in second_hop_neighbors_dict.values() for neighbor in neighbors
        }

        # Deterministic node order with the seeds first (set iteration order is arbitrary).
        ordered_seeds = list(dict.fromkeys(seeds))
        seed_set = set(ordered_seeds)
        other_nodes = sorted((first_hop_node_ids_set | second_hop_node_ids_set) - seed_set)
        # Explicit dtype: an empty list would otherwise produce a float tensor.
        n_id = torch.tensor(ordered_seeds + other_nodes, dtype=torch.long)

        adj_t = self.create_adj_tensor(first_hop_neighbors_dict, second_hop_neighbors_dict, n_id)

        return batch_size, n_id, [adj_t]

    def create_adj_tensor(self, first_hop_neighbors_dict, second_hop_neighbors_dict, n_id):
        """Assemble the sub-graph adjacency in the local index space defined by `n_id`.

        Parameters:
        - first_hop_neighbors_dict: seed id -> list of sampled neighbour ids.
        - second_hop_neighbors_dict: first-hop node id -> list of sampled neighbour ids.
        - n_id: tensor of global node ids; position in this tensor is the local index.

        Returns:
        - A one-element list holding an EdgeIndex whose `edge_index` field carries a
          SparseTensor of shape (len(n_id), len(n_id)), with rows indexing the expanded
          node and columns indexing its sampled neighbours.
        """
        # Merge both hops into one "node -> set of neighbours" view.
        combined_neighbors = defaultdict(set)
        for seed_node, neighbors in first_hop_neighbors_dict.items():
            combined_neighbors[seed_node].update(neighbors)
        for first_hop_node, neighbors in second_hop_neighbors_dict.items():
            combined_neighbors[first_hop_node].update(neighbors)

        # Global node id -> position in n_id.
        mapping = {node: idx for idx, node in enumerate(n_id.tolist())}

        row_indices = []
        col_indices = []
        for node, neighbors in combined_neighbors.items():
            if node not in mapping:
                continue
            for neighbor in neighbors:
                if neighbor in mapping:
                    row_indices.append(mapping[node])
                    col_indices.append(mapping[neighbor])

        edge_index = torch.tensor([row_indices, col_indices], dtype=torch.long)

        edge_index_sparse = SparseTensor(
            row=edge_index[0],
            col=edge_index[1],
            sparse_sizes=(len(n_id), len(n_id))
        )

        edge_index_obj = EdgeIndex(
            edge_index=edge_index_sparse,
            e_id=None,
            size=edge_index_sparse.sizes()
        )

        return [edge_index_obj]



In [None]:
edge_index = train_data.edge_index
# Seed only from nodes that take part in at least one edge (either endpoint).
linked_nodes = torch.unique(torch.cat([edge_index[0], edge_index[1]]))

# Training sampler: shuffled mini-batches of 32 seeds drawn from the linked nodes.
train_sampler = CustomNeighborSampler(
  train_data.edge_index,
  closest_neighbor_indices_dict = train_closest_neighbors,
  node_idx=linked_nodes,  # Use only linked nodes
  sizes=[4, 4],
  batch_size=32,
  shuffle=True,
  num_workers=0
)

In [None]:
# Validation sampler over the combined dev + train graph, in fixed (unshuffled) order.
# node_idx=None leaves seed selection to NeighborSampler's default
# (presumably every node in the graph — confirm against the PyG docs).
val_sampler = CustomNeighborSampler(
  val_data.edge_index,
  closest_neighbor_indices_dict = val_closest_neighbors,
  node_idx=None, # earlier experiment: torch.arange(964)
  sizes=[4, 4],
  batch_size=256,
  shuffle=False,
  num_workers=0
)

In [None]:
# Sanity check: show the contents of the first training mini-batch only.
for batch_size, n_id, adj in train_sampler:
  print(f'batch size: {batch_size}')
  print(f'n_id: {n_id}')
  print(f'n_id size: {n_id.size()}')
  print(f'adj: {adj}')
  break


batch size: 32
n_id: tensor([ 2562,  2055,  2057,  5131,  9745, 11793, 11797,  7702,  5658, 10273,
         6690,  5153,   552,  8745,  2601,  6188,  6189,  4143, 10288,  1583,
         6193, 11315,  3640,  5692,  5182,   574,  9798,  6727, 10823,  3657,
        10826,  7771,  1628,  6749,  5729,  4195,  3173, 10861,   113,  9851,
         2173,  2178,   643,  6276,  2693,   648, 10383,  1680,  8336,   149,
         6296,  3227, 11931,  9887,  8868,  2217,   170,  1707,  7851,   173,
         9900,  2733,  3255,  5307,  3266,  1730,  5317, 11461,  1222,  8392,
         3274,  1742,  9936,   721,   208,  6356,  6357,  5338,  6363,  1247,
         5856,  2786,  7394,   227,  6896,  5876,  5885,  8959,  5888,   775,
         6920, 10510,  5390,  8464, 11025,  9486,  2319,  4880,  2320, 10006,
         6427,  6940,   796,  7968, 10024, 10537,   811, 11052,  9516,  1326,
         7984,  9525,  5944,  6969,  7994,  6461,  4414,  8001,  5968,  1363,
         7507,  4436,  2389, 11607, 11610, 

## $\color{blue}{Train-Validate}$


In [None]:
def accuracy(outputs, labels):
    """Return, as a Python float, the fraction of rows whose top-scoring class equals the label.

    - outputs: 2-D tensor of per-class scores, one row per sample.
    - labels: 1-D tensor of true class indices.
    """
    predicted = torch.max(outputs, 1).indices
    n_correct = (predicted == labels).sum().item()
    n_samples = labels.size(0)
    return n_correct / n_samples

In [None]:
import numpy as np

def train(model, sampler, criterion, optimizer, scheduler):
    """Run one training epoch and return (mean batch loss, mean batch accuracy).

    Features and labels are read from the notebook-level `train_data`; each batch is one
    sub-graph yielded by `sampler`. The scheduler is stepped once, after the whole epoch.

    NOTE(review): loss and accuracy are computed over every node in the sampled sub-graph
    (`n_id`), not only over the batch's seed nodes — confirm this is intended.
    """
    model.train()
    batch_losses, batch_accuracies = [], []

    for _batch_size, node_ids, adjs in sampler:
        optimizer.zero_grad()

        features = train_data.x[node_ids].to(device)
        targets = train_data.y[node_ids].to(device)
        # The sampler stores edges as (node -> neighbour); transpose before the model call.
        sub_edges = adjs[0][0].edge_index.t().to(device)

        logits = model(features, sub_edges)
        batch_loss = criterion(logits, targets)

        batch_losses.append(batch_loss.item())
        batch_accuracies.append(accuracy(logits, targets))

        # Backpropagation and optimization
        batch_loss.backward()
        optimizer.step()

    scheduler.step()

    return np.mean(batch_losses), np.mean(batch_accuracies)

In [None]:
import torch

def score(reals, preds):
  """Fraction of positions where `preds` equals `reals` (element-wise comparison)."""
  n_matches = (reals == preds).sum()
  return n_matches / len(reals)

def validate(model, sampler, criterion):
    """
    Validate the model on the validation dataset using the provided sampler.

    Parameters:
    - model: The model to be evaluated.
    - sampler: The sampler to sample validation data (sub-graphs of `val_data`).
    - criterion: The loss function used for evaluation.

    Returns:
    - total_loss: loss over the first n rows of the batch (n = number of dev rows).
    - total_acc: accuracy over those rows.
    - connected_acc: accuracy over the rows flagged by `df_dev.connected`.
    - isolated_acc: accuracy over the rows not flagged by `df_dev.connected`.

    NOTE(review): the metrics are overwritten on every batch, so the values returned are
    those of the LAST batch only; and `[:n]` assumes the first n rows of each batch are
    the dev nodes — confirm both against the sampler's node ordering.
    """

    model.eval()

    n = df_dev.shape[0] # cutoff for validation points
    # Boolean mask of dev points that are connected on the graph.
    mask = torch.as_tensor(df_dev.connected.to_numpy().astype(bool))

    with torch.no_grad():
        for batch_size, n_id, adjs in sampler:
            edge_index = adjs[0][0].edge_index.t().to(device)
            x = val_data.x[n_id].to(device)
            out = model(x, edge_index)
            y = val_data.y[n_id].to(device)

            _, predicted = torch.max(out, 1)
            reals = y[:n]
            preds = predicted[:n]
            outs = out[:n,:]
            total_loss = criterion(outs, reals)
            total_acc = score(reals, preds)

            # A batch may hold fewer than n nodes; trim the mask so indexing cannot fail.
            batch_mask = mask[:reals.size(0)].to(reals.device)

            # connected
            connected_acc = score(reals[batch_mask], preds[batch_mask])

            # isolated
            isolated_acc = score(reals[~batch_mask], preds[~batch_mask])

    return  total_loss, total_acc, connected_acc, isolated_acc





In [None]:
import time

def _as_float(value):
  """Convert a 0-d tensor, NumPy scalar or plain number to a Python float."""
  return float(value.item()) if hasattr(value, 'item') else float(value)


def tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 0, trace=False, patience=60):
  """
  Train and validate `model` for up to `epochs` epochs and summarise the best epoch.

  Uses the notebook-level `train_sampler` / `val_sampler`, `train()` and `validate()`.

  Parameters:
  - epochs: maximum number of epochs (must be >= 1).
  - model: model to train; moved to `device`.
  - lr: peak learning rate for AdamW.
  - alpha: AdamW weight decay.
  - max_accuracy: best connected accuracy seen so far (across runs); the model is saved
    to `path` whenever an epoch beats it.
  - path: file the best `state_dict` is written to.
  - verbose: 1 - print model results; 2 - print epoch and model results.
  - trace: print wall-clock time and peak GPU memory per phase.
  - patience: stop after this many consecutive epochs without a new best connected accuracy.

  Returns:
  - Stats for the epoch with the highest connected accuracy; every metric is a plain
    Python float, and `max_accuracy` is the (possibly updated) best-so-far value.
  """
  if epochs < 1:
    raise ValueError('epochs must be at least 1')

  model = model.to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=alpha)
  max_accuracy = _as_float(max_accuracy)

  # Warm-up and linear decay scheduler, stepped once per epoch.
  # NOTE(review): with num_warmup_steps >= 1 the factor for epoch 0 is 0.0, i.e. the
  # first epoch runs with a learning rate of zero — confirm this is intended.
  num_warmup_steps = int(0.1 * epochs)  # 10% of epochs for warm-up
  def lr_lambda(current_step):
      if current_step < num_warmup_steps:
          return float(current_step) / float(max(1, num_warmup_steps))
      return max(
          0.0, float(epochs - current_step) / float(max(1, epochs - num_warmup_steps))
      )

  scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

  # Hold epoch stats
  train_losses = []
  train_accuracy = []
  dev_losses = []
  dev_accuracy = []
  connected_accuracy = []
  isolated_accuracy = []
  epoch_holder = []

  # Break if no improvement
  current_best = 0
  no_improvement = 0

  # Peak-memory statistics only exist when CUDA is available.
  use_cuda_stats = trace and torch.cuda.is_available()

  # Run epochs
  for epoch in range(epochs):

    # break out of epochs
    if no_improvement >= patience:
      break

    if trace:
      if use_cuda_stats:
        torch.cuda.reset_peak_memory_stats()
      start_time = time.time()

    train_loss, train_acc = train(model, train_sampler, criterion, optimizer, scheduler)

    if trace:
      print("\n--- Profiling Results for Training Phase ---")
      training_time = time.time() - start_time
      max_train_memory = torch.cuda.max_memory_allocated() if use_cuda_stats else 0
      print(f'Time: {training_time}\nMax memory: {max_train_memory}')
      if use_cuda_stats:
        torch.cuda.reset_peak_memory_stats()
      start_time = time.time()
      print("\n--- Profiling Results for Validation Phase ---")

    dev_loss, dev_acc, connected_acc, isolated_acc = validate(model, val_sampler, criterion)

    if trace:
      validation_time = time.time() - start_time
      max_validation_memory = torch.cuda.max_memory_allocated() if use_cuda_stats else 0
      print(f'Time: {validation_time}\nMax memory: {max_validation_memory}')

    # Store plain floats so the history (and the returned Stats) holds no device tensors.
    train_loss = _as_float(train_loss)
    train_acc = _as_float(train_acc)
    dev_loss = _as_float(dev_loss)
    dev_acc = _as_float(dev_acc)
    connected_acc = _as_float(connected_acc)
    isolated_acc = _as_float(isolated_acc)

    # Store epoch stats
    train_losses.append(train_loss)
    train_accuracy.append(train_acc)
    dev_losses.append(dev_loss)
    dev_accuracy.append(dev_acc)
    connected_accuracy.append(connected_acc)
    isolated_accuracy.append(isolated_acc)
    epoch_holder.append(epoch + 1)

    # check for improvement
    if connected_acc > current_best:
      current_best = connected_acc
      no_improvement = 0
    else:
      no_improvement += 1

    # save best model
    if connected_acc > max_accuracy:
      torch.save(model.state_dict(), path)
      max_accuracy = connected_acc

    # optionally print epoch results
    if verbose == 2:
      print(f'\n --------- \nEpoch: {epoch + 1}\n')
      print(f'Epoch {epoch + 1} train loss: {train_loss:.4f}')
      print(f'Epoch {epoch + 1} train accuracy: {train_acc:.4f}')
      print(f'Epoch {epoch + 1} dev loss: {dev_loss:.4f}')
      print(f'Epoch {epoch + 1} dev accuracy: {dev_acc:.4f}')
      print(f'Epoch {epoch + 1} connected accuracy: {connected_acc:.4f}')
      print(f'Epoch {epoch + 1} isolated accuracy: {isolated_acc:.4f}')

  # Summarise the epoch with the best connected accuracy.
  max_ind = np.argmax(connected_accuracy)

  stats = Stats(
      train_losses[max_ind],
      train_accuracy[max_ind],
      dev_losses[max_ind],
      dev_accuracy[max_ind],
      connected_accuracy[max_ind],
      isolated_accuracy[max_ind],
      epoch_holder[max_ind],
      lr, alpha,
      max_accuracy
  )

  # optionally print model results
  if verbose in [1,2]:
    print('\n ######## \n')
    print(f'lr:{stats.lr}, alpha:{stats.alpha} @ epoch {stats.epoch}.')
    print(f'TL:{stats.train_loss}, TA:{stats.train_accuracy}.')
    print(f'DL:{stats.dev_loss}, DA:{stats.dev_accuracy}')
    print(f'con_acc:{stats.connected_accuracy}, iso_acc:{stats.isolated_accuracy}')

  return stats

#### $\color{red}{Sanity-check:}$

In [None]:
from collections import namedtuple
# Summary record for one training run: metrics at the best epoch, the hyperparameters
# used, and the best connected accuracy seen so far.
Stats = namedtuple(
    'Stats',
    'train_loss train_accuracy dev_loss dev_accuracy '
    'connected_accuracy isolated_accuracy epoch lr alpha max_accuracy',
)

In [None]:
# Short sanity run (20 epochs, per-epoch printing); the best state_dict is written to "binme2".
tv_run(epochs=20, model=model, lr=0.00005, alpha=0.005, max_accuracy=0, path="binme2", verbose=2)

val loss tensor(4.0646, device='cuda:0')
val acc 0.12332695984703633
val loss tensor(4.0628, device='cuda:0')
val acc 0.1218809980806142
val loss tensor(4.0759, device='cuda:0')
val acc 0.12035225048923678
val loss tensor(4.0515, device='cuda:0')
val acc 0.12755598831548198
val loss tensor(4.0622, device='cuda:0')
val acc 0.1276207839562443
val loss tensor(4.0724, device='cuda:0')
val acc 0.12833168805528133
val loss tensor(4.0553, device='cuda:0')
val acc 0.1326530612244898
val loss tensor(4.0514, device='cuda:0')
val acc 0.1439153439153439
val loss tensor(4.0830, device='cuda:0')
val acc 0.1038961038961039
val loss tensor(4.0479, device='cuda:0')
val acc 0.1313340227507756
val loss tensor(4.0324, device='cuda:0')
val acc 0.17025862068965517
val loss tensor(4.0537, device='cuda:0')
val acc 0.11566484517304189
val loss tensor(4.0516, device='cuda:0')
val acc 0.1342031686859273
val loss tensor(4.0695, device='cuda:0')
val acc 0.10797665369649806
val loss tensor(4.0433, device='cuda:0')


Stats(train_loss=3.8302038237253826, train_accuracy=0.15355335892644795, dev_loss=tensor(4.0307, device='cuda:0'), dev_accuracy=tensor(0.1532, device='cuda:0'), connected_accuracy=0.1532033383846283, isolated_accuracy=tensor(0.1532, device='cuda:0'), epoch=1, lr=5e-05, alpha=0.005, max_accuracy=tensor(0.1532, device='cuda:0'))

In [None]:
def gen_config(lr_low, lr_high, alpha_low, alpha_high, rng=None):
  """Draw a (learning rate, weight decay) pair log-uniformly.

  Parameters:
  - lr_low, lr_high: bounds of the base-10 exponent for the learning rate.
  - alpha_low, alpha_high: bounds of the base-10 exponent for the weight decay.
  - rng: optional random source exposing `.uniform(low, high)` (e.g. a
    `numpy.random.Generator`); pass a seeded one for reproducible searches. When omitted,
    NumPy's global generator is re-seeded from OS entropy, as before.

  Returns:
  - (lr, alpha), each rounded to 6 decimal places.
  """
  if rng is None:
    np.random.seed()
    rng = np.random
  lr = round(10**float(rng.uniform(lr_low, lr_high)), 6)
  alpha = round(10**float(rng.uniform(alpha_low, alpha_high)), 6)
  return lr, alpha

In [None]:
def gen_ranges( lr, lr_range, alpha, alpha_range):
  """Return search bounds centred on `lr` and `alpha`.

  Each interval has the given total width, i.e. it extends half the range on either side
  of its centre.

  Returns:
  - (lr_low, lr_high, alpha_low, alpha_high)
  """
  lr_low = lr - lr_range/2
  lr_high = lr + lr_range/2

  alpha_low = alpha - alpha_range/2
  alpha_high = alpha + alpha_range/2

  return (lr_low, lr_high, alpha_low, alpha_high)

In [None]:
def search_stats(results):
  """Return the entry of `results` with the highest `dev_accuracy`.

  Ties keep the earliest entry. Entries whose accuracy is NaN are ignored. Returns None
  only when `results` is empty (or holds nothing but NaN accuracies), so a run where
  every accuracy is 0 still yields a result.
  """
  best_stats = None
  for candidate in results:
    acc = candidate.dev_accuracy
    if acc != acc:  # NaN never compares equal to itself
      continue
    if best_stats is None or acc > best_stats.dev_accuracy:
      best_stats = candidate
  return best_stats

In [None]:
"""
Main Admin
"""
# Search budget and bookkeeping. `path` is where tv_run saves the best state_dict.
epochs = 60
max_accuracy = 0
path = "class/models/GNN_geom_distance.4.pt"
results = []

"""
init random search
lr [10^-5 - 10^-1]
alpha [10^-5 - 10^-1]
bs [8, 32, 128]
"""
# The bounds below are base-10 exponents. As coded, the initial search covers
# lr in [10^-5, 10^-3] and alpha in [10^-5, 10^-2] (narrower than the ranges listed in
# the string above); batch size is not searched in this cell.
lr_low = -5
lr_high = -3
lr_range = lr_high - lr_low

alpha_low = -5
alpha_high = -2
alpha_range = alpha_high - alpha_low

d = 768
h = 400
c = 70
num_relations = 2

count = 0

"""
Hyperparameter Search
"""

# 4 rounds x 6 random configurations. After each round the search window is re-centred
# on the best configuration found so far and shrunk to a third of its width.
for i in range(4):
  # debug
  print("\n################\n")
  print(f'round: {i}')
  # print(f'lr_low{lr_low}, lr_high{lr_high}, lr_range{lr_range}')
  # print(f'alpha_low{alpha_low}, lr_high{alpha_high}, lr_range{alpha_range}')
  print('max', max_accuracy)
  print("\n################\n")


  for j in range(6):
    count += 1
    print(count)

    # get config
    lr, alpha = gen_config(lr_low, lr_high, alpha_low, alpha_high)
    # define model (a fresh one per configuration)
    model = GNNModel(d,h,c)
    model = model.to(device)

    # run training; max_accuracy is carried across runs so only a new global best is saved
    res = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 1)
    max_accuracy = res.max_accuracy
    results.append(res)

  # get best result of the round or even so far
  # NOTE(review): search_stats ranks by dev_accuracy, while tv_run saves the model by
  # connected accuracy — confirm the two criteria are meant to differ.
  stats = search_stats(results)


  print(stats) # debug

  # reconfigure the new hypers: centre (in log10 space) on the best lr / alpha
  lr = np.log10(stats.lr)
  lr_range = lr_range / 3

  alpha = np.log10(stats.alpha)
  alpha_range = alpha_range / 3

  config = gen_ranges(lr, lr_range, alpha, alpha_range)
  lr_low, lr_high, alpha_low, alpha_high = config
  lr_range = lr_high - lr_low
  alpha_range = alpha_high - alpha_low



################

round: 0
max 0

################

1
T [0.00917431153357029, 0.12614677846431732, 0.24541282653808594, 0.28669723868370056, 0.36467888951301575, 0.39678898453712463, 0.4266054928302765, 0.45642200112342834, 0.4816513657569885, 0.5137614607810974, 0.5321100950241089, 0.5481651425361633, 0.5596330165863037, 0.5779816508293152, 0.5940366983413696, 0.60550457239151, 0.6100917458534241, 0.6100917458534241, 0.6146788597106934, 0.6146788597106934, 0.6169724464416504, 0.6238532066345215, 0.6330274939537048, 0.6376146674156189, 0.6376146674156189, 0.6353210806846619, 0.6444953680038452, 0.6513761281967163, 0.6490825414657593, 0.6444953680038452, 0.6444953680038452, 0.6376146674156189, 0.6467889547348022, 0.6513761281967163]

 ######## 

lr:5.2e-05, alpha:0.002825 @ epoch 28.
TL:0.46861774086952207, TA:0.8937791779287226.
DL:1.698555827140808, DA:0.5549792647361755
con_acc:0.6513761281967163, iso_acc:0.4753788113594055
2
T [0.0, 0.12155962735414505, 0.23853209614753723, 0.30963

## $\color{blue}{Test-Predictions}$



In [None]:
# Rebuild the architecture and load saved weights for evaluation.
# NOTE(review): this loads the file "binme", whereas the runs above save to "binme2" and
# to "class/models/GNN_geom_distance.4.pt" — confirm which checkpoint is intended.
model = GNNModel(d,h,c)
model.load_state_dict(torch.load("binme", weights_only=True))
model.to(device)

GNNModel(
  (gnn_layers): ModuleList(
    (0-1): 2 x GNNLayer(768, 768)
  )
  (fc1): Linear(in_features=768, out_features=400, bias=True)
  (batch_norm_fc1): BatchNorm1d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=400, out_features=70, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (relu): ReLU()
)

In [None]:
import torch

def score(reals, preds):
  """Share of entries in `preds` that match `reals` element-wise."""
  hits = (reals == preds).sum()
  total = len(reals)
  return hits / total

def validate(model, sampler, criterion):
    """
    Evaluate `model` on `val_data` using the sub-graphs yielded by `sampler`.

    Parameters:
    - model: The model to be evaluated.
    - sampler: The sampler to sample validation data.
    - criterion: The loss function used for evaluation.

    Returns:
    - total_loss: loss over the first n rows of the batch (n = number of dev rows).
    - total_acc: accuracy over those rows.
    - connected_acc: accuracy over the rows selected by `df_dev.connected`.
    - isolated_acc: accuracy over the remaining rows.

    NOTE(review): every batch overwrites the metrics, so only the last batch's values are
    returned, and `[:n]` assumes the dev nodes are the first n rows — confirm.
    """
    model.eval()

    n = df_dev.shape[0]  # cutoff for validation points
    mask = df_dev.connected.to_numpy()  # validation points connected on the graph

    with torch.no_grad():
        for _batch_size, node_ids, adjs in sampler:
            sub_edges = adjs[0][0].edge_index.t().to(device)
            features = val_data.x[node_ids].to(device)
            logits = model(features, sub_edges)
            targets = val_data.y[node_ids].to(device)

            predicted = torch.max(logits, 1).indices
            reals, preds, outs = targets[:n], predicted[:n], logits[:n, :]
            total_loss = criterion(outs, reals)
            total_acc = score(reals, preds)

            # Split the dev rows into graph-connected and isolated points.
            connected_acc = score(reals[mask], preds[mask])
            isolated_acc = score(reals[~mask], preds[~mask])

    return total_loss, total_acc, connected_acc, isolated_acc



In [None]:
# Evaluate the loaded model; the four returned metrics are kept at notebook scope.
total_loss, total_acc, connected_acc, isolated_acc = validate(model, val_sampler, nn.CrossEntropyLoss())

In [None]:
def score(reals, preds):
  """Proportion of element-wise agreements between `reals` and `preds`."""
  agreements = reals == preds
  return agreements.sum() / len(reals)

In [None]:
# Use the accuracy returned by validate(); `reals` / `preds` exist only inside that function.
print(f'Overall score: {total_acc:.4f}')

Overall score: 0.6006


In [None]:
# Split labels / predictions into graph-connected and isolated points.
# NOTE(review): in the cells visible here, `reals`, `preds` and `mask` are only assigned
# inside validate(), not at notebook scope — confirm where these names are defined, or
# this cell will fail on a fresh kernel.
reals_connected = reals[mask]
preds_connected = preds[mask]
reals_isol = reals[~mask]
preds_isol = preds[~mask]

In [None]:
# Use the connected accuracy returned by validate() instead of names local to that function.
print(f'Connected score: {connected_acc:.4f}')

Connected score: 0.6766


In [None]:
# Use the isolated accuracy returned by validate() instead of names local to that function.
print(f'Isolated score: {isolated_acc:.4f}')

Isolated score: 0.5379
