In [62]:
import tsl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import math

from torch.optim import Adam
from tsl.datasets import PeMS04, PeMS07, PeMS08, PemsBay
from tsl.datasets import MetrLA
from tsl.data import SpatioTemporalDataset
from tsl.data.datamodule import (SpatioTemporalDataModule,
                                 TemporalSplitter)
from tsl.data.preprocessing import StandardScaler
from einops.layers.torch import Rearrange

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm

In [31]:
dataset_MetrLA = MetrLA(root='data/MetrLA')

# get_connectivity uses get_similarity under the hood
connectivity = dataset_MetrLA.get_connectivity(threshold=0.1, include_self=False, normalize_axis=1, layout="edge_index")

# subclass of torch.utils.data.Dataset
torch_dataset = SpatioTemporalDataset(
    target=dataset_MetrLA.dataframe(),
    connectivity=connectivity,
    mask=dataset_MetrLA.mask,
    horizon=6,
    window=12,
    stride=1
)

scalers = {'target': StandardScaler(axis=(0, 1))}

# Split data sequentially:
#   |------------ dataset -----------|
#   |--- train ---|- val -|-- test --|
splitter = TemporalSplitter(val_len=0.1, test_len=0.2)

dm = SpatioTemporalDataModule(
    dataset=torch_dataset,
    scalers=scalers,
    splitter=splitter,
    batch_size=64,
)

dm.setup()

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()

  date_range = pd.date_range(df.index[0], df.index[-1], freq='5T')
  df = df.replace(to_replace=0., method='ffill')


In [60]:
sample1 = torch_dataset[:10]
sample1

StaticBatch(
  input=(x=[b=10, t=12, n=207, f=1], edge_index=[2, e=1515], edge_weight=[e=1515]),
  target=(y=[b=10, t=6, n=207, f=1]),
  has_mask=True,
  transform=[x, y]
)

In [2]:
data = torch.randn(1, 2, 12, 1)

In [64]:
class PositionalEncoding(nn.Module):
    """
    Implements the classic sinusoidal positional encoding.
    
    For an input tensor of shape (B, T, d_model) or (B, N, T, d_model),
    it adds a positional encoding to every token along the time axis.
    """
    def __init__(self, d_model, dropout=0.1, max_len=500):
        """
        Args:
            d_model (int): Dimensionality of the token embeddings.
            dropout (float): Dropout rate applied after adding PE.
            max_len (int): Maximum sequence length to precompute positional encoding.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create a (max_len, d_model) matrix; each row is the positional encoding for that time step.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Compute the division term using the logarithm of 10000 (a common choice)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * 
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        pe = pe.unsqueeze(0)  # Shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)  # pe is not a parameter, but persistent

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input embedding with shape either (B, T, d_model)
                              or (B, N, T, d_model).
        Returns:
            torch.Tensor: Output tensor after adding positional encodings.
        """
        if x.dim() == 3:
            # x has shape (B, T, d_model)
            x = x + self.pe[:, :x.size(1)]
        elif x.dim() == 4:
            # x has shape (B, N, T, d_model). Expand pe to (1,1,T,d_model) and add along the time dimension.
            x = x + self.pe[:, :x.size(2)].unsqueeze(1)
        else:
            raise ValueError("Unsupported input dimension for PositionalEncoding")
        return self.dropout(x)

