In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# Spatio-Temporal Graph Neural Network models
This notebook will present an implementation of different Spatio-Temporal Graph Neural Network (STGNN) models, that will be used in the next notebooks.

## Graph Convolutional Network (GCN)
We define a simple Graph Convolutional Network (GCN) model, that will be used as a building block for the following models.

In [17]:
class GCN(nn.Module):
	"""
	Simple two layers GCN model.
	"""
	def __init__(self, in_channels: int, layer_sizes: list[int] = None, bias: bool = True):
		super(GCN, self).__init__()
		layer_sizes = layer_sizes or [32, 32]
		self.convs = nn.ModuleList([
			GCNConv(in_channels, layer_sizes[0], bias=bias),
		] + [
			GCNConv(layer_sizes[i], layer_sizes[i + 1], bias=bias) for i in range(len(layer_sizes) - 1)
		])
							   		
	def forward(self, x: torch.tensor, edge_index: torch.tensor, edge_weight: torch.tensor) -> torch.tensor:
		"""
		Performs a forward pass on the GCN model.
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb)
		:param edge_index: The edge index of the graph A (2, Edges_nb) 
		:param edge_weight: The edge weight of the graph (Edges_nb,)
		:return: The hidden state of the GCN h_t (Nodes_nb, Hidden_size)
		"""
		for conv in self.convs:
			x = F.relu(conv(x, edge_index, edge_weight))
		return x

## Temporal Graph Convolutional Network (T-GCN)
The Temporal Graph Convolutional Network (T-GCN) -- from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320) -- is a simple model that uses GCN layers followed by a GRU layer. The GCN layers are used to extract the spatial features of the graph at a given timestep, and the GRU layer is used to model the temporal dependencies between the different timestamps.
However, in the case of stock prices, the data on the stock itself is very important. The drawback of using GCN is that the node feature is lost through the aggregation phase. To address this issue, we can concatenate the node feature with the output of the GCN layer before feeding it to the GRU layer.

