# CS 224W Project: Text Augmented Graphs (OGBN-Arxiv)

### This Colab contains experiments with several techniques for combining textual and graph information, evaluated on the OGBN-Arxiv dataset.

## Install PyG and other required libraries

In [1]:
# Setup cell: installs the PyG sparse extensions, PyG itself, OGB and FAISS,
# then imports everything the rest of the notebook uses.
# NOTE(review): argparse, GCNConv and SAGEConv appear unused in this notebook.
import argparse
import torch
import torch.nn.functional as F
# The PyG extension wheel index is keyed by the installed torch version string
# (e.g. torch-2.1.0+cu121); torch-scatter and torch-sparse come from the same index.
torch_version = str(torch.__version__)
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
# NOTE(review): installs are unpinned; pin versions if results must be reproducible.
!pip install torch-scatter -f $scatter_src
!pip install torch-sparse -f $sparse_src
!pip install torch-geometric
!pip install ogb
# faiss-gpu provides the kNN search used later to build virtual edges.
!pip install faiss-gpu
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import numpy as np
import pickle
# Use the GPU when available; models and graph data are moved to `device` later on.
device = f'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)

Looking in links: https://pytorch-geometric.com/whl/torch-2.1.0+cu121.html
Looking in links: https://pytorch-geometric.com/whl/torch-2.1.0+cu121.html


## Link the Colab to your Google Drive

We use Google Drive to load pre-trained LM embeddings and logits. Instructions for downloading these files are provided later in the Colab.


In [2]:
from google.colab import drive
drive.mount("/content/drive/")

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


## Add code to load the OGBN Arxiv dataset


In [3]:
DATASET = "ogbn-arxiv"


class LoadData:
    """
    Loader for the OGB node-property-prediction dataset named by ``DATASET``.

    The graph is made undirected and its edges are stored as a sparse
    adjacency matrix (``data.adj_t``) for the full-batch GNN models.
    """

    def load_data(self):
        """
        Load the dataset, preprocess the graph and fetch the official splits.

        Returns:
        --------
        data: PyG data object
            The undirected graph with a sparse ``adj_t`` adjacency.
        split_idx: dict
            The 'train' / 'valid' / 'test' node index tensors.
        num_classes: int
            The number of target classes.
        """
        dataset = PygNodePropPredDataset(name=DATASET)
        preprocess = T.Compose([T.ToUndirected(), T.ToSparseTensor()])
        graph = preprocess(dataset[0])
        return graph, dataset.get_idx_split(), dataset.num_classes

## Define the loss function and the model training step



In [4]:
class Loss:
    """
    Plain cross-entropy objective evaluated only on the training nodes.

    Subclasses can override ``get_loss`` to add regularization terms.
    """

    def get_loss(self, out, labels, train_idx):
        """
        Cross-entropy between the model logits and the labels of the training nodes.

        Parameters:
        -----------
        out : torch.Tensor
            Logits for the training nodes, one row per entry of ``train_idx``.
        labels : torch.Tensor
            Class labels for all nodes; ``train_idx`` selects the relevant ones.
        train_idx : torch.Tensor or list
            Indices of the training nodes.

        Returns:
        --------
        torch.Tensor
            Scalar cross-entropy loss.
        """
        train_labels = labels[train_idx]
        return F.cross_entropy(input=out, target=train_labels)


def train(model, data, train_idx, optimizer, loss_obj):
    """
    Trains a neural network model for one full-batch epoch.

    The whole graph is propagated through the model; only the rows for the
    training nodes are scored by ``loss_obj``.

    Parameters:
    -----------
    model : torch.nn.Module
        The neural network model to be trained. It is called as
        ``model(x, adj_t)``, the same way it is called during evaluation.
    data : object
        The data object containing features (``x``), adjacency (``adj_t``) and labels (``y``).
    train_idx : torch.Tensor or list
        The indices of the training nodes.
    optimizer : torch.optim.Optimizer
        The optimizer used for updating model weights.
    loss_obj : Loss
        An object with a ``get_loss(out, labels, train_idx)`` method.

    Returns:
    --------
    float
        The loss value computed for this training epoch.
    """
    model.train()
    optimizer.zero_grad()
    # train_idx is used only to slice the output. In PyG's BasicGNN models
    # (GraphSAGE, GCN) the third positional forward argument is `edge_weight`,
    # so the index must not be passed to the model itself.
    out = model(data.x, data.adj_t)[train_idx]
    # data.y has a trailing singleton dimension; drop it to get class ids.
    loss = loss_obj.get_loss(out, data.y.squeeze(1), train_idx)
    loss.backward()
    optimizer.step()
    return loss.item()

