# Text Classification - Training a GNN


## $\color{blue}{Sections:}$

* Preamble
1.   Admin
2.   Dataset
3.   Model
4.   Train - Validate
5.   Training Loop

## $\color{blue}{Preamble:}$

We now train a GNN in basic PyTorch. The model will look like a GCN. Inference willbe completed in another notebook as the whole graph must be uploaded at once.

## $\color{blue}{Admin}$
* Install relevant Libraries
* Import relevant Libraries

In [None]:
import torch
import pandas as pd
from google.colab import drive
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
drive.mount("/content/drive")
%cd '/content/drive/MyDrive'

Mounted at /content/drive
/content/drive/MyDrive


## $\color{blue}{Data}$

* Connect to Drive
* Load the data
* Load adjacency matrices

In [None]:
path = 'class/datasets/'
df_train = pd.read_pickle(path + 'df_train')
df_dev = pd.read_pickle(path + 'df_dev')
df_test = pd.read_pickle(path + 'df_test')

In [None]:
path = 'class/tensors/adj_{}.pt'

# train
train_people = torch.load(path.format('train_people'))
train_locations = torch.load(path.format('train_locations'))
train_entities = torch.load(path.format('train_entities'))

# dev
dev_people = torch.load(path.format('dev_people'))
dev_locations = torch.load(path.format('dev_locations'))
dev_entities = torch.load(path.format('dev_entities'))

# val (contains the adjacency matrix for both the training and the development set)
val_people = torch.load(path.format('val_people.1'))
val_locations = torch.load(path.format('val_locations.1'))
val_entities = torch.load(path.format('val_entities.1'))

  train_people = torch.load(path.format('train_people'))
  train_locations = torch.load(path.format('train_locations'))
  train_entities = torch.load(path.format('train_entities'))
  dev_people = torch.load(path.format('dev_people'))
  dev_locations = torch.load(path.format('dev_locations'))
  dev_entities = torch.load(path.format('dev_entities'))
  val_people = torch.load(path.format('val_people.1'))
  val_locations = torch.load(path.format('val_locations.1'))
  val_entities = torch.load(path.format('val_entities.1'))


## $\color{blue}{Sampling:}$

In [None]:
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy

def sample_neighborhood(primary_inds, input, adj, distance, neighbor_max = 4):
    """
    Takes the given inds and the inputs, returns the sampled set of indices and the corresponding activation flag.
    If the activation flag is True, then the the datapoint is primary and has been search for neighbors and neighbors
    of neighbors. This indicates that we rely on the datapoint for metrics and loss.

    Args:
      primary_inds : iterable
          indices sampled by the dataloader
      input : torch.Tensor
          (m,d) input tensor
      adj : torch.Tensor
          (m,m) adjacency matrix
      distance : torch.Tensor
          (m,m) : cosine similarity between all inputs
      neighbour_max : int (optional)
          The maximum number of neighbors to consider for each point

    Returns:
      sampled_indices : list
          indices of all datapoints to be processed in the batch
      activation_flag : list
          boolean_flag indicating whether the corresponding datapoint is to be considered for metrics
    """

    def _get_closest_neighbors(ind):
      """get up to neighbor_max close neighbors"""
      local_neighbors = []
      local_activation_flag = []
      candidate_neighbors = [neighbor.item() for neighbor in (adj[ind] > 0).nonzero(as_tuple=True)[0] if neighbor.item() not in sampled_indices]
      candidate_distances = [(neighbor, distance[primary_ind][neighbor]) for neighbor in candidate_neighbors]
      sorted_neighbors = sorted(candidate_distances, key=lambda x: x[1])
      return [neighbor for neighbor, dist in sorted_neighbors[:neighbor_max]], candidate_neighbors

    sampled_indices = []
    activation_flag = []
    all_banned_neighbors = []

    for primary_ind in primary_inds:

      # if primary ind has been added as a neighbor, convert the activation flag to true, else add it as a standard primary index
      if primary_ind in sampled_indices:
        activation_flag[sampled_indices.index(primary_ind)] = True
      else:
        sampled_indices.append(primary_ind)
        activation_flag.append(True)

      # print('\n', primary_ind)
      # print('sampled_indices', sampled_indices)

      level_1_neighbors, candidate_neighbors = _get_closest_neighbors(primary_ind)
      banned_neighbors = list(set(candidate_neighbors) - set(level_1_neighbors))
      all_banned_neighbors.extend(banned_neighbors)
      # print('banned_neighbors', banned_neighbors)
      # print('all_banned_neighbors', all_banned_neighbors)

      level_1_activation_flag = [False for el in level_1_neighbors]
      sampled_indices.extend(level_1_neighbors)
      activation_flag.extend(level_1_activation_flag)

      # print('level_1_neighbors', level_1_neighbors)
      # print('sampled_indices', sampled_indices)
      # print('level_1_activation_flag', level_1_activation_flag)
      # print('activation_flag', activation_flag)

      for level_1_ind in level_1_neighbors:
        level_2_neighbors, _ = _get_closest_neighbors(level_1_ind)
        level_2_activation_flag = [False for el in level_2_neighbors]
        sampled_indices.extend(level_2_neighbors)
        activation_flag.extend(level_2_activation_flag)

        # print('level_2_neighbors', level_2_neighbors)
        # print('sampled_indices', sampled_indices)
        # print('level_2_activation_flag', level_2_activation_flag)
        # print('activation_flag', activation_flag)

    # include only 4 level one neighbors for each primary index to avoid pollution
    clean_indices = []
    clean_flags = []
    for i in range(len(sampled_indices)):
      target = sampled_indices[i]
      if target in primary_inds:
        clean_indices.append(target)
        clean_flags.append(True)
      elif target in all_banned_neighbors:
        continue
      else:
        clean_indices.append(target)
        clean_flags.append(False)
    # print(f'\ncleaning\nsampled_indices : {sampled_indices}\nclean_indices : {clean_indices}')
    return clean_indices, clean_flags


## $\color{blue}{Sampling Test:}$

