# Text Classification - GNN Inference


## $\color{blue}{Sections:}$

* Preamble
1.   Admin
2.   Data
3.   Dataset
4.   Model
5.   Validate
6.   Post-Treatment

## $\color{blue}{Preamble:}$

We now look at inference on our GNN, returning the predictions from the GNN for further analysis.

## $\color{blue}{Admin}$
* Install relevant Libraries
* Import relevant Libraries

In [1]:
import torch
import pandas as pd
from google.colab import drive
import numpy as np
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `device` is read as a global by the dataset and model-loading cells below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
# Mount Google Drive (google.colab helper imported above).
drive.mount("/content/drive")
# Work from the Drive root so the relative 'class/...' paths used in later cells resolve.
%cd '/content/drive/MyDrive'

Mounted at /content/drive
/content/drive/MyDrive


## $\color{blue}{Data}$

* Connect to Drive
* Load the data
* Load adjacency matrices

In [3]:
# Load the pickled train / dev / test DataFrames.
# NOTE(review): unpickling can execute arbitrary code - only load files you trust.
path = 'class/datasets/'
df_train = pd.read_pickle(path + 'df_train')
df_dev = pd.read_pickle(path + 'df_dev')
df_test = pd.read_pickle(path + 'df_test')
# Keep only the embedding column and the label column of each split.
df1 = df_train[['vanilla_embedding.1','chapter_idx']]
df2 = df_dev[['vanilla_embedding.1','chapter_idx']]
# Train rows first, then dev rows, so dev rows start at position len(df_train).
# concat keeps each frame's own row labels (no ignore_index), so labels in df_val may repeat.
# NOTE(review): the later cells hard-code 12000 as that offset - presumably len(df_train); confirm.
df_val = pd.concat([df1, df2])

In [4]:
# File-name template for the saved adjacency tensors; '{}' is filled with '<split>_<relation>'.
# Note that this rebinds the global `path` used by the previous cell.
path = 'class/tensors/adj_{}.pt'

# train
train_people = torch.load(path.format('train_people'))
train_locations = torch.load(path.format('train_locations'))
train_entities = torch.load(path.format('train_entities'))

# dev
dev_people = torch.load(path.format('dev_people'))
dev_locations = torch.load(path.format('dev_locations'))
dev_entities = torch.load(path.format('dev_entities'))

# val (contains the adjacency matrix for both the training and the development set)
# Of the nine tensors, only val_entities is used by the cells shown below.
val_people = torch.load(path.format('val_people'))
val_locations = torch.load(path.format('val_locations'))
val_entities = torch.load(path.format('val_entities'))

  train_people = torch.load(path.format('train_people'))
  train_locations = torch.load(path.format('train_locations'))
  train_entities = torch.load(path.format('train_entities'))
  dev_people = torch.load(path.format('dev_people'))
  dev_locations = torch.load(path.format('dev_locations'))
  dev_entities = torch.load(path.format('dev_entities'))
  val_people = torch.load(path.format('val_people'))
  val_locations = torch.load(path.format('val_locations'))
  val_entities = torch.load(path.format('val_entities'))


## $\color{blue}{Dataset:}$

In [7]:
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy

