# Text Classification - GNN Inference


## $\color{blue}{Sections:}$

* Preamble
1.   Admin
2.   Dataset
3.   Model
4.   Train - Validate
5.   Training Loop

## $\color{blue}{Preamble:}$

We now look to infer the entire graph at once, relying on layerwise propogation.

## $\color{blue}{Admin}$
* Install relevant Libraries
* Import relevant Libraries

In [1]:
import torch
import pandas as pd
from google.colab import drive
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
drive.mount("/content/drive")
%cd '/content/drive/MyDrive'

Mounted at /content/drive
/content/drive/MyDrive


## $\color{blue}{Data}$

* Connect to Drive
* Load the data
* Load adjacency matrices

In [3]:
path = 'class/datasets/'
df_train = pd.read_pickle(path + 'df_train')
df_dev = pd.read_pickle(path + 'df_dev')
df_test = pd.read_pickle(path + 'df_test')
df1 = df_train[['vanilla_embedding.1','chapter_idx']]
df2 = df_dev[['vanilla_embedding.1','chapter_idx']]
df_val = pd.concat([df1, df2])

In [4]:
path = 'class/tensors/adj_{}.pt'

# train
train_people = torch.load(path.format('train_people'))
train_locations = torch.load(path.format('train_locations'))
train_entities = torch.load(path.format('train_entities'))

# dev
dev_people = torch.load(path.format('dev_people'))
dev_locations = torch.load(path.format('dev_locations'))
dev_entities = torch.load(path.format('dev_entities'))

# val (contains the adjacency matrix for both the training and the development set)
val_people = torch.load(path.format('val_people'))
val_locations = torch.load(path.format('val_locations'))
val_entities = torch.load(path.format('val_entities'))

  train_people = torch.load(path.format('train_people'))
  train_locations = torch.load(path.format('train_locations'))
  train_entities = torch.load(path.format('train_entities'))
  dev_people = torch.load(path.format('dev_people'))
  dev_locations = torch.load(path.format('dev_locations'))
  dev_entities = torch.load(path.format('dev_entities'))
  val_people = torch.load(path.format('val_people'))
  val_locations = torch.load(path.format('val_locations'))
  val_entities = torch.load(path.format('val_entities'))


## $\color{blue}{Dataset:}$

In [5]:
from torch.utils.data import Dataset, DataLoader

def sample_neighborhood(A, inds, neighbor_max, batch_max, seed=None):
    # Set the random seed for deterministic responses
    if seed is not None:
        np.random.seed(seed)

    np.random.shuffle(A)  # Shuffle the list of adjacency matrices in place
    sampled_indices = set(inds)  # Initialize the set of sampled indices

    for ind in inds:  # Iterate through node in mini-batch
        neighbors = set()

        for adj in A:  # Iterate through all adjacency matrices
            # Get the indices of all neighbors that idx links to
            neighbors.update((adj[ind] > 0).nonzero(as_tuple=True)[0].tolist())

            if len(neighbors) > neighbor_max:  # Check if we have too many neighbors
                # Take a random subset using np.random.choice
                neighbors = set(np.random.choice(list(neighbors), neighbor_max, replace=False))

            sampled_indices.update(neighbors)  # Add new neighbors

            if len(sampled_indices) >= batch_max:
                break

    sampled_indices = list(sampled_indices)[:batch_max]  # Limit to batch_max
    return sampled_indices

In [6]:
class GNNDataset(Dataset):
  def __init__(self, H, A, labels, meta_indices, neighbor_max=8, batch_max=64, seed=None):
    """Custom dataset with neighborhood sampling

    Args:
      H : torch.tensor
        input embeddings (n x d)

      A : list[torch.tensor]
        list of (n x n)

      labels : torch.LongTensor
        y

      meta_indices : torch.LongTensor
        index of datapoint to filter validation score

      neighbor_max : int
        max neighbors for each node in mini-batch

      batch_max : int
        max size of batch

    """
    # All inits must be tensors
    self.H = H.to(device)
    self.A = [a.to(device) for a in A]
    self.labels = labels.to(device)
    self.meta_indices = meta_indices
    self.neighbor_max = neighbor_max
    self.batch_max = batch_max
    self.seed = seed

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, inds):
    # print('\n####################\n')
    # print('GET ITEM CALLED', 'INDS:', inds)
    # Sample neighborhood

    # get inds in list
    inds = inds.tolist() if torch.is_tensor(inds) else (inds if isinstance(inds,list) else [inds])

    # return the required inds
    sampled_indices = sample_neighborhood(self.A, inds, self.neighbor_max, self.batch_max,seed=self.seed)

    # get the input for the required inds
    H_batch = self.H[sampled_indices]

    # get the adjacency matrix for the required inds
    A_batch = [self.A[k][sampled_indices][:, sampled_indices] for k in range(len(self.A))]

    # get the labels for the required inds
    labels_batch = self.labels[sampled_indices]

    # get meta indices
    index_batch = self.meta_indices[sampled_indices]

    return H_batch, A_batch, labels_batch, index_batch