In [None]:
import torch
import numpy as np
a = torch.Tensor([
    [0,1,0,0,0,0,0,1],
    [1,0,0,1,0,1,0,0],
    [0,1,0,1,0,0,1,0],
    [0,1,0,0,1,0,0,1],
    [1,1,0,0,1,0,0,0],
    [0,0,1,1,0,1,1,0],
    [0,0,0,1,0,0,0,1],
    [0,0,0,1,0,1,1,1]
])

seed=42

In [None]:
import torch

def neighbor_analysis(primary_inds: list[int], inds: list[int], adjacency_matrix: torch.Tensor):
    result = []

    # Loop over each primary index
    for primary_idx in primary_inds:
        # Get neighbors for the primary index
        neighbors = []
        for ind in inds:
            # Check if ind is a neighbor of primary_idx
            if adjacency_matrix[primary_idx, ind] == 1:
                neighbors.append(ind)

        # Find neighbors of neighbors (n + 1 neighbors)
        neighbors_of_neighbors = []
        for neighbor in neighbors:
            for ind in inds:
                # Check if ind is a neighbor of the current neighbor
                if adjacency_matrix[neighbor, ind] == 1 and ind != primary_idx:
                    neighbors_of_neighbors.append(ind)

        # Remove duplicates for the neighbors of neighbors
        neighbors_of_neighbors = list(set(neighbors_of_neighbors))

        # Add the tuple to results
        result.append((primary_idx, neighbors, neighbors_of_neighbors))

    return result

primary_inds, input, adj, distance,

In [None]:
def calculate_cosine(H):
    dot_product = torch.matmul(H, H.transpose(0, 1))  # shape (m, m)

    lengths = torch.sqrt(dot_product.diagonal()).unsqueeze(1)  # shape (m, 1)
    denominator = lengths @ lengths.transpose(0, 1)  # shape (m, m)

    return dot_product / denominator  # shape (m, m)

In [None]:
H = torch.stack(list(df_train['vanilla_embedding.1']))


In [None]:
distance = calculate_cosine(H)

In [None]:
import random
inds = random.sample(list(range(12000)),6)
print(inds)
sample = sample_neighborhood(inds,H,train_entities,distance)
print(sample)
print(len(sample[0]))

[6339, 2165, 4826, 11292, 11280, 5401]

 6339
sampled_indices [6339]
banned_neighbors [9728, 3589, 7176, 11273, 11784, 3084, 3597, 5644, 10259, 11796, 2581, 5653, 6170, 3617, 6690, 7203, 11302, 552, 8755, 10804, 6198, 3640, 8252, 2622, 64, 7744, 10305, 10817, 10823, 9802, 10826, 591, 81, 7250, 5207, 4187, 9821, 10847, 2144, 9825, 11872, 4195, 11877, 5223, 4717, 10861, 8305, 5246, 9342, 1664, 5760, 3714, 10879, 6276, 10375, 648, 11915, 144, 2195, 3733, 10393, 6298, 6814, 8865, 6819, 1702, 9897, 4781, 8365, 1711, 3762, 7346, 9906, 7864, 11449, 2234, 5307, 7866, 3261, 11453, 8383, 10431, 1730, 3266, 2247, 4808, 9415, 6862, 2768, 3797, 9431, 6363, 7387, 10462, 5856, 7394, 1253, 6377, 7403, 2796, 2286, 3311, 1776, 5876, 7928, 2810, 765, 5885, 9471, 1794, 1795, 6403, 6920, 1290, 2315, 4363, 1293, 8464, 11025, 8468, 3353, 10532, 8485, 7975, 9511, 10024, 1838, 3375, 306, 11571, 6965, 7477, 9525, 7482, 2366, 4414, 2880, 3392, 8000, 4419, 8001, 11595, 7501, 10576, 9042, 4436, 345, 3418, 9568, 24

In [None]:
analysis = neighbor_analysis(inds, sample[0], train_entities)
for item in analysis:
  print(item)

(6339, [10087, 11957, 10768, 11320], [8961, 10768, 26, 6427, 8605, 6948, 167, 11957, 11320, 8133, 1226, 11083, 721, 8020, 6232, 3420, 10461, 2143, 4576, 10087, 11625, 6129, 7409, 2803, 2165])
(2165, [10087, 6427, 8605, 4576, 8020], [8961, 10768, 26, 6427, 8605, 167, 11957, 11320, 6339, 8133, 11083, 721, 8020, 6232, 3420, 4576, 10087, 11625, 6129, 7409, 2803])
(4826, [], [])
(11292, [6129, 2143, 9235, 8778, 11085, 3998], [8961, 2183, 6283, 2701, 655, 10768, 1936, 8465, 9235, 6427, 3613, 4894, 3998, 545, 2341, 167, 2425, 4905, 5419, 560, 8114, 9910, 11320, 4028, 8778, 1612, 11085, 721, 3420, 4574, 2143, 6116, 613, 10087, 3815, 11625, 7399, 7409, 7153, 2803, 5108, 6129, 10489, 1274])
(11280, [], [])
(5401, [9910, 5054, 1274], [2183, 6283, 2701, 655, 1936, 9235, 4894, 545, 2341, 5419, 560, 11186, 4028, 5054, 8778, 1612, 4685, 4695, 4574, 3815, 7399, 8812, 6129, 7153, 5108, 10489, 1274])


In [None]:
# inds = random.sample(list(range(12000)),4)
inds = list(range(101,105))
sample = sample_neighborhood([train_entities], inds, 16, 64, seed=42)
print(sample)
print(len(sample))

[9221, 2061, 3599, 10261, 9214, 6937, 4386, 2855, 8492, 4654, 10039, 4152, 2363, 11581, 317, 6210, 580, 6725, 9046, 599, 5721, 3673, 1883, 3679, 4707, 4964, 101, 102, 103, 104, 8551, 5480, 5485, 11121, 3953, 7285, 8824, 11904, 6274, 9607, 11660, 7317, 5271, 8856, 3235, 7594, 4524, 5550, 10927, 4023, 1722, 4030, 10430, 8641, 4305, 8916, 9688, 2009, 10203, 10207, 9440, 9185, 1250, 8680, 6640, 11763, 2804, 510]
68


## $\color{blue}{Dataset:}$

In [None]:
import torch

class GNNDataset(Dataset):
  def __init__(self, H, A, labels, length, neighbor_max=4):
    """Custom dataset with neighborhood sampling

    Args:
      H : torch.tensor
        input embeddings (n x d)

      A : torch.tensor
        adjacency matrix (n x n)

      labels : torch.LongTensor
        y

      meta_indices : torch.LongTensor
        index of datapoint to filter validation score

      neighbor_max : int
        max neighbors for each node in mini-batch

      batch_max : int
        max size of batch

    """
    # All inits must be tensors
    self.H = H.to(device)
    self.A = A.to(device)
    self.cosine = self.calculate_cosine(self.H)
    self.labels = labels.to(device)
    self.neighbor_max = neighbor_max
    self.length = length

  def __len__(self):
    return self.length

  def __getitem__(self, inds):
    # Sample neighborhood

    # get inds in list
    inds = inds.tolist() if torch.is_tensor(inds) else (inds if isinstance(inds,list) else [inds])

    # return the required inds (The inds are the sampes and the active flag dictates relatively if that sample should be counted)
    sampled_indices, active_flag = sample_neighborhood(inds, self.H, self.A, self.cosine, self.neighbor_max)

    # get the input for the required inds
    H_batch = self.H[sampled_indices]

    # get the adjacency matrix for the required inds
    A_batch = self.A[sampled_indices][:, sampled_indices]

    # get the labels for the required inds
    labels_batch = self.labels[sampled_indices]

    return H_batch, A_batch, labels_batch, torch.LongTensor(active_flag).to(device)

  def calculate_cosine(self, H):
      dot_product = torch.matmul(H, H.transpose(0, 1))  # shape (m, m)

      lengths = torch.sqrt(dot_product.diagonal()).unsqueeze(1)  # shape (m, 1)
      denominator = lengths @ lengths.transpose(0, 1)  # shape (m, m)

      return dot_product / denominator  # shape (m, m)


 H, A, labels, length, neighbor_max=4

In [None]:
# # training loader
# H_train = torch.stack(list(df_train['vanilla_embedding.1']))
# labels_train = torch.LongTensor(list(df_train['chapter_idx']))
# A_train = train_entities


# train_dataset = GNNDataset(H_train, A_train, labels_train, H_train.size(0), neighbor_max=4)

# # Prevent dataloader from calling a single index at a time
# custom_train_sampler = torch.utils.data.sampler.BatchSampler(
#     torch.utils.data.sampler.RandomSampler(train_dataset),
#     batch_size=8,
#     drop_last=False)


# train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)

# validation loader
df1 = df_train[['vanilla_embedding.1', 'chapter_idx']]
df2 = df_dev[['vanilla_embedding.1', 'chapter_idx']]
df_val = pd.concat([df2, df1])
H_val = torch.stack(list(df_val['vanilla_embedding.1']))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
A_val = val_entities

validation_dataset = GNNDataset(H_val, A_val, labels_val, df2.shape[0], neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=8,
    drop_last=False)

dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)


