# Introduction to Torch Spatio-Temporal (TSL)
official documentation: https://torch-spatiotemporal.readthedocs.io/en/latest/index.html

purpose of this notebook:
- data structures
- datasets
- models

In [24]:
import tsl
import torch
import torch_geometric

from torch.optim import Adam
from tsl.datasets import PeMS04, PeMS07, PeMS08, PemsBay
from tsl.datasets import MetrLA

import numpy as np
import pandas as pd

In [25]:
print(f"tsl version  : {tsl.__version__}")
print(f"torch version: {torch.__version__}")
print(f"torch_geometric version: {torch_geometric.__version__}")

tsl version  : 0.9.5
torch version: 2.4.0
torch_geometric version: 2.6.1


### Preprocessing

In [26]:
dataset_MetrLA = MetrLA(root='data/MetrLA')

print(f"Sampling period: {dataset_MetrLA.freq}")
print(f"Has missing values: {dataset_MetrLA.has_mask}")
print(f"Percentage of missing values: {(1 - dataset_MetrLA.mask.mean()) * 100:.2f}%")
print(f"Has exogenous variables: {dataset_MetrLA.has_covariates}")
print(f"Covariates: {', '.join(dataset_MetrLA.covariates.keys())}")

dataset_MetrLA  # type tsl.datasets.Dataset

Sampling period: <5 * Minutes>
Has missing values: True
Percentage of missing values: 8.11%
Has exogenous variables: True
Covariates: dist


  date_range = pd.date_range(df.index[0], df.index[-1], freq='5T')
  df = df.replace(to_replace=0., method='ffill')


MetrLA(length=34272, n_nodes=207, n_channels=1)

In [27]:
print(dataset_MetrLA.dist)

[[    0.      inf     inf ...     inf  8114.8 10009.7]
 [    inf     0.   2504.6 ...     inf     inf     inf]
 [    inf  1489.3     0.  ...     inf     inf  9837. ]
 ...
 [    inf     inf     inf ...     0.      inf     inf]
 [ 9599.8     inf     inf ...     inf     0.      inf]
 [10119.9  9374.8     inf ...     inf  9018.7     0. ]]


In [28]:
print(f"Default similarity: {dataset_MetrLA.similarity_score}")
print(f"Available similarity options: {dataset_MetrLA.similarity_options}")
print("==========================================")

sim = dataset_MetrLA.get_similarity("distance")  # or dataset_MetrLA.compute_similarity()

sim.shape

Default similarity: distance
Available similarity options: {'distance'}


(207, 207)

In [29]:
# get_connectivity uses get_similarity under the hood
connectivity = dataset_MetrLA.get_connectivity(threshold=0.1, include_self=False, normalize_axis=1, layout="edge_index")

connectivity[0].shape, connectivity[1].shape

((2, 1515), (1515,))

### Building a torch-ready dataset

convert dataset_MetrLA in (tsl.datasets.Dataset) format to (tsl.data.SpatioTemporalDataset) format

In [30]:
from tsl.data import SpatioTemporalDataset

In [31]:
df = dataset_MetrLA.dataframe()

df.head()

nodes,773869,767541,767542,717447,717446,717445,773062,767620,737529,717816,...,772167,769372,774204,769806,717590,717592,717595,772168,718141,769373
channels,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2012-03-01 00:00:00,64.375,67.625,67.125,61.5,66.875,68.75,65.125,67.125,59.625,62.75,...,45.625,65.5,64.5,66.428574,66.875,59.375,69.0,59.25,69.0,61.875
2012-03-01 00:05:00,62.666668,68.555557,65.444443,62.444443,64.444443,68.111115,65.0,65.0,57.444443,63.333332,...,50.666668,69.875,66.666664,58.555557,62.0,61.111111,64.444443,55.888889,68.444443,62.875
2012-03-01 00:10:00,64.0,63.75,60.0,59.0,66.5,66.25,64.5,64.25,63.875,65.375,...,44.125,69.0,56.5,59.25,68.125,62.5,65.625,61.375,69.85714,62.0
2012-03-01 00:15:00,64.0,63.75,60.0,59.0,66.5,66.25,64.5,64.25,63.875,65.375,...,44.125,69.0,56.5,59.25,68.125,62.5,65.625,61.375,69.85714,62.0
2012-03-01 00:20:00,64.0,63.75,60.0,59.0,66.5,66.25,64.5,64.25,63.875,65.375,...,44.125,69.0,56.5,59.25,68.125,62.5,65.625,61.375,69.85714,62.0