## Define the model testing function


In [5]:
@torch.no_grad()
def test(model, data, split_idx, evaluator):
    """
    Evaluates the performance of a trained model on training, validation, and test datasets.

    Parameters:
    -----------
    model : torch.nn.Module
        The neural network model to be evaluated.
    data : object
        The data object containing features, adjacency information, and labels.
    split_idx : dict
        A dictionary with keys 'train', 'valid', and 'test' mapping to the respective data indices.
    evaluator : object
        An object used to evaluate the model's predictions. It must have an 'eval' method.

    Returns:
    --------
    tuple: (train_acc, valid_acc, test_acc)
        A tuple containing the accuracy on the training, validation, and test sets.
    """
    model.eval()

    out = model(data.x, data.adj_t)
    y_pred = out.argmax(dim=-1, keepdim=True).cpu()

    # Evaluate accuracy on training, validation, and test sets
    train_acc = evaluator.eval({
        'y_true': data.y[split_idx['train']],
        'y_pred': y_pred[split_idx['train']],
    })['acc']
    valid_acc = evaluator.eval({
        'y_true': data.y[split_idx['valid']],
        'y_pred': y_pred[split_idx['valid']],
    })['acc']
    test_acc = evaluator.eval({
        'y_true': data.y[split_idx['test']],
        'y_pred': y_pred[split_idx['test']],
    })['acc']

    return train_acc, valid_acc, test_acc


## Define Hyperparameters

Our code supports two models:


1.   GraphSAGE (torch_geometric.nn.models.GraphSAGE)
2.   GCN (torch_geometric.nn.models.GCN)




In [6]:
hyperparams = dict(
    model='GraphSAGE',        # 'GraphSAGE' or 'GCN'
    hidden_layer_size=128,
    num_layers=3,
    dropout=0.5,
    learning_rate=1e-3,
    epochs=1000,
)

## Create an end-to-end training loop



In [7]:
def train_loop(load_data_obj, loss_obj, hyperparams):
    """
    Executes the full-batch training loop for a graph neural network.

    Each epoch runs one training step and then evaluates on all splits. The
    reported test accuracy is the one from the epoch with the highest
    validation accuracy.

    Parameters:
    -----------
    load_data_obj : LoadData
        An object whose ``load_data()`` returns ``(data, split_idx, num_classes)``.
    loss_obj : Loss
        An object defining ``get_loss(out, labels, train_idx)``.
    hyperparams : dict
        Expected keys: 'model' (str, 'GraphSAGE' or 'GCN'), 'hidden_layer_size' (int),
        'num_layers' (int), 'dropout' (float), 'learning_rate' (float), 'epochs' (int).
        The number of output channels is taken from the dataset's class count.

    Returns:
    --------
    tuple:
        best_test_acc (float): Test accuracy at the best-validation epoch.
        model: The model after the final epoch (not a best-validation checkpoint).

    Raises:
    -------
    ValueError
        If ``hyperparams['model']`` is not a supported model, or ``epochs`` < 1.
    """
    model_name = hyperparams['model']
    # Validate before the (expensive) data load; previously any unknown name
    # silently fell through to GCN.
    if model_name not in ('GraphSAGE', 'GCN'):
        raise ValueError(f"Unsupported model {model_name!r}; expected 'GraphSAGE' or 'GCN'")
    if hyperparams['epochs'] < 1:
        raise ValueError("hyperparams['epochs'] must be at least 1")

    data, split_idx, num_classes = load_data_obj.load_data()
    # Both names are classes in torch_geometric.nn.models with the same constructor.
    model_cls = getattr(torch_geometric.nn.models, model_name)
    model = model_cls(
        data.x.shape[1],
        hidden_channels=hyperparams['hidden_layer_size'],
        num_layers=hyperparams['num_layers'],
        out_channels=num_classes,
        dropout=hyperparams['dropout'],
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams['learning_rate'])
    evaluator = Evaluator(name=DATASET)
    train_idx = split_idx['train']
    data.to(device)

    valid_accs = []
    test_accs = []
    for epoch in range(hyperparams['epochs']):
        loss = train(model, data, train_idx, optimizer, loss_obj)
        train_acc, valid_acc, test_acc = test(model, data, split_idx, evaluator)
        valid_accs.append(valid_acc)
        test_accs.append(test_acc)
        print(f'Epoch: {epoch:02d}, '
              f'Loss: {loss:.4f}, '
              f'Train: {100 * train_acc:.2f}%, '
              f'Valid: {100 * valid_acc:.2f}% '
              f'Test: {100 * test_acc:.2f}%')
    # Model selection on the validation split; report the matching test accuracy.
    best_test_acc = test_accs[np.argmax(valid_accs)]
    print(f'Best Test accuracy is {best_test_acc}')
    return best_test_acc, model