In [None]:
# training loader created with a subset of the data
H_train = torch.stack(list(df_train['vanilla_embedding.1'])[:2000])
labels_train = torch.LongTensor(list(df_train['chapter_idx'])[:2000])
A_train = train_entities[:2000,:2000]


train_dataset = GNNDataset(H_train, A_train, labels_train, H_train.size(0), neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_train_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(train_dataset),
    batch_size=8,
    drop_last=False)


train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)

In [None]:
# Check batches

# Number of batches to inspect
num_batches_to_check = 2

for batch_idx, (inputs, adjacency, labels, flag) in enumerate(train_loader):
    print('\n##########################\n')
    print(f"Batch {batch_idx + 1}/{num_batches_to_check}:")
    print('-' * 10)
    print("Inputs:")
    print(f"  Type: {type(inputs)}")
    print(f"  Shape: {inputs.size()}")
    print('-' * 10)
    print("Adjacency:")
    print(f"  Type: {type(adjacency)}")
    print(f"  Shape: {adjacency[0].size()}")
    print('-' * 10)
    print("Indices:")
    print(f"  Type: {type(flag)}")
    print(f"  Shape: {flag.size()}")
    print(flag)
    print('-' * 10)
    print("Labels:")
    print(f"  Type: {type(labels)}")
    print(f"  Shape: {labels.size()}")
    print(labels)

    # Stop after inspecting the desired number of batches
    if batch_idx + 1 >= num_batches_to_check:
        break


##########################

Batch 1/2:
----------
Inputs:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 13, 768])
----------
Adjacency:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([13, 13])
----------
Indices:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 13])
tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1]], device='cuda:0')
----------
Labels:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 13])
tensor([[26, 17, 64, 12, 11, 16, 16, 16, 10, 63, 17, 14,  2]], device='cuda:0')

##########################

Batch 2/2:
----------
Inputs:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 61, 768])
----------
Adjacency:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([61, 61])
----------
Indices:
  Type: <class 'torch.Tensor'>
  Shape: torch.Size([1, 61])
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

The dataloader seems to be working correctly, the implementation of custom sampling on all indices at once leads to DataLoaders collate function inserting a new dimension that it will stach against. Because all indices are dealt with at once, there is no stacking.

The simple solution will be to simply squeeze the tensors in the training loop. The validation loader eradicates randomness from the process.

## $\color{blue}{Model:}$

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.3, training=True):
        super(GNNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.training = training

        self.T = nn.Parameter(torch.Tensor(in_features, out_features))
        self.E = nn.Parameter(torch.Tensor(in_features, out_features))

        # Batch normalization
        self.batch_norm = nn.BatchNorm1d(out_features)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.T)
        nn.init.xavier_uniform_(self.E)

    def forward(self, H, A):
        messages_projection = A.T @ H @ self.E
        degrees = A.sum(dim=1, keepdim=True)
        degrees[degrees == 0] = 1.0
        messages_projection /= degrees

        self_projection = H @ self.T

        # Include skip connection
        H_out = F.leaky_relu(self_projection + messages_projection) + H
        H_out = F.dropout(H_out, p=self.dropout)

        # Apply batch normalization
        H_out = self.batch_norm(H_out)

        return H_out

