# Text Classification - Training a GNN


## $\color{blue}{Sections:}$

* Preamble
1.   Admin
2.   Dataset
3.   Model
4.   Train - Validate
5.   Training Loop

## $\color{blue}{Preamble:}$

We now train a GNN in basic PyTorch. The model will look like a GCN. Inference willbe completed in another notebook as the whole graph must be uploaded at once.

## $\color{blue}{Admin}$
* Install relevant Libraries
* Import relevant Libraries

In [1]:
import torch
import pandas as pd
from google.colab import drive
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
drive.mount("/content/drive")
%cd '/content/drive/MyDrive'

Mounted at /content/drive
/content/drive/MyDrive


## $\color{blue}{Data}$

* Connect to Drive
* Load the data
* Load adjacency matrices

In [3]:
path = 'class/datasets/'
df_train = pd.read_pickle(path + 'df_train_augmentation_ft')
df_dev = pd.read_pickle(path + 'df_dev_augmentation_ft')
df_test = pd.read_pickle(path + 'df_test_augmentation_ft')


In [4]:
df_train.columns

Index(['master', 'book_idx', 'chapter_idx', 'content', 'vanilla_embedding.1',
       'direct_ft_augmented_embedding', 'ner_responses'],
      dtype='object')

In [5]:
df1 = df_train[['book_idx', 'chapter_idx', 'content', 'direct_ft_augmented_embedding', 'ner_responses']]
df2 = df_dev[['book_idx', 'chapter_idx', 'content', 'direct_ft_augmented_embedding', 'ner_responses']]
df_val = pd.concat([df2,df1])

In [6]:
df_val.shape

(21220, 5)

In [7]:
path = 'class/tensors/adj_{}.pt'

# train
# train_people = torch.load(path.format('train_people'))
# train_locations = torch.load(path.format('train_locations'))
train_entities = torch.load(path.format('train_augmented_entities'))

# # dev
# dev_people = torch.load(path.format('dev_people'))
# dev_locations = torch.load(path.format('dev_locations'))
dev_entities = torch.load(path.format('dev_augmented_entities'))

# # val (contains the adjacency matrix for both the training and the development set)
# val_people = torch.load(path.format('val_people'))
# val_locations = torch.load(path.format('val_locations'))
val_entities = torch.load(path.format('val_augmented_entities'))

## $\color{blue}{Dataset:}$

In [8]:
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy

def sample_neighborhood(A, inds, neighbor_max, branch_max, seed=None):
    # Set the random seed for deterministic responses
    if seed is not None:
        np.random.seed(seed)

    np.random.shuffle(A)  # Shuffle the list of adjacency matrices in place
    sampled_indices = set(inds)  # Initialize the set of sampled indices

    for ind in inds:  # Iterate through node in mini-batch
        # print("-----\n ind: ", ind)
        break_to_outer = False
        neighbors = set()

        for adj in A:  # Iterate through all adjacency matrices
            if break_to_outer:
              break

            # Get the indices of all neighbors that idx links to
            disclude = set([ind]) | sampled_indices
            new_neighbors = [neighbor.item() for neighbor in (adj[ind] > 0).nonzero(as_tuple=True)[0] if neighbor.item() not in disclude]
            neighbors.update(new_neighbors)

            if len(neighbors) >= neighbor_max:  # Check if we have too many neighbors
                # Take a random subset using np.random.choice
                neighbors = set(np.random.choice(list(neighbors), neighbor_max, replace=False))


            copy_neighbors = deepcopy(neighbors)
            for idx in copy_neighbors:
                if break_to_outer:
                  break

                neighbors_neighbors = set()
                for adj in A:
                    disclude = set([ind,idx]) | sampled_indices | neighbors
                    new_neighbors_neighbors = [neighbor.item() for neighbor in (adj[idx] > 0).nonzero(as_tuple=True)[0] if neighbor.item() not in disclude]
                    if len(new_neighbors_neighbors) > neighbor_max:
                      new_neighbors_neighbors = set(np.random.choice(list(new_neighbors_neighbors), neighbor_max, replace = False))
                    neighbors_neighbors.update(new_neighbors_neighbors)
                    if len(neighbors) + len(neighbors_neighbors) >= branch_max:
                      neighbors_neighbors = set(np.random.choice(list(neighbors_neighbors), branch_max - len(neighbors), replace=False))
                      neighbors.update(neighbors_neighbors)
                      break_to_outer = True
                      break

                    neighbors.update(neighbors_neighbors)
                    # print("New value of neighbors with new neighbors_neighbors: ", neighbors)

        sampled_indices.update(neighbors)  # Add new neighbors
        # print(f"____\n END OF ind {ind}; sampled indeices is now {sampled_indices}")

    sampled_indices = [int(el) for el in sampled_indices]

    return sampled_indices