In [32]:
# subclass of torch.utils.data.Dataset
torch_dataset = SpatioTemporalDataset(
    target=dataset_MetrLA.dataframe(),
    connectivity=connectivity,
    mask=dataset_MetrLA.mask,
    horizon=6,
    window=18,
    stride=1
)

torch_dataset

SpatioTemporalDataset(n_samples=34249, n_nodes=207, n_channels=1)

In [33]:
sample = torch_dataset[0]

"""
    sample is of 
    - type torch_geometric.data.Data
    - shape (window, num_nodes, num_features)
"""
sample.x.shape, sample

(torch.Size([18, 207, 1]),
 Data(
   input=(x=[t=18, n=207, f=1], edge_index=[2, e=1515], edge_weight=[e=1515]),
   target=(y=[t=6, n=207, f=1]),
   has_mask=True
 ))

In [34]:
"""
    t = time steps dimension
    n = node dimension
    e = edge dimension
    f = feature dimension
    b = batch dimension
"""
sample.pattern

{'x': 't n f',
 'mask': 't n f',
 'edge_index': '2 e',
 'edge_weight': 'e',
 'y': 't n f'}

In [35]:
batch = torch_dataset[:5]
batch

StaticBatch(
  input=(x=[b=5, t=18, n=207, f=1], edge_index=[2, e=1515], edge_weight=[e=1515]),
  target=(y=[b=5, t=6, n=207, f=1]),
  has_mask=True
)

In [36]:
from tsl.data.datamodule import (SpatioTemporalDataModule,
                                 TemporalSplitter)
from tsl.data.preprocessing import StandardScaler

# Normalize data using mean and std computed over time and node dimensions
scalers = {'target': StandardScaler(axis=(0, 1))}

# Split data sequentially:
#   |------------ dataset -----------|
#   |--- train ---|- val -|-- test --|
splitter = TemporalSplitter(val_len=0.1, test_len=0.2)

dm = SpatioTemporalDataModule(
    dataset=torch_dataset,
    scalers=scalers,
    splitter=splitter,
    batch_size=64,
)

dm.setup()
print(dm)

{Train dataloader: size=24642}
{Validation dataloader: size=2722}
{Test dataloader: size=6849}
{Predict dataloader: None}


In [37]:
# extracting the loaders separately for training, validation and testing (for pytorch-lightning, dm can be plugged in as is)
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()

In [38]:
import torch.nn as nn

from tsl.nn.blocks.encoders import RNN
from tsl.nn.layers import NodeEmbedding, DiffConv
from einops.layers.torch import Rearrange  # reshape data with Einstein notation


class TimeThenSpaceModel(nn.Module):
    def __init__(self, input_size: int, n_nodes: int, horizon: int,
                 hidden_size: int = 32,
                 rnn_layers: int = 1,
                 gnn_kernel: int = 2):
        super(TimeThenSpaceModel, self).__init__()

        self.encoder = nn.Linear(input_size, hidden_size)

        self.node_embeddings = NodeEmbedding(n_nodes, hidden_size) # free params learned individually for each node (taming local effects in STGNNs, Cini et al.)

        self.time_nn = RNN(input_size=hidden_size,
                           hidden_size=hidden_size,
                           n_layers=rnn_layers,
                           cell='gru',
                           return_only_last_state=True)
        
        self.space_nn = DiffConv(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 k=gnn_kernel)

        self.decoder = nn.Linear(hidden_size, input_size * horizon)
        self.rearrange = Rearrange('b n (t f) -> b t n f', t=horizon)

    def forward(self, x, edge_index, edge_weight):
        # x: [batch time nodes features]
        x_enc = self.encoder(x)  # linear encoder: x_enc = xΘ + b
        x_emb = x_enc + self.node_embeddings()  # add node-identifier embeddings
        h = self.time_nn(x_emb)  # temporal processing: x=[b t n f] -> h=[b n f]
        z = self.space_nn(h, edge_index, edge_weight)  # spatial processing
        x_out = self.decoder(z)  # linear decoder: z=[b n f] -> x_out=[b n t⋅f]
        x_horizon = self.rearrange(x_out)
        return x_horizon

In [39]:
hidden_size = 32   #@param
rnn_layers = 1     #@param
gnn_kernel = 2     #@param

input_size = torch_dataset.n_channels   # 1 channel
n_nodes = torch_dataset.n_nodes         # 207 nodes
horizon = torch_dataset.horizon         # 12 time steps

stgnn = TimeThenSpaceModel(input_size=input_size,
                           n_nodes=n_nodes,
                           horizon=horizon,
                           hidden_size=hidden_size,
                           rnn_layers=rnn_layers,
                           gnn_kernel=gnn_kernel)
print(stgnn)