## Run the training loop using the Base Model

In [8]:
# Baseline: train the GNN on the raw OGB node features.
loss_obj = Loss()
load_data = LoadData()
standard_acc, base_model = train_loop(load_data_obj=load_data, loss_obj=loss_obj, hyperparams=hyperparams)

Epoch: 00, Loss: 3.6682, Train: 11.00%, Valid: 22.97% Test: 21.56%
Epoch: 01, Loss: 3.6248, Train: 11.09%, Valid: 23.00% Test: 21.56%
Epoch: 02, Loss: 3.5821, Train: 12.13%, Valid: 23.21% Test: 21.71%
Epoch: 03, Loss: 3.5354, Train: 15.53%, Valid: 24.33% Test: 22.50%
Epoch: 04, Loss: 3.4821, Train: 21.51%, Valid: 27.00% Test: 24.45%
Epoch: 05, Loss: 3.4190, Train: 25.42%, Valid: 28.87% Test: 25.96%
Epoch: 06, Loss: 3.3461, Train: 26.96%, Valid: 29.55% Test: 26.56%
Epoch: 07, Loss: 3.2713, Train: 27.44%, Valid: 29.69% Test: 26.67%
Epoch: 08, Loss: 3.2037, Train: 27.43%, Valid: 29.46% Test: 26.48%
Epoch: 09, Loss: 3.1528, Train: 24.90%, Valid: 23.63% Test: 20.96%
Epoch: 10, Loss: 3.1300, Train: 18.13%, Valid: 8.02% Test: 6.15%
Epoch: 11, Loss: 3.1302, Train: 17.91%, Valid: 7.63% Test: 5.86%
Epoch: 12, Loss: 3.1323, Train: 17.91%, Valid: 7.63% Test: 5.86%
Epoch: 13, Loss: 3.1219, Train: 17.91%, Valid: 7.63% Test: 5.86%
Epoch: 14, Loss: 3.0938, Train: 17.91%, Valid: 7.63% Test: 5.86%
Epoch

# Load the dataset with LM-initialized embeddings

We fine-tune an MPNet model on the ogbn-arxiv task, using each paper's title and abstract as input. We then use the embeddings generated by this model as the initial node features for the GNN.

The embeddings are available at https://drive.google.com/file/d/184qquWQuXbSog2PDZMG5xuWZU043hn3u/view?usp=sharing. Please copy this file to your Google Drive and update the file path in the code below accordingly.



In [9]:
# Fine-tuned MPNet node embeddings, expected one row per node in dataset order.
# Update this path to point at your own copy of the file.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
filepath = "/content/drive/Shareddrives/CS224W Project/graph_embeddings/finetuned/mpnet_arxiv.pkl"
with open(filepath, "rb") as emb_file:
    embs = pickle.load(emb_file)