#check the conditions

In [9]:
import torch
import numpy as np
a = torch.Tensor([
    [0,1,0,0,0,0,0,1],
    [1,0,0,1,0,1,0,0],
    [0,1,0,1,0,0,1,0],
    [0,1,0,0,1,0,0,1],
    [1,1,0,0,1,0,0,0],
    [0,0,1,1,0,1,1,0],
    [0,0,0,1,0,0,0,1],
    [0,0,0,1,0,1,1,1]
])
A = []
A.append(a)
inds = [1,2]
nm = 2
bm = 4
seed=42

In [10]:
import torch

def neighbor_analysis(primary_inds: list[int], inds: list[int], adjacency_matrix: torch.Tensor):
    result = []

    # Loop over each primary index
    for primary_idx in primary_inds:
        # Get neighbors for the primary index
        neighbors = []
        for ind in inds:
            # Check if ind is a neighbor of primary_idx
            if adjacency_matrix[primary_idx, ind] == 1:
                neighbors.append(ind)

        # Find neighbors of neighbors (n + 1 neighbors)
        neighbors_of_neighbors = []
        for neighbor in neighbors:
            for ind in inds:
                # Check if ind is a neighbor of the current neighbor
                if adjacency_matrix[neighbor, ind] == 1 and ind != primary_idx:
                    neighbors_of_neighbors.append(ind)

        # Remove duplicates for the neighbors of neighbors
        neighbors_of_neighbors = list(set(neighbors_of_neighbors))

        # Add the tuple to results
        result.append((primary_idx, neighbors, neighbors_of_neighbors))

    return result

In [53]:
import random
inds = random.sample(list(range(12000)),6)
print(inds)
#inds = list(range(101,105))
sample = sample_neighborhood([train_entities], inds, 4, 16)
print(sample)
print(len(sample))

[1811, 4896, 2133, 2576, 6407, 9118]
[2946, 6407, 15881, 6154, 12937, 12812, 2576, 3858, 1811, 14356, 17299, 6164, 11671, 12823, 1818, 14620, 9118, 6046, 4896, 12833, 13087, 15396, 8741, 3110, 17572, 6184, 1962, 2095, 6195, 15668, 15669, 4788, 1977, 1851, 6204, 6208, 14718, 6345, 4306, 2133, 12634, 16989, 2142, 13920, 6370, 7402, 17261, 15727, 13047, 6136, 13051, 5885, 638]
53


In [23]:
analysis = neighbor_analysis(inds, sample, train_entities)
for item in analysis:
  print(item)

(7782, [16911, 8095, 16549, 17580, 16559, 17594, 18107, 17224, 17992, 8009, 12243, 16989, 17251, 16885, 8699, 8061], [17251, 16549, 17224, 17992, 8009, 17580, 8061, 16559, 16911, 12243, 16885, 8699, 17594, 18107, 16989, 8095])
(7572, [], [])
(1572, [1426, 533, 1562, 1692, 1569, 11560, 12599, 1605, 1742, 1615, 11856, 11606], [1569, 1025, 1605, 11560, 1742, 1615, 11856, 1426, 1843, 533, 11606, 12599, 1562, 11547, 1692, 3774])
(9939, [], [])
(2646, [13462, 13466, 12325, 1331, 2642], [1315, 12325, 2642, 1331, 13462, 1305, 13466, 1308, 1310])
(4608, [], [])


