In [1]:
import tsl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric

from torch.optim import Adam
from tsl.datasets import PeMS04, PeMS07, PeMS08, PemsBay
from tsl.datasets import MetrLA
from tsl.data import SpatioTemporalDataset
from tsl.data.datamodule import (SpatioTemporalDataModule,
                                 TemporalSplitter)
from tsl.data.preprocessing import StandardScaler
from einops.layers.torch import Rearrange

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

In [2]:
dataset_MetrLA = MetrLA(root='data/MetrLA')

# get_connectivity uses get_similarity under the hood
connectivity = dataset_MetrLA.get_connectivity(threshold=0.1, include_self=False, normalize_axis=1, layout="edge_index")

# subclass of torch.utils.data.Dataset
torch_dataset = SpatioTemporalDataset(
    target=dataset_MetrLA.dataframe(),
    connectivity=connectivity,
    mask=dataset_MetrLA.mask,
    horizon=6,
    window=12,
    stride=1
)

scalers = {'target': StandardScaler(axis=(0, 1))}

# Split data sequentially:
#   |------------ dataset -----------|
#   |--- train ---|- val -|-- test --|
splitter = TemporalSplitter(val_len=0.1, test_len=0.2)

dm = SpatioTemporalDataModule(
    dataset=torch_dataset,
    scalers=scalers,
    splitter=splitter,
    batch_size=64,
)

dm.setup()

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()

  date_range = pd.date_range(df.index[0], df.index[-1], freq='5T')
  df = df.replace(to_replace=0., method='ffill')


In [None]:
sample_data = torch_dataset[:10]
sample_data

StaticBatch(
  input=(x=[b=10, t=12, n=207, f=1], edge_index=[2, e=1515], edge_weight=[e=1515]),
  target=(y=[b=10, t=6, n=207, f=1]),
  has_mask=True,
  transform=[x, y]
)

In [20]:
x = sample_data.x
B, L, N, D = x.shape

print("x: ", x.shape)

hidden = 24
num_heads = 8
hidden_per_head = hidden // num_heads

print("hidden: ", hidden, " num heads: ", num_heads, " hidden per head: ", hidden_per_head)

query = nn.Linear(1, hidden)
key = nn.Linear(1, hidden)
value = nn.Linear(1, hidden)
out_proj = nn.Linear(hidden, 1)

Q = query(x).view(B, L, N, num_heads, hidden_per_head)
K = key(x).view(B, L, N, num_heads, hidden_per_head)
V = value(x).view(B, L, N, num_heads, hidden_per_head)

# Rearrange for batch computation across nodes
Q = Q.permute(0, 2, 3, 1, 4)  # (B, N, num_heads, L, head_dim)
K = K.permute(0, 2, 3, 1, 4)  # (B, N, num_heads, L, head_dim)
V = V.permute(0, 2, 3, 1, 4)  # (B, N, num_heads, L, head_dim)

print("Q, K, V: ", Q.shape, K.shape, V.shape)

attn = Q @ K.mT / hidden_per_head ** 0.5
attn = F.softmax(attn, dim=-1)

print("attn: ", attn.shape)

out = attn @ V # (B, N, num_heads, L, head_dim)
out = out.permute(0, 3, 1, 2, 4).contiguous()  # (B, L, N, num_heads, head_dim)
out = out.view(B, L, N, hidden)
out = out_proj(out)

print("out: ", out.shape)

x:  torch.Size([10, 12, 207, 1])
hidden:  24  num heads:  8  hidden per head:  3
Q, K, V:  torch.Size([10, 207, 8, 12, 3]) torch.Size([10, 207, 8, 12, 3]) torch.Size([10, 207, 8, 12, 3])
attn:  torch.Size([10, 207, 8, 12, 12])
out:  torch.Size([10, 12, 207, 1])


### Modeling Attention Schemes with ST data

In [14]:
# scratch pad

x = sample_data.x[:10]
x_flat = x.permute(0, 2, 1, 3).reshape(10, 207, 12)
print("input shape: ", x_flat.shape)

num_heads = 8

query = nn.Linear(1 * 12, 4 * num_heads)
key = nn.Linear(1 * 12, 4 * num_heads)
value = nn.Linear(1 * 12, 4 * num_heads)

