# Text Classification - Training a GNN


## $\color{blue}{Sections:}$

* Preamble
1.   Admin
2.   Dataset
3.   Model
4.   Train - Validate
5.   Training Loop

## $\color{blue}{Preamble:}$

We now train a GNN in basic PyTorch. The model will look like a GCN. Inference will be completed in another notebook, as the whole graph must be loaded at once.

## $\color{blue}{Admin}$
* Install relevant Libraries
* Import relevant Libraries

In [None]:
import torch
import pandas as pd
from google.colab import drive
import numpy as np
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [None]:
# Mount Google Drive inside the Colab runtime.
drive.mount("/content/drive")
# IPython magic: switch the working directory to the Drive root so that the
# relative 'class/...' paths used in the following cells resolve.
%cd '/content/drive/MyDrive'

Mounted at /content/drive
/content/drive/MyDrive


## $\color{blue}{Data}$

* Connect to Drive
* Load the data
* Load adjacency matrices

In [None]:
# The train / dev / test DataFrames are stored as pickles in one folder.
path = 'class/datasets/'
df_train, df_dev, df_test = (
    pd.read_pickle(path + name) for name in ('df_train', 'df_dev', 'df_test')
)

In [None]:
# One saved adjacency tensor per (split, relation) pair.
path = 'class/tensors/adj_{}.pt'

# train
train_people, train_locations, train_entities = (
    torch.load(path.format(name))
    for name in ('train_people', 'train_locations', 'train_entities')
)

# dev
dev_people, dev_locations, dev_entities = (
    torch.load(path.format(name))
    for name in ('dev_people', 'dev_locations', 'dev_entities')
)

# val: adjacency over the training and the development set together
val_people, val_locations, val_entities = (
    torch.load(path.format(name))
    for name in ('val_people', 'val_locations', 'val_entities')
)

  train_people = torch.load(path.format('train_people'))
  train_locations = torch.load(path.format('train_locations'))
  train_entities = torch.load(path.format('train_entities'))
  dev_people = torch.load(path.format('dev_people'))
  dev_locations = torch.load(path.format('dev_locations'))
  dev_entities = torch.load(path.format('dev_entities'))
  val_people = torch.load(path.format('val_people'))
  val_locations = torch.load(path.format('val_locations'))
  val_entities = torch.load(path.format('val_entities'))


## $\color{blue}{Dataset:}$

In [None]:
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy

def _linked_nodes(adj, node, exclude):
    """Return the nodes that `node` links to in `adj` (entries > 0), minus `exclude`."""
    # A single .tolist() brings every hit to Python at once instead of paying
    # one tensor-to-scalar conversion per element.
    hits = (adj[node] > 0).nonzero(as_tuple=True)[0].tolist()
    return [n for n in hits if n not in exclude]


def sample_neighborhood(A, inds, neighbor_max, branch_max, seed=None):
    """Grow a mini-batch of seed nodes with sampled 1-hop and 2-hop neighbours.

    Args:
      A : list[torch.tensor]
        one (n x n) adjacency matrix per relation; an entry > 0 counts as an
        edge. The list is read only - its order is never changed.

      inds : list[int]
        seed node indices of the mini-batch

      neighbor_max : int
        cap applied whenever a set of freshly found neighbours is sub-sampled

      branch_max : int
        node budget for one seed's neighbourhood; expansion of that seed stops
        as soon as the budget is reached

      seed : int | None
        when given, sampling is reproducible and the global NumPy RNG is left
        untouched; when None the global NumPy RNG is used

    Returns:
      list[int] : the seed indices plus every sampled neighbour, without
      duplicates and in no particular order.
    """
    # A private RandomState yields the same stream np.random.seed(seed) would,
    # but does not reset the global generator that other code draws from.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # Visit relations in random order, but shuffle a copy: callers index their
    # own list by relation id and must not find it reordered afterwards.
    relations = list(A)
    rng.shuffle(relations)

    sampled_indices = set(inds)

    for ind in inds:  # every seed node of the mini-batch
        break_to_outer = False
        neighbors = set()

        for adj in relations:
            if break_to_outer:
                break

            # 1-hop: nodes linked from the seed that are not in the batch yet
            disclude = {ind} | sampled_indices
            neighbors.update(_linked_nodes(adj, ind, disclude))

            if len(neighbors) >= neighbor_max:
                neighbors = set(
                    rng.choice(list(neighbors), neighbor_max, replace=False).tolist()
                )

            # 2-hop: iterate over a snapshot because `neighbors` grows below
            for idx in deepcopy(neighbors):
                if break_to_outer:
                    break

                neighbors_neighbors = set()
                for hop_adj in relations:
                    disclude = {ind, idx} | sampled_indices | neighbors
                    new_neighbors_neighbors = _linked_nodes(hop_adj, idx, disclude)
                    if len(new_neighbors_neighbors) > neighbor_max:
                        new_neighbors_neighbors = set(
                            rng.choice(
                                new_neighbors_neighbors, neighbor_max, replace=False
                            ).tolist()
                        )
                    neighbors_neighbors.update(new_neighbors_neighbors)

                    if len(neighbors) + len(neighbors_neighbors) >= branch_max:
                        # Fill whatever is left of the budget. Clamp at zero:
                        # with neighbor_max > branch_max the 1-hop set alone can
                        # exceed the budget, and a negative size would raise.
                        room = max(branch_max - len(neighbors), 0)
                        neighbors_neighbors = set(
                            rng.choice(
                                list(neighbors_neighbors), room, replace=False
                            ).tolist()
                        )
                        neighbors.update(neighbors_neighbors)

                        break_to_outer = True
                        break

                    neighbors.update(neighbors_neighbors)

        sampled_indices.update(neighbors)  # add this seed's neighbourhood

    return list(sampled_indices)