In [24]:
# inds = random.sample(list(range(12000)),4)
inds = list(range(101,105))
sample = sample_neighborhood([train_entities], inds, 16, 64, seed=42)
print(sample)
print(len(sample))

[14338, 8, 11785, 13835, 524, 18, 3602, 3611, 3617, 1058, 3621, 2094, 3632, 14384, 14386, 2609, 11828, 14388, 14387, 54, 56, 7231, 13377, 1090, 3651, 65, 71, 14920, 3656, 74, 14922, 3659, 77, 3662, 76, 3664, 80, 1104, 1108, 3669, 3668, 84, 1112, 88, 90, 14430, 14432, 99, 101, 102, 103, 104, 12905, 1127, 1128, 106, 14447, 112, 131, 1156, 1157, 4231, 140, 13968, 13973, 1175, 153, 3738, 3737, 667, 160, 2721, 162, 1187, 3750, 3240, 14510, 178, 1206, 189, 1215, 12479, 194, 12484, 12492, 204, 206, 14038, 13533, 4332, 13038, 3832, 761, 13051, 11010, 1289, 11022, 11023, 787, 11027, 795, 11048, 11054, 11055, 11056, 14642, 1843, 11064, 11070, 11071, 11073, 12099, 12101, 11078, 12102, 11080, 11081, 11082, 3403, 11079, 11085, 846, 11088, 3409, 849, 11090, 11089, 11091, 854, 3415, 1367, 11097, 11098, 11102, 16735, 12126, 11103, 12131, 13669, 11622, 11111, 11116, 12653, 3438, 3439, 11120, 11122, 11125, 12150, 14201, 11130, 3962, 11132, 12155, 12668, 14203, 3458, 3472, 12177, 3987, 11159, 12185, 2972

In [11]:
class GNNDataset(Dataset):
  def __init__(self, H, A, labels, meta_indices, neighbor_max=4, branch_max=16, seed=None):
    """Custom dataset with neighborhood sampling

    Args:
      H : torch.tensor
        input embeddings (n x d)

      A : list[torch.tensor]
        list of (n x n)

      labels : torch.LongTensor
        y

      meta_indices : torch.LongTensor
        index of datapoint to filter validation score

      neighbor_max : int
        max neighbors for each node in mini-batch

      batch_max : int
        max size of batch

    """
    # All inits must be tensors
    self.H = H.to(device)
    self.A = [a.to(device) for a in A]
    self.labels = labels.to(device)
    self.meta_indices = meta_indices
    self.neighbor_max = neighbor_max
    self.branch_max = branch_max
    self.seed = seed

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, inds):
    # print('\n####################\n')
    # print('GET ITEM CALLED', 'INDS:', inds)
    # Sample neighborhood

    # get inds in list
    inds = inds.tolist() if torch.is_tensor(inds) else (inds if isinstance(inds,list) else [inds])

    # return the required inds
    sampled_indices = sample_neighborhood(self.A, inds, self.neighbor_max, self.branch_max,seed=self.seed)

    # get the input for the required inds
    H_batch = self.H[sampled_indices]

    # get the adjacency matrix for the required inds
    A_batch = [self.A[k][sampled_indices][:, sampled_indices] for k in range(len(self.A))]

    # get the labels for the required inds
    labels_batch = self.labels[sampled_indices]

    # get meta indices
    index_batch = self.meta_indices[sampled_indices]

    return H_batch, A_batch, labels_batch, index_batch

In [26]:
df_train.columns

Index(['master', 'book_idx', 'chapter_idx', 'content', 'vanilla_embedding.1',
       'direct_ft_augmented_embedding', 'ner_responses'],
      dtype='object')

In [28]:
H = torch.stack(list(torch.Tensor(df_train['direct_ft_augmented_embedding'])))
labels = torch.LongTensor(list(df_train['chapter_idx']))
A = []
A.append(train_entities)
meta_indices = torch.LongTensor(list(range(df_train.shape[0])))

  H = torch.stack(list(torch.Tensor(df_train['direct_ft_augmented_embedding'])))