def sample_neighborhood(A, inds, neighbor_max, branch_max, seed=None):
    """Expand a mini-batch of node indices with sampled 1- and 2-hop neighbours.

    For every seed node in ``inds`` the adjacency matrices are visited in a
    random order. Direct neighbours are collected and, once at least
    ``neighbor_max`` are available, cut down to a random subset of exactly
    ``neighbor_max``. Neighbours of those neighbours are then added until the
    node's neighbourhood reaches ``branch_max`` entries, at which point
    sampling for that seed node stops.

    Args:
      A : list[torch.tensor]
        adjacency matrices; ``adj[i][j] > 0`` is treated as a link from i to j.
        The list is not modified: a shuffled copy is used internally, so the
        caller's relation order stays aligned with per-relation model weights.

      inds : list[int]
        seed node indices of the mini-batch

      neighbor_max : int
        size of the random subset kept when a node has that many or more
        candidate neighbours

      branch_max : int
        neighbourhood size at which sampling for a seed node stops

      seed : int or None
        if given, NumPy's global RNG is re-seeded with it on every call

    Returns:
      list
        the unique indices of the seed nodes plus all sampled neighbours
        (unordered; entries may be Python ints or NumPy integers)
    """
    # Set the random seed for deterministic responses
    if seed is not None:
        np.random.seed(seed)

    # Shuffle a copy: shuffling `A` itself would reorder the caller's list
    # (e.g. GNNDataset.A) and accumulate across calls.
    relation_order = list(A)
    np.random.shuffle(relation_order)

    sampled_indices = set(inds)  # Everything selected so far, seeds included

    for ind in inds:  # Iterate through the nodes of the mini-batch
        break_to_outer = False
        neighbors = set()

        for adj in relation_order:  # Iterate through all adjacency matrices
            if break_to_outer:
                break

            # Indices that `ind` links to and that have not been picked yet
            disclude = {ind} | sampled_indices
            linked = (adj[ind] > 0).nonzero(as_tuple=True)[0].tolist()
            neighbors.update(n for n in linked if n not in disclude)

            if len(neighbors) >= neighbor_max:  # Too many neighbours: keep a random subset
                neighbors = set(np.random.choice(list(neighbors), neighbor_max, replace=False))

            # Iterate over a snapshot because `neighbors` grows inside the loop
            copy_neighbors = deepcopy(neighbors)
            for idx in copy_neighbors:
                if break_to_outer:
                    break

                neighbors_neighbors = set()
                for hop_adj in relation_order:
                    disclude = {ind, idx} | sampled_indices | neighbors
                    hop_linked = (hop_adj[idx] > 0).nonzero(as_tuple=True)[0].tolist()
                    new_neighbors_neighbors = [n for n in hop_linked if n not in disclude]
                    if len(new_neighbors_neighbors) > neighbor_max:
                        new_neighbors_neighbors = set(
                            np.random.choice(list(new_neighbors_neighbors), neighbor_max, replace=False)
                        )
                    neighbors_neighbors.update(new_neighbors_neighbors)

                    if len(neighbors) + len(neighbors_neighbors) >= branch_max:
                        # Fill the remaining budget; never request a negative sample size
                        budget = max(branch_max - len(neighbors), 0)
                        neighbors_neighbors = set(
                            np.random.choice(list(neighbors_neighbors), budget, replace=False)
                        )
                        neighbors.update(neighbors_neighbors)

                        break_to_outer = True
                        break

                    neighbors.update(neighbors_neighbors)

        sampled_indices.update(neighbors)  # Add this node's neighbourhood

    return list(sampled_indices)

In [8]:
class GNNDataset(Dataset):
  """Dataset whose items are neighbourhood-sampled sub-graphs rather than single rows."""

  def __init__(self, H, A, labels, meta_indices, neighbor_max=4, branch_max=16, seed=None):
    """Store the full graph; sub-graphs are sampled lazily in ``__getitem__``.

    Args:
      H : torch.tensor
        input embeddings (n x d); moved to the global ``device``

      A : list[torch.tensor]
        list of (n x n) adjacency matrices; each is moved to the global ``device``

      labels : torch.LongTensor
        y; moved to the global ``device``

      meta_indices : torch.LongTensor
        index of datapoint to filter validation score (stored as given)

      neighbor_max : int
        max neighbors for each node in mini-batch

      branch_max : int
        neighbourhood size at which sampling for a node stops

      seed : int or None
        forwarded to ``sample_neighborhood`` on every ``__getitem__`` call

    """
    # Tensors that live on the compute device
    self.H = H.to(device)
    self.A = [adj.to(device) for adj in A]
    self.labels = labels.to(device)
    # Kept exactly as passed in
    self.meta_indices = meta_indices
    # Sampling configuration
    self.neighbor_max = neighbor_max
    self.branch_max = branch_max
    self.seed = seed

  def __len__(self):
    """Number of labelled nodes."""
    return len(self.labels)

  def __getitem__(self, inds):
    """Return ``(H_batch, A_batch, labels_batch, index_batch)`` for the sampled sub-graph.

    ``inds`` may be a tensor, a list of indices or a single index; it is
    normalised to a list, expanded with sampled neighbours, and every returned
    tensor is restricted to that expanded index set.
    """
    # Normalise the request to a plain list of indices
    if torch.is_tensor(inds):
      inds = inds.tolist()
    elif not isinstance(inds, list):
      inds = [inds]

    # Seeds plus their sampled neighbourhood
    sampled_indices = sample_neighborhood(
        self.A, inds, self.neighbor_max, self.branch_max, seed=self.seed
    )

    # Embeddings of the selected nodes
    H_batch = self.H[sampled_indices]

    # Rows, then columns, of every adjacency matrix restricted to the selection
    A_batch = [adj[sampled_indices][:, sampled_indices] for adj in self.A]

    # Labels and bookkeeping indices of the selected nodes
    labels_batch = self.labels[sampled_indices]
    index_batch = self.meta_indices[sampled_indices]

    return H_batch, A_batch, labels_batch, index_batch