In [10]:
class LoadDataLMInit(LoadData):
    """
    A class extending LoadData to replace the node features with language-model embeddings.

    Attributes:
    -----------
    embs : array-like
        Node embeddings (one row per node) used as the node features.
    normalize : bool
        Whether each embedding row is L2-normalized before use.
    """

    def __init__(self, embs, normalize=False):
        """
        The constructor for LoadDataLMInit class.

        Parameters:
        -----------
        embs : array-like
            An array-like structure containing node embeddings. Each row is the
            embedding of the node with the same index in the dataset.
        normalize : bool, optional
            If True, L2-normalize each embedding row before using it as node
            features. Defaults to False, so the raw embeddings are used unchanged.
        """
        self.embs = embs
        self.normalize = normalize
        super(LoadDataLMInit, self).__init__()

    def load_data(self):
        """
        Loads the graph dataset with the embeddings as node features, applies
        transformations, and retrieves split indices.

        Returns:
        --------
        data : PyG data object
            The graph data object with initialized node features.
        split_idx : dict
            A dictionary containing the indices for train, validation, and test splits.
        num_classes : int
            The number of classes in the dataset.
        """
        dataset = PygNodePropPredDataset(name=DATASET)
        data = dataset[0]
        node_features = torch.tensor(self.embs)
        # Normalization used to be computed and then discarded; it is now an
        # explicit opt-in so no full-size tensor is built for nothing.
        if self.normalize:
            node_features = F.normalize(node_features, dim=-1)
        data.x = node_features
        split_idx = dataset.get_idx_split()
        transform = T.Compose([T.ToUndirected(), T.ToSparseTensor()])
        data = transform(data)
        return data, split_idx, dataset.num_classes


## Run the pipeline with LM initialized embeddings

In [None]:
# Same pipeline as the baseline, but the node features are the MPNet embeddings.
load_data = LoadDataLMInit(embs=embs)
lminit_acc, lminit_model = train_loop(load_data_obj=load_data, loss_obj=loss_obj, hyperparams=hyperparams)

Epoch: 00, Loss: 3.6924, Train: 1.83%, Valid: 3.60% Test: 4.25%
Epoch: 01, Loss: 3.6616, Train: 4.53%, Valid: 9.02% Test: 9.91%
Epoch: 02, Loss: 3.6309, Train: 30.15%, Valid: 39.30% Test: 37.35%
Epoch: 03, Loss: 3.5958, Train: 46.82%, Valid: 51.64% Test: 50.88%
Epoch: 04, Loss: 3.5544, Train: 50.87%, Valid: 54.65% Test: 54.58%
Epoch: 05, Loss: 3.5045, Train: 52.80%, Valid: 56.15% Test: 56.55%
Epoch: 06, Loss: 3.4433, Train: 52.78%, Valid: 56.52% Test: 57.17%
Epoch: 07, Loss: 3.3713, Train: 52.14%, Valid: 56.24% Test: 57.15%
Epoch: 08, Loss: 3.2855, Train: 51.21%, Valid: 55.72% Test: 56.90%
Epoch: 09, Loss: 3.1847, Train: 50.23%, Valid: 55.10% Test: 56.39%
Epoch: 10, Loss: 3.0684, Train: 49.26%, Valid: 54.51% Test: 56.00%
Epoch: 11, Loss: 2.9403, Train: 48.30%, Valid: 54.01% Test: 55.61%
Epoch: 12, Loss: 2.8038, Train: 47.47%, Valid: 53.41% Test: 55.29%
Epoch: 13, Loss: 2.6633, Train: 46.81%, Valid: 52.72% Test: 54.75%
Epoch: 14, Loss: 2.5287, Train: 46.13%, Valid: 51.73% Test: 53.99%
E

## Add virtual edges to the Dataset