In [31]:
train_dataset = GNNDataset(H, A, labels, meta_indices, neighbor_max=4, branch_max=10)

# Prevent dataloader from calling a single index at a time
custom_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(train_dataset),
    batch_size=16,
    drop_last=False)


train_loader = DataLoader(train_dataset, sampler = custom_sampler)


In [34]:
# Check batches

# Number of batches to inspect
num_batches_to_check = 2

for batch_idx, (inputs, adjacency, labels, indices) in enumerate(train_loader):
    print('\n##########################\n')
    print(f"Batch {batch_idx + 1}/{num_batches_to_check}:")
    print('-' * 10)
    print("Inputs:")
    print(f"  Type: {type(inputs)}")
    print(f"  Shape: {inputs.size()}")
    print('-' * 10)
    print("Adjacency:")
    print(f"  Type: {type(adjacency)}")
    print(f"  Shape: {adjacency[0].size()}")
    print('-' * 10)
    print("Indices:")
    print(f"  Type: {type(indices)}")
    print(f"  Shape: {indices.size()}")
    print(indices)
    print('-' * 10)
    print("Labels:")
    print(f"  Type: {type(labels)}")
    print(f"  Shape: {labels.size()}")
    print(labels)

    # Stop after inspecting the desired number of batches
    if batch_idx + 1 >= num_batches_to_check:
        break

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [37]:
H = torch.stack(list(torch.Tensor(df_val['direct_ft_augmented_embedding'])))
labels = torch.LongTensor(list(df_val['chapter_idx']))
A = []
A.append(val_entities)
meta_indices = torch.LongTensor(list(range(df_val.shape[0])))

In [56]:
validation_dataset = GNNDataset(H, A, labels, meta_indices, neighbor_max=4, branch_max=10, seed=42)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=16,
    drop_last=False)


train_loader_fixed = DataLoader(validation_dataset, sampler = custom_validation_sampler)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [57]:
# Check batches

# Number of batches to inspect
num_batches_to_check = 2

for batch_idx, (inputs, adjacency, labels, indices) in enumerate(train_loader_fixed):
    print('\n##########################\n')
    print(f"Batch {batch_idx + 1}/{num_batches_to_check}:")
    print('-' * 10)
    print("Inputs:")
    print(f"  Type: {type(inputs)}")
    print(f"  Shape: {inputs.size()}")
    print('-' * 10)
    print("Adjacency:")
    print(f"  Type: {type(adjacency)}")
    print(f"  Shape: {adjacency[0].size()}")
    print('-' * 10)
    print("Indices:")
    print(f"  Type: {type(indices)}")
    print(f"  Shape: {indices.size()}")
    print(indices)
    print('-' * 10)
    print("Labels:")
    print(f"  Type: {type(labels)}")
    print(f"  Shape: {labels.size()}")
    print(labels)

    # Stop after inspecting the desired number of batches
    if batch_idx + 1 >= num_batches_to_check:
        break

NameError: name 'train_loader_fixed' is not defined

The dataloader seems to be working correctly, the implementation of custom sampling on all indices at once leads to DataLoaders collate function inserting a new dimension that it will stach against. Because all indices are dealt with at once, there is no stacking.

The simple solution will be to simply squeeze the tensors in the training loop. The validation loader eradicates randomness from the process.

## $\color{blue}{Model:}$