#check the conditions

In [None]:
import torch
import numpy as np
# Small 8-node, non-symmetric adjacency matrix for checking the sampler by hand.
a = torch.Tensor(
    [
        [0, 1, 0, 0, 0, 0, 0, 1],
        [1, 0, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1],
        [1, 1, 0, 0, 1, 0, 0, 0],
        [0, 0, 1, 1, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 1],
        [0, 0, 0, 1, 0, 1, 1, 1],
    ]
)
A = [a]                  # a single relation
inds = [1, 2]            # seed nodes
nm, bm, seed = 2, 4, 42  # neighbor_max, branch_max, RNG seed

In [None]:
import torch

def neighbor_analysis(primary_inds: list[int], inds: list[int], adjacency_matrix: torch.Tensor):
    """Report, for each primary node, its neighbours among a set of candidates.

    An edge is any adjacency entry > 0, the same test the neighbourhood
    sampler uses, so weighted adjacency matrices are analysed consistently.

    Args:
        primary_inds: nodes to analyse (the seed nodes of a batch).
        inds: candidate nodes; only these are reported as neighbours.
        adjacency_matrix: square matrix indexed as [source, target].

    Returns:
        One tuple per primary node:
        (primary_idx, direct neighbours in candidate order,
        sorted unique nodes linked from any of those neighbours).
        The primary node itself is never listed in the last element, but a
        node may appear both as a neighbour and as a neighbour of a neighbour.
    """
    result = []

    for primary_idx in primary_inds:
        # Candidates the primary node links to directly
        neighbors = [ind for ind in inds if adjacency_matrix[primary_idx, ind] > 0]

        # Candidates linked from any direct neighbour; the set removes repeats
        neighbors_of_neighbors = {
            ind
            for neighbor in neighbors
            for ind in inds
            if ind != primary_idx and adjacency_matrix[neighbor, ind] > 0
        }

        # sorted() makes the reported order deterministic
        result.append((primary_idx, neighbors, sorted(neighbors_of_neighbors)))

    return result

In [None]:
import random
# Draw six random seed nodes from the first 12000 node ids and expand them
# with neighbor_max = 4 and branch_max = 16.
# (Use inds = list(range(101, 105)) instead for a fixed set of seeds.)
inds = random.sample(range(12000), 6)
print(inds)
sample = sample_neighborhood([train_entities], inds, 4, 16)
print(sample, len(sample), sep="\n")

[2144, 2327, 5933, 10288, 346, 5847]
[7431, 6920, 5010, 8468, 11540, 2327, 5149, 543, 8865, 6951, 11052, 5933, 1711, 10288, 432, 4146, 3632, 4786, 6198, 2879, 2374, 4679, 9802, 7883, 4945, 9042, 7638, 5847, 346, 9820, 2144, 11872, 6624, 7523, 6888, 1001, 1774, 8574]
38


In [None]:
# For every seed node, list which sampled nodes are its direct neighbours and
# which are linked from those neighbours; one tuple is printed per seed.
analysis = neighbor_analysis(inds, sample, train_entities)
for item in analysis:
  print(item)

(2144, [6920, 5010, 8468, 8865, 1711, 6198, 2879, 2374, 9802, 9042, 11872, 8574], [11872, 8865, 6624, 2374, 6920, 9802, 1711, 432, 5010, 9042, 8468, 4146, 6198, 7638, 8574, 2879])
(2327, [], [])
(5933, [], [])
(10288, [7431, 11540, 5149, 543, 6951, 11052, 3632, 4786, 4679, 7883, 4945, 9820, 7523, 6888, 1001, 1774], [7523, 6951, 4679, 6888, 1001, 7883, 11052, 7431, 1774, 3632, 4945, 4786, 11540, 9820, 5149, 543])
(346, [], [])
(5847, [], [])


In [None]:
# Same check with a fixed block of seed nodes and a fixed RNG seed.
# (Use inds = random.sample(list(range(12000)), 4) for random seed nodes.)
inds = [101, 102, 103, 104]
sample = sample_neighborhood([train_entities], inds, 16, 64, seed=42)
print(sample, len(sample), sep="\n")

[9221, 2061, 3599, 10261, 9214, 6937, 4386, 2855, 8492, 4654, 10039, 4152, 2363, 11581, 317, 6210, 580, 6725, 9046, 599, 5721, 3673, 1883, 3679, 4707, 4964, 101, 102, 103, 104, 8551, 5480, 5485, 11121, 3953, 7285, 8824, 11904, 6274, 9607, 11660, 7317, 5271, 8856, 3235, 7594, 4524, 5550, 10927, 4023, 1722, 4030, 10430, 8641, 4305, 8916, 9688, 2009, 10203, 10207, 9440, 9185, 1250, 8680, 6640, 11763, 2804, 510]
68