class GNNModel(nn.Module):
    def __init__(self, d, h, c, num_layers=2, dropout=0.3):
        super(GNNModel, self).__init__()
        self.num_layers = num_layers
        self.gnn_layers = nn.ModuleList([GNNLayer(d, d, dropout) for _ in range(num_layers)])
        self.fc1 = nn.Linear(d, h)
        self.batch_norm_fc1 = nn.BatchNorm1d(h)
        self.fc2 = nn.Linear(h, c)
        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, H, A):
        for layer in self.gnn_layers:
            H = layer(H, A)

        H = F.dropout(H, p=self.dropout, training=self.training)
        H = F.relu(self.batch_norm_fc1(self.fc1(H)))
        Output = self.fc2(H)
        return Output

    def forward_layer(self, H, A, layer_idx):
        """Forward pass for a specific layer."""
        H = self.gnn_layers[layer_idx](H, A)
        return H


In [None]:
import torch.optim as optim

d = 768
h = 400   # hidden dimension of fully connected layer
c = 70   # number of classes

# Model, Loss, Optimizer
model = GNNModel(d, h, c)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)



In [None]:
def count_parameters_per_module(model):
    print("Module and parameter counts:")

    for name, module in model.named_modules():
        # Skip the top-level module (the model itself)
        if not isinstance(module, nn.Module) or name == "":
            continue

        param_count = sum(p.numel() for p in module.parameters() if p.requires_grad)

        if param_count > 0:  # Only print modules that have parameters
            print(f"{name}: {param_count} parameters")

In [None]:
count_parameters_per_module(model)

Module and parameter counts:
gnn_layers: 2362368 parameters
gnn_layers.0: 1181184 parameters
gnn_layers.0.batch_norm: 1536 parameters
gnn_layers.1: 1181184 parameters
gnn_layers.1.batch_norm: 1536 parameters
fc1: 307600 parameters
batch_norm_fc1: 800 parameters
fc2: 28070 parameters


## $\color{blue}{Train-Validate:}$

In [None]:
def accuracy(outputs, labels):
    # argmax to get predicted classes
    _, predicted = torch.max(outputs, 1)

    # count correct
    correct = (predicted == labels).sum().item()

    # get average
    acc = correct / labels.size(0)  # Total number of samples
    return acc

In [None]:
import numpy as np

def train(model, train_loader, criterion, optimizer):
    model.train()
    epoch_train_losses = []
    epoch_train_accuracy = []

    for batch_idx, (H, A, y, flag) in enumerate(train_loader):
        optimizer.zero_grad()

        H = H.squeeze(0)
        A = A.squeeze(0)
        y = y.squeeze(0)
        flag = flag.squeeze(0)

        out = model(H,A)



        train_loss = criterion(out, y)
        train_accuracy = accuracy(out, y)


        epoch_train_losses.append(train_loss.item())
        epoch_train_accuracy.append(train_accuracy)

        # Backpropagation and optimization
        train_loss.backward()
        optimizer.step()

    return np.mean(epoch_train_losses), np.mean(epoch_train_accuracy)

In [None]:
def validate(model, dev_loader, criterion):
    model.eval()
    epoch_dev_losses = []
    epoch_dev_accuracy = []
    pred_holder = []
    real_holder = []

    with torch.no_grad():
        for batch_idx, (H, A, y, flag) in enumerate(dev_loader):
            H = H.squeeze(0)
            A = A.squeeze(0)
            y = y.squeeze(0)
            flag = flag.squeeze(0).bool()

            out = model(H, A)

            filtered_out = out[flag]
            filtered_y = y[flag]

            dev_loss = criterion(filtered_out, filtered_y)
            dev_accuracy = accuracy(filtered_out, filtered_y)

            epoch_dev_losses.append(dev_loss.item())
            epoch_dev_accuracy.append(dev_accuracy)

    # Avoid division by zero if no validation points were processed
    return np.mean(epoch_dev_losses), np.mean(epoch_dev_accuracy)

## $\color{blue}{Training:}$

In [None]:
from collections import namedtuple
Stats = namedtuple('Stats', [
    'train_loss',
    'train_accuracy',
    'dev_loss',
    'dev_accuracy',
    'epoch',
    'lr',
    'alpha',
    'max_accuracy'
])

In [None]:
import time
def tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 0):
  """
  Runs a training setup
  verbose == 1 - print model results
  verbose == 2 -> print epoch and model results
  """
  model = model.to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=alpha)

  # Prepare data loaders
  train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)
  dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)

  # Hold epoch stats
  train_losses = []
  train_accuracy = []
  dev_losses = []
  dev_accuracy = []
  epoch_holder = []

  # Break if no improvement
  current_best = 0
  no_improvement = 0


  # Run epochs
  for epoch in range(epochs):

    # break out of epochs
    if no_improvement >= 6:
      break

    # call training
    torch.cuda.reset_peak_memory_stats()  # Reset memory stats
    start_time = time.time()  # Start timing the training
    train_loss, train_acc = train(model, train_loader, criterion, optimizer)
    print("\n--- Profiling Results for Training Phase ---")
    training_time = time.time() - start_time  # Calculate elapsed time
    max_train_memory = torch.cuda.max_memory_allocated()  # Get max GPU memory used during training
    print(f'Time: {training_time}\nMax memory: {max_train_memory}')

    # call validation
    torch.cuda.reset_peak_memory_stats()  # Reset memory stats
    start_time = time.time()  # Start timing the training
    dev_loss, dev_acc = validate(model, dev_loader, criterion)
    print("\n--- Profiling Results for Validation Phase ---")
    validation_time = time.time() - start_time  # Calculate elapsed time
    max_validation_memory = torch.cuda.max_memory_allocated()  # Get max GPU memory used during training
    print(f'Time: {validation_time}\nMax memory: {max_validation_memory}')



    train_losses.append(train_loss)
    train_accuracy.append(train_acc)

    dev_losses.append(dev_loss)
    dev_accuracy.append(dev_acc)
    epoch_holder.append(epoch + 1)

    # check for improvement
    if dev_acc > current_best:
      current_best = dev_acc
      no_improvement = 0
    else:
      no_improvement += 1

    # save best model
    if dev_acc > max_accuracy:
      torch.save(model.state_dict(), path)
      max_accuracy = dev_acc


    # optionally print epoch results
    if verbose == 2:
      print(f'\n --------- \nEpoch: {epoch + 1}\n')
      print(f'Epoch {epoch + 1} train loss: {train_loss:.4f}')
      print(f'Epoch {epoch + 1} train accuracy: {train_acc:.4f}')
      print(f'Epoch {epoch + 1} dev loss: {dev_loss:.4f}')
      print(f'Epoch {epoch + 1} dev accuracy: {dev_acc:.4f}')

      # save best results
  max_ind = np.argmax(dev_accuracy)

  stats = Stats(
      train_losses[max_ind],
      train_accuracy[max_ind],
      dev_losses[max_ind],
      dev_accuracy[max_ind],
      epoch_holder[max_ind],
      lr, alpha,
      max_accuracy
  )

  # optionally print model results
  if verbose in [1,2]:
    print('\n ######## \n')
    print(f'lr:{stats.lr}, alpha:{stats.alpha} @ epoch {stats.epoch}.')
    print(f'TL:{stats.train_loss}, TA:{stats.train_accuracy}.')
    print(f'DL:{stats.dev_loss}, DA:{stats.dev_accuracy}')

  return (training_time, max_train_memory), (validation_time, max_validation_memory)