We attempt data augmentation by adding virtual edges that capture relationships between papers with similar titles and abstracts. Each node is connected to its k nearest neighbours, measured on the embeddings of the fine-tuned MPNet model described above. We use the [FAISS](https://github.com/facebookresearch/faiss) library for fast nearest-neighbour search.

In [None]:
import faiss
class LoadDataVirtualEdges(LoadData):
    """
    A class extending LoadData to load graph data with additional virtual edges.

    Every node gets an extra edge to each of its k nearest neighbours in
    embedding space (L2 distance). The combined graph is then made undirected.

    Attributes:
    -----------
    embs : array-like
        Node embeddings (one row per node) used to find nearest neighbours.
    k : int
        The number of nearest neighbours linked to each node.
    """

    def __init__(self, embs, k):
        self.embs = embs
        self.k = k
        super(LoadDataVirtualEdges, self).__init__()

    def get_nearest_neighbours(self, embs, k):
        """
        Computes the k nearest neighbours of every node with an exact FAISS L2 index.

        Parameters:
        -----------
        embs : array-like
            Node embeddings, one row per node.
        k : int
            The number of nearest neighbours to find for each node (excluding itself).

        Returns:
        --------
        numpy.ndarray
            Neighbour indices of shape (num_nodes, k + 1). Column 0 is normally the
            query node itself.
        """
        # FAISS only accepts C-contiguous float32 arrays.
        embs = np.ascontiguousarray(embs, dtype=np.float32)
        index = faiss.IndexFlatL2(embs.shape[1])
        # Use the GPU when FAISS can see one; otherwise search on the CPU.
        if faiss.get_num_gpus() > 0:
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)
        index.add(embs)
        # Ask for k + 1 because each node is (normally) its own nearest neighbour.
        _, neighbours = index.search(embs, k + 1)
        return neighbours

    def convert_to_edge_index(self, I, k):
        """
        Converts nearest-neighbour indices into source/destination edge tensors.

        The first neighbour column is dropped as the self-match.
        NOTE(review): with duplicate embeddings the self-match may not land in
        column 0, which would yield a self-loop and drop one real neighbour.

        Parameters:
        -----------
        I : numpy.ndarray
            Neighbour indices of shape (num_nodes, k + 1).
        k : int
            The number of neighbours kept per node.

        Returns:
        --------
        tuple: (src_nodes, dst_nodes)
            Two 1-D tensors on `device`, each of length num_nodes * k.
        """
        num_nodes = I.shape[0]
        src_nodes = torch.arange(num_nodes).repeat_interleave(k).to(device)
        dst_nodes = torch.tensor(I[:, 1:]).flatten().to(device)
        return src_nodes, dst_nodes

    def load_data(self):
        """
        Loads the graph dataset, adds kNN virtual edges, applies transformations,
        and retrieves split indices.

        Returns:
        --------
        data : PyG data object
            The graph data object with virtual edges added.
        split_idx : dict
            A dictionary containing the indices for train, validation, and test splits.
        num_classes : int
            The number of classes in the dataset.
        """
        dataset = PygNodePropPredDataset(name=DATASET)
        data = dataset[0]
        neighbours = self.get_nearest_neighbours(self.embs, self.k)
        # Use this instance's k, not a notebook-global variable.
        src, dst = self.convert_to_edge_index(neighbours, self.k)
        virtual_edges = torch.stack((src, dst)).to('cpu')
        data.edge_index = torch.cat((data.edge_index, virtual_edges), dim=1)
        transform = T.Compose([T.ToUndirected(), T.ToSparseTensor()])
        data = transform(data)
        return data, dataset.get_idx_split(), dataset.num_classes

## Run the pipeline with Virtual Edges added

In [None]:
k = 4  # number of kNN virtual edges added per node
load_data = LoadDataVirtualEdges(embs=embs, k=k)
virtual_edge_acc, virtual_edge_model = train_loop(load_data_obj=load_data, loss_obj=loss_obj, hyperparams=hyperparams)

## Add KL Divergence Regularization

KLLoss extends the basic cross-entropy loss with a regularization term based on the KL divergence. The term penalizes divergence between the language model's predicted distribution (from `lm_logits`) and the GNN's output distribution, i.e. $\mathcal{L} = \mathrm{CE} + \lambda \, \mathrm{KL}(p_{\text{LM}} \,\|\, p_{\text{GNN}})$. The parameter `lmbda` ($\lambda$) controls the weight of this regularization term.


The MPNet logits are available at https://drive.google.com/file/d/13yY-y7FEpFhe2lOOVX-LMZy58oEFxTS4/view?usp=sharing. Please copy this file to your Google Drive and update the file path in the code below accordingly.


In [None]:
# MPNet classifier logits, expected one row per node in dataset order.
# Update this path to point at your own copy of the file.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
filepath_logits = '/content/drive/Shareddrives/CS224W Project/graph_embeddings/finetuned/mpnet_logits_arxiv.pkl'
with open(filepath_logits, "rb") as logits_file:
    lm_logits = pickle.load(logits_file)