In [7]:
# training loader
df1 = df_train[['vanilla_embedding.1', 'chapter_idx']]
df2 = df_dev[['vanilla_embedding.1', 'chapter_idx']]
df_val = pd.concat([df1, df2])
H_val = torch.stack(list(df_val['vanilla_embedding.1']))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
A_val = []
A_val.append(val_entities)
val_indices = torch.LongTensor(list(range(df_val.shape[0])))



validation_dataset = GNNDataset(H_val, A_val, labels_val, val_indices, neighbor_max=8, batch_max=32, seed=42)

# Prevent dataloader from calling a single index at a time
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=8,
    drop_last=False)


dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)

In [22]:
from torch.utils.data import Dataset, DataLoader

class GNNInferenceDataset(Dataset):
    def __init__(self, H, A, labels, threshold=12000):
        """
        Args:
            H (torch.Tensor): tensor of node embeddings.
            A (list[torch.Tensor]): A list of torch.tensor adjacency matrices, one for each link type.
            labels (torch.LongTensor): classification labels
            threshold (int) : cut-off point for the training data
        """
        self.H = H.to(device)
        self.A = [a.to(device) for a in A]
        self.labels = labels.to(device)
        self.is_training = (torch.arange(H.size(0)) < threshold).to(device)

    def __len__(self):
        # Return the total number of samples
        return len(self.labels)

    def __getitem__(self, idx):
        # Get the data point and its training status
        H_batch = self.H[idx]
        A_batch = [self.A[k][idx] for k in range(len(self.A))]
        labels_batch = self.labels[idx]
        is_training_batch = self.is_training[idx]

        # Return an instance of DataPoint
        return H_batch, A_batch, labels_batch, is_training_batch

In [23]:
H = torch.stack(list(df_val['vanilla_embedding.1']))
labels = torch.LongTensor(list(df_val['chapter_idx']))
A = []
A.append(val_entities)

In [24]:
inference_dataset = GNNInferenceDataset(H,A,labels)

In [8]:
len(inference_dataset)

12964

In [9]:
inference_dataset[0][3]

tensor(True, device='cuda:0')

In [12]:
inference_dataset[12000][3]

tensor(False, device='cuda:0')

In [25]:
inference_dataloader = DataLoader(inference_dataset, batch_size=len(inference_dataset), shuffle=False)

## $\color{blue}{Model:}$

In [8]:
# Make robust to change of degrees at inference

import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer(nn.Module):
   def __init__(self, in_features, out_features, num_relations=1):
      super(GNNLayer, self).__init__()
      self.in_features = in_features
      self.out_features = out_features
      self.num_relations = num_relations

      # Define weight matrices for different relationship types
      self.T = nn.ParameterList([nn.Parameter(torch.Tensor(in_features, out_features)) for _ in range(num_relations)])
      self.E = nn.ParameterList([nn.Parameter(torch.Tensor(in_features, out_features)) for _ in range(num_relations)])

      # Learnable normalization factor
      self.norm = nn.Parameter(torch.ones(out_features))

      # Initialize parameters
      self.reset_parameters()

   def reset_parameters(self):
      for t in self.T:
          nn.init.xavier_uniform_(t)
      for e in self.E:
          nn.init.xavier_uniform_(e)

   def forward(self, H, A):
      H_out = torch.zeros_like(H)
      for k in range(self.num_relations):

          messages_projection = A[k].T @ H @ self.E[k] # get messages
          degrees = A[k].sum(dim=1, keepdim=True) # calculate degrees
          degrees[degrees == 0] = 1.0
          messages_projection /= degrees # adjust messages to degrees

          self_projection = H @ self.T[k] # get self projection

          H_out += F.leaky_relu(self_projection + messages_projection) # combine self projection and messages

      return H_out / self.norm