In [None]:
epochs = 1
lr = 0.0005
alpha = 0.00005
path = "class/models/GNN_trace.pt"
max_accuracy = 0
model = GNNModel(d,h,c)

In [None]:
training_stats, validation_stats = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)


--- Profiling Results for Training Phase ---
Time: 19.855705499649048
Max memory: 1540577792

--- Profiling Results for Validation Phase ---
Time: 95.31733846664429
Max memory: 1535012864

 --------- 
Epoch: 1

Epoch 1 train loss: 2.5837
Epoch 1 train accuracy: 0.4124
Epoch 1 dev loss: 2.9375
Epoch 1 dev accuracy: 0.3027

 ######## 

lr:0.0005, alpha:0.0001 @ epoch 1.
TL:2.5836983284950255, TA:0.4123930328585182.
DL:2.9375194724926277, DA:0.30268595041322316


In [None]:
df1 = df_train[['vanilla_embedding.1', 'chapter_idx']]
df2 = df_dev[['vanilla_embedding.1', 'chapter_idx']]
df_val = pd.concat([df2, df1])
H_val = torch.stack(list(df_val['vanilla_embedding.1']))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
A_val = val_entities

validation_dataset = GNNDataset(H_val, A_val, labels_val, df2.shape[0], neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=1024,
    drop_last=False)

dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)

torch.cuda.reset_peak_memory_stats()  # Reset memory stats
start_time = time.time()  # Start timing the training
dev_loss, dev_acc = validate(model, dev_loader, criterion)
print("\n--- Profiling Results for Validation Phase ---")
validation_time = time.time() - start_time  # Calculate elapsed time
max_validation_memory = torch.cuda.max_memory_allocated()  # Get max GPU memory used during training
print(f'Time: {validation_time}\nMax memory: {max_validation_memory}')
print(dev_loss)
print(dev_acc)


--- Profiling Results for Validation Phase ---
Time: 32.89159631729126
Max memory: 4280548864
2.9000470638275146
0.31016597510373445


We have a lot of space increase batch size

In [None]:
# training loader created with a subset of the data
H_train = torch.stack(list(df_train['vanilla_embedding.1'])[:2000])
labels_train = torch.LongTensor(list(df_train['chapter_idx'])[:2000])
A_train = train_entities[:2000,:2000]


train_dataset = GNNDataset(H_train, A_train, labels_train, H_train.size(0), neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_train_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(train_dataset),
    batch_size=64,
    drop_last=False)


df1 = df_train[['vanilla_embedding.1', 'chapter_idx']]
df2 = df_dev[['vanilla_embedding.1', 'chapter_idx']]
df_val = pd.concat([df2, df1])
H_val = torch.stack(list(df_val['vanilla_embedding.1']))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
A_val = val_entities

validation_dataset = GNNDataset(H_val, A_val, labels_val, df2.shape[0], neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=1024,
    drop_last=False)


In [None]:
training_stats, validation_stats = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)


--- Profiling Results for Training Phase ---
Time: 11.893709421157837
Max memory: 4352926208

--- Profiling Results for Validation Phase ---
Time: 32.80560326576233
Max memory: 4490978304

 --------- 
Epoch: 1

Epoch 1 train loss: 1.3462
Epoch 1 train accuracy: 0.6995
Epoch 1 dev loss: 2.6072
Epoch 1 dev accuracy: 0.3527

 ######## 

lr:0.0005, alpha:0.0001 @ epoch 1.
TL:1.3462302163243294, TA:0.6994579701375027.
DL:2.6072261333465576, DA:0.35269709543568467


In [None]:
# training loader created with a subset of the data
H_train = torch.stack(list(df_train['vanilla_embedding.1'])[:2000])
labels_train = torch.LongTensor(list(df_train['chapter_idx'])[:2000])
A_train = train_entities[:2000,:2000]


train_dataset = GNNDataset(H_train, A_train, labels_train, H_train.size(0), neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_train_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(train_dataset),
    batch_size=16,
    drop_last=False)


df1 = df_train[['vanilla_embedding.1', 'chapter_idx']]
df2 = df_dev[['vanilla_embedding.1', 'chapter_idx']]
df_val = pd.concat([df2, df1])
H_val = torch.stack(list(df_val['vanilla_embedding.1']))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
A_val = val_entities

validation_dataset = GNNDataset(H_val, A_val, labels_val, df2.shape[0], neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=1024,
    drop_last=False)

In [None]:
training_stats, validation_stats = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)


--- Profiling Results for Training Phase ---
Time: 17.648733139038086
Max memory: 4351542272