In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer(nn.Module):
    def __init__(self, in_features, out_features, num_relations=1, dropout=0.4):
        super(GNNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_relations = num_relations
        self.dropout = dropout

        self.T = nn.ParameterList([nn.Parameter(torch.Tensor(in_features, out_features)) for _ in range(num_relations)])
        self.E = nn.ParameterList([nn.Parameter(torch.Tensor(in_features, out_features)) for _ in range(num_relations)])

        # Batch normalization
        self.batch_norm = nn.BatchNorm1d(out_features)

        self.reset_parameters()

    def reset_parameters(self):
        for t in self.T:
            nn.init.xavier_uniform_(t)
        for e in self.E:
            nn.init.xavier_uniform_(e)

    def forward(self, H, A):
        H_out = torch.zeros_like(H)
        for k in range(self.num_relations):
            messages_projection = A[k].T @ H @ self.E[k]
            degrees = A[k].sum(dim=1, keepdim=True)
            degrees[degrees == 0] = 1.0
            messages_projection /= degrees

            self_projection = H @ self.T[k]

            # Include skip connection
            H_out += F.leaky_relu(self_projection + messages_projection) + H

        # Apply batch normalization
        H_out = self.batch_norm(H_out)

        # Apply dropout
        H_out = F.dropout(H_out, p=self.dropout, training=self.training)

        return H_out

class GNNModel(nn.Module):
    def __init__(self, d, h, c, num_relations=1, num_layers=3, dropout=0.4):
        super(GNNModel, self).__init__()
        self.num_layers = num_layers
        self.gnn_layers = nn.ModuleList([GNNLayer(d, d, num_relations, dropout) for _ in range(num_layers)])
        self.fc1 = nn.Linear(d, h)
        self.batch_norm_fc1 = nn.BatchNorm1d(h)
        self.fc2 = nn.Linear(h, c)
        self.dropout = dropout

    def forward(self, H, A):
        for layer in self.gnn_layers:
            H = layer(H, A)

        H = F.relu(self.batch_norm_fc1(self.fc1(H)))
        H = F.dropout(H, p=self.dropout, training=self.training)
        Output = self.fc2(H)
        return Output

    def forward_layer(self, H, A, layer_idx):
        """Forward pass for a specific layer."""
        H = self.gnn_layers[layer_idx](H, A)
        return H


In [44]:
class GNNModel(nn.Module):
   def __init__(self, d, h, c, num_relations=1, num_layers=2):
      super(GNNModel, self).__init__()
      self.num_layers = num_layers
      self.gnn_layers = nn.ModuleList([GNNLayer(d, d) for _ in range(num_layers)])
      self.fc1 = nn.Linear(d, h)
      self.fc2 = nn.Linear(h, c)

   def forward(self, H, A):
      for layer in self.gnn_layers:
         H = layer(H, A)
      # Classification
      H = F.relu(self.fc1(H))
      Output = self.fc2(H)
      return Output

In [45]:
import torch.optim as optim

d = 768
h = 400   # hidden dimension of fully connected layer
c = 70   # number of classes
num_relations = 1   # number of relationship types

# Model, Loss, Optimizer
model = GNNModel(d, h, c, num_relations)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [15]:
def count_parameters_per_module(model):
    print("Module and parameter counts:")

    for name, module in model.named_modules():
        # Skip the top-level module (the model itself)
        if not isinstance(module, nn.Module) or name == "":
            continue

        param_count = sum(p.numel() for p in module.parameters() if p.requires_grad)

        if param_count > 0:  # Only print modules that have parameters
            print(f"{name}: {param_count} parameters")

In [16]:
count_parameters_per_module(model)

Module and parameter counts:
gnn_layers: 2362368 parameters
gnn_layers.0: 1181184 parameters
gnn_layers.0.T: 589824 parameters
gnn_layers.0.E: 589824 parameters
gnn_layers.0.batch_norm: 1536 parameters
gnn_layers.1: 1181184 parameters
gnn_layers.1.T: 589824 parameters
gnn_layers.1.E: 589824 parameters
gnn_layers.1.batch_norm: 1536 parameters
fc1: 307600 parameters
fc2: 28070 parameters


## $\color{blue}{Train-Validate:}$

In [17]:
def accuracy(outputs, labels):
    # argmax to get predicted classes
    _, predicted = torch.max(outputs, 1)

    # count correct
    correct = (predicted == labels).sum().item()

    # get average
    acc = correct / labels.size(0)  # Total number of samples
    return acc

In [38]:
import numpy as np

def train(model, train_loader, criterion, optimizer):
    model.train()
    epoch_train_losses = []
    epoch_train_accuracy = []

    for batch_idx, (H, A, y, indices) in enumerate(train_loader):
        print('train_batch;', batch_idx)
        optimizer.zero_grad()

        H = H.squeeze(0)
        A = [a.squeeze(0) for a in A]
        y = y.squeeze(0)

        out = model(H,A)
        train_loss = criterion(out, y)
        train_accuracy = accuracy(out, y)


        epoch_train_losses.append(train_loss.item())
        epoch_train_accuracy.append(train_accuracy)

        # Backpropagation and optimization
        train_loss.backward()
        optimizer.step()

    return np.mean(epoch_train_losses), np.mean(epoch_train_accuracy)

In [39]:
def validate(model, dev_loader, criterion, threshold=746):
    model.eval()
    epoch_dev_losses = []
    epoch_dev_accuracy = []
    pred_holder = []
    real_holder = []

    with torch.no_grad():
        for batch_idx, (H, A, y, indices) in enumerate(dev_loader):
            print('val_batch', batch_idx)
            H = H.squeeze(0)
            A = [a.squeeze(0) for a in A]
            y = y.squeeze(0)
            indices = indices.squeeze(0)

            out = model(H, A)

            # Filter out training points
            mask = indices < threshold
            filtered_out = out[mask]
            filtered_y = y[mask]

            # Calculate loss and accuracy only on filtered outputs
            if filtered_out.size(0) > 0:  # Ensure there are samples to evaluate
                dev_loss = criterion(filtered_out, filtered_y)
                dev_accuracy = accuracy(filtered_out, filtered_y)

                epoch_dev_losses.append(dev_loss.item())
                epoch_dev_accuracy.append(dev_accuracy)

    # Avoid division by zero if no validation points were processed
    return np.mean(epoch_dev_losses), np.mean(epoch_dev_accuracy)

## $\color{blue}{Training:}$

In [20]:
from collections import namedtuple
Stats = namedtuple('Stats', [
    'train_loss',
    'train_accuracy',
    'dev_loss',
    'dev_accuracy',
    'epoch',
    'lr',
    'alpha',
    'max_accuracy'
])

In [21]:
def gen_config(lr_low, lr_high, alpha_low, alpha_high):
  np.random.seed()
  lr = round(10**float(np.random.uniform(lr_low,lr_high)),6)
  alpha = round(10**float(np.random.uniform(alpha_low,alpha_high)),6)
  return lr, alpha

In [22]:
def gen_ranges( lr, lr_range, alpha, alpha_range):

  lr_center = lr
  lr_low = lr_center - lr_range/2
  lr_high = lr_center + lr_range/2
  lr_diff = lr_high - lr_low

  alpha_center = alpha
  alpha_low = alpha_center - alpha_range/2
  alpha_high = alpha_center + alpha_range/2
  alpha_diff = alpha_high - alpha_low

  return (lr_low, lr_high, alpha_low, alpha_high)

In [23]:
def search_stats(results):
  best_stats = None
  max_dev_accuracy = 0
  for i in range(len(results)):
    acc = results[i].dev_accuracy
    if acc > max_dev_accuracy:
      best_stats = results[i]
      max_dev_accuracy = acc
  return best_stats

In [46]:
def tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 0):
  """
  Runs a training setup
  verbose == 1 - print model results
  verbose == 2 -> print epoch and model results
  """
  model = model.to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=alpha)

  # Prepare data loaders
  train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)
  dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)

  # Hold epoch stats
  train_losses = []
  train_accuracy = []
  dev_losses = []
  dev_accuracy = []
  epoch_holder = []

  # Break if no improvement
  current_best = 0
  no_improvement = 0


  # Run epochs
  for epoch in range(epochs):

    # break out of epochs
    if no_improvement >= 6:
      break

    # call training and validation functions
    train_loss, train_acc = train(model, train_loader, criterion, optimizer)
    dev_loss, dev_acc = validate(model, dev_loader, criterion)

    # Store epoch stats
    train_losses.append(train_loss)
    train_accuracy.append(train_acc)
    dev_losses.append(dev_loss)
    dev_accuracy.append(dev_acc)
    epoch_holder.append(epoch + 1)

    # check for improvement
    if dev_acc > current_best:
      current_best = dev_acc
      no_improvement = 0
    else:
      no_improvement += 1

    # save best model
    if dev_acc > max_accuracy:
      torch.save(model.state_dict(), path)
      max_accuracy = dev_acc


    # optionally print epoch results
    if verbose == 2:
      print(f'\n --------- \nEpoch: {epoch + 1}\n')
      print(f'Epoch {epoch + 1} train loss: {train_loss:.4f}')
      print(f'Epoch {epoch + 1} train accuracy: {train_acc:.4f}')
      print(f'Epoch {epoch + 1} dev loss: {dev_loss:.4f}')
      print(f'Epoch {epoch + 1} dev accuracy: {dev_acc:.4f}')

      # save best results
  max_ind = np.argmax(dev_accuracy)

  stats = Stats(
      train_losses[max_ind],
      train_accuracy[max_ind],
      dev_losses[max_ind],
      dev_accuracy[max_ind],
      epoch_holder[max_ind],
      lr, alpha,
      max_accuracy
  )

  # optionally print model results
  if verbose in [1,2]:
    print('\n ######## \n')
    print(f'lr:{stats.lr}, alpha:{stats.alpha} @ epoch {stats.epoch}.')
    print(f'TL:{stats.train_loss}, TA:{stats.train_accuracy}.')
    print(f'DL:{stats.dev_loss}, DA:{stats.dev_accuracy}')

  return stats