In [None]:
class GNNDataset(Dataset):
  """Graph dataset whose items are neighbourhood-sampled sub-graphs."""

  def __init__(self, H, A, labels, meta_indices, neighbor_max=4, branch_max=16, seed=None):
    """Custom dataset with neighborhood sampling

    Args:
      H : torch.tensor
        input embeddings (n x d)

      A : list[torch.tensor]
        list of (n x n) adjacency matrices, one per relation

      labels : torch.LongTensor
        y

      meta_indices : torch.LongTensor
        index of datapoint to filter validation score

      neighbor_max : int
        max neighbors sampled for each node in mini-batch

      branch_max : int
        node budget for the neighbourhood of one mini-batch node

      seed : int | None
        forwarded to sample_neighborhood on every lookup; a fixed value makes
        the sampled neighbourhoods repeatable

    """
    # All inits must be tensors
    self.H = H.to(device)
    self.A = [a.to(device) for a in A]
    self.labels = labels.to(device)
    self.meta_indices = meta_indices  # not moved to `device`
    self.neighbor_max = neighbor_max
    self.branch_max = branch_max
    self.seed = seed

  def __len__(self):
    """Number of nodes (one label per node)."""
    return len(self.labels)

  def __getitem__(self, inds):
    """Return (H_batch, A_batch, labels_batch, index_batch) for the sub-graph
    made of `inds` plus their sampled neighbours.

    `inds` may be a tensor, a list of ints or a single int.
    """
    # Normalise the request to a list of indices
    if torch.is_tensor(inds):
      inds = inds.tolist()
    elif not isinstance(inds, list):
      inds = [inds]

    # Expand the batch. A copy of the relation list is handed over so that the
    # sampler can never reorder self.A: A_batch below must keep relation k at
    # position k.
    sampled_indices = sample_neighborhood(
        list(self.A), inds, self.neighbor_max, self.branch_max, seed=self.seed)

    # Inputs, per-relation sub-adjacency (rows then columns), labels and
    # meta indices of the sampled nodes
    H_batch = self.H[sampled_indices]
    A_batch = [a[sampled_indices][:, sampled_indices] for a in self.A]
    labels_batch = self.labels[sampled_indices]
    index_batch = self.meta_indices[sampled_indices]

    return H_batch, A_batch, labels_batch, index_batch

In [None]:
# Node features, class labels, the single (entity) relation and row ids of
# the training graph.
H = torch.stack(df_train['vanilla_embedding.1'].tolist())
labels = torch.LongTensor(df_train['chapter_idx'].tolist())
A = [train_entities]
meta_indices = torch.arange(df_train.shape[0])

In [None]:
train_dataset = GNNDataset(H, A, labels, meta_indices, neighbor_max=4, branch_max=10)

# Passing a BatchSampler as `sampler` makes the DataLoader hand a whole list
# of indices to the dataset at once instead of a single index at a time.
custom_sampler = torch.utils.data.BatchSampler(
    sampler=torch.utils.data.RandomSampler(train_dataset),
    batch_size=6,
    drop_last=False,
)

train_loader = DataLoader(train_dataset, sampler=custom_sampler)


In [None]:
# Check batches

# Number of batches to inspect
num_batches_to_check = 2

# The loop targets carry a batch_ prefix so that iterating does not rebind the
# notebook-level `labels` tensor the datasets are built from.
for batch_idx, (batch_inputs, batch_adjacency, batch_labels, batch_indices) in enumerate(train_loader):
    print('\n##########################\n')
    print(f"Batch {batch_idx + 1}/{num_batches_to_check}:")
    print('-' * 10)
    print("Inputs:")
    print(f"  Type: {type(batch_inputs)}")
    print(f"  Shape: {batch_inputs.size()}")
    print('-' * 10)
    print("Adjacency:")
    print(f"  Type: {type(batch_adjacency)}")
    # Only the first relation's matrix is shown
    print(f"  Shape: {batch_adjacency[0].size()}")
    print('-' * 10)
    print("Indices:")
    print(f"  Type: {type(batch_indices)}")
    print(f"  Shape: {batch_indices.size()}")
    print(batch_indices)
    print('-' * 10)
    print("Labels:")
    print(f"  Type: {type(batch_labels)}")
    print(f"  Shape: {batch_labels.size()}")
    print(batch_labels)

    # Stop after inspecting the desired number of batches
    if batch_idx + 1 >= num_batches_to_check:
        break


##########################

Batch 1/2:
----------
Inputs:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 26, 768])
----------
Adjacency:
  Type: <class 'list'>
  Shape: torch.Size([1, 26, 26])
----------
Indices:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 26])
tensor([[ 9152,  4865,  8832,  2305,  9542, 10381,  8910,   146, 10964,  2389,
          4501,  3799,  3220,  2711,  9438,   607,   865,  1448,  8298,  2540,
          4271,  2101,  5814,   696,  1145,  3899]])
----------
Labels:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 26])
tensor([[53, 17, 42, 12, 16,  5,  8, 47, 48, 57, 45, 54,  9, 15, 64, 12, 59, 11,
          5, 56, 58, 15, 47,  5,  5, 67]], device='cuda:0')

##########################

