In [7]:
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# Temporal Graph Neural Network models
This notebook presents implementations of several Temporal Graph Neural Network (TGNN) models that will be used in the next notebooks.

## Graph Convolutional Network (GCN)
We define a simple Graph Convolutional Network (GCN) model that will be used as a building block for the Temporal Graph Convolutional Network (T-GCN) model.

In [3]:
class GCN(nn.Module):
	"""
	Stack of GCNConv layers, each followed by a ReLU activation.

	One layer is created per entry of ``layer_sizes``, so the depth is configurable;
	the default configuration is two layers of 32 channels each.
	"""
	def __init__(self, in_channels: int, layer_sizes: Optional[list[int]] = None):
		"""
		Builds the stack of graph convolutions.
		:param in_channels: Number of input features per node
		:param layer_sizes: Output size of each GCN layer; None or an empty list falls back to [32, 32]
		"""
		super().__init__()
		layer_sizes = layer_sizes or [32, 32]
		# Layer i maps dims[i] -> dims[i + 1], starting from the input feature size.
		dims = [in_channels, *layer_sizes]
		self.convs = nn.ModuleList([
			GCNConv(dim_in, dim_out) for dim_in, dim_out in zip(dims, dims[1:])
		])

	def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor] = None) -> torch.Tensor:
		"""
		Performs a forward pass on the GCN model.
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb)
		:param edge_index: The edge index of the graph A (2, Edges_nb)
		:param edge_weight: The edge weight of the graph (Edges_nb,); forwarded as-is to every layer, may be omitted (None)
		:return: The hidden state of the GCN h_t (Nodes_nb, Hidden_size), after the ReLU of the last layer
		"""
		for conv in self.convs:
			x = F.relu(conv(x, edge_index, edge_weight))
		return x

## Temporal Graph Convolutional Network (T-GCN)
The Temporal Graph Convolutional Network (T-GCN) -- from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320) -- is a simple model that combines GCN layers with a GRU layer. The GCN layers extract the spatial features of the graph at a given timestep, and the GRU layer models the temporal dependencies between the successive timesteps.

In [4]:
class TGCNCell(nn.Module):
	"""
	T-GCN Cell for one timestep, from https://arxiv.org/pdf/1811.05320.

	A stack of graph convolutions computes the spatial features f(A, X_t), which are then
	combined with the previous hidden state through GRU-style update/reset gates.
	"""
	def __init__(self, in_channels: int, hidden_size: int):
		"""
		:param in_channels: Number of input features per node
		:param hidden_size: Size of the hidden state, also used for both graph convolution layers
		"""
		super().__init__()
		self.gcn = GCN(in_channels, [hidden_size, hidden_size])
		self.lin_u = nn.Linear(2 * hidden_size, hidden_size)
		self.lin_r = nn.Linear(2 * hidden_size, hidden_size)
		self.lin_c = nn.Linear(2 * hidden_size, hidden_size)

	def _graph_conv(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor]) -> torch.Tensor:
		"""
		Computes f(A, X_t) (Eq. 2): ReLU after the inner graph convolutions, sigmoid after the last one.
		"""
		# The layers are applied one by one instead of calling self.gcn(...): that forward pass
		# puts a ReLU after the last layer as well, so a sigmoid on top of it could only
		# produce values in [0.5, 1).
		*inner_convs, last_conv = self.gcn.convs
		for conv in inner_convs:
			x = F.relu(conv(x, edge_index, edge_weight))
		return torch.sigmoid(last_conv(x, edge_index, edge_weight))

	def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor], h: torch.Tensor) -> torch.Tensor:
		"""
		Performs a forward pass on a single T-GCN cell (GCN + GRU).
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb)
		:param edge_index: The edge index of the graph A (2, Edges_nb)
		:param edge_weight: The edge weight of the graph (Edges_nb,), forwarded as-is to the graph convolutions
		:param h: The hidden state of the GRU h_{t-1} (Nodes_nb, Hidden_size)
		:return: The hidden state of the GRU h_t (Nodes_nb, Hidden_size)
		"""
		x = self._graph_conv(x, edge_index, edge_weight)  # f(A,X_t), Eq. 2
		gate_input = torch.cat([x, h], dim=-1)  # shared by the update and reset gates
		u = torch.sigmoid(self.lin_u(gate_input))  # u_t, Eq. 3
		r = torch.sigmoid(self.lin_r(gate_input))  # r_t, Eq. 4
		c = torch.tanh(self.lin_c(torch.cat([x, r * h], dim=-1)))  # c_t, Eq. 5

		return u * h + (1 - u) * c  # h_t, Eq. 6

In [5]:
class TGCN(nn.Module):
	"""
	T-GCN model from https://arxiv.org/pdf/1811.05320.

	A single T-GCN cell is unrolled over the input sequence, and the last hidden state is
	projected to the output features by a linear layer.
	"""
	def __init__(self, in_channels: int, out_channels: int, hidden_size: int):
		"""
		:param in_channels: Number of input features per node
		:param out_channels: Number of output features per node
		:param hidden_size: Size of the hidden state of the T-GCN cell
		"""
		super().__init__()
		self.hidden_size = hidden_size
		self.cell = TGCNCell(in_channels, hidden_size)
		self.out = nn.Linear(hidden_size, out_channels)

	def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor] = None) -> torch.Tensor:
		"""
		Performs a forward pass on the T-GCN model.
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb, SeqLength)
		:param edge_index: The edge index of the graph A (2, Edges_nb)
		:param edge_weight: The edge weight of the graph (Edges_nb,); forwarded as-is to the cell, may be omitted (None)
		:return: The output of the model (Nodes_nb, OutFeatures_nb)
		"""
		# Allocate the initial state with the device and dtype of x, so that the model
		# also runs when the input is on a GPU or is not float32.
		h = x.new_zeros(x.shape[0], self.hidden_size)
		for t in range(x.shape[-1]):
			h = self.cell(x[:, :, t], edge_index, edge_weight, h)
		return self.out(h)