#### $\color{red}{Sanity-check:}$

In [25]:
# model
model = GNNModel(d, h, c, num_relations)

In [64]:
df_val.columns

Index(['book_idx', 'chapter_idx', 'content', 'direct_ft_augmented_embedding',
       'ner_responses'],
      dtype='object')

In [31]:
torch.Tensor(np.stack(df_train['direct_ft_augmented_embedding'].to_list())).size()

torch.Size([20474, 768])

In [47]:
# training loader
H_train = torch.Tensor(np.stack(df_train['direct_ft_augmented_embedding'].to_list()))
labels_train = torch.LongTensor(list(df_train['chapter_idx']))
A_train = []
A_train.append(train_entities)
train_indices = torch.LongTensor(list(range(df_train.shape[0])))

train_dataset = GNNDataset(H_train, A_train, labels_train, train_indices, neighbor_max=4, branch_max=16)

# Prevent dataloader from calling a single index at a time
custom_train_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(train_dataset),
    batch_size=64,
    drop_last=False)


train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)

# training loader
df1 = df_train[['direct_ft_augmented_embedding', 'chapter_idx']]
df2 = df_dev[['direct_ft_augmented_embedding', 'chapter_idx']]
df_val = pd.concat([df2, df1])
H_val = torch.Tensor(np.stack(df_val['direct_ft_augmented_embedding'].to_list()))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
A_val = []
A_val.append(val_entities)
val_indices = torch.LongTensor(list(range(df_val.shape[0])))