Batch 2/2:
----------
Inputs:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 16, 768])
----------
Adjacency:
  Type: <class 'list'>
  Shape: torch.Size([1, 16, 16])
----------
Indices:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 16])
tenso

In [None]:
# Rebuild the training tensors (features, labels, relation list, row ids)
# before constructing the next dataset.
H = torch.stack(df_train['vanilla_embedding.1'].tolist())
labels = torch.LongTensor(df_train['chapter_idx'].tolist())
A = [train_entities]
meta_indices = torch.arange(df_train.shape[0])

In [None]:
# GNNDataset's keyword is `branch_max`; the former `batch_max` spelling is not
# a parameter of the class and raises a TypeError.
# NOTE(review): 32 is carried over from the old `batch_max` argument - confirm
# it is the intended per-node budget.
train_dataset_fixed = GNNDataset(H, A, labels, meta_indices, neighbor_max=8, branch_max=32, seed=42)

# Prevent dataloader from calling a single index at a time; sequential order
# plus the fixed seed makes every pass over this loader repeatable.
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(train_dataset_fixed),
    batch_size=8,
    drop_last=False)


train_loader_fixed = DataLoader(train_dataset_fixed, sampler = custom_validation_sampler)

In [None]:
# Check batches

# Number of batches to inspect
num_batches_to_check = 2

# The loop targets carry a batch_ prefix so that iterating does not rebind the
# notebook-level `labels` tensor the datasets are built from.
for batch_idx, (batch_inputs, batch_adjacency, batch_labels, batch_indices) in enumerate(train_loader_fixed):
    print('\n##########################\n')
    print(f"Batch {batch_idx + 1}/{num_batches_to_check}:")
    print('-' * 10)
    print("Inputs:")
    print(f"  Type: {type(batch_inputs)}")
    print(f"  Shape: {batch_inputs.size()}")
    print('-' * 10)
    print("Adjacency:")
    print(f"  Type: {type(batch_adjacency)}")
    # Only the first relation's matrix is shown
    print(f"  Shape: {batch_adjacency[0].size()}")
    print('-' * 10)
    print("Indices:")
    print(f"  Type: {type(batch_indices)}")
    print(f"  Shape: {batch_indices.size()}")
    print(batch_indices)
    print('-' * 10)
    print("Labels:")
    print(f"  Type: {type(batch_labels)}")
    print(f"  Shape: {batch_labels.size()}")
    print(batch_labels)

    # Stop after inspecting the desired number of batches
    if batch_idx + 1 >= num_batches_to_check:
        break


##########################

Batch 1/2:
----------
Inputs:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 32, 768])
----------
Adjacency:
  Type: <class 'list'>
  Shape: torch.Size([1, 32, 32])
----------
Indices:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 32])
tensor([[    0,     1,     2,     3,     4,     5,     6,     7,  5446,  9222,
          2250,   583,  9547,  7760,  6290,  8020,  8983,  5533,  6623,   993,
          8424,  8616,  9900,  1900,  9773,  2668,   817,  3572,  7989,  9849,
         10620,  8189]])
----------
Labels:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 32])
tensor([[31, 15, 15, 59, 62, 63, 17, 15, 31, 15, 31, 16, 16, 31, 59, 56, 16, 31,
         50, 31, 31,  5, 14, 59, 59, 57, 31, 58,  7, 31,  9, 59]],
       device='cuda:0')

##########################

Batch 2/2:
----------
Inputs:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 32, 768])
----------
Adjacency:
  Type: <class 'list'>
  Shape: torch.Size([1, 32, 32])
---

The dataloader seems to be working correctly. Because the custom sampler hands all indices of a batch to the dataset at once, the DataLoader's collate function has nothing to stack: it only inserts a new leading dimension of size 1 (the dimension it would normally stack along).

The simple solution will be to simply squeeze the tensors in the training loop. The validation loader eradicates randomness from the process.