In [9]:
# Stack the per-row embedding tensors into one (n x d) matrix and collect the labels.
H_val = torch.stack(list(df_val['vanilla_embedding.1']))
labels_val = torch.LongTensor(list(df_val['chapter_idx']))
# Only the entity adjacency matrix is used, i.e. a single relation.
A_val = []
A_val.append(val_entities)
# Position of every row in df_val; validate() compares these against its threshold.
val_indices = torch.LongTensor(list(range(df_val.shape[0])))

# seed=42 is forwarded to sample_neighborhood, which re-seeds NumPy on every __getitem__ call.
validation_dataset = GNNDataset(H_val, A_val, labels_val, val_indices, neighbor_max=4, branch_max=16, seed=42)

# Prevent dataloader from calling a single index at a time
# The BatchSampler is passed as `sampler`, so each dataset "index" is a list of up to 8 seeds.
custom_validation_sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.SequentialSampler(validation_dataset),
    batch_size=8,
    drop_last=False)


# batch_size is left at the DataLoader default, so every returned tensor gains a leading
# dimension that validate() removes with squeeze(0).
# Despite its name, dev_loader iterates over the combined train + dev graph.
dev_loader = DataLoader(validation_dataset, sampler = custom_validation_sampler)

## $\color{blue}{Model:}$

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer(nn.Module):
    """Relational message-passing layer with skip connection, batch norm and dropout.

    For each relation ``k`` the layer combines a self projection ``H @ T[k]``
    with a degree-normalised neighbour message ``A[k].T @ H @ E[k]``, applies a
    leaky ReLU and adds the input ``H`` back. The per-relation results are
    summed, batch-normalised and passed through dropout.

    Because the result is accumulated in a tensor shaped like ``H`` and ``H``
    itself is added back, ``in_features`` and ``out_features`` are expected to
    be equal.
    """

    def __init__(self, in_features, out_features, num_relations=1, dropout=0.3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_relations = num_relations
        self.dropout = dropout

        # T[k]: self-projection weights, E[k]: neighbour-message weights of relation k
        self.T = nn.ParameterList(
            [nn.Parameter(torch.empty(in_features, out_features)) for _ in range(num_relations)]
        )
        self.E = nn.ParameterList(
            [nn.Parameter(torch.empty(in_features, out_features)) for _ in range(num_relations)]
        )

        # Normalises the summed relation outputs feature-wise
        self.batch_norm = nn.BatchNorm1d(out_features)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise all projection matrices (T first, then E)."""
        for weight in (*self.T, *self.E):
            nn.init.xavier_uniform_(weight)

    def forward(self, H, A):
        """Apply the layer to node features ``H`` given the adjacency list ``A``.

        ``A[k]`` is indexed for ``k`` in ``range(num_relations)``. Messages are
        aggregated with the transposed adjacency but normalised by its row
        sums (rows summing to zero are divided by 1 instead); the two agree
        when the adjacency matrix is symmetric.
        """
        accumulated = torch.zeros_like(H)
        for k, (self_weight, message_weight) in enumerate(zip(self.T, self.E)):
            adjacency = A[k]

            # Neighbour messages, normalised per node
            messages = adjacency.T @ H @ message_weight
            row_sums = adjacency.sum(dim=1, keepdim=True)
            row_sums = row_sums.masked_fill(row_sums == 0, 1.0)
            messages = messages / row_sums

            own = H @ self_weight

            # Skip connection: the input is added back once per relation
            accumulated = accumulated + (F.leaky_relu(own + messages) + H)

        normalised = self.batch_norm(accumulated)
        return F.dropout(normalised, p=self.dropout, training=self.training)

class GNNModel(nn.Module):
    """Stack of GNN layers followed by a two-layer classification head.

    Args:
      d : int
        node feature size; every GNN layer maps d -> d

      h : int
        hidden width of the classification head

      c : int
        number of output classes

      num_relations : int
        number of adjacency matrices (relations) each GNN layer consumes

      num_layers : int
        number of stacked GNN layers
    """

    def __init__(self, d, h, c, num_relations=1, num_layers=2):
        super(GNNModel, self).__init__()
        self.num_layers = num_layers
        self.num_relations = num_relations
        # Pass num_relations on to every layer. It used to be dropped here, so the
        # layers always fell back to a single relation whatever the caller asked for.
        self.gnn_layers = nn.ModuleList(
            [GNNLayer(d, d, num_relations) for _ in range(num_layers)]
        )
        self.fc1 = nn.Linear(d, h)
        self.fc2 = nn.Linear(h, c)

    def forward(self, H, A):
        """Return class logits for node features ``H`` and adjacency list ``A``."""
        for layer in self.gnn_layers:
            H = layer(H, A)
        # Classification
        H = F.relu(self.fc1(H))
        Output = self.fc2(H)
        return Output


In [11]:
# Model hyper-parameters, passed positionally to GNNModel(d, h, c, num_relations).
d = 768  # node feature size; GNN layers map d -> d and fc1 takes d inputs
h = 400  # hidden width of the classification head (fc1 output)
c = 70  # number of output classes (fc2 output)
num_relations = 1  # one adjacency matrix per batch (A_val holds a single tensor)
# Saved weights restored with load_state_dict; this rebinds the global `path` again.
path = "class/models/GNN.3.pt"


In [12]:
# Rebuild the architecture, then restore the trained weights stored at `path`.
model = GNNModel(d, h, c, num_relations)
criterion = nn.CrossEntropyLoss()
# map_location remaps tensors saved on another device, so a checkpoint written on a GPU
# still loads on a CPU-only runtime. weights_only=True restricts unpickling to tensor data,
# which is all a state dict needs.
model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
model = model.to(device)

  model.load_state_dict(torch.load(path))


## $\color{blue}{Validate:}$

In [13]:
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    ``outputs`` holds one row of class scores per sample and ``labels`` the
    matching class ids; the result is a Python float.
    """
    # Column index of the largest score in every row
    predicted = torch.max(outputs, 1).indices

    # Number of rows where prediction and label agree
    hits = predicted.eq(labels).sum().item()

    # Normalise by the number of samples
    return hits / labels.size(0)

In [14]:
def validate(model, dev_loader, criterion, threshold=12000):
    """Run the model over ``dev_loader`` and collect predictions for held-out points.

    Every batch is a sampled sub-graph, so the same data point can occur in
    many batches and therefore many times in the returned lists; post
    treatment is needed to obtain a single prediction per point.

    Args:
      model : nn.Module
        called as ``model(H, A)``; put into eval mode here

      dev_loader : iterable
        yields ``(H, A, y, indices)`` with a leading dimension of size 1 on
        every tensor (removed with ``squeeze(0)``)

      criterion : callable
        loss applied to the filtered outputs and labels

      threshold : int
        only points whose index is ``>= threshold`` are scored; smaller
        indices are treated as training points and ignored

    Returns:
      tuple
        ``(mean batch loss, mean batch accuracy, predictions, labels, indices)``.
        Both means are unweighted averages over the batches that contained at
        least one scored point; they are NaN when no batch did.
    """
    model.eval()
    epoch_dev_losses = []
    epoch_dev_accuracy = []
    pred_holder = []
    real_holder = []
    index_holder = []

    with torch.no_grad():
        for H, A, y, indices in dev_loader:
            H = H.squeeze(0)
            A = [a.squeeze(0) for a in A]
            y = y.squeeze(0)
            indices = indices.squeeze(0)

            out = model(H, A)

            # Filter out training points
            mask = indices >= threshold
            filtered_out = out[mask]
            filtered_y = y[mask]
            filtered_indices = indices[mask]

            # Nothing to score in this batch
            if filtered_out.size(0) == 0:
                continue

            dev_loss = criterion(filtered_out, filtered_y)
            _, predicted = torch.max(filtered_out, 1)

            # Accuracy is computed inline rather than through the module-level
            # accuracy() helper: the notebook later rebinds the name `accuracy`
            # to a float, which would make a second call to validate() fail.
            correct = (predicted == filtered_y).sum().item()

            epoch_dev_losses.append(dev_loss.item())
            epoch_dev_accuracy.append(correct / filtered_y.size(0))

            # One bulk transfer per tensor instead of one .item() call per element
            pred_holder += predicted.tolist()
            real_holder += filtered_y.tolist()
            index_holder += filtered_indices.tolist()

    if not epoch_dev_losses:
        # No point passed the threshold: report NaN explicitly instead of
        # relying on np.mean([]) and its RuntimeWarning
        return float("nan"), float("nan"), pred_holder, real_holder, index_holder

    return np.mean(epoch_dev_losses), np.mean(epoch_dev_accuracy), pred_holder, real_holder, index_holder

In [15]:
# Run inference over the whole loader and keep the per-sample predictions, labels and indices.
# NOTE(review): this rebinds the name `accuracy` from the helper function to a float, so
# anything that calls accuracy(...) after this cell fails until its definition cell is re-run.
loss, accuracy, preds, reals, inds = validate(model, dev_loader, criterion)

In [16]:
loss

1.6364245035638487

In [17]:
accuracy

0.7093572360587586

## $\color{blue}{Post-Treatment:}$

In [18]:
df_dev.columns

Index(['index', 'master', 'book_idx', 'book', 'chapter_idx', 'chapter',
       'author', 'content', 'vanilla_embedding', 'vanilla_preds',
       'vanilla_pseudo_book', 'vanilla_moe_e2e_soft_preds',
       'vanilla_moe_e2e_soft_pseudo_book', 'vanilla_moe_e2e_hard_preds',
       'vanilla_moe_e2e_hard_pseudo_book', 'vanilla_moe_e2e_soft_forest_preds',
       'vanilla_moe_e2e_soft_forest_pseudo_book', 'vanilla_moe_hard_pre_preds',
       'vanilla_moe_hard_pre_pseudo_book', 'vanilla_embedding.1',
       'direct_ft_preds', 'direct_ft_pseudo_book', 'ft_embedding',
       'embedding_ft_preds', 'embedding_ft_pseudo_book', 'direct_ft_moe_preds',
       'direct_ft_moe_pseudo_book', 'ft_embedding_pal', 'mistral_ots_book',
       'mistral_ft_book', 'gpt_4o_mini_ft_book', 'gpt_4o_mini_book',
       'gpt_4o_mini_book_checkpoint', 'ner_responses'],
      dtype='object')

In [19]:
from collections import namedtuple

# store results for every validation point in a named tuple

Check = namedtuple("Check", ['id', 'df_label','model_label', 'predicted_label'])

# inds, reals and preds are parallel lists returned by validate(); walk them together.
res = []
for sampled_ind, real_label, pred_label in zip(inds, reals, preds):
  # Shift the graph index back to a row label of df_dev
  point_ind = sampled_ind - 12000
  # Label stored in the DataFrame, kept next to the label the model was scored against
  df_label = df_dev.loc[point_ind, 'chapter_idx']
  res.append(Check(point_ind, df_label, real_label, pred_label))

In [20]:
# note proportion of data points with NER neighbors

A_val[0].size()  # has no visible effect: only the cell's last expression is displayed
adj = A_val[0]
# True for every node whose adjacency row has a non-zero sum
# (i.e. at least one link, assuming the entries are non-negative)
neighbors = torch.sum(adj,dim=1)!=0
# Fraction of nodes flagged above; shown as the cell output
torch.sum(neighbors)/len(neighbors)

tensor(0.4673)

In [21]:
# bool if node has neighbors

# Drop the first 12000 positions so that position i lines up with row i of the dev split.
# NOTE(review): 12000 is presumably len(df_train), matching validate()'s default threshold - confirm.
neighbors_valid = neighbors[12000:]
# Peek at the first 20 flags
neighbors_valid[:20]

tensor([False, False,  True, False,  True,  True, False,  True,  True,  True,
         True, False, False,  True,  True,  True,  True, False, False, False])

In [22]:
# list of results just for connected points

# Keep only the results whose dev row is flagged as connected in neighbors_valid
neighbors_res = [entry for entry in res if neighbors_valid[entry.id]]
# A row appears once per time it was sampled, so collapse the ids to unique values
neighbors_ids = list({entry.id for entry in neighbors_res})

In [23]:
# Get a list of predictions for each index (inds is index recorded in model output)

# Bucket the predictions by dev-row position in a single pass over the model output,
# instead of rescanning all of `inds` once per dev row.
# check[i] ends up holding every prediction recorded for graph index i + 12000,
# in the order the model produced them.
check = [[] for _ in range(df_dev.shape[0])]
for sampled_ind, pred in zip(inds, preds):
  dev_pos = sampled_ind - 12000
  # Ignore indices that do not map to a dev row
  if 0 <= dev_pos < len(check):
    check[dev_pos].append(pred)

In [24]:
# Use the mode as the central tendency when the same data point has multiple predictions

def mode(lstr):
  """Return the most frequent value in ``lstr``.

  Ties go to the smallest value, because ``np.unique`` returns the distinct
  values sorted and ``argmax`` picks the first maximum.
  """
  values, frequencies = np.unique(lstr, return_counts=True)
  return values[frequencies.argmax()]

# Collapse each dev row's list of sampled predictions into a single label.
# NOTE(review): mode() raises on an empty list (argmax of an empty array), so this
# assumes every dev row was predicted at least once.
predictions = [mode(el) for el in check]

In [25]:
# Compare GNN and vanilla accuracy separately for connected dev rows ("_n", neighbors)
# and isolated dev rows ("_i", loners).
# *_count_* = number of correct predictions, *_number_* = number of rows in the group.
gnn_count_n = 0
gnn_number_n = 0
gnn_count_i = 0
gnn_number_i = 0

van_count_n = 0
van_number_n = 0
van_count_i = 0
van_number_i = 0

# Set membership is O(1); testing against the list would rescan it for every row
neighbor_id_set = set(neighbors_ids)

for i in range(df_dev.shape[0]):
  # Look up single cells instead of materialising the whole row several times
  label = df_dev.loc[i, "chapter_idx"]
  gnn_hit = int(label == predictions[i])
  van_hit = int(label == df_dev.loc[i, 'vanilla_preds'])

  if i in neighbor_id_set:
    gnn_number_n += 1
    van_number_n += 1
    gnn_count_n += gnn_hit
    van_count_n += van_hit
  else:
    van_number_i += 1
    gnn_number_i += 1
    gnn_count_i += gnn_hit
    van_count_i += van_hit


def _rate(count, number):
  """Return count / number, or NaN for an empty group instead of raising ZeroDivisionError."""
  return count / number if number else float('nan')


print(f'GNN Neighbors: {_rate(gnn_count_n, gnn_number_n)}')
print(f'Van Neighbors: {_rate(van_count_n, van_number_n)}')

print('----')

print(f'GNN Loners: {_rate(gnn_count_i, gnn_number_i)}')
print(f'Van Loners: {_rate(van_count_i, van_number_i)}')


GNN Neighbors: 0.7018348623853211
Van Neighbors: 0.6490825688073395
----
GNN Loners: 0.5018939393939394
Van Loners: 0.553030303030303


In [26]:
# Update DataFrame

# One aggregated GNN prediction per dev row
df_dev['gcn_preds'] = predictions
# .tolist() yields plain Python bools. list(tensor) would store one 0-d torch tensor per row,
# which ends up in the pickle and ties the saved DataFrame to torch (and to the tensor's device).
df_dev['connected'] = neighbors_valid.tolist()

In [27]:
# Persist the dev split together with the new 'gcn_preds' and 'connected' columns.
# NOTE(review): this overwrites the same 'class/datasets/df_dev' pickle that was loaded at the
# top of the notebook, so the input file is replaced in place.
path = "class/datasets/"
df_dev.to_pickle(path + "df_dev")