--- Profiling Results for Validation Phase ---
Time: 33.276201009750366
Max memory: 4490362880

 --------- 
Epoch: 1

Epoch 1 train loss: 2.6135
Epoch 1 train accuracy: 0.4058
Epoch 1 dev loss: 2.9246
Epoch 1 dev accuracy: 0.2998

 ######## 

lr:0.0005, alpha:5e-05 @ epoch 1.
TL:2.6134842596054075, TA:0.4058028126283634.
DL:2.9246041774749756, DA:0.29979253112033194


looks like it is causing overfitting lets keep training the same and make gains on validation

In [None]:
# training loader created with a subset of the data
H_train = torch.stack(list(df_train['vanilla_embedding.1'])[:2000])
labels_train = torch.LongTensor(list(df_train['chapter_idx'])[:2000])
A_train = train_entities[:2000,:2000]


train_dataset = GNNDataset(H_train, A_train, labels_train, H_train.size(0), neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_train_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(train_dataset),
    batch_size=8,
    drop_last=False)


df1 = df_train[['vanilla_embedding.1', 'chapter_idx']]
df2 = df_dev[['vanilla_embedding.1', 'chapter_idx']]
df_val = pd.concat([df2, df1])
H_val = torch.stack(list(df_val['vanilla_embedding.1']))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
A_val = val_entities

validation_dataset = GNNDataset(H_val, A_val, labels_val, df2.shape[0], neighbor_max=4)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=1024,
    drop_last=False)

In [None]:
training_stats, validation_stats = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)


--- Profiling Results for Training Phase ---
Time: 19.32697558403015
Max memory: 4351483904

--- Profiling Results for Validation Phase ---
Time: 33.48703742027283
Max memory: 4340441088

 --------- 
Epoch: 1

Epoch 1 train loss: 0.9849
Epoch 1 train accuracy: 0.7763
Epoch 1 dev loss: 2.4917
Epoch 1 dev accuracy: 0.3641

 ######## 

lr:0.0005, alpha:5e-05 @ epoch 1.
TL:0.9848547171354294, TA:0.7763352196350619.
DL:2.4916648864746094, DA:0.36410788381742737


# new refence above
now make bs log2

In [None]:
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy

def sample_neighborhood(primary_inds, input, adj, distance, neighbor_max = 4):
    """
    Takes the given inds and the inputs, returns the sampled set of indices and the corresponding activation flag.
    If the activation flag is True, then the the datapoint is primary and has been search for neighbors and neighbors
    of neighbors. This indicates that we rely on the datapoint for metrics and loss.

    Args:
      primary_inds : iterable
          indices sampled by the dataloader
      input : torch.Tensor
          (m,d) input tensor
      adj : torch.Tensor
          (m,m) adjacency matrix
      distance : torch.Tensor
          (m,m) : cosine similarity between all inputs
      neighbour_max : int (optional)
          The maximum number of neighbors to consider for each point

    Returns:
      sampled_indices : list
          indices of all datapoints to be processed in the batch
      activation_flag : list
          boolean_flag indicating whether the corresponding datapoint is to be considered for metrics
    """

    def _get_closest_neighbors(ind):
      """get up to neighbor_max close neighbors"""
      local_neighbors = []
      local_activation_flag = []
      candidate_neighbors = [neighbor.item() for neighbor in (adj[ind] > 0).nonzero(as_tuple=True)[0] if neighbor.item() not in sampled_indices]
      candidate_distances = [(neighbor, distance[primary_ind][neighbor]) for neighbor in candidate_neighbors]
      sorted_neighbors = sorted(candidate_distances, key=lambda x: x[1])
      return [neighbor for neighbor, dist in sorted_neighbors[:neighbor_max]], candidate_neighbors

    sampled_indices = []
    activation_flag = []
    all_banned_neighbors = []

    for primary_ind in primary_inds:

      # if primary ind has been added as a neighbor, convert the activation flag to true, else add it as a standard primary index
      if primary_ind in sampled_indices:
        activation_flag[sampled_indices.index(primary_ind)] = True
      else:
        sampled_indices.append(primary_ind)
        activation_flag.append(True)

      # print('\n', primary_ind)
      # print('sampled_indices', sampled_indices)

      level_1_neighbors, candidate_neighbors = _get_closest_neighbors(primary_ind)
      banned_neighbors = list(set(candidate_neighbors) - set(level_1_neighbors))
      all_banned_neighbors.extend(banned_neighbors)
      # print('banned_neighbors', banned_neighbors)
      # print('all_banned_neighbors', all_banned_neighbors)

      level_1_activation_flag = [False for el in level_1_neighbors]
      sampled_indices.extend(level_1_neighbors)
      activation_flag.extend(level_1_activation_flag)

      # print('level_1_neighbors', level_1_neighbors)
      # print('sampled_indices', sampled_indices)
      # print('level_1_activation_flag', level_1_activation_flag)
      # print('activation_flag', activation_flag)

      for level_1_ind in level_1_neighbors:
        level_2_neighbors, _ = _get_closest_neighbors(level_1_ind)
        level_2_activation_flag = [False for el in level_2_neighbors]
        sampled_indices.extend(level_2_neighbors)
        activation_flag.extend(level_2_activation_flag)

        # print('level_2_neighbors', level_2_neighbors)
        # print('sampled_indices', sampled_indices)
        # print('level_2_activation_flag', level_2_activation_flag)
        # print('activation_flag', activation_flag)

    # include only 4 level one neighbors for each primary index to avoid pollution
    clean_indices = []
    clean_flags = []
    for i in range(len(sampled_indices)):
      target = sampled_indices[i]
      if target in primary_inds:
        clean_indices.append(target)
        clean_flags.append(True)
      elif target in all_banned_neighbors:
        continue
      else:
        clean_indices.append(target)
        clean_flags.append(False)
    # print(f'\ncleaning\nsampled_indices : {sampled_indices}\nclean_indices : {clean_indices}')

    limit = 2**int(np.floor(np.log2(len(clean_indices))))
    return clean_indices[:limit], clean_flags[:limit]


In [None]:
training_stats, validation_stats = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)


