This notebook loads a homogeneous graph and trains a GNN on a link-prediction objective, using either GCN or GAT layers, negative sampling, and a binary cross-entropy loss. It exports the resulting node embeddings and training-loss plots.

In [None]:
# Imports
import os
from google.colab import drive
import torch
import torch.nn.functional as F
!pip install torch_geometric > /dev/null 2>&1
!pip install pytorch_lightning > /dev/null 2>&1
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import negative_sampling
import pytorch_lightning as pl
from torch_geometric.data import DataLoader
import matplotlib.pyplot as plt
import glob

# Mount Google Drive (force_remount=True remounts even if it is already mounted),
# then make the project folder the working directory so that the relative paths
# used below (e.g. 'karsen_redo/HOMOGNN/...') resolve inside it.
drive.mount('/content/drive', force_remount=True)
os.chdir('/content/drive/MyDrive/STANFORD/SENIOR (2024-2025)/CS224W/cs224w_project')

In [None]:
def find_filepath(decade, topic, base_dir='karsen_redo/HOMOGNN'):
    """
    Return the path of the saved homogeneous graph for a decade and topic.

    Searches for ``{base_dir}/homogeneous_graph-{decade}-{decade+10}-{topic}*.pt``;
    a relative ``base_dir`` is resolved against the current working directory.

    Args:
        decade: Start of the decade (e.g. ``60``); the end is ``decade + 10``.
        topic: Topic name as it appears in the filename. It is not escaped, so
            glob wildcards in it are honoured.
        base_dir: Directory holding the graph files.

    Returns:
        The lexicographically first matching path, or ``None`` if nothing matches.
    """
    filename_pattern = f'homogeneous_graph-{decade}-{decade + 10}-{topic}*.pt'
    # Escape the directory so characters such as '[' in it are matched literally.
    pattern = os.path.join(glob.escape(base_dir), filename_pattern)

    # glob returns matches in arbitrary filesystem order; sort so the choice
    # is deterministic when several files match.
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None


