In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import contextily

import torch
import torch.nn as nn
import torch.nn.functional as F

# Plots

In [None]:
from torch_geometric_temporal.dataset import METRLADatasetLoader

# Point the loader at the local 'data' directory, then request a dataset whose
# samples pair 12 input timesteps with 12 output timesteps.
loader = METRLADatasetLoader('data')
dataset = loader.get_dataset(
    num_timesteps_in=12,
    num_timesteps_out=12,
)

In [None]:
# Assemble an undirected NetworkX graph from the dataset's edge list: every
# column of `edge_index` is one (source, target) pair, matched by position with
# its entry in `edge_weight`, which is stored as the edge's 'weight' attribute.
graph = nx.Graph()
graph.add_weighted_edges_from(
    (source, target, weight)
    for (source, target), weight in zip(dataset.edge_index.T, dataset.edge_weight)
)
# Drop self-loops; the list is materialised first so the graph is not modified
# while it is still being iterated over.
graph.remove_edges_from(list(nx.selfloop_edges(graph)))

In [None]:
# Download the sensor location table and keep only the coordinates, in
# (longitude, latitude) column order, as a plain NumPy array.
sensor_table = pd.read_csv(
    'https://raw.githubusercontent.com/tijsmaas/TrafficPrediction/master/data/metr-la/graph_sensor_locations.csv'
)
locations = sensor_table[['longitude', 'latitude']].to_numpy()

In [None]:
# Axes of the stored array: (timestamp, sensor, feature), where the features
# are [speed, time of day].
node_values = np.load('data/node_values.npy')
speed = node_values[:, :, 0]  # keep feature 0 (speed) only

In [None]:
# Node ids double as row indices into `locations` and column indices into
# `speed`, so materialise them once as a plain list instead of indexing NumPy
# arrays with a NetworkX NodeView.
node_ids = list(graph.nodes)
positions = dict(zip(node_ids, locations[node_ids]))
node_colors = speed[0, node_ids]  # colour each node by its speed at the first timestep
edge_colors = [weight for _, _, weight in graph.edges(data='weight')]

fig, ax = plt.subplots(figsize=(10,10))
nx.draw(graph, positions, node_color=node_colors, edge_color=edge_colors, edgecolors='black', ax=ax, node_size=50)
ax.set_title('Sensor graph (node colour: speed at first timestep)')
# The Stamen providers were removed from xyzservices/contextily (the tiles moved
# to Stadia Maps and need an API key), so `providers.Stamen.TonerLite` raises
# AttributeError on current installs. CartoDB Positron is a comparable light
# basemap that needs no key. `crs=4326` matches the lon/lat node positions.
contextily.add_basemap(ax=ax, crs=4326, source=contextily.providers.CartoDB.Positron)

In [None]:
# Overlay the last 200 readings of four sensors on a single figure, one line
# per sensor, labelled with the sensor's column index.
adjacent_sensors = (6, 91, 93, 136)
with plt.style.context('default'):
    plt.figure(figsize=(10,5))
    for sensor in adjacent_sensors:
        plt.plot(speed[-200:, sensor], label=str(sensor))
    plt.legend()
    plt.title('Sensor data of adjacent nodes')

In [None]:
# Load the array saved in data/adj_mat.npy.
# NOTE(review): presumably the sensor-to-sensor adjacency matrix of the graph
# built above — confirm against the data source.
adjacency = np.load('data/adj_mat.npy')

In [None]:
# Heat-map of the adjacency array on square cells. The rows are flipped
# vertically before plotting, so the array's first row ends up at the top.
fig, ax = plt.subplots(figsize=(5, 5))
ax.pcolormesh(np.flipud(adjacency))
ax.set_aspect('equal')

# Model training

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from src import METRLADataModule, TemporalGNN
from src.printer import PrintMetricsCallback
from src.utils import read_logs

In [None]:
import warnings
warnings.filterwarnings('ignore', '.*does not have many workers.*')

In [None]:
# Seed Python, NumPy and PyTorch RNGs before anything stochastic is created, so
# that model initialisation (and any randomness in data loading) repeats across
# Restart & Run All.
pl.seed_everything(42, workers=True)

# Data module over the local 'data' directory: 12 input steps, 12 predicted
# steps, batches of 16, loading in the main process (num_workers=0).
data = METRLADataModule(root_dir='data', train_steps=12, predict_steps=12, num_workers=0, batch_size=16)
# in_features=2 presumably corresponds to the two node features
# (speed, time of day) — TODO confirm against TemporalGNN.
model = TemporalGNN(in_features=2, hidden_features=64)

In [None]:
# Keep the single best checkpoint (saved as 'best') plus the most recent one.
# 'val_loss' is a loss, so the best model is the one with the LOWEST value:
# mode must be 'min'. With 'max' the checkpoint with the worst validation loss
# would be kept.
best_checkpointer = ModelCheckpoint(
    save_top_k=1, save_last=True, monitor='val_loss', mode='min', filename='best')
# Metrics are written as CSV under the current working directory.
csv_logger = CSVLogger('')
printer = PrintMetricsCallback(
    metrics=['val_loss', 'train_loss'])

# Train for at most 10 epochs on one device, logging every step.
trainer = pl.Trainer(
    log_every_n_steps=1,
    logger=csv_logger,
    callbacks=[best_checkpointer, printer],
    max_epochs=10,
    accelerator='auto',
    devices=1)

In [None]:
# Run training with the trainer configured above, passing the model and the
# data module; the callbacks and CSV logger attached to the trainer are used.
trainer.fit(model, data)

In [None]:
# Read the metrics written by this session's CSVLogger rather than a hard-coded
# 'version_9' directory: the version number changes with every run, so a fixed
# path either fails on a fresh checkout or silently plots a stale run.
logs = read_logs(f'{csv_logger.log_dir}/metrics.csv')

plt.plot(logs['train_loss_step'])
plt.ylabel('train_loss_step')
plt.title('Training loss per logged step');