--- Profiling Results for Training Phase ---
Time: 19.72818875312805
Max memory: 2920893952

--- Profiling Results for Validation Phase ---
Time: 33.25472927093506
Max memory: 3038790144

 --------- 
Epoch: 1

Epoch 1 train loss: 0.8533
Epoch 1 train accuracy: 0.7958
Epoch 1 dev loss: 2.5077
Epoch 1 dev accuracy: 0.3769

 ######## 

lr:0.0005, alpha:5e-05 @ epoch 1.
TL:0.8532900586128235, TA:0.7958125.
DL:2.5076897144317627, DA:0.3769035532994924


# No impact on power 2

Check sparse calc

In [None]:
#rest sample neighborhood

from torch.utils.data import Dataset, DataLoader
from copy import deepcopy

def sample_neighborhood(primary_inds, input, adj, distance, neighbor_max = 4):
    """
    Takes the given inds and the inputs, returns the sampled set of indices and the corresponding activation flag.
    If the activation flag is True, then the the datapoint is primary and has been search for neighbors and neighbors
    of neighbors. This indicates that we rely on the datapoint for metrics and loss.

    Args:
      primary_inds : iterable
          indices sampled by the dataloader
      input : torch.Tensor
          (m,d) input tensor
      adj : torch.Tensor
          (m,m) adjacency matrix
      distance : torch.Tensor
          (m,m) : cosine similarity between all inputs
      neighbour_max : int (optional)
          The maximum number of neighbors to consider for each point

    Returns:
      sampled_indices : list
          indices of all datapoints to be processed in the batch
      activation_flag : list
          boolean_flag indicating whether the corresponding datapoint is to be considered for metrics
    """

    def _get_closest_neighbors(ind):
      """get up to neighbor_max close neighbors"""
      local_neighbors = []
      local_activation_flag = []
      candidate_neighbors = [neighbor.item() for neighbor in (adj[ind] > 0).nonzero(as_tuple=True)[0] if neighbor.item() not in sampled_indices]
      candidate_distances = [(neighbor, distance[primary_ind][neighbor]) for neighbor in candidate_neighbors]
      sorted_neighbors = sorted(candidate_distances, key=lambda x: x[1])
      return [neighbor for neighbor, dist in sorted_neighbors[:neighbor_max]], candidate_neighbors

    sampled_indices = []
    activation_flag = []
    all_banned_neighbors = []

    for primary_ind in primary_inds:

      # if primary ind has been added as a neighbor, convert the activation flag to true, else add it as a standard primary index
      if primary_ind in sampled_indices:
        activation_flag[sampled_indices.index(primary_ind)] = True
      else:
        sampled_indices.append(primary_ind)
        activation_flag.append(True)

      # print('\n', primary_ind)
      # print('sampled_indices', sampled_indices)

      level_1_neighbors, candidate_neighbors = _get_closest_neighbors(primary_ind)
      banned_neighbors = list(set(candidate_neighbors) - set(level_1_neighbors))
      all_banned_neighbors.extend(banned_neighbors)
      # print('banned_neighbors', banned_neighbors)
      # print('all_banned_neighbors', all_banned_neighbors)

      level_1_activation_flag = [False for el in level_1_neighbors]
      sampled_indices.extend(level_1_neighbors)
      activation_flag.extend(level_1_activation_flag)

      # print('level_1_neighbors', level_1_neighbors)
      # print('sampled_indices', sampled_indices)
      # print('level_1_activation_flag', level_1_activation_flag)
      # print('activation_flag', activation_flag)

      for level_1_ind in level_1_neighbors:
        level_2_neighbors, _ = _get_closest_neighbors(level_1_ind)
        level_2_activation_flag = [False for el in level_2_neighbors]
        sampled_indices.extend(level_2_neighbors)
        activation_flag.extend(level_2_activation_flag)

        # print('level_2_neighbors', level_2_neighbors)
        # print('sampled_indices', sampled_indices)
        # print('level_2_activation_flag', level_2_activation_flag)
        # print('activation_flag', activation_flag)

    # include only 4 level one neighbors for each primary index to avoid pollution
    clean_indices = []
    clean_flags = []
    for i in range(len(sampled_indices)):
      target = sampled_indices[i]
      if target in primary_inds:
        clean_indices.append(target)
        clean_flags.append(True)
      elif target in all_banned_neighbors:
        continue
      else:
        clean_indices.append(target)
        clean_flags.append(False)
    # print(f'\ncleaning\nsampled_indices : {sampled_indices}\nclean_indices : {clean_indices}')
    return clean_indices, clean_flags


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.3, training=True):
        super(GNNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.training = training

        self.T = nn.Parameter(torch.Tensor(in_features, out_features))
        self.E = nn.Parameter(torch.Tensor(in_features, out_features))

        # Batch normalization
        self.batch_norm = nn.BatchNorm1d(out_features)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.T)
        nn.init.xavier_uniform_(self.E)

    def forward(self, H, A):
        messages_projection = A.T @ H @ self.E
        degrees = A.sum(dim=1, keepdim=True)
        degrees[degrees == 0] = 1.0
        messages_projection /= degrees

        self_projection = H @ self.T

        # Include skip connection
        H_out = F.leaky_relu(self_projection + messages_projection) + H
        H_out = F.dropout(H_out, p=self.dropout)

        # Apply batch normalization
        H_out = self.batch_norm(H_out)

        return H_out

class GNNModel(nn.Module):
    def __init__(self, d, h, c, num_layers=2, dropout=0.3):
        super(GNNModel, self).__init__()
        self.num_layers = num_layers
        self.gnn_layers = nn.ModuleList([GNNLayer(d, d, dropout) for _ in range(num_layers)])
        self.fc1 = nn.Linear(d, h)
        self.batch_norm_fc1 = nn.BatchNorm1d(h)
        self.fc2 = nn.Linear(h, c)
        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, H, A):
        for layer in self.gnn_layers:
            H = layer(H, A)

        H = F.dropout(H, p=self.dropout, training=self.training)
        H = F.relu(self.batch_norm_fc1(self.fc1(H)))
        Output = self.fc2(H)
        return Output

    def forward_layer(self, H, A, layer_idx):
        """Forward pass for a specific layer."""
        H = self.gnn_layers[layer_idx](H, A)
        return H