In [9]:
class GNNModel(nn.Module):
  def __init__(self, d, h, c, num_relations=1, num_layers=3):
    super(GNNModel, self).__init__()
    self.num_layers = num_layers
    self.gnn_layers = nn.ModuleList([GNNLayer(d, d) for _ in range(num_layers)])
    self.fc1 = nn.Linear(d, h)
    self.fc2 = nn.Linear(h, c)

  def forward(self, H, A):
    for layer in self.gnn_layers:
      H = layer(H, A)
    # Classification
    H = F.relu(self.fc1(H))
    Output = self.fc2(H)
    return Output

  def forward_layer(self, H, A, layer_idx):
    """Forward pass for a specific layer."""
    H = self.gnn_layers[layer_idx](H, A)
    return H

In [10]:
d = 768
h = 400
c = 70
num_relations = 1
path = "class/models/GNN.1.pt"


In [13]:
model = GNNModel(d, h, c, num_relations)
criterion = nn.CrossEntropyLoss()
path = "class/models/GNN.1.pt"
model.load_state_dict(torch.load(path))
model = model.to(device)

  model.load_state_dict(torch.load(path))


In [18]:
def layerwise_propagation(H, A, model):
  with torch.no_grad(): # No need to track gradients during inference
    for idx in range(model.num_layers):
      H = model.forward_layer(H, A, idx)
      # After propagating through all GNN layers, pass through the classification head
      H = F.leaky_relu(model.fc1(H))
      Output = model.fc2(H)
  return Output

In [11]:
def accuracy(outputs, labels):
    # argmax to get predicted classes
    _, predicted = torch.max(outputs, 1)

    # count correct
    correct = (predicted == labels).sum().item()

    # get average
    acc = correct / labels.size(0)  # Total number of samples
    return acc

In [34]:
def validate(model, val_loader, criterion):
    model.eval()
    epoch_dev_losses = []
    epoch_dev_accuracy = []

    with torch.no_grad():
        for batch_idx, (H, A, y, training) in enumerate(val_loader):

            out = model(H, A)
            print(out.size())

            # Filter out training points
            mask = ~training
            filtered_out = out[mask]
            filtered_y = y[mask]


            dev_loss = criterion(filtered_out, filtered_y)
            dev_accuracy = accuracy(filtered_out, filtered_y)

    # Avoid division by zero if no validation points were processed
    return dev_loss.item(), dev_accuracy, A

In [12]:
def validate(model, dev_loader, criterion, threshold=12000):
    model.eval()
    epoch_dev_losses = []
    epoch_dev_accuracy = []

    with torch.no_grad():
        for batch_idx, (H, A, y, indices) in enumerate(dev_loader):
            H = H.squeeze(0)
            A = [a.squeeze(0) for a in A]
            y = y.squeeze(0)
            indices = indices.squeeze(0)

            out = model(H, A)

            # Filter out training points
            mask = indices >= threshold
            filtered_out = out[mask]
            filtered_y = y[mask]

            # Calculate loss and accuracy only on filtered outputs
            if filtered_out.size(0) > 0:  # Ensure there are samples to evaluate
                dev_loss = criterion(filtered_out, filtered_y)
                dev_accuracy = accuracy(filtered_out, filtered_y)

                epoch_dev_losses.append(dev_loss.item())
                epoch_dev_accuracy.append(dev_accuracy)

    # Avoid division by zero if no validation points were processed
    return np.mean(epoch_dev_losses), np.mean(epoch_dev_accuracy)

In [14]:
loss, acc = validate(model, dev_loader, criterion)

In [15]:
loss

3.177604614874062

In [16]:
acc

0.6281040596830071

In [38]:
type(val_entities)
val_entities.size()
val_entities.sum()

tensor(1085820.)

In [43]:
val_entities = val_entities.to(device)


In [59]:
(val_entities == A[0]).sum() / (val_entities == A[0]).size(0)**2

tensor(1., device='cuda:0')