## $\color{blue}{Model:}$

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer(nn.Module):
    """Relational message-passing layer with a residual connection.

    For every relation k the layer adds
    leaky_relu(H @ T[k] + mean of incoming messages @ E[k]) + H,
    then applies batch normalisation and dropout to the sum over relations.

    Args:
        in_features: width of the incoming node embeddings.
        out_features: width of the outgoing node embeddings; must equal
            in_features because the input is added back as a skip connection.
        num_relations: number of adjacency matrices passed to forward().
        dropout: dropout probability applied to the layer output.

    Raises:
        ValueError: if in_features != out_features.
    """

    def __init__(self, in_features, out_features, num_relations=1, dropout=0.3):
        super().__init__()
        if in_features != out_features:
            # forward() adds H to the projected features, so the widths must
            # agree; fail here rather than with a shape error mid-training.
            raise ValueError(
                "GNNLayer uses a skip connection and needs in_features == out_features"
            )
        self.in_features = in_features
        self.out_features = out_features
        self.num_relations = num_relations
        self.dropout = dropout

        # T[k]: projection of a node's own embedding; E[k]: projection of the
        # aggregated messages. Both are initialised in reset_parameters().
        self.T = nn.ParameterList([nn.Parameter(torch.empty(in_features, out_features)) for _ in range(num_relations)])
        self.E = nn.ParameterList([nn.Parameter(torch.empty(in_features, out_features)) for _ in range(num_relations)])

        # Batch normalization
        self.batch_norm = nn.BatchNorm1d(out_features)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialisation of every projection matrix."""
        for t in self.T:
            nn.init.xavier_uniform_(t)
        for e in self.E:
            nn.init.xavier_uniform_(e)

    def forward(self, H, A):
        """Apply the layer.

        Args:
            H: node embeddings, one row per node.
            A: sequence of square adjacency matrices, A[k] for relation k.

        Returns:
            Tensor with the same shape as H.
        """
        H_out = torch.zeros_like(H)
        for k in range(self.num_relations):
            incoming = A[k].T
            messages_projection = incoming @ H @ self.E[k]

            # Row j of A.T @ H sums the embeddings of all i with an edge
            # A[i, j], so the matching mean divides by the row sums of A.T
            # (the column sums of A). For a symmetric A this is unchanged.
            degrees = incoming.sum(dim=1, keepdim=True)
            degrees[degrees == 0] = 1.0  # isolated nodes: avoid dividing by 0
            messages_projection = messages_projection / degrees

            self_projection = H @ self.T[k]

            # Include skip connection
            H_out += F.leaky_relu(self_projection + messages_projection) + H

        # Apply batch normalization
        H_out = self.batch_norm(H_out)

        # Apply dropout
        H_out = F.dropout(H_out, p=self.dropout, training=self.training)

        return H_out

class GNNModel(nn.Module):
    """Stack of GNN layers followed by a two-layer classification head.

    Args:
        d: width of the node embeddings (kept by every GNN layer).
        h: hidden width of the fully connected head.
        c: number of output classes.
        num_relations: number of adjacency matrices each layer consumes.
        num_layers: number of stacked GNN layers.
        dropout: dropout probability used in the layers and in the head.

    NOTE: a later cell defines another class with the same name; whichever
    definition is executed last is the one in effect.
    """

    def __init__(self, d, h, c, num_relations=1, num_layers=3, dropout=0.3):
        super().__init__()
        self.num_layers = num_layers
        stack = [GNNLayer(d, d, num_relations, dropout) for _ in range(num_layers)]
        self.gnn_layers = nn.ModuleList(stack)
        self.fc1 = nn.Linear(d, h)
        self.batch_norm_fc1 = nn.BatchNorm1d(h)
        self.fc2 = nn.Linear(h, c)
        self.dropout = dropout

    def forward(self, H, A):
        """Return one row of class scores per node."""
        for gnn_layer in self.gnn_layers:
            H = gnn_layer(H, A)

        # Head: linear -> batch norm -> ReLU -> dropout -> linear
        hidden = self.batch_norm_fc1(self.fc1(H))
        hidden = F.relu(hidden)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.fc2(hidden)

    def forward_layer(self, H, A, layer_idx):
        """Forward pass for a specific layer."""
        return self.gnn_layers[layer_idx](H, A)


In [None]:
class GNNModel(nn.Module):
    """GNN layers followed by a plain two-layer classifier (no batch norm or
    dropout in the head).

    Args:
        d: width of the node embeddings (kept by every GNN layer).
        h: hidden width of the fully connected head.
        c: number of output classes.
        num_relations: number of adjacency matrices each layer consumes.
        num_layers: number of stacked GNN layers.
    """

    def __init__(self, d, h, c, num_relations=1, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        # num_relations has to reach the layers; otherwise every layer is
        # built for a single relation whatever the caller asked for.
        self.gnn_layers = nn.ModuleList([GNNLayer(d, d, num_relations) for _ in range(num_layers)])
        self.fc1 = nn.Linear(d, h)
        self.fc2 = nn.Linear(h, c)

    def forward(self, H, A):
        """Return one row of class scores per node."""
        for layer in self.gnn_layers:
            H = layer(H, A)
        # Classification
        H = F.relu(self.fc1(H))
        Output = self.fc2(H)
        return Output

In [None]:
import torch.optim as optim

# Model sizes
d = 768             # width of the input embeddings
h, c = 400, 70      # hidden width of the fully connected layer, number of classes
num_relations = 1   # number of relationship types

# Model, loss and optimizer
model = GNNModel(d, h, c, num_relations)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [None]:
def count_parameters_per_module(model):
    """Print the number of trainable parameters of every named sub-module.

    The model itself (empty name) is skipped, and so are modules without
    trainable parameters. A container and its children are listed separately,
    so the printed counts overlap.
    """
    print("Module and parameter counts:")

    # named_modules() yields the model itself under the empty name
    submodules = ((name, module) for name, module in model.named_modules() if name != "")
    for name, module in submodules:
        param_count = sum(p.numel() for p in module.parameters() if p.requires_grad)
        if param_count:
            print(f"{name}: {param_count} parameters")

In [None]:
# Print the trainable-parameter breakdown of the model built above.
count_parameters_per_module(model)

Module and parameter counts:
gnn_layers: 2362368 parameters
gnn_layers.0: 1181184 parameters
gnn_layers.0.T: 589824 parameters
gnn_layers.0.E: 589824 parameters
gnn_layers.0.batch_norm: 1536 parameters
gnn_layers.1: 1181184 parameters
gnn_layers.1.T: 589824 parameters
gnn_layers.1.E: 589824 parameters
gnn_layers.1.batch_norm: 1536 parameters
fc1: 307600 parameters
fc2: 28070 parameters


## $\color{blue}{Train-Validate:}$

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of rows of `outputs` whose highest-scoring column
    equals the matching entry of `labels`."""
    predicted = outputs.argmax(dim=1)
    hits = predicted.eq(labels).sum().item()
    return hits / labels.size(0)