In [30]:
# Example of usage: a path graph with 10 nodes (0 - 1 - ... - 9), 2 features per node, 5 timesteps
model = TGCN(in_channels=2, out_channels=1, hidden_size=32)
x = torch.rand(10, 2, 5)  # (Nodes_nb, Features_nb, SeqLength)
# One (source, target) pair per row, each undirected edge listed in both directions;
# the transpose gives the (2, Edges_nb) layout.
edge_index = torch.tensor([
	[0, 1], [1, 0],
	[1, 2], [2, 1],
	[2, 3], [3, 2],
	[3, 4], [4, 3],
	[4, 5], [5, 4],
	[5, 6], [6, 5],
	[6, 7], [7, 6],
	[7, 8], [8, 7],
	[8, 9], [9, 8],
]).t().contiguous()
edge_weight = torch.rand(edge_index.size(1))  # one random weight per directed edge
model(x, edge_index, edge_weight).shape

torch.Size([10, 1])

## Attention Temporal Graph Convolutional Network (A3T-GCN)
The Attention Temporal Graph Convolutional Network (A3T-GCN) -- from the paper [A3T-GCN: Attention Temporal Graph Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583) -- is a model that extends the T-GCN model by adding an attention mechanism to the hidden states computed by the GRU cells. The attention mechanism is used to learn the importance of the different timesteps: the hidden states of the whole input sequence are combined into a single context vector through a weighted sum.

In [31]:
class A3TGCN(nn.Module):
	"""
	A3T-GCN model from https://arxiv.org/pdf/2006.11583.

	A T-GCN cell is unrolled over the input sequence; the hidden states of all timesteps are
	then combined by an attention-weighted sum and projected to the output features.
	"""
	def __init__(self, in_channels: int, out_channels: int, hidden_size: int):
		"""
		:param in_channels: Number of input features per node
		:param out_channels: Number of output features per node
		:param hidden_size: Size of the hidden state of the T-GCN cell
		"""
		super().__init__()
		self.hidden_size = hidden_size
		self.cell = TGCNCell(in_channels, hidden_size)
		self.attention = nn.Sequential(
			nn.Linear(hidden_size, hidden_size),
			nn.ReLU(),
			nn.Linear(hidden_size, 1),
			# The scores have shape (Nodes_nb, SeqLength, 1). Normalize them over the time
			# axis, the same axis the weighted sum in forward() reduces, so that the weights
			# of each node sum to 1 across its timesteps.
			nn.Softmax(dim=1),
		)
		self.out = nn.Linear(hidden_size, out_channels)

	def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: Optional[torch.Tensor] = None) -> torch.Tensor:
		"""
		Performs a forward pass on the A3T-GCN model.
		:param x: The feature matrix of the graph X_t (Nodes_nb, Features_nb, SeqLength)
		:param edge_index: The edge index of the graph A (2, Edges_nb)
		:param edge_weight: The edge weight of the graph (Edges_nb,); forwarded as-is to the cell, may be omitted (None)
		:return: The output of the model (Nodes_nb, OutFeatures_nb)
		"""
		nodes_nb, seq_length = x.shape[0], x.shape[-1]
		# Allocate the states with the device and dtype of x, so that the model also runs
		# when the input is on a GPU or is not float32.
		h = x.new_zeros(nodes_nb, self.hidden_size)
		hs = x.new_zeros(nodes_nb, seq_length, self.hidden_size)
		for t in range(seq_length):
			h = self.cell(x[:, :, t], edge_index, edge_weight, h)
			hs[:, t, :] = h
		attention_scores = self.attention(hs)  # alpha_i, Eq. 8,9 (Nodes_nb, SeqLength, 1)
		c = torch.sum(hs * attention_scores, dim=1)  # C_t, Eq. 10
		return self.out(c)

In [33]:
# Example of usage: a path graph with 10 nodes (0 - 1 - ... - 9), 2 features per node, 5 timesteps
model = A3TGCN(in_channels=2, out_channels=1, hidden_size=32)
x = torch.rand(10, 2, 5)  # (Nodes_nb, Features_nb, SeqLength)
# One (source, target) pair per row, each undirected edge listed in both directions;
# the transpose gives the (2, Edges_nb) layout.
edge_index = torch.tensor([
	[0, 1], [1, 0],
	[1, 2], [2, 1],
	[2, 3], [3, 2],
	[3, 4], [4, 3],
	[4, 5], [5, 4],
	[5, 6], [6, 5],
	[6, 7], [7, 6],
	[7, 8], [8, 7],
	[8, 9], [9, 8],
]).t().contiguous()
edge_weight = torch.rand(edge_index.size(1))  # one random weight per directed edge
model(x, edge_index, edge_weight).shape

torch.Size([10, 1])