TimeThenSpaceModel(
  (encoder): Linear(in_features=1, out_features=32, bias=True)
  (node_embeddings): NodeEmbedding(n_nodes=207, embedding_size=32)
  (time_nn): RNN(
    (rnn): GRU(32, 32)
  )
  (space_nn): DiffConv(32, 32)
  (decoder): Linear(in_features=32, out_features=6, bias=True)
  (rearrange): Rearrange('b n (t f) -> b t n f', t=6)
)


In [40]:
epochs = 1
criterion = nn.MSELoss()
optimizer = Adam(stgnn.parameters(), lr=1e-3)

stgnn.train()

for epoch in range(epochs):
    for batch in train_loader:
        optimizer.zero_grad()
        
        x, edge_index, edge_weight, y = batch.x, batch.edge_index, batch.edge_weight, batch.y
        
        y_hat = stgnn(x, edge_index, edge_weight)
        loss = criterion(y_hat, y)
        
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{epochs} - loss: {loss.item():.4f}")

Epoch 1/1 - loss: 75.9901


### Pytorch Lightning

In [41]:
from tsl.metrics.torch import MaskedMAE, MaskedMAPE
from tsl.engines import Predictor

loss_fn = MaskedMAE()

metrics = {'mae': MaskedMAE(),
           'mape': MaskedMAPE(),
           'mae_at_15': MaskedMAE(at=2),  # '2' indicates the third time step,
                                          # which correspond to 15 minutes ahead
           'mae_at_30': MaskedMAE(at=5),
           'mae_at_60': MaskedMAE(at=11)}

# setup predictor
predictor = Predictor(
    model=stgnn,                   # our initialized model
    optim_class=torch.optim.Adam,  # specify optimizer to be used...
    optim_kwargs={'lr': 0.001},    # ...and parameters for its initialization
    loss_fn=loss_fn,               # which loss function to be used
    metrics=metrics                # metrics to be logged during train/val/test
)

In [42]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='logs',
    save_top_k=1,
    monitor='val_mae',
    mode='min',
)

trainer = pl.Trainer(max_epochs=1,
                     limit_train_batches=100,  # end an epoch after 100 updates
                     callbacks=[checkpoint_callback])

trainer.fit(predictor, datamodule=dm)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/anaconda3/envs/torch-st/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/HunJ/Documents/ResearchLab/CodeBases/temporal-gnn/notebooks/logs exists and is not empty.

  | Name          | Type               | Params | Mode 
-------------------------------------------------------------
0 | loss_fn       | MaskedMAE          | 0      | train
1 | train_metrics | MetricCollection   | 0      | train
2 | val_metrics   | MetricCollection   | 0      | train
3 | test_metrics  | MetricCollection   | 0      | train
4 | model         | TimeThenSpaceModel | 18.4 K | train
-------------------------------------------------------------
18.4 K    Trainable params
0         Non-trainable params
18.4 K    Total params
0.073     Total estimated model params size (MB)
29        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/torch-st/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

Only args ['edge_weight', 'x', 'edge_index'] are forwarded to the model (TimeThenSpaceModel).


                                                                           

/opt/anaconda3/envs/torch-st/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 100/100 [00:07<00:00, 12.55it/s, v_num=3, val_mae=6.900, val_mae_at_15=6.540, val_mae_at_30=6.530, val_mae_at_60=0.000, val_mape=0.150, train_mae=202.0, train_mae_at_15=202.0, train_mae_at_30=202.0, train_mae_at_60=0.000, train_mape=3.690]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 100/100 [00:07<00:00, 12.52it/s, v_num=3, val_mae=6.900, val_mae_at_15=6.540, val_mae_at_30=6.530, val_mae_at_60=0.000, val_mape=0.150, train_mae=202.0, train_mae_at_15=202.0, train_mae_at_30=202.0, train_mae_at_60=0.000, train_mape=3.690]


In [43]:
predictor.load_model(checkpoint_callback.best_model_path)
predictor.freeze()

trainer.test(predictor, datamodule=dm);

  storage = torch.load(filename, lambda storage, loc: storage)
Predictor with already instantiated model is loading a state_dict from /Users/HunJ/Documents/ResearchLab/CodeBases/temporal-gnn/notebooks/logs/epoch=0-step=100-v1.ckpt. Cannot  check if model hyperparameters are the same.
/opt/anaconda3/envs/torch-st/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 108/108 [00:03<00:00, 30.79it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss            7.02519416809082
        test_mae             7.313732624053955
     test_mae_at_15          6.849016189575195
     test_mae_at_30          7.109511375427246
     test_mae_at_60                 0.0
        test_mape           0.1687723696231842
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