In [None]:
import numpy as np

def train(model, train_loader, criterion, optimizer):
    """Run one optimisation pass over `train_loader`.

    Every batch is a (H, A, y, indices) tuple; the indices are not used here.

    Returns:
        (mean batch loss, mean batch accuracy) over the pass.
    """
    model.train()
    batch_losses, batch_accuracies = [], []

    for H, A, y, _ in train_loader:
        optimizer.zero_grad()

        # Drop the leading dimension added by the DataLoader
        H, y = H.squeeze(0), y.squeeze(0)
        A = [adj.squeeze(0) for adj in A]

        logits = model(H, A)
        loss = criterion(logits, y)

        batch_losses.append(loss.item())
        batch_accuracies.append(accuracy(logits, y))

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

    return np.mean(batch_losses), np.mean(batch_accuracies)

In [None]:
def validate(model, dev_loader, criterion, threshold=12000):
    """Evaluate `model` on `dev_loader` without gradient tracking.

    Only rows whose meta index is >= `threshold` are scored; the other rows of
    a batch take part in the forward pass but not in the metrics.
    (Presumably `threshold` is the number of training nodes - confirm.)

    Returns:
        (mean batch loss, mean batch accuracy) over the batches that had at
        least one scored row. If no batch had one, np.mean of an empty list
        gives nan.
    """
    model.eval()
    batch_losses, batch_accuracies = [], []

    with torch.no_grad():
        for H, A, y, indices in dev_loader:
            # Drop the leading dimension added by the DataLoader
            H, y, indices = H.squeeze(0), y.squeeze(0), indices.squeeze(0)
            A = [adj.squeeze(0) for adj in A]

            out = model(H, A)

            # Filter out training points
            keep = indices >= threshold
            kept_out, kept_y = out[keep], y[keep]

            # Nothing to score in this batch
            if kept_out.size(0) == 0:
                continue

            batch_losses.append(criterion(kept_out, kept_y).item())
            batch_accuracies.append(accuracy(kept_out, kept_y))

    return np.mean(batch_losses), np.mean(batch_accuracies)

## $\color{blue}{Training:}$

In [None]:
from collections import namedtuple
# Summary of one training run: metrics of its best epoch, the hyperparameters
# used, and the best dev accuracy known when the run finished.
Stats = namedtuple(
    'Stats',
    (
        'train_loss',
        'train_accuracy',
        'dev_loss',
        'dev_accuracy',
        'epoch',
        'lr',
        'alpha',
        'max_accuracy',
    ),
)

In [None]:
def _round_sig(value, digits=4):
  """Round `value` to `digits` significant figures."""
  return float(f"{value:.{digits}g}")


def gen_config(lr_low, lr_high, alpha_low, alpha_high):
  """Draw a (learning rate, weight decay) pair log-uniformly.

  Args:
    lr_low, lr_high : base-10 exponents bounding the learning rate
    alpha_low, alpha_high : base-10 exponents bounding the weight decay

  Returns:
    (lr, alpha) as floats rounded to 4 significant figures. Rounding to a
    fixed number of decimals would turn very small values into 0.0 and leave
    a single digit of precision near 1e-6.
  """
  # Re-seed from OS entropy so the draw does not repeat when other code has
  # put the global NumPy RNG into a fixed state.
  np.random.seed()
  lr = _round_sig(10 ** float(np.random.uniform(lr_low, lr_high)))
  alpha = _round_sig(10 ** float(np.random.uniform(alpha_low, alpha_high)))
  return lr, alpha

In [None]:
def gen_ranges(lr, lr_range, alpha, alpha_range):
  """Return (lr_low, lr_high, alpha_low, alpha_high): intervals of width
  lr_range and alpha_range centred on lr and alpha."""
  half_lr = lr_range / 2
  half_alpha = alpha_range / 2
  return (lr - half_lr, lr + half_lr, alpha - half_alpha, alpha + half_alpha)

In [None]:
def search_stats(results):
  """Return the entry of `results` with the highest dev_accuracy.

  The first entry wins ties. An entry with accuracy 0 can be selected, so the
  result is None only when `results` is empty or every accuracy is nan.
  """
  best_stats = None
  # Start below any real accuracy; a start value of 0 would reject runs that
  # scored exactly 0 and leave the caller with None.
  max_dev_accuracy = float("-inf")
  for candidate in results:
    if candidate.dev_accuracy > max_dev_accuracy:
      best_stats = candidate
      max_dev_accuracy = candidate.dev_accuracy
  return best_stats