Q = query(x_flat).unsqueeze(-1)
K = key(x_flat).unsqueeze(-1)
V = value(x_flat).unsqueeze(-1)
print("build block shapes (Q, K, V): ", Q.shape, K.shape, V.shape)

# Reshape for multi-head attention
Q = Q.view(10, 207, num_heads, 4, 1).transpose(1, 2)
K = K.view(10, 207, num_heads, 4, 1).transpose(1, 2)
V = V.view(10, 207, num_heads, 4, 1).transpose(1, 2)

print("build block shapes (Q, K, V): ", Q.shape, K.shape, V.shape)

attn = F.scaled_dot_product_attention(Q, K, V, attn_mask=None).squeeze(-1)
print("attention shape: ", attn.shape)

out = attn.transpose(1, 2).reshape(10, 207, 4 * num_heads)
print("output shape: ", out.shape)

out = nn.Linear(4 * num_heads, 4)(out)
print("output shape (after proj): ", out.shape)

# e = (Q @ K.mT) / 4 ** 0.5
# print("energy shape: ", e.shape)

# attention = F.softmax(e, dim=-1)
# out = (attention @ V).squeeze(-1)
# print("output shape: ", out.shape)

input shape:  torch.Size([10, 207, 12])
build block shapes (Q, K, V):  torch.Size([10, 207, 32, 1]) torch.Size([10, 207, 32, 1]) torch.Size([10, 207, 32, 1])
build block shapes (Q, K, V):  torch.Size([10, 8, 207, 4, 1]) torch.Size([10, 8, 207, 4, 1]) torch.Size([10, 8, 207, 4, 1])
attention shape:  torch.Size([10, 8, 207, 4])
output shape:  torch.Size([10, 207, 32])
output shape (after proj):  torch.Size([10, 207, 4])