In [None]:
class KLLoss(Loss):
    """
    A class extending Loss with a Kullback-Leibler (KL) divergence regularization term.

    Over the training nodes the loss is
    ``cross_entropy(out, y) + lmbda * KL(p_LM || p_model)``, where
    ``p_LM = softmax(lm_logits[train_idx])`` and ``p_model = softmax(out)``.

    Attributes:
    -----------
    lm_logits : array-like or torch.Tensor
        Per-node logits from the language model, indexable by node id.
    lmbda : float
        The weight (lambda) of the KL divergence regularization term in the overall loss.
    """

    def __init__(self, lm_logits, lmbda):
        """
        The constructor for KLLoss class.

        Parameters:
        -----------
        lm_logits : array-like or torch.Tensor
            Per-node logits from the language model (numpy array or tensor).
        lmbda : float
            The weight (lambda) of the KL divergence regularization term in the overall loss.
        """
        self.lmbda = lmbda
        self.lm_logits = lm_logits
        super(KLLoss, self).__init__()

    def get_loss(self, out, labels, train_idx):
        """
        Parameters:
        -----------
        out : torch.Tensor
            Model logits for the training nodes, one row per entry of ``train_idx``.
        labels : torch.Tensor
            Class labels for all nodes.
        train_idx : torch.Tensor or list
            The indices of the training nodes.

        Returns:
        --------
        torch.Tensor
            Cross-entropy plus ``lmbda`` times the KL divergence from the LM distribution.
        """
        loss = F.cross_entropy(out, labels[train_idx])
        # Accept a numpy array or a tensor for the LM logits, and move the selected
        # rows onto out's device/dtype (not a global device).
        lm_logits = torch.as_tensor(self.lm_logits)
        idx = torch.as_tensor(train_idx, device=lm_logits.device)
        target_log_probs = F.log_softmax(lm_logits[idx].to(device=out.device, dtype=out.dtype), dim=1)
        # 'batchmean' sums over classes and averages over nodes, which is the
        # mathematical KL. PyTorch documents that reduction='mean' additionally
        # divides by the number of classes, and it warns when 'mean' is used.
        reg = F.kl_div(F.log_softmax(out, dim=1), target_log_probs,
                       reduction='batchmean', log_target=True)
        return loss + self.lmbda * reg


In [None]:
# Train on raw features with the KL regularizer towards the LM predictions.
lmbda = 1
load_data = LoadData()
kl_loss_obj = KLLoss(lm_logits, lmbda)
# train_loop returns (best_test_acc, model); unpack so kl_best_acc is an accuracy
# like the other runs' results.
kl_best_acc, kl_model = train_loop(load_data, kl_loss_obj, hyperparams)

## Ensembling

We demonstrate a simple ensemble: scaled logits from the language model are added to the output logits of the base GNN, and the combined logits are used for prediction.



In [None]:
# LM-logit weights to try. The weight is chosen on the validation split and the
# test accuracy is reported only for that weight (choosing on test would leak).
ENSEMBLE_WEIGHTS = [0.2, 0.4, 0.6, 0.8, 1]

data, split_idx, _ = LoadData().load_data()
data.to(device)
base_model.to(device)
base_model.eval()  # inference only: disable dropout
evaluator = Evaluator(name=DATASET)

with torch.no_grad():
    # Base GNN logits for every node; computed once and reused for every weight.
    final_out = base_model(data.x, data.adj_t)

# Separate name so the pickled `lm_logits` is left untouched and the cell can be re-run.
lm_logits_t = torch.as_tensor(lm_logits).to(device=device, dtype=final_out.dtype)


def _ensemble_split_acc(y_pred, split):
    """Accuracy of the (num_nodes, 1) class predictions on one split."""
    return evaluator.eval({
        'y_true': data.y[split_idx[split]],
        'y_pred': y_pred[split_idx[split]],
    })['acc']


max_valid_acc = -1.0
max_test_acc = 0
max_i = 0
for i in ENSEMBLE_WEIGHTS:
    # Start from the base logits each time so the weights do not accumulate
    # across iterations.
    ensemble_out = final_out + i * lm_logits_t
    y_pred = ensemble_out.argmax(dim=-1, keepdim=True)
    valid_acc = _ensemble_split_acc(y_pred, 'valid')
    if valid_acc > max_valid_acc:
        max_valid_acc = valid_acc
        max_test_acc = _ensemble_split_acc(y_pred, 'test')
        max_i = i
print(f'Best LM weight (by validation): {max_i}, valid acc: {max_valid_acc}')
print(f'Test acc at best validation weight: {max_test_acc}')