In [None]:
def tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 0, patience = 6):
  """
  Runs a training setup: trains `model` with AdamW for up to `epochs` epochs
  and validates after each one.

  Uses the notebook-level train_dataset / custom_train_sampler and
  validation_dataset / custom_validation_sampler.

  epochs - maximum number of epochs (at least 1)
  lr, alpha - learning rate and weight decay for AdamW
  max_accuracy - best dev accuracy seen so far (across runs); the model's
                 state dict is saved to `path` whenever an epoch beats it
  patience - stop once this many epochs in a row failed to improve on the
             best dev accuracy of this run
  verbose == 1 - print model results
  verbose == 2 -> print epoch and model results

  Returns a Stats tuple with the metrics of the epoch that had the best dev
  accuracy in this run, lr, alpha and the updated max_accuracy.

  Raises ValueError if epochs < 1.
  """
  if epochs < 1:
    # Without a single epoch there is no best epoch to report
    raise ValueError("epochs must be at least 1")

  model = model.to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=alpha)

  # Prepare data loaders
  train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)
  dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)

  # Hold epoch stats
  train_losses = []
  train_accuracy = []
  dev_losses = []
  dev_accuracy = []
  epoch_holder = []

  # Break if no improvement
  current_best = 0
  no_improvement = 0

  # Run epochs
  for epoch in range(epochs):

    # break out of epochs
    if no_improvement >= patience:
      break

    # call training and validation functions
    train_loss, train_acc = train(model, train_loader, criterion, optimizer)
    dev_loss, dev_acc = validate(model, dev_loader, criterion)

    # Store epoch stats
    train_losses.append(train_loss)
    train_accuracy.append(train_acc)
    dev_losses.append(dev_loss)
    dev_accuracy.append(dev_acc)
    epoch_holder.append(epoch + 1)

    # check for improvement (within this run)
    if dev_acc > current_best:
      current_best = dev_acc
      no_improvement = 0
    else:
      no_improvement += 1

    # save best model (across runs)
    if dev_acc > max_accuracy:
      torch.save(model.state_dict(), path)
      max_accuracy = dev_acc

    # optionally print epoch results
    if verbose == 2:
      print(f'\n --------- \nEpoch: {epoch + 1}\n')
      print(f'Epoch {epoch + 1} train loss: {train_loss:.4f}')
      print(f'Epoch {epoch + 1} train accuracy: {train_acc:.4f}')
      print(f'Epoch {epoch + 1} dev loss: {dev_loss:.4f}')
      print(f'Epoch {epoch + 1} dev accuracy: {dev_acc:.4f}')

  # report the epoch with the best dev accuracy
  max_ind = int(np.argmax(dev_accuracy))

  stats = Stats(
      train_losses[max_ind],
      train_accuracy[max_ind],
      dev_losses[max_ind],
      dev_accuracy[max_ind],
      epoch_holder[max_ind],
      lr, alpha,
      max_accuracy
  )

  # optionally print model results
  if verbose in [1,2]:
    print('\n ######## \n')
    print(f'lr:{stats.lr}, alpha:{stats.alpha} @ epoch {stats.epoch}.')
    print(f'TL:{stats.train_loss}, TA:{stats.train_accuracy}.')
    print(f'DL:{stats.dev_loss}, DA:{stats.dev_accuracy}')

  return stats

#### $\color{red}{Sanity-check:}$

In [None]:
# Fresh, untrained model for the sanity-check run; d, h, c and num_relations
# are the sizes set in the earlier setup cell.
model = GNNModel(d, h, c, num_relations)

In [None]:
# Training loader: features, labels, the entity relation and row ids of
# df_train, sampled with neighbor_max = 4 and branch_max = 16.
H_train = torch.stack(df_train['vanilla_embedding.1'].tolist())
labels_train = torch.LongTensor(df_train['chapter_idx'].tolist())
A_train = [train_entities]
train_indices = torch.arange(df_train.shape[0])

train_dataset = GNNDataset(H_train, A_train, labels_train, train_indices, neighbor_max=4, branch_max=16)

# A BatchSampler passed as `sampler` makes the DataLoader request a whole list
# of indices from the dataset at once; batches are drawn in random order.
custom_train_sampler = torch.utils.data.BatchSampler(
    sampler=torch.utils.data.RandomSampler(train_dataset),
    batch_size=8,
    drop_last=False,
)

train_loader = DataLoader(train_dataset, sampler=custom_train_sampler)

# Validation loader: the training rows followed by the development rows, with
# the adjacency that spans both sets. Row ids follow that order.
df1, df2 = (frame[['vanilla_embedding.1', 'chapter_idx']] for frame in (df_train, df_dev))
df_val = pd.concat([df1, df2])
H_val = torch.stack(df_val['vanilla_embedding.1'].tolist())
labels_val = torch.LongTensor(df_val['chapter_idx'].tolist())
A_val = [val_entities]
val_indices = torch.arange(df_val.shape[0])

# A fixed seed keeps the sampled neighbourhoods identical on every pass
validation_dataset = GNNDataset(H_val, A_val, labels_val, val_indices, neighbor_max=4, branch_max=16, seed=42)

# Whole index lists per request, visited in sequential order
custom_validation_sampler = torch.utils.data.BatchSampler(
    sampler=torch.utils.data.SequentialSampler(validation_dataset),
    batch_size=8,
    drop_last=False,
)

dev_loader = DataLoader(validation_dataset, sampler=custom_validation_sampler)



In [None]:
# Settings for a short two-epoch sanity-check run
epochs, max_accuracy = 2, 0
lr, alpha = 0.0005, 0.0001           # learning rate, weight decay
path = "class/models/GNN.2.pt"      # where the best state dict is saved