validation_dataset = GNNDataset(H_val, A_val, labels_val, val_indices, neighbor_max=4, branch_max=16, seed=42)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=64,
    drop_last=False)


dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)



In [49]:
epochs = 20
lr = 0.00018
alpha = 0.0006
path = "class/models/GNN_augmented_ft.pt"
max_accuracy = 0.7168

In [50]:
tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)

train_batch; 0
train_batch; 1
train_batch; 2
train_batch; 3
train_batch; 4
train_batch; 5
train_batch; 6
train_batch; 7
train_batch; 8
train_batch; 9
train_batch; 10
train_batch; 11
train_batch; 12
train_batch; 13
train_batch; 14
train_batch; 15
train_batch; 16
train_batch; 17
train_batch; 18
train_batch; 19
train_batch; 20
train_batch; 21
train_batch; 22
train_batch; 23
train_batch; 24
train_batch; 25
train_batch; 26
train_batch; 27
train_batch; 28
train_batch; 29
train_batch; 30
train_batch; 31
train_batch; 32
train_batch; 33
train_batch; 34
train_batch; 35
train_batch; 36
train_batch; 37
train_batch; 38
train_batch; 39
train_batch; 40
train_batch; 41
train_batch; 42
train_batch; 43
train_batch; 44
train_batch; 45
train_batch; 46
train_batch; 47
train_batch; 48
train_batch; 49
train_batch; 50
train_batch; 51
train_batch; 52
train_batch; 53
train_batch; 54
train_batch; 55
train_batch; 56
train_batch; 57
train_batch; 58
train_batch; 59
train_batch; 60
train_batch; 61
train_batch; 62
tr