In [18]:
class TGCNCell(nn.Module):
	"""
	T-GCN Cell for one timestep, from https://arxiv.org/pdf/1811.05320.
	"""
	def __init__(self, in_channels: int, hidden_size: int):
		super(TGCNCell, self).__init__()
		self.gcn = GCN(in_channels, [hidden_size, hidden_size])
		self.lin_u = nn.Linear(2 * hidden_size, hidden_size)
		self.lin_r = nn.Linear(2 * hidden_size, hidden_size)
		self.lin_c = nn.Linear(2 * hidden_size, hidden_size)
		
	def forward(self, x: torch.tensor, edge_index: torch.tensor, edge_weight: torch.tensor, h: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
		"""
		Performs a forward pass on a single T-GCN cell (GCN + GRU).
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb)
		:param edge_index: The edge index of the graph A (2, Edges_nb) 
		:param edge_weight: The edge weight of the graph (Edges_nb,)
		:param h: The hidden state of the GRU h_{t-1} (Nodes_nb, Hidden_size)
		:return: The hidden state of the GRU h_t (Nodes_nb, Hidden_size)
		"""
		gcn_out = F.sigmoid(self.gcn(x, edge_index, edge_weight))  # f(A,X_t), Eq. 2
		u = F.sigmoid(self.lin_u(torch.cat([gcn_out, h], dim=-1)))  # u_t, Eq. 3
		r = F.sigmoid(self.lin_r(torch.cat([gcn_out, h], dim=-1)))  # r_t,  Eq. 4
		c = F.tanh(self.lin_c(torch.cat([gcn_out, r * h], dim=-1)))  # c_t, Eq. 5

		return u * h + (1 - u) * c  # h_t, Eq. 6

In [19]:
class TGCN(nn.Module):
	"""
	T-GCN model from https://arxiv.org/pdf/1811.05320.
	"""
	def __init__(self, in_channels: int, out_channels: int, hidden_size: int, layers_nb: int = 2):
		super(TGCN, self).__init__()
		self.hidden_size = hidden_size
		self.layers_nb = max(1, layers_nb)
		self.cells = nn.ModuleList(
			[TGCNCell(in_channels, hidden_size)] + [TGCNCell(hidden_size, hidden_size) for _ in range(self.layers_nb - 1)]
		)
		self.out = nn.Linear(hidden_size, out_channels)
		
	def forward(self, x: torch.tensor, edge_index: torch.tensor, edge_weight: torch.tensor) -> torch.tensor:
		"""
		Performs a forward pass on the T-GCN model.
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb, SeqLength)
		:param edge_index: The edge index of the graph A (2, Edges_nb) 
		:param edge_weight: The edge weight of the graph (Edges_nb,)
		:return: The output of the model (Nodes_nb, OutFeatures_nb)
		"""
		h_prev = [
			torch.zeros(x.shape[0], self.hidden_size) for _ in range(self.layers_nb)
		]
		for t in range(x.shape[-1]):
			h = x[:, :, t]  # h is the output of the previous GRU layer (the input features for the first layer)
			for i, cell in enumerate(self.cells):
				h = cell(h, edge_index, edge_weight, h_prev[i])
				h_prev[i] = h
		return self.out(h_prev[-1])

In [27]:
# Example of usage
model = TGCN(2, 1, 32)
x = torch.rand(10, 2, 5)
edge_index = torch.tensor(
	[[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9], 
	 [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8]]
)
edge_weight = torch.rand(edge_index.shape[1])
model(x, edge_index, edge_weight)

tensor([[0.0054],
        [0.0054],
        [0.0054],
        [0.0054],
        [0.0054],
        [0.0054],
        [0.0054],
        [0.0054],
        [0.0054],
        [0.0054]], grad_fn=<AddmmBackward0>)

## Attention Temporal Graph Convolutional Network (A3T-GCN)
The Attention Temporal Graph Convolutional Network (A3T-GCN) -- from the paper [A3T-GCN: Attention Temporal Graph Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583) -- is a model that extends the T-GCN model by adding an attention mechanism to the hidden states computed by the GRU cells. The attention mechanism is used to learn the importance of the different nodes in the graph at a given timestep.

In [21]:
class A3TGCN(nn.Module):
	"""
	A3T-GCN model from https://arxiv.org/pdf/2006.11583.
	"""
	def __init__(self, in_channels: int, out_channels: int, hidden_size: int, layers_nb: int = 2):
		super(A3TGCN, self).__init__()
		self.hidden_size = hidden_size
		self.layers_nb = max(1, layers_nb)
		self.cells = nn.ModuleList(
			[TGCNCell(in_channels, hidden_size)] + [TGCNCell(hidden_size, hidden_size) for _ in range(self.layers_nb - 1)]
		)
		self.attention = nn.Sequential(
			nn.Linear(hidden_size, 1),
			nn.Softmax(dim=1),
		)
		nn.init.uniform_(self.attention[0].weight)
		self.out = nn.Linear(hidden_size, out_channels)
		
	def forward(self, x: torch.tensor, edge_index: torch.tensor, edge_weight: torch.tensor) -> torch.tensor:
		"""
		Performs a forward pass on the A3T-GCN model.
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb, SeqLength)
		:param edge_index: The edge index of the graph A (2, Edges_nb) 
		:param edge_weight: The edge weight of the graph (Edges_nb,)
		:return: The output of the model (Nodes_nb, OutFeatures_nb)
		"""
		h_prev = [
			torch.zeros(x.shape[0], self.hidden_size) for _ in range(self.layers_nb)
		]
		h_final = torch.zeros(x.shape[0], x.shape[-1], self.hidden_size)
		for t in range(x.shape[-1]):
			h = x[:, :, t]  # h is the output of the previous GRU layer (the input features for the first layer)
			for i, cell in enumerate(self.cells):
				h = cell(h, edge_index, edge_weight, h_prev[i])
				h_prev[i] = h
			h_final[:, t, :] = h
		att = self.attention(h_final)
		c_t = self.out(torch.sum(h_final * att, dim=1))
		return c_t

In [22]:
# Example of usage
model = A3TGCN(2, 1, 32)
x = torch.rand(10, 2, 5)
edge_index = torch.tensor(
	[[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9], 
	 [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8]]
)
edge_weight = torch.rand(edge_index.shape[1])
model(x, edge_index, edge_weight)

tensor([[0.2586],
        [0.2596],
        [0.2595],
        [0.2590],
        [0.2590],
        [0.2593],
        [0.2590],
        [0.2592],
        [0.2604],
        [0.2591]], grad_fn=<AddmmBackward0>)

## Diffusion Convolutional Recurrent Neural Network (DCRNN)
The Diffusion Convolutional Recurrent Neural Network (DCRNN) -- from the paper [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926) -- is a model that uses a diffusion convolutional layer to extract the spatial features of the graph at a given timestep, and a GRU layer to model the temporal dependencies between the different timestamps.

In [23]:
class DCGRUCell(nn.Module):
	"""
	DCRNN Cell for one timestep, from https://arxiv.org/pdf/1707.01926.
	"""
	def __init__(self, in_channels: int, hidden_size: int):
		super(DCGRUCell, self).__init__()
		self.gcn_r = GCN(in_channels + hidden_size, [hidden_size, hidden_size], bias=True)
		self.gcn_u = GCN(in_channels + hidden_size, [hidden_size, hidden_size], bias=True)
		self.gcn_c = GCN(in_channels + hidden_size, [hidden_size, hidden_size], bias=True)
		
	def forward(self, x: torch.tensor, edge_index: torch.tensor, edge_weight: torch.tensor, h: torch.tensor) -> torch.tensor:
		"""
		Performs a forward pass on a single DCRNN cell.
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb)
		:param edge_index: The edge index of the graph A (2, Edges_nb) 
		:param edge_weight: The edge weight of the graph (Edges_nb,)
		:param h: The hidden state of the GRU h_{t-1} (Nodes_nb, Hidden_size)
		:return: The hidden state of the GRU h_t (Nodes_nb, Hidden_size)
		"""
		x_h = torch.cat([x, h], dim=-1)
		r = F.sigmoid(self.gcn_r(x_h, edge_index, edge_weight))
		u = F.sigmoid(self.gcn_u(x_h, edge_index, edge_weight))
		c = F.tanh(self.gcn_c(torch.cat([x, r * h], dim=-1), edge_index, edge_weight))
		return u * h + (1 - u) * c

In [24]:
class DCGNN(nn.Module):
	"""
	DCGNN model from https://arxiv.org/pdf/1707.01926.
	"""
	def __init__(self, in_channels: int, out_channels: int, hidden_size: int, layers_nb: int = 2):
		super(DCGNN, self).__init__()
		self.hidden_size = hidden_size
		self.layers_nb = max(1, layers_nb)
		self.cells = nn.ModuleList(
			[DCGRUCell(in_channels, hidden_size)] + [DCGRUCell(hidden_size, hidden_size) for _ in range(self.layers_nb - 1)]
		)
		self.out = nn.Linear(hidden_size, out_channels)
		
	def forward(self, x: torch.tensor, edge_index: torch.tensor, edge_weight: torch.tensor) -> torch.tensor:
		"""
		Performs a forward pass on the DCRNN model.
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb, SeqLength)
		:param edge_index: The edge index of the graph A (2, Edges_nb) 
		:param edge_weight: The edge weight of the graph (Edges_nb,)
		:return: The output of the model (Nodes_nb, OutFeatures_nb)
		"""
		h_prev = [
			torch.zeros(x.shape[0], self.hidden_size) for _ in range(self.layers_nb)
		]
		for t in range(x.shape[-1]):
			h = x[:, :, t]  # h is the output of the previous GRU layer (the input features for the first layer)
			for i, cell in enumerate(self.cells):
				h = cell(h, edge_index, edge_weight, h_prev[i])
				h_prev[i] = h
		return self.out(h_prev[-1])

In [25]:
# Example of usage
model = DCGNN(2, 1, 32)
x = torch.rand(10, 2, 5)
edge_index = torch.tensor(
	[[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9], 
	 [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8]]
)
edge_weight = torch.rand(edge_index.shape[1])
model(x, edge_index, edge_weight)

tensor([[-0.0706],
        [-0.0714],
        [-0.0708],
        [-0.0704],
        [-0.0709],
        [-0.0702],
        [-0.0712],
        [-0.0713],
        [-0.0706],
        [-0.0696]], grad_fn=<AddmmBackward0>)