In [None]:
# Sanity-check run with the settings defined above; verbose = 2 prints the
# metrics of every epoch as well as the final summary.
tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)

0
8
16
24
32
40
48
56
64
72
80
88
96
104
112
120
128
136
144
152
160
168
176
184
192
200
208
216
224
232
240
248
256
264
272
280
288
296
304
312
320
328
336
344
352
360
368
376
384
392
400
408
416
424
432
440
448
456
464
472
480
488
496
504
512
520
528
536
544
552
560
568
576
584
592
600
608
616
624
632
640
648
656
664
672
680
688
696
704
712
720
728
736
744
752
760
768
776
784
792
800
808
816
824
832
840
848
856
864
872
880
888
896
904
912
920
928
936
944
952
960
968
976
984
992
1000
1008
1016
1024
1032
1040
1048
1056
1064
1072
1080
1088
1096
1104
1112
1120
1128
1136
1144
1152
1160
1168
1176
1184
1192
1200
1208
1216
1224
1232
1240
1248
1256
1264
1272
1280
1288
1296
1304
1312
1320
1328
1336
1344
1352
1360
1368
1376
1384
1392
1400
1408
1416
1424
1432
1440
1448
1456
1464
1472
1480
1488
1496
1504
1512
1520
1528
1536
1544
1552
1560
1568
1576
1584
1592
1600
1608
1616
1624
1632
1640
1648
1656
1664
1672
1680
1688
1696
1704
1712
1720
1728
1736
1744
1752
1760
1768
1776
1784
1792
1800
1808
1816


Stats(train_loss=0.3517950084544718, train_accuracy=0.9055995973645191, dev_loss=1.4560895394253177, dev_accuracy=0.6747574191728593, epoch=1, lr=0.0005, alpha=0.0001, max_accuracy=0.6747574191728593)

#### $\color{red}{Run:}$

In [None]:
"""
Main Admin
"""
epochs = 30
max_accuracy = 0
path = "class/models/GNN.3.pt"
results = []

# Initial random-search window, given as base-10 exponents:
#   lr    in [10^-5, 10^-3]
#   alpha in [10^-5, 10^-3]
# (the batch size is not part of the search)
lr_low = -5
lr_high = -3
lr_range = lr_high - lr_low

alpha_low = -5
alpha_high = -3
alpha_range = alpha_high - alpha_low

d = 768
h = 400
c = 70
num_relations = 1

count = 0

# Hyperparameter search: 3 rounds of 6 random trials. After every round the
# window shrinks to a third of its width, centred on the best trial so far.
for i in range(3):
  # debug
  print("\n################\n")
  print(f'round: {i}')
  print('max', max_accuracy)
  print("\n################\n")

  for j in range(6):
    count += 1
    print(count)

    # get config
    lr, alpha = gen_config(lr_low, lr_high, alpha_low, alpha_high)
    # define model
    model = GNNModel(d, h, c, num_relations)
    model = model.to(device)

    # run training
    res = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)
    max_accuracy = res.max_accuracy
    results.append(res)

  # best result over every trial run so far, not only this round
  stats = search_stats(results)

  print(stats) # debug

  if stats is None:
    # No trial produced a usable dev accuracy, so there is nothing to centre
    # on; search the same window again in the next round.
    continue

  # reconfigure the new hypers (centres are base-10 exponents)
  lr = np.log10(stats.lr)
  lr_range = lr_range / 3

  alpha = np.log10(stats.alpha)
  alpha_range = alpha_range / 3

  config = gen_ranges(lr, lr_range, alpha, alpha_range)
  lr_low, lr_high, alpha_low, alpha_high = config
  lr_range = lr_high - lr_low
  alpha_range = alpha_high - alpha_low



################

round: 0
max 0

################

1

 --------- 
Epoch: 1

Epoch 1 train loss: 3.0942
Epoch 1 train accuracy: 0.2960
Epoch 1 dev loss: 1.9864
Epoch 1 dev accuracy: 0.4889

 --------- 
Epoch: 2

Epoch 2 train loss: 1.9579
Epoch 2 train accuracy: 0.5328
Epoch 2 dev loss: 1.4462
Epoch 2 dev accuracy: 0.5802

 --------- 
Epoch: 3

Epoch 3 train loss: 1.4683
Epoch 3 train accuracy: 0.6356
Epoch 3 dev loss: 1.2633
Epoch 3 dev accuracy: 0.6220

 --------- 
Epoch: 4

Epoch 4 train loss: 1.1842
Epoch 4 train accuracy: 0.6981
Epoch 4 dev loss: 1.1948
Epoch 4 dev accuracy: 0.6162

 --------- 
Epoch: 5

Epoch 5 train loss: 0.9946
Epoch 5 train accuracy: 0.7451
Epoch 5 dev loss: 1.1355
Epoch 5 dev accuracy: 0.6359

 --------- 
Epoch: 6

Epoch 6 train loss: 0.8639
Epoch 6 train accuracy: 0.7772
Epoch 6 dev loss: 1.1321
Epoch 6 dev accuracy: 0.6428

 --------- 
Epoch: 7

Epoch 7 train loss: 0.7508
Epoch 7 train accuracy: 0.8082
Epoch 7 dev loss: 1.1236
Epoch 7 dev accuracy: 0.6572