KeyboardInterrupt: 

#### $\color{red}{Run:}$

In [None]:
"""
Main Admin
"""
epochs = 30
max_accuracy = 0
path = "class/models/GNN.3.pt"
results = []

"""
init random search
lr [10^-5 - 10^-1]
alpha [10^-5 - 10^-1]
bs [8, 32, 128]
"""
lr_low = -5
lr_high = -3
lr_range = lr_high - lr_low

alpha_low = -5
alpha_high = -3
alpha_range = alpha_high - alpha_low

d = 768
h = 400
c = 70
num_relations = 1

count = 0

"""
Hyperparameter Search
"""

for i in range(3):
  # debug
  print("\n################\n")
  print(f'round: {i}')
  # print(f'lr_low{lr_low}, lr_high{lr_high}, lr_range{lr_range}')
  # print(f'alpha_low{alpha_low}, lr_high{alpha_high}, lr_range{alpha_range}')
  print('max', max_accuracy)
  print("\n################\n")


  for j in range(6):
    count += 1
    print(count)

    # get config
    lr, alpha = gen_config(lr_low, lr_high, alpha_low, alpha_high)
    # define model
    model = GNNModel(d, h, c, num_relations)
    model = model.to(device)

    # run training
    res = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)
    max_accuracy = res.max_accuracy
    results.append(res)

  # get best result of the round or even so far
  stats = search_stats(results)


  print(stats) # debug

  # reconfigure the new hypers
  lr = np.log10(stats.lr)
  lr_range = lr_range / 3

  alpha = np.log10(stats.alpha)
  alpha_range = alpha_range / 3

  config = gen_ranges(lr, lr_range, alpha, alpha_range)
  lr_low, lr_high, alpha_low, alpha_high = config
  lr_range = lr_high - lr_low
  alpha_range = alpha_high - alpha_low



################

round: 0
max 0

################

1

 --------- 
Epoch: 1

Epoch 1 train loss: 3.0942
Epoch 1 train accuracy: 0.2960
Epoch 1 dev loss: 1.9864
Epoch 1 dev accuracy: 0.4889

 --------- 
Epoch: 2

Epoch 2 train loss: 1.9579
Epoch 2 train accuracy: 0.5328
Epoch 2 dev loss: 1.4462
Epoch 2 dev accuracy: 0.5802

 --------- 
Epoch: 3

Epoch 3 train loss: 1.4683
Epoch 3 train accuracy: 0.6356
Epoch 3 dev loss: 1.2633
Epoch 3 dev accuracy: 0.6220

 --------- 
Epoch: 4

Epoch 4 train loss: 1.1842
Epoch 4 train accuracy: 0.6981
Epoch 4 dev loss: 1.1948
Epoch 4 dev accuracy: 0.6162

 --------- 
Epoch: 5

Epoch 5 train loss: 0.9946
Epoch 5 train accuracy: 0.7451
Epoch 5 dev loss: 1.1355
Epoch 5 dev accuracy: 0.6359

 --------- 
Epoch: 6

Epoch 6 train loss: 0.8639
Epoch 6 train accuracy: 0.7772
Epoch 6 dev loss: 1.1321
Epoch 6 dev accuracy: 0.6428

 --------- 
Epoch: 7

Epoch 7 train loss: 0.7508
Epoch 7 train accuracy: 0.8082
Epoch 7 dev loss: 1.1236
Epoch 7 dev accuracy: 0.6572