In [None]:
epochs = 1
lr = 0.0005
alpha = 0.00005
path = "class/models/GNN_trace.pt"
max_accuracy = 0
model = GNNModel(d,h,c)

In [None]:
training_stats, validation_stats = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)


--- Profiling Results for Training Phase ---
Time: 19.56251311302185
Max memory: 2921960448

--- Profiling Results for Validation Phase ---
Time: 32.56710410118103
Max memory: 3060967424

 --------- 
Epoch: 1

Epoch 1 train loss: 2.5983
Epoch 1 train accuracy: 0.4109
Epoch 1 dev loss: 2.8818
Epoch 1 dev accuracy: 0.3029

 ######## 

lr:0.0005, alpha:5e-05 @ epoch 1.
TL:2.5983164286613465, TA:0.41088487058579787.
DL:2.8818299770355225, DA:0.3029045643153527


# No Effect

low precision

In [None]:
from torch.cuda.amp import GradScaler, autocast
import numpy as np

def train(model, train_loader, criterion, optimizer, scaler):
    model.train()
    epoch_train_losses = []
    epoch_train_accuracy = []

    for batch_idx, (H, A, y, flag) in enumerate(train_loader):
        optimizer.zero_grad()

        H = H.squeeze(0)
        A = A.squeeze(0)
        y = y.squeeze(0)
        flag = flag.squeeze(0)

        with autocast():  # Assuming torch.cuda.amp.autocast
            out = model(H, A)
            train_loss = criterion(out, y)
            train_accuracy = accuracy(out, y)

        # Scale the loss, perform backward, and update
        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_train_losses.append(train_loss.item())
        epoch_train_accuracy.append(train_accuracy)

    return np.mean(epoch_train_losses), np.mean(epoch_train_accuracy)

In [None]:
import time
def tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 0):
  """
  Runs a training setup
  verbose == 1 - print model results
  verbose == 2 -> print epoch and model results
  """
  model = model.to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=alpha)
  scaler = GradScaler()

  # Prepare data loaders
  train_loader = DataLoader(train_dataset, sampler = custom_train_sampler)
  dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)

  # Hold epoch stats
  train_losses = []
  train_accuracy = []
  dev_losses = []
  dev_accuracy = []
  epoch_holder = []

  # Break if no improvement
  current_best = 0
  no_improvement = 0


  # Run epochs
  for epoch in range(epochs):

    # break out of epochs
    if no_improvement >= 6:
      break

    # call training
    torch.cuda.reset_peak_memory_stats()  # Reset memory stats
    start_time = time.time()  # Start timing the training
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, scaler)
    print("\n--- Profiling Results for Training Phase ---")
    training_time = time.time() - start_time  # Calculate elapsed time
    max_train_memory = torch.cuda.max_memory_allocated()  # Get max GPU memory used during training
    print(f'Time: {training_time}\nMax memory: {max_train_memory}')

    # call validation
    torch.cuda.reset_peak_memory_stats()  # Reset memory stats
    start_time = time.time()  # Start timing the training
    dev_loss, dev_acc = validate(model, dev_loader, criterion)
    print("\n--- Profiling Results for Validation Phase ---")
    validation_time = time.time() - start_time  # Calculate elapsed time
    max_validation_memory = torch.cuda.max_memory_allocated()  # Get max GPU memory used during training
    print(f'Time: {validation_time}\nMax memory: {max_validation_memory}')



    train_losses.append(train_loss)
    train_accuracy.append(train_acc)

    dev_losses.append(dev_loss)
    dev_accuracy.append(dev_acc)
    epoch_holder.append(epoch + 1)

    # check for improvement
    if dev_acc > current_best:
      current_best = dev_acc
      no_improvement = 0
    else:
      no_improvement += 1

    # save best model
    if dev_acc > max_accuracy:
      torch.save(model.state_dict(), path)
      max_accuracy = dev_acc


    # optionally print epoch results
    if verbose == 2:
      print(f'\n --------- \nEpoch: {epoch + 1}\n')
      print(f'Epoch {epoch + 1} train loss: {train_loss:.4f}')
      print(f'Epoch {epoch + 1} train accuracy: {train_acc:.4f}')
      print(f'Epoch {epoch + 1} dev loss: {dev_loss:.4f}')
      print(f'Epoch {epoch + 1} dev accuracy: {dev_acc:.4f}')

      # save best results
  max_ind = np.argmax(dev_accuracy)

  stats = Stats(
      train_losses[max_ind],
      train_accuracy[max_ind],
      dev_losses[max_ind],
      dev_accuracy[max_ind],
      epoch_holder[max_ind],
      lr, alpha,
      max_accuracy
  )

  # optionally print model results
  if verbose in [1,2]:
    print('\n ######## \n')
    print(f'lr:{stats.lr}, alpha:{stats.alpha} @ epoch {stats.epoch}.')
    print(f'TL:{stats.train_loss}, TA:{stats.train_accuracy}.')
    print(f'DL:{stats.dev_loss}, DA:{stats.dev_accuracy}')

  return (training_time, max_train_memory), (validation_time, max_validation_memory)

In [None]:
training_stats, validation_stats = tv_run(epochs, model, lr, alpha, max_accuracy, path, verbose = 2)

  scaler = GradScaler()
  with autocast():  # Assuming torch.cuda.amp.autocast



--- Profiling Results for Training Phase ---
Time: 19.43116569519043
Max memory: 4391714816

--- Profiling Results for Validation Phase ---
Time: 33.2727472782135
Max memory: 4380591104

 --------- 
Epoch: 1

Epoch 1 train loss: 1.3900
Epoch 1 train accuracy: 0.6911
Epoch 1 dev loss: 2.5692
Epoch 1 dev accuracy: 0.3392

 ######## 

lr:0.0005, alpha:5e-05 @ epoch 1.
TL:1.3899921100139618, TA:0.6910955044707143.
DL:2.569183588027954, DA:0.3392116182572614


# change batch size
# sparse metrices
# mixed precision training
# memory pinning