In [None]:
"""
This defines the class used for the GNN. It is unsupervised and based on link prediction, and uses negative sampling with binary cross-entropy loss. It has two layers, either GCN or GAT.
"""
class LinkPredictionGNN(pl.LightningModule):
    """Two-layer GNN encoder trained with an unsupervised link-prediction objective.

    Node embeddings come from two graph-convolution layers (GCN by default, GAT
    when ``conv='gat'``). A bilinear decoder scores node pairs. Observed edges are
    positives, and the same number of pairs from ``negative_sampling`` are
    negatives. Both are scored with binary cross-entropy on logits.

    Args:
        input_dim: Number of input node features.
        hidden_dim: Output width of the first GNN layer.
        output_dim: Size of the final node embeddings.
        lr: Learning rate for Adam.
        conv: ``'gcn'`` (default) or ``'gat'`` (case-insensitive); selects the
            convolution layer used for both layers.

    Raises:
        ValueError: if ``conv`` is not ``'gcn'`` or ``'gat'``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, lr, conv='gcn'):
        super().__init__()
        conv_layers = {'gcn': GCNConv, 'gat': GATConv}
        conv_key = conv.lower()
        if conv_key not in conv_layers:
            raise ValueError(f"conv must be 'gcn' or 'gat', got {conv!r}")
        conv_cls = conv_layers[conv_key]

        # GNN layers
        self.conv1 = conv_cls(input_dim, hidden_dim)
        self.conv2 = conv_cls(hidden_dim, output_dim)

        # Decoder for link prediction: one logit per (src, dst) embedding pair.
        self.decoder = torch.nn.Bilinear(output_dim, output_dim, 1)
        # Loss of every training step, read after fitting to plot the curve.
        self.losses = []
        self.lr = lr

    def forward(self, x, edge_index):
        """Return node embeddings of width ``output_dim`` for node features ``x``."""
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is active only while self.training is True.
        x = F.dropout(x, training=self.training)
        return self.conv2(x, edge_index)

    def _score(self, z, edge_index):
        """Return one decoder logit per column (src, dst) of ``edge_index``."""
        # squeeze(-1) rather than squeeze(): keeps a 1-D result even for a single edge.
        return self.decoder(z[edge_index[0]], z[edge_index[1]]).squeeze(-1)

    def training_step(self, batch, batch_idx):
        """Compute the positive + negative BCE link-prediction loss for one graph batch."""
        x, edge_index = batch.x, batch.edge_index
        z = self.forward(x, edge_index)

        # Observed edges are positives; sample as many non-edges as negatives.
        pos_edge_index = edge_index
        neg_edge_index = negative_sampling(
            edge_index,
            num_nodes=x.size(0),
            num_neg_samples=pos_edge_index.size(1),
        )

        pos_pred = self._score(z, pos_edge_index)
        neg_pred = self._score(z, neg_edge_index)

        pos_loss = F.binary_cross_entropy_with_logits(pos_pred, torch.ones_like(pos_pred))
        neg_loss = F.binary_cross_entropy_with_logits(neg_pred, torch.zeros_like(neg_pred))
        loss = pos_loss + neg_loss

        self.losses.append(loss.item())
        print(f"Step {batch_idx}: train_loss = {loss.item()}")  # Print the loss
        # 'train_loss' is the metric monitored by EarlyStopping / ModelCheckpoint.
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def train_link_prediction_gnn(graph, hidden_dim, output_dim, lr, max_epochs=100):
    """
    Train a LinkPredictionGNN on a single graph and return its node embeddings.

    Args:
        graph: Graph object exposing ``x`` (node features; ``x.shape[1]`` is used
            as the input dimension) and ``edge_index``.
        hidden_dim: Width of the first GNN layer.
        output_dim: Size of the returned node embeddings.
        lr: Adam learning rate.
        max_epochs: Upper bound on training epochs. EarlyStopping on
            ``train_loss`` (patience 10) may stop training earlier.

    Returns:
        ``(final_embeddings, model)``. The embeddings are computed with the
        trained model in eval mode (dropout off) and without gradient tracking.
    """
    model = LinkPredictionGNN(
        input_dim=graph.x.shape[1],
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        lr=lr,
    )

    # Only one graph, so wrap it in a list: one batch (= one training step) per epoch.
    train_loader = DataLoader([graph], batch_size=1)

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        log_every_n_steps=5,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor='train_loss', patience=10),
            pl.callbacks.ModelCheckpoint(monitor='train_loss'),
        ],
    )
    trainer.fit(model, train_dataloaders=train_loader)

    # forward() applies dropout whenever model.training is True, and nothing here
    # switches to eval mode. Without this, the exported embeddings would be random.
    model.eval()
    # Feed inputs on whatever device the model's parameters ended up on.
    device = next(model.parameters()).device
    with torch.no_grad():
        final_embeddings = model(graph.x.to(device), graph.edge_index.to(device))

    return final_embeddings, model

In [None]:
def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it has one and it does not exist."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _save_loss_plot(losses, png_path):
    """Plot ``losses`` against 1-based epoch numbers, save to ``png_path`` and show it."""
    # One graph per DataLoader -> one training step per epoch, so index + 1 is the epoch.
    epochs = range(1, len(losses) + 1)
    # Use a dedicated figure so curves from earlier runs never end up on this plot.
    fig, ax = plt.subplots()
    ax.plot(epochs, losses, label='train loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss over Epochs')
    ax.legend()
    fig.savefig(png_path)
    plt.show()
    plt.close(fig)


def run_homo_gnn(TOPIC, DECADE, hidden_dim, output_dim, lr, embeddings_path_name, newspaper_path_name):
    """
    Train the homogeneous-graph GNN for one (decade, topic) graph and save its outputs.

    Files written:
        ``{embeddings_path_name}.pt``: embeddings of all nodes.
        ``newspaper_path_name``: embeddings of the first ``N`` nodes, where ``N``
            is the integer at the end of the graph filename. Presumably ``N`` is
            the newspaper count and newspaper nodes come first in the node
            ordering (TODO confirm against the graph builder).
        ``{embeddings_path_name}.png``: training-loss plot.
    Missing parent directories of these paths are created.

    Raises:
        FileNotFoundError: if no graph file exists for ``DECADE`` / ``TOPIC``.
    """
    file_path = find_filepath(DECADE, TOPIC)
    if file_path is None:
        raise FileNotFoundError(
            f'No homogeneous graph file found for decade {DECADE}, topic {TOPIC!r}'
        )

    # Graph filenames end in "-<count>.pt"; parse <count>.
    filename = os.path.basename(file_path)
    num_newspapers = int(filename.split("-")[-1].split(".")[0])

    # The file is a pickled graph object rather than a plain tensor state dict,
    # so it must be loaded with weights_only=False (PyTorch >= 2.6 defaults to
    # True). Only load graph files from trusted sources.
    graph = torch.load(file_path, weights_only=False)
    node_embeddings, trained_model = train_link_prediction_gnn(graph, hidden_dim, output_dim, lr)

    embeddings_file = f'{embeddings_path_name}.pt'
    _ensure_parent_dir(embeddings_file)
    _ensure_parent_dir(newspaper_path_name)
    torch.save(node_embeddings, embeddings_file)
    # Save to the caller-supplied path. The previous hard-coded name ignored
    # hidden/output sizes, so every size overwrote the same file.
    torch.save(node_embeddings[:num_newspapers], newspaper_path_name)

    _save_loss_plot(trained_model.losses, f'{embeddings_path_name}.png')

In [None]:
"""
This cell trains the GNN for a set of parameters and saves the resulting embeddings and loss plots in predetermined paths.
"""

# Experiment grid. NOTE: `conv` and `num_layers` only label the output files;
# run_homo_gnn is called with topic, decade, sizes, lr and paths only.
conv = 'gcn'
sizes = [(64, 32), (128, 64)]
num_layers = 2
lr = 0.005

for decade in [60]:
    for topic in ['labor', 'civil-rights', 'macro']:
        # All runs for one (topic, decade) share an output folder.
        run_dir = f'karsen_redo/HOMOGNN/{topic}_{decade}{decade + 10}'
        for hidden_size, out_size in sizes:
            print(f'\n DECADE {decade}, TOPIC {topic}, HIDDEN SIZE {hidden_size}, OUTPUT SIZE {out_size}')
            # Hyper-parameter tag shared by both output filenames.
            run_tag = f'{conv}_h{hidden_size}_o{out_size}_l{num_layers}_lr{lr * 1000}'
            embeddings_path_name = f'{run_dir}/homo_{run_tag}'
            newspaper_path_name = f'{run_dir}/newspaper_embeds-{run_tag}.pt'

            # Train the GNN and save its embeddings and loss plot.
            run_homo_gnn(topic, decade, hidden_size, out_size, lr, embeddings_path_name, newspaper_path_name)