This notebook builds the heterogeneous article-newspaper graph for a given topic-decade from that topic-decade's datasets. It then exports the resulting graph, together with the mapping from newspaper name to newspaper node index.

In [None]:
# Imports

import pickle
from google.colab import drive
import os
import numpy as np

from tqdm import tqdm
from collections import defaultdict
import ast
!pip install datasets > /dev/null 2>&1
from datasets import load_dataset
import pandas as pd
from datetime import datetime

%pip install torch > /dev/null 2>&1
import torch
%pip install torch-geometric > /dev/null 2>&1
from torch_geometric.data import HeteroData

# Mount Google Drive at /content/drive first, then switch the working
# directory to the project folder inside it; the relative 'karsen_redo/...'
# paths used by this notebook are resolved against that folder.
drive.mount('/content/drive', force_remount=True)
os.chdir('/content/drive/MyDrive/STANFORD/SENIOR (2024-2025)/CS224W/cs224w_project')

In [None]:
# Reference instant used to turn parsed dates into epoch seconds without
# going through the host's local timezone.
_UNIX_EPOCH = datetime(1970, 1, 1)


def _collect_publication_edges(metadata_column, batch_size=1000):
    """Scan per-article outlet metadata and build article -> newspaper edges.

    Args:
        metadata_column: pandas Series whose k-th element holds the outlet
            records of article k. Each record is indexed by
            'newspaper_title', whose value is parsed with ast.literal_eval
            into a sequence of names. A None element is treated as
            "no outlets".
        batch_size: Number of articles covered by one progress-bar step.

    Returns:
        A tuple (edges, newspaper_node_index): edges is a list of
        (article position, newspaper node id) pairs, and
        newspaper_node_index maps each outlet name to its node id, assigned
        in order of first appearance.
    """
    edges = []
    newspaper_node_index = {}

    batch_starts = range(0, len(metadata_column), batch_size)
    for start in tqdm(batch_starts, desc="Processing batches"):
        # Positional slicing, so the scan does not rely on the index labels.
        batch = metadata_column.iloc[start:start + batch_size]

        for offset, outlets in enumerate(batch):
            if outlets is None:
                continue
            article_id = start + offset  # Article node id == row position

            for outlet in outlets:  # Per newspaper
                names = ast.literal_eval(outlet['newspaper_title'])
                if not names:
                    continue
                # Only the first listed name identifies the newspaper.
                outlet_name = names[0]
                newspaper_id = newspaper_node_index.setdefault(
                    outlet_name, len(newspaper_node_index)
                )
                edges.append((article_id, newspaper_id))

    return edges, newspaper_node_index


def _edge_index_from_pairs(pairs):
    """Convert (source, target) pairs into a (2, num_edges) long tensor."""
    # reshape(-1, 2) keeps the tensor two-dimensional even when `pairs` is
    # empty, so the result is (2, 0) rather than a flat (0,) tensor.
    return torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()


def _earliest_timestamp(date_list):
    """Return the earliest '%b-%d-%Y' date as UTC epoch seconds (0.0 if none)."""
    if not date_list:
        return 0.0
    earliest = min(datetime.strptime(date, '%b-%d-%Y') for date in date_list)
    # Plain timedelta arithmetic is independent of the machine's timezone and
    # also works for dates before 1970, unlike naive datetime.timestamp().
    return (earliest - _UNIX_EPOCH).total_seconds()


def create_hetero_graph(topic: str, decade: int) -> None:
    """
    Creates a heterogeneous graph for a given topic-decade and saves it.

    Reads the outlet metadata pickle and the article dataset for the
    topic-decade, links every article to the newspapers listed in its
    metadata (in both directions), attaches node features, and writes the
    graph plus the newspaper-name -> node-id mapping under
    karsen_redo/HETEROGNN/.

    Args:
        topic: Topic name as it appears in the dataset and file names.
        decade: Start of the decade as it appears in the dataset and file
            names; ``decade + 10`` is used as the end.

    Returns:
        None. The outputs are the two files written to disk.
    """
    data = HeteroData()

    newspaper_df = pd.read_pickle(f"karsen_redo/DATA/outlet_metadata_{decade}{decade+10}_{topic}.pkl")
    topic_data = load_dataset(f"pnsahoo/{decade}-{decade+10}-{topic}-embedding")['train']
    topic_data_df = topic_data.to_pandas()

    # (source nodes, target nodes)
    article_to_newspaper_edges, newspaper_node_index = _collect_publication_edges(
        topic_data_df['newspaper_metadata']
    )

    # Construct bidirectional edges; the reverse relation is the same edge
    # list with its source and target rows swapped.
    edge_index_tensor = _edge_index_from_pairs(article_to_newspaper_edges)
    data['article', 'published_in', 'newspaper'].edge_index = edge_index_tensor
    data['newspaper', 'publishes', 'article'].edge_index = edge_index_tensor.flip(0)

    # Newspaper node features: every column that is not an identifier,
    # embedding or location field.
    # NOTE(review): feature rows follow newspaper_df's row order, whereas the
    # edge targets follow newspaper_node_index (first-appearance order) --
    # confirm the metadata file is stored in that same order.
    excluded_columns = {
        'outlet_name', 'avg_embedding', 'city', 'state',
        'newspaper_city', 'newspaper_state', 'newspaper_coordinates',
    }
    numeric_columns = [col for col in newspaper_df.columns if col not in excluded_columns]
    data['newspaper'].x = torch.tensor(newspaper_df[numeric_columns].values, dtype=torch.float)

    # Article node features: embedding, wire coordinates, earliest date.
    embeddings = []
    wire_coordinates = []
    date_features = []
    for article in topic_data:
        embeddings.append(article['embedding'])
        wire_coordinates.append(article['wire_coordinates'])
        date_features.append(_earliest_timestamp(article['dates']))

    embeddings_df = pd.DataFrame(embeddings)
    wire_coordinates_df = pd.DataFrame(wire_coordinates, columns=['latitude', 'longitude'])
    date_features_df = pd.DataFrame(date_features, columns=['earliest_timestamp'])

    # Convert to tensors and concatenate column-wise into one feature matrix.
    article_features = torch.cat(
        (
            torch.tensor(embeddings_df.values, dtype=torch.float),
            torch.tensor(wire_coordinates_df.values, dtype=torch.float),
            torch.tensor(date_features_df.values, dtype=torch.float),
        ),
        dim=1,
    )
    data['article'].x = article_features

    # Save the graph and the newspaper name -> node id mapping
    torch.save(data, f'karsen_redo/HETEROGNN/{topic}_{decade}{decade+10}_hetgraph.pt')
    with open(f"karsen_redo/HETEROGNN/newspaper_node_index-{decade}-{decade+10}-{topic}.pkl", "wb") as f:
        pickle.dump(newspaper_node_index, f)



In [None]:
# Choose the topic-decade to build the graph for by editing the two
# constants below (DECADE is the start of the decade; the function pairs it
# with DECADE + 10 in the dataset and file names).

DECADE = 50
TOPIC = "labor"
create_hetero_graph(topic=TOPIC, decade=DECADE)