In [1]:
import tsl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric

from torch.optim import Adam
from tsl.datasets import PeMS04, PeMS07, PeMS08, PemsBay
from tsl.datasets import MetrLA
from tsl.data import SpatioTemporalDataset
from tsl.data.datamodule import (SpatioTemporalDataModule,
                                 TemporalSplitter)
from tsl.data.preprocessing import StandardScaler
from einops.layers.torch import Rearrange

import numpy as np
import pandas as pd

from tqdm import tqdm

In [2]:
# MetrLA dataset object, rooted at data/MetrLA
dataset_MetrLA = MetrLA(root='data/MetrLA')

# Graph connectivity in "edge_index" layout
# (get_connectivity uses get_similarity under the hood)
connectivity = dataset_MetrLA.get_connectivity(
    layout="edge_index",
    threshold=0.1,
    include_self=False,
    normalize_axis=1,
)

# Windowed dataset (a subclass of torch.utils.data.Dataset):
# window=18 input steps, horizon=6 target steps, stride=1 between samples
torch_dataset = SpatioTemporalDataset(
    target=dataset_MetrLA.dataframe(),
    mask=dataset_MetrLA.mask,
    connectivity=connectivity,
    window=18,
    horizon=6,
    stride=1,
)

# Scaler applied to the 'target' entry of the dataset
scalers = {'target': StandardScaler(axis=(0, 1))}

# Sequential split of the data:
#   |------------ dataset -----------|
#   |--- train ---|- val -|-- test --|
splitter = TemporalSplitter(val_len=0.1, test_len=0.2)

dm = SpatioTemporalDataModule(
    dataset=torch_dataset,
    splitter=splitter,
    scalers=scalers,
    batch_size=64,
)

dm.setup()

# One loader per split
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()

  date_range = pd.date_range(df.index[0], df.index[-1], freq='5T')
  df = df.replace(to_replace=0., method='ffill')


### Modeling Attention Schemes with Spatio-Temporal (ST) Data

In [65]:
# Feature-vectorization approach: the temporal dim is folded into the feature dim.
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention across axis 2 of a 4-D input.

    ``forward`` unpacks its input as ``(b, t, n, f)``. For each of the ``n``
    entries (presumably graph nodes -- confirm against the data loader) the
    ``t`` temporal steps and ``f`` features are flattened into one vector of
    length ``t * f``. That vector is linearly projected to queries, keys and
    values for every head, and attention weights are computed between the
    ``n`` entries.

    Args:
        input_size: size of the last input axis (``f``).
        hidden_size: width of each attention head; also the output width.
        window_size: length of the temporal axis (``t``). The projections
            expect ``input_size * window_size`` input features, so the input
            passed to ``forward`` must satisfy ``t * f == input_size * window_size``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, input_size, hidden_size, window_size, num_heads=8, dropout=0.6):
        super().__init__()

        flat_features = input_size * window_size
        all_heads_width = hidden_size * num_heads

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # Per-head projections are packed side by side into one Linear each.
        self.query = nn.Linear(flat_features, all_heads_width)
        self.key = nn.Linear(flat_features, all_heads_width)
        self.value = nn.Linear(flat_features, all_heads_width)
        # Maps the concatenated heads back down to a single hidden_size vector.
        self.out = nn.Linear(all_heads_width, hidden_size)

        self.dropout = nn.Dropout(dropout)

        # Scores are divided by sqrt(head width).
        self.scale = self.hidden_size ** 0.5

    def forward(self, x):
        """Return a ``(b, n, hidden_size)`` tensor for an input of size ``(b, t, n, f)``."""
        batch, steps, nodes, feats = x.size()

        # (b, t, n, f) -> (b, n, t, f) -> (b, n, t * f)
        flat = x.transpose(1, 2).reshape(batch, nodes, steps * feats)

        def split_heads(projected):
            # (b, n, heads * hidden) -> (b, heads, n, hidden)
            per_head = projected.view(batch, nodes, self.num_heads, self.hidden_size)
            return per_head.permute(0, 2, 1, 3)

        queries = split_heads(self.query(flat))
        keys = split_heads(self.key(flat))
        values = split_heads(self.value(flat))

        # (b, heads, n, n) scores, normalised over the last (key) axis.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        weights = torch.softmax(scores, dim=-1)

        # Dropout acts on the attention weights themselves.
        weights = self.dropout(weights)

        attended = torch.matmul(weights, values)

        # (b, heads, n, hidden) -> (b, n, heads * hidden), then the output projection.
        merged = attended.permute(0, 2, 1, 3).reshape(
            batch, nodes, self.hidden_size * self.num_heads
        )
        return self.out(merged)

In [66]:
# window_size=18 mirrors the window=18 passed to SpatioTemporalDataset above;
# input_size=1 presumably means one feature per step -- confirm against batch.x.
attention_block = SelfAttention(
    input_size=1,
    hidden_size=4,
    window_size=18,
    num_heads=8,
    dropout=0.6,
)

# Smoke test: run only the first training batch through the block, then stop.
for batch in train_loader:
    x = batch.x
    edge_index = batch.edge_index
    edge_weight = batch.edge_weight
    y = batch.y
    out = attention_block(x)
    break