In [51]:
# feature vectorization approach (flatten the temporal dim along the feature dim)
class SelfAttention(nn.Module):
    def __init__(self, input_size, hidden_size, window_size, horizon, num_heads=8, dropout=0.6):
        super(SelfAttention, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.window_size = window_size
        self.horizon = horizon
        self.num_heads = num_heads
        
        self.query = nn.Linear(input_size * window_size, hidden_size * num_heads)
        self.key = nn.Linear(input_size * window_size, hidden_size * num_heads)
        self.value = nn.Linear(input_size * window_size, hidden_size * num_heads)
        self.out = nn.Linear(hidden_size * num_heads, hidden_size)
        
        self.dropout = nn.Dropout(dropout)
        
        # STGNN specific utility functions
        self.decoder = nn.Linear(hidden_size, input_size * horizon)
        
        self.scale = self.hidden_size ** 0.5

    def forward(self, x):
        # feature vectorization: flatten the temporal dim along the feature dim
        b, t, n, f = x.size()
        x = x.permute(0, 2, 1, 3).reshape(b, n, t * f)
        
        # masks out the upper triangle (diagonal not included)
        self_causal_mask = torch.triu(torch.ones(b, self.num_heads, n, self.hidden_size, self.hidden_size), diagonal=1).bool()

        Q = self.query(x).unsqueeze(-1)
        K = self.key(x).unsqueeze(-1)
        V = self.value(x).unsqueeze(-1)
        
        # Reshape for multi-head attention
        Q = Q.view(b, n, self.num_heads, self.hidden_size, 1).transpose(1, 2)
        K = K.view(b, n, self.num_heads, self.hidden_size, 1).transpose(1, 2)
        V = V.view(b, n, self.num_heads, self.hidden_size, 1).transpose(1, 2)
        
        # e = torch.matmul(Q, K.transpose(-2, -1)) / self.scale 
        e = (Q @ K.mT) / self.scale
        e = e.masked_fill(self_causal_mask, float('-inf'))
        attention = F.softmax(e, dim=-1)
        
        # Apply dropout to attention weights
        attention = self.dropout(attention)
        
        out = (attention @ V).squeeze(-1)
        
        # combine the heads
        out = out.transpose(1, 2).contiguous().view(b, n, self.hidden_size * self.num_heads)
        out = self.out(out)
        
        out_horizon = self.decoder(out)
        out_horizon = out_horizon.view(b, n, self.horizon, f).permute(0, 2, 1, 3)
        
        return out_horizon
    

In [52]:
sample_model = SelfAttention(input_size=1, hidden_size=4, window_size=12, horizon=6, num_heads=8, dropout=0.6)

out = sample_model(x)

In [53]:
attention_block = SelfAttention(input_size=1, hidden_size=4, window_size=12, horizon=6, num_heads=8, dropout=0.6)

epochs = 1
criterion = nn.MSELoss()
optimizer = Adam(attention_block.parameters(), lr=0.001)

for epoch in range(epochs):
    attention_block.train()
    for batch in train_loader:
        x, edge_index, edge_weight, y = batch.x, batch.edge_index, batch.edge_weight, batch.y
        
        out = attention_block(x)
        loss = criterion(out, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f'Epoch {epoch+1}/{epochs} Loss: {loss.item()}')
    break

eval_loss = 0
attention_block.eval()
with torch.no_grad():
    for batch in test_loader:
        x, edge_index, edge_weight, y = batch.x, batch.edge_index, batch.edge_weight, batch.y

        out = attention_block(x)
        eval_loss += criterion(out, y).item()

print(f'Evaluation Loss: {eval_loss / len(test_loader)}')

Epoch 1/1 Loss: 706.1438598632812
Evaluation Loss: 533.530632019043


### Cross Attention for neighborhood aggregation (TraverseNet)

In [4]:
x = sample_data.x
x_flat = x.permute(0, 2, 1, 3).reshape(10, 207, 12)
print("x shape after flattening time (query): ", x_flat.shape)

cw = 6 # cross attention window
x_flat_cw = x_flat[:,:,cw:]
print("cross attention shape (key, val): ", x_flat_cw.shape)

x shape after flattening time (query):  torch.Size([10, 207, 12])
cross attention shape (key, val):  torch.Size([10, 207, 6])


In [12]:
num_heads = 8

query = nn.Linear(1 * 12, 4 * num_heads)
key = nn.Linear(1 * 12, 4 * num_heads)
value = nn.Linear(1 * 12, 4 * num_heads)

Q = query(x_flat)
K = key(x_flat)
V = value(x_flat)
print("Q, K, V: ", Q.shape, K.shape, V.shape)

e = (Q @ K.mT) / 4 ** 0.5
print("energy shape: ", e.shape)

attn = F.softmax(e, dim=-1)

out = attn @ V
print("output shape: ", out.shape)

Q, K, V:  torch.Size([10, 207, 32]) torch.Size([10, 207, 32]) torch.Size([10, 207, 32])
energy shape:  torch.Size([10, 207, 207])
output shape:  torch.Size([10, 207, 32])


In [10]:
class CrossAttention(nn.Module):
    def __init__(self, input_size, hidden_size, window_size, horizon, num_heads=8, dropout=0.6):
        super(CrossAttention, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.window_size = window_size
        self.horizon = horizon
        self.num_heads = num_heads
        
        self.query = nn.Linear(input_size * window_size, hidden_size * num_heads)
        self.key = nn.Linear(input_size * window_size, hidden_size * num_heads)
        self.value = nn.Linear(input_size * window_size, hidden_size * num_heads)
        self.out = nn.Linear(hidden_size * num_heads, hidden_size)
        
        self.dropout = nn.Dropout(dropout)
        
        # STGNN specific utility functions
        self.decoder = nn.Linear(hidden_size, input_size * horizon)
        
        self.scale = self.hidden_size ** 0.5

    def forward(self, x):
        b, t, n, f = x.size()
        x = x.permute(0, 2, 1, 3).reshape(b, n, t * f)
        
        # masks out the upper triangle (diagonal not included)

        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        
        e = (Q @ K.mT) / self.scale
        attention = F.softmax(e, dim=-1)
        
        # Apply dropout to attention weights
        attention = self.dropout(attention)
        
        out = (attention @ V).squeeze(-1)
        
        # combine the heads
        out = out.transpose(1, 2).contiguous().view(b, n, self.hidden_size * self.num_heads)
        out = self.out(out)
        
        out_horizon = self.decoder(out)
        out_horizon = out_horizon.view(b, n, self.horizon, f).permute(0, 2, 1, 3)
        
        return out_horizon

### Causal Masks

In [37]:
self_causal_mask = torch.triu(torch.ones(8, 4,4), diagonal=1).bool()
self_causal_mask

tensor([[[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [Fal