# Loading Benchmark Datasets

In [1]:
import tsl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
import math

from torch.optim import Adam
from tsl.datasets import PeMS04, PeMS07, PeMS08, PemsBay
from tsl.datasets import MetrLA, ElectricityBenchmark, SolarBenchmark, TrafficBenchmark, ExchangeBenchmark
from tsl.data import SpatioTemporalDataset
from tsl.data.datamodule import (SpatioTemporalDataModule,
                                 TemporalSplitter)
from tsl.data.preprocessing import StandardScaler
from einops.layers.torch import Rearrange

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

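### Approaches to Connectivity Construction

Each benchmark is given a synthetic sparse graph. A random configuration-model graph with node degree at most $d$ is sampled several times, and the draw with the smallest second-largest absolute adjacency eigenvalue $\lambda_2$ is kept. A draw with $\lambda_2 \le 2\sqrt{d-1}$ meets the Ramanujan bound.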

In [2]:
from torch_geometric.data import Data

In [3]:
def _sample_configuration_edges(num_nodes, d, generator=None):
    """Sample one random configuration-model graph and simplify it.

    Every node receives ``d`` stubs; the shuffled stubs are paired
    consecutively. Pairs that would form a self-loop or repeat an existing
    edge are dropped, so the result is a simple undirected graph whose node
    degrees are at most ``d`` (not necessarily exactly ``d``).

    Returns:
        Long tensor of shape [2, num_edges] listing each undirected edge in
        both directions.
    """
    stubs = torch.arange(num_nodes).repeat_interleave(d)
    stubs = stubs[torch.randperm(stubs.size(0), generator=generator)]

    seen = set()  # canonical (min, max) pairs -> O(1) duplicate checks
    edges = []
    for u, v in stubs.reshape(-1, 2).tolist():
        key = (u, v) if u < v else (v, u)
        if u == v or key in seen:
            continue
        seen.add(key)
        edges.append((u, v))
        edges.append((v, u))

    # reshape(-1, 2) keeps a well-formed [2, 0] tensor if every pair was rejected
    return torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t()


def _second_largest_abs_eigenvalue(edge_index, num_nodes):
    """Return the second-largest absolute eigenvalue of the graph's adjacency matrix."""
    adj = torch.zeros((num_nodes, num_nodes))
    adj[edge_index[0], edge_index[1]] = 1.0
    abs_eigs = torch.linalg.eigvalsh(adj).abs().sort().values
    return abs_eigs[-2].item()


def create_ramanujan_expander(num_nodes, d, num_trials=10, seed=None):
    """
    Sample sparse random graphs and keep the one with the best spectral gap.

    Each trial draws a configuration-model graph with ``d`` stubs per node
    (self-loops and duplicate edges are discarded, so degrees are at most
    ``d``). The draw with the smallest second-largest absolute adjacency
    eigenvalue (lambda_2) is kept. Sampling stops early once lambda_2 meets
    the Ramanujan bound ``2 * sqrt(d - 1)``.

    Args:
        num_nodes: Number of nodes in the graph.
        d: Stubs per node. Must be a positive even integer smaller than
            ``num_nodes``. Evenness guarantees the total stub count is even,
            so every stub can be paired.
        num_trials: Number of random graphs to sample (default 10).
        seed: Optional integer seed for reproducible sampling. If None, the
            global torch RNG is used.

    Returns:
        edge_index: Long tensor of shape [2, num_edges]. Every undirected
        edge appears in both directions.

    Raises:
        ValueError: If ``d`` is odd or non-positive, if ``d >= num_nodes``,
            or if ``num_trials < 1``.
    """
    if d % 2 != 0:
        raise ValueError("For d-regular graphs, d must be even")
    if d <= 0:
        raise ValueError("d must be positive")
    if d >= num_nodes:
        raise ValueError("d must be less than the number of nodes")
    if num_trials < 1:
        raise ValueError("num_trials must be at least 1")

    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    ramanujan_bound = 2 * np.sqrt(d - 1)
    best_edge_index = None
    best_lambda2 = float('inf')

    for _ in range(num_trials):
        edge_index = _sample_configuration_edges(num_nodes, d, generator)
        lambda2 = _second_largest_abs_eigenvalue(edge_index, num_nodes)

        # Keep the draw with the widest spectral gap seen so far
        if lambda2 < best_lambda2:
            best_lambda2 = lambda2
            best_edge_index = edge_index
            # The Ramanujan bound is met, so further sampling cannot help
            if lambda2 <= ramanujan_bound:
                break

    is_ramanujan = best_lambda2 <= ramanujan_bound
    print(f"Created {'Ramanujan' if is_ramanujan else 'non-Ramanujan'} graph with λ₂ = {best_lambda2:.4f} (bound: {ramanujan_bound:.4f})")

    return best_edge_index


In [4]:
# Shared experiment configuration for every benchmark.
SEED = 42             # the expander graphs are random; seed for reproducible connectivity
EXPANDER_DEGREE = 6   # d passed to create_ramanujan_expander
WINDOW = 96           # input window length (time steps)
HORIZON = 96          # forecast horizon (time steps)
BATCH_SIZE = 64

benchmark_electricity = ElectricityBenchmark(root='data/benchmark_electricity')
benchmark_exchange = ExchangeBenchmark(root='data/benchmark_exchange')
benchmark_solar = SolarBenchmark(root='data/benchmark_solar')
benchmark_traffic = TrafficBenchmark(root='data/benchmark_traffic')

benchmarks = [benchmark_electricity, benchmark_exchange, benchmark_solar, benchmark_traffic]

# Maps benchmark name -> set-up SpatioTemporalDataModule. Despite the name, it
# holds datamodules; call .train_dataloader() / .val_dataloader() /
# .test_dataloader() on an entry to get the loaders.
benchmark_loader_dict = {}

torch.manual_seed(SEED)

for benchmark in benchmarks:
    print(f"Processing {benchmark.name}...")

    edge_index = create_ramanujan_expander(num_nodes=len(benchmark.nodes), d=EXPANDER_DEGREE)

    # subclass of torch.utils.data.Dataset
    torch_dataset = SpatioTemporalDataset(
        target=benchmark.dataframe(),
        connectivity=(edge_index, None),
        mask=benchmark.mask,
        horizon=HORIZON,
        window=WINDOW,
        stride=1
    )

    # Fresh scaler/splitter per benchmark so no fitted state is shared across datasets
    scalers = {'target': StandardScaler(axis=(0, 1))}

    # Split data sequentially:
    #   |------------ dataset -----------|
    #   |--- train ---|- val -|-- test --|
    splitter = TemporalSplitter(val_len=0.1, test_len=0.2)

    dm = SpatioTemporalDataModule(
        dataset=torch_dataset,
        scalers=scalers,
        splitter=splitter,
        batch_size=BATCH_SIZE,
    )

    dm.setup()

    benchmark_loader_dict[benchmark.name] = dm

Processing ElectricityBenchmark...
Created Ramanujan graph with λ₂ = 4.4055 (bound: 4.4721)
Processing ExchangeBenchmark...
Created Ramanujan graph with λ₂ = 2.5929 (bound: 4.4721)
Processing SolarBenchmark...
Created Ramanujan graph with λ₂ = 4.3100 (bound: 4.4721)
Processing TrafficBenchmark...
Created Ramanujan graph with λ₂ = 4.4207 (bound: 4.4721)


In [10]:
# Report the number of batches in each split for every benchmark datamodule.
for name, datamodule in benchmark_loader_dict.items():
    print(name)
    split_loaders = (
        datamodule.train_dataloader(),
        datamodule.val_dataloader(),
        datamodule.test_dataloader(),
    )
    print(*(len(loader) for loader in split_loaders))

ElectricityBenchmark
292 32 82
ExchangeBenchmark
81 8 24
SolarBenchmark
587 64 164
TrafficBenchmark
193 21 55


In [45]:
# Peek at the first 10 training samples of the Electricity benchmark.
# Entries of benchmark_loader_dict are datamodules, so fetch the loader first.
benchmark_loader_dict['ElectricityBenchmark'].train_dataloader().dataset[:10]

StaticBatch(
  input=(x=[b=10, t=96, n=321, f=1], edge_index=[2, e=1912]),
  target=(y=[b=10, t=96, n=321, f=1]),
  has_mask=True,
  transform=[x, y]
)