In [69]:
class TemporalAttention(nn.Module):
    def __init__(self, input_size, hidden_size, window_size, num_heads=8, dropout=0.6):
        super(TemporalAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.window_size = window_size
        self.num_heads = num_heads
        self.dropout = dropout
        
        self.encoder = nn.Linear(input_size, hidden_size)
        self.pe = PositionalEncoding(hidden_size)

        self.self_q = nn.Linear(hidden_size, hidden_size * num_heads)
        self.self_k = nn.Linear(hidden_size, hidden_size * num_heads)
        self.self_v = nn.Linear(hidden_size, hidden_size * num_heads)
        
        self.cross_q = nn.Linear(hidden_size, hidden_size * num_heads)
        self.cross_k = nn.Linear(hidden_size, hidden_size * num_heads)
        self.cross_v = nn.Linear(hidden_size, hidden_size * num_heads)

        self.out = nn.Linear(hidden_size * num_heads, hidden_size)
        self.dropout = nn.Dropout(dropout)
        
        self.scale = hidden_size ** 0.5
        causal_mask = torch.triu(torch.ones(window_size, window_size), diagonal=1).bool()
        self.register_buffer('causal_mask', causal_mask)
    
    def forward(self, x):
        x = x.permute(0, 2, 1, 3)
        B, N, T, D = x.size()
        
        x = self.encoder(x)
        x = self.pe(x)
        print("x pe: ", x.size())
        
        Q_self = self.self_q(x).view(B, N, T, self.num_heads, self.hidden_size).transpose(2, 3)
        K_self = self.self_k(x).view(B, N, T, self.num_heads, self.hidden_size).transpose(2, 3)
        V_self = self.self_v(x).view(B, N, T, self.num_heads, self.hidden_size).transpose(2, 3)
        
        e_self = (Q_self @ K_self.mT) / self.scale
        e_self = e_self.masked_fill(self.causal_mask, float('-inf'))
        print("mask size: ", self.causal_mask.size())
        print("e_self: ", e_self.size())
        
        attention_Self = F.softmax(e_self, dim=-1)
        attention_Self = self.dropout(attention_Self)
        out_self = attention_Self @ V_self
        print("out_self: ", out_self.size())
        
        out_self = out_self.transpose(2, 3).contiguous().view(B, N, T, self.hidden_size * self.num_heads)
        out_self = self.out(out_self)
        print("out_self (after out proj): ", out_self.size())

In [70]:
STA = TemporalAttention(input_size=1, hidden_size=4, window_size=12, num_heads=8, dropout=0.6)
STA(sample1.x)

x pe:  torch.Size([10, 207, 12, 4])
mask size:  torch.Size([12, 12])
e_self:  torch.Size([10, 207, 8, 12, 12])
out_self:  torch.Size([10, 207, 8, 12, 4])
out_self (after out proj):  torch.Size([10, 207, 12, 4])


In [54]:
# iterative version
class IterCrossAttention(nn.Module):
    def __init__(self, input_size, hidden_size, window_size, num_heads=8, dropout=0.6):
        super(IterCrossAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.window_size = window_size
        self.num_heads = num_heads
        self.dropout = dropout
        
        self.cross_q = nn.Linear(input_size, hidden_size * num_heads)
        self.cross_k = nn.Linear(input_size, hidden_size * num_heads)
        self.cross_v = nn.Linear(input_size, hidden_size * num_heads)

        self.out = nn.Linear(hidden_size * num_heads, hidden_size)
        self.dropout = nn.Dropout(dropout)
        
        self.scale = hidden_size ** 0.5
        self.causal_mask = torch.triu(torch.ones(window_size, window_size), diagonal=1).bool()
    
    def forward(self, x):
        x = x.permute(0, 2, 1, 3)
        B, N, T, D = x.size()
        
        out_cross_list = []
        for i in range(N):
            x_i = x[:, i, :, :] # shape (B, T, D)
            Q_i = self.cross_q(x_i).view(B, T, self.num_heads, self.hidden_size).transpose(1, 2) # shape (B, num_heads, T, hidden_size)
            
            N_i = [j for j in range(N) if j != i]
            x_neighbors = x[:, N_i, :, :] # shape (B, N-1, T, D)
            
            K_j = self.cross_k(x_neighbors).view(B, (N-1) * T, self.num_heads, self.hidden_size).transpose(1, 2) # shape (B, num_heads, (N-1)*T, hidden_size)
            V_j = self.cross_v(x_neighbors).view(B, (N-1) * T, self.num_heads, self.hidden_size).transpose(1, 2) # shape (B, num_heads, (N-1)*T, hidden_size)
            
            print("Q, K, V: ", Q_i.size(), K_j.size(), V_j.size())
            
            e_cross = (Q_i @ K_j.mT) / self.scale
            print("e_cross: ", e_cross.size())
            
            # expand the causal mask. for each node, we have (N-1) blocks each of size (T * T)
            expanded_mask = self.causal_mask.repeat(1, (N-1))
            e_cross = e_cross.masked_fill(expanded_mask, float('-inf'))
            
            attention_cross = F.softmax(e_cross, dim=-1)
            attention_cross = self.dropout(attention_cross)
            
            out_i = attention_cross @ V_j
            print("out_i: ", out_i.size())
            out_i = out_i.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.hidden_size)
            out_i = self.out(out_i)
            print("out_i (after out proj): ", out_i.size())
            
            out_cross_list.append(out_i)
        
        out_cross = torch.stack(out_cross_list, dim=1) # shape (B, N, T, hidden_size)
        print("out_cross: ", out_cross.size())

In [6]:
ic_attention = IterCrossAttention(input_size=1, hidden_size=4, window_size=12, num_heads=8, dropout=0.6)
ic_attention(data)

Q, K, V:  torch.Size([1, 8, 12, 4]) torch.Size([1, 8, 12, 4]) torch.Size([1, 8, 12, 4])
e_cross:  torch.Size([1, 8, 12, 12])
out_i:  torch.Size([1, 8, 12, 4])
out_i (after out proj):  torch.Size([1, 12, 4])
Q, K, V:  torch.Size([1, 8, 12, 4]) torch.Size([1, 8, 12, 4]) torch.Size([1, 8, 12, 4])
e_cross:  torch.Size([1, 8, 12, 12])
out_i:  torch.Size([1, 8, 12, 4])
out_i (after out proj):  torch.Size([1, 12, 4])
out_cross:  torch.Size([1, 2, 12, 4])


In [36]:
def build_combined_mask(edge_index, edge_weight, N, T, mask_self=False):
    """
    Constructs a combined mask of shape (N*T, N*T) that embeds:
      - A temporal causal mask (T x T), and
      - A graph connectivity mask derived from edge_index and edge_weight.
    
    For every pair of nodes, if a connection exists (edge_weight > 0),
    the corresponding (T, T) block equals the temporal mask;
    otherwise, it is set to -inf.
    
    Arguments:
      edge_index: LongTensor of shape (2, E) where each column is (source, target).
      edge_weight: Tensor of shape (E,) with positive weights (e.g. 1).
      N: Number of nodes.
      T: Number of time steps.
      device: torch.device.
    
    Returns:
      A mask of shape (N*T, N*T) to be added to the attention scores.
    """
    # Build a dense connectivity indicator of shape (N, N)
    # For each edge (u->v): set A[v,u] = edge_weight (so only when there is an edge, we want to allow attention).
    A = torch.zeros((N, N))
    A[edge_index[1], edge_index[0]] = edge_weight  # Note: our convention: edge (u, v) means u -> v

    # Build a connectivity mask: if A > 0, we allow attention (0 added), else, we want to block by setting to -inf.
    connectivity_indicator = torch.where(A > 0, torch.zeros_like(A), torch.full_like(A, float('-inf')))
    
    if mask_self:
        # Mask self-attention: we do not want nodes to attend to themselves.
        connectivity_indicator.fill_diagonal_(float('-inf'))
    
    # Create a dense temporal (causal) mask of shape (T, T)
    # For example, we use a lower-triangular mask to prevent query at time t from attending to keys at future times
    temporal_mask = torch.triu(torch.ones((T, T)) * float('-inf'), diagonal=1)
    
    """
    Now, we “lift” these masks to the full (N*T, N*T) mask.
    For the graph part, we create a block mask using a Kronecker product.
    The idea is:
       mask = kron(connectivity_indicator, ones(T, T)) + kron(ones(N, N), temporal_mask)
    For a connected node pair, connectivity_indicator==0 so that block becomes 0 + temporal_mask,
    and for an unconnected pair, connectivity_indicator==-inf so block remains -inf.
    """
    mask = torch.kron(connectivity_indicator, torch.ones((T, T))) + torch.kron(torch.ones((N, N)), temporal_mask)
    
    return mask.bool()  # shape: (N*T, N*T)

In [71]:
# iterative version
class SpatioTemporalAttention(nn.Module):
    def __init__(self, input_size, hidden_size, num_nodes, window_size, edge_index, edge_weight, num_heads=8, dropout=0.6, mask_self=False):
        super(SpatioTemporalAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_nodes = num_nodes
        self.window_size = window_size
        self.num_heads = num_heads
        self.dropout = dropout
        
        self.encoder = nn.Linear(input_size, hidden_size)
        self.pe = PositionalEncoding(hidden_size)
        
        self.cross_q = nn.Linear(hidden_size, hidden_size * num_heads)
        self.cross_k = nn.Linear(hidden_size, hidden_size * num_heads)
        self.cross_v = nn.Linear(hidden_size, hidden_size * num_heads)

        self.out = nn.Linear(hidden_size * num_heads, hidden_size)
        self.dropout = nn.Dropout(dropout)
        
        self.scale = hidden_size ** 0.5
        
        # includes self-loop for self-attention + causal mask for ST attention
        # self.causal_mask = torch.triu(torch.ones(num_nodes * window_size, num_nodes * window_size), diagonal=1).bool()
        self.mask = build_combined_mask(edge_index, edge_weight, num_nodes, window_size)
    
    def forward(self, x):
        x = x.permute(0, 2, 1, 3)
        B, N, T, D = x.size()
        
        x = self.encoder(x)
        x = self.pe(x)
        
        # (B, N, T, D) -> (B, N, T, num_heads, hidden_size) -> (B, num_heads, N, T, hidden_size)
        Q = self.cross_q(x).view(B, N, T, self.num_heads, self.hidden_size).transpose(1, 3)
        K = self.cross_k(x).view(B, N, T, self.num_heads, self.hidden_size).transpose(1, 3)
        V = self.cross_v(x).view(B, N, T, self.num_heads, self.hidden_size).transpose(1, 3)
        
        print("Q, K, V: ", Q.size(), K.size(), V.size())
        
        # (B, num_heads, N * T, hidden_size)
        Q_flat = Q.reshape(B, self.num_heads, N * T, self.hidden_size)
        K_flat = K.reshape(B, self.num_heads, N * T, self.hidden_size)
        V_flat = V.reshape(B, self.num_heads, N * T, self.hidden_size)
        
        print("Q_flat, K_flat, V_flat: ", Q_flat.size(), K_flat.size(), V_flat.size())
        
        e = Q_flat @ K_flat.mT / self.scale # (B, num_heads, N * T, N * T)
        e = e.masked_fill(self.mask, float('-inf'))
        
        print("e: ", e.size())
        print("mask size: ", self.mask.size())
        
        attention = F.softmax(e, dim=-1)
        attention = self.dropout(attention)
        
        out_flat = attention @ V_flat # (B, num_heads, N * T, hidden_size)
        print("out_flat: ", out_flat.size())
        out = out_flat.reshape(B, self.num_heads, N, T, self.hidden_size).transpose(1,3) # (B, N, T, num_heads, hidden_size)
        print("out: ", out.size())
        out = out.contiguous().view(B, N, T, self.num_heads * self.hidden_size)
        out = self.out(out)
        print("out (after out proj): ", out.size())
        
        return out

In [72]:
# edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
# edge_weight = torch.tensor([1, 1], dtype=torch.float)
edge_index = sample1.edge_index
edge_weight = sample1.edge_weight

sta = SpatioTemporalAttention(input_size=1, hidden_size=4, num_nodes=sample1.x.shape[2], window_size=12, edge_index=edge_index, edge_weight=edge_weight, num_heads=8, dropout=0.6)
print("original data: ", sample1.x.size())
out = sta(sample1.x)

original data:  torch.Size([10, 12, 207, 1])
Q, K, V:  torch.Size([10, 8, 12, 207, 4]) torch.Size([10, 8, 12, 207, 4]) torch.Size([10, 8, 12, 207, 4])
Q_flat, K_flat, V_flat:  torch.Size([10, 8, 2484, 4]) torch.Size([10, 8, 2484, 4]) torch.Size([10, 8, 2484, 4])
e:  torch.Size([10, 8, 2484, 2484])
mask size:  torch.Size([2484, 2484])
out_flat:  torch.Size([10, 8, 2484, 4])
out:  torch.Size([10, 12, 207, 8, 4])
out (after out proj):  torch.Size([10, 207, 12, 4])


### Sparse mask build

In [37]:
def build_sparse_combined_mask(edge_index, edge_weight, N, T, device, mask_self=False):
    """
    Constructs a sparse combined mask of shape (N*T, N*T) that fuses:
      - A temporal (causal) mask (only allowing keys at or before each time step), and 
      - A graph connectivity mask based on edge_index and edge_weight.
    
    For each valid edge (from source u to target v) – where by convention we 
    set A[v,u] from edge_index and edge_weight – we allow a block of size T×T 
    corresponding to that node pair. Within the block only the lower-triangular 
    positions (t_query >= t_key) are allowed (with value 0, meaning no penalty). 
    All other positions should be (conceptually) -inf.
    
    Args:
        edge_index : LongTensor of shape (2, E). Each column is an edge (u, v)
                     meaning node u feeds into node v.
        edge_weight: Tensor of shape (E,) with positive weights.
        N          : Number of nodes.
        T          : Number of time steps.
        device     : torch.device.
        mask_self  : If True, even if an edge exists for a self-connection,
                     its block will be entirely disallowed.
    
    Returns:
        sparse_mask: A sparse COO tensor of shape (N*T, N*T) that stores
                     the allowed positions (with value 0). When applying the mask,
                     one should treat missing entries as -inf.
    """
    # Filter out self-connections if needed
    if mask_self:
        valid = edge_index[0] != edge_index[1]
        valid_edges = edge_index[:, valid]
    else:
        valid_edges = edge_index

    E = valid_edges.shape[1]

    # Obtain lower-triangular indices for a T x T block.
    # These indices indicate positions where t_query >= t_key.
    # tril_indices returns a tensor of shape (2, L) where L = T*(T+1)//2.
    tril = torch.tril_indices(T, T, device=device)  # shape (2, L)
    L_val = tril.shape[1]  # number of allowed temporal positions per block

    # For each valid edge, the block corresponds to:
    #   Row block:   target node's time steps, i.e., indices: v * T + t_query.
    #   Column block: source node's time steps, i.e., indices: u * T + t_key.
    # valid_edges[0] are source node indices, valid_edges[1] are target node indices.
    # We use broadcasting to generate all allowed indices.

    # Compute row indices: (E, L) tensor where each row corresponds to a valid edge.
    row_indices = valid_edges[1].unsqueeze(1) * T + tril[0].unsqueeze(0)
    # Compute column indices in the same way.
    col_indices = valid_edges[0].unsqueeze(1) * T + tril[1].unsqueeze(0)

    # Flatten the indices so that each valid edge contributes L_val entries.
    row_indices = row_indices.reshape(-1)
    col_indices = col_indices.reshape(-1)
    sparse_indices = torch.stack([row_indices, col_indices], dim=0)  # shape (2, E * L_val)

    # The allowed positions (in the causal part) get a value equal to the temporal mask.
    # With our construction, the temporal mask is defined as:
    #    0 if t_query >= t_key, and -inf otherwise.
    # Here we only store positions where t_query >= t_key, so we set value 0.
    values = torch.zeros(E * L_val, device=device)

    # Construct the sparse mask—in our sparse representation, missing entries are implicitly 0.
    # To treat missing entries as -inf during attention, later on you can compute:
    #   combined_mask = dense_full + sparse_mask.to_dense()
    # where dense_full is a (N*T, N*T) tensor filled with -inf.
    sparse_mask = torch.sparse_coo_tensor(sparse_indices, values, size=(N*T, N*T), device=device)
    return sparse_mask

# Example usage:
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    N = 5    # number of nodes
    T = 12   # number of time steps

    # Example edge_index: shape (2, E)
    # Here, edge (u, v) means node u's features are used to attend to node v.
    edge_index = torch.tensor([
        [0, 2, 3, 1, 4, 0],  # source nodes
        [1, 0, 4, 4, 2, 3]   # target nodes
    ], dtype=torch.long, device=device)
    edge_weight = torch.ones(edge_index.shape[1], device=device)

    # Build sparse mask; set mask_self=True to mask self-attention.
    sparse_mask = build_sparse_combined_mask(edge_index, edge_weight, N, T, device, mask_self=True)
    print("Sparse mask indices shape:", sparse_mask._indices().shape)
    print("Sparse mask values shape:", sparse_mask._values().shape)

    # Note: When applying this mask to attention scores (of shape (B, num_heads, N*T, N*T)),
    # you may want to convert this sparse mask to a dense tensor and then add it to the attention scores:
    #   dense_mask = torch.full((N*T, N*T), float('-inf'), device=device)
    #   dense_mask = dense_mask + sparse_mask.to_dense()
    # Or, if you can leverage a sparse-aware attention function, have it interpret missing entries as -inf.


Sparse mask indices shape: torch.Size([2, 468])
Sparse mask values shape: torch.Size([468])
