In [None]:
!pip install torch==2.2.1+cu121 torchvision==0.17.1+cu121 -f https://download.pytorch.org/whl/torch_stable.html --no-cache-dir
!pip install "numpy<2" --no-cache-dir
!pip install dgl==2.4.0 -f https://data.dgl.ai/wheels/torch-2.2/cu121/repo.html --no-cache-dir

In [None]:
import argparse
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [None]:
import os
import ssl

import dgl

import numpy as np
import torch
from six.moves import urllib
from torch.utils.data import DataLoader, Dataset


def download_file(dataset):
    """Download one file of the DGL traffic datasets into ``./data``.

    Parameters
    ----------
    dataset : str
        File name inside the ``dgl-data/dataset`` S3 prefix (for example
        ``"graph_la.bin"``). The local copy is stored under the same name.

    Notes
    -----
    The payload is read completely before the destination file is opened, so
    a failed transfer does not leave an empty file behind that would later be
    mistaken for a finished download.
    """
    print("Start Downloading data: {}".format(dataset))
    url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
        dataset
    )
    print("Start Downloading File....")
    # NOTE(review): certificate verification is disabled on purpose by the
    # original code; kept as-is, but consider a verified context.
    context = ssl._create_unverified_context()
    # Close the HTTP response deterministically instead of leaking it.
    with urllib.request.urlopen(url, context=context) as response:
        payload = response.read()
    # The function always writes below ./data, so make sure it exists.
    os.makedirs("./data", exist_ok=True)
    with open("./data/{}".format(dataset), "wb") as handle:
        handle.write(payload)


class SnapShotDataset(Dataset):
    """Dataset backed by an ``.npz`` archive holding arrays ``x`` and ``y``.

    Parameters
    ----------
    path : str
        Directory that contains (or will receive) the archive.
    npz_file : str
        File name of the archive inside ``path``.

    Notes
    -----
    When the archive is missing it is fetched with ``download_file``, which
    always stores into ``./data``; a ``path`` other than ``"data"`` therefore
    only works if the file is already present.
    """

    def __init__(self, path, npz_file):
        file_path = os.path.join(path, npz_file)
        if not os.path.exists(file_path):
            if path:
                # makedirs also handles nested directories and is a no-op
                # when the directory already exists.
                os.makedirs(path, exist_ok=True)
            download_file(npz_file)
        # The archive is only needed while the arrays are extracted; close it
        # afterwards so the file handle is not leaked.
        with np.load(file_path) as archive:
            self.x = archive["x"]
            self.y = archive["y"]

    def __len__(self):
        """Return the number of samples (length of the first axis of ``x``)."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair(s) selected by ``idx``.

        Tensor indices are converted to plain Python values first so they can
        be used to index the NumPy arrays.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.x[idx, ...], self.y[idx, ...]


def METR_LAGraphDataset():
    """Return the METR-LA sensor graph, downloading it on first use.

    The graph is cached as ``data/graph_la.bin``; the first graph stored in
    that file is returned.
    """
    graph_file = "data/graph_la.bin"
    already_cached = os.path.exists(graph_file)
    if not already_cached:
        if not os.path.exists("data"):
            os.mkdir("data")
        download_file("graph_la.bin")
    graphs, _ = dgl.load_graphs(graph_file)
    return graphs[0]


class METR_LATrainDataset(SnapShotDataset):
    """METR-LA training split; also exposes normalisation statistics."""

    def __init__(self):
        super().__init__("data", "metr_la_train.npz")
        # Statistics are taken over feature channel 0 of the training inputs.
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()


class METR_LATestDataset(SnapShotDataset):
    """METR-LA test split loaded from ``data/metr_la_test.npz``."""

    def __init__(self):
        super().__init__("data", "metr_la_test.npz")


class METR_LAValidDataset(SnapShotDataset):
    """METR-LA validation split loaded from ``data/metr_la_valid.npz``."""

    def __init__(self):
        super().__init__("data", "metr_la_valid.npz")


def PEMS_BAYGraphDataset():
    """Return the PEMS-BAY sensor graph, downloading it on first use.

    The graph is cached as ``data/graph_bay.bin``; the first graph stored in
    that file is returned.
    """
    graph_file = "data/graph_bay.bin"
    already_cached = os.path.exists(graph_file)
    if not already_cached:
        if not os.path.exists("data"):
            os.mkdir("data")
        download_file("graph_bay.bin")
    graphs, _ = dgl.load_graphs(graph_file)
    return graphs[0]


class PEMS_BAYTrainDataset(SnapShotDataset):
    """PEMS-BAY training split; also exposes normalisation statistics."""

    def __init__(self):
        super().__init__("data", "pems_bay_train.npz")
        # Statistics are taken over feature channel 0 of the training inputs.
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()


class PEMS_BAYTestDataset(SnapShotDataset):
    """PEMS-BAY test split loaded from ``data/pems_bay_test.npz``."""

    def __init__(self):
        super().__init__("data", "pems_bay_test.npz")


class PEMS_BAYValidDataset(SnapShotDataset):
    """PEMS-BAY validation split loaded from ``data/pems_bay_valid.npz``."""

    def __init__(self):
        super().__init__("data", "pems_bay_valid.npz")


In [None]:
import dgl
import dgl.function as fn
import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
from dgl.base import DGLError


class DiffConv(nn.Module):
    """Diffusion convolution layer from the DCRNN paper.

    The layer works on precomputed diffusion graphs (see ``attach_graph``):
    the input is projected by one bias-free linear layer per graph, propagated
    over that graph with its edge weights, and the per-graph results are
    combined with a learnable weight vector. The original author lists
    traffic prediction as a use case.

    Parameter
    ==========
    in_feats : int
        number of input feature

    out_feats : int
        number of output feature

    k : int
        number of diffusion steps

    in_graph_list : list
        diffusion graphs for the "in" direction (second return value of
        ``attach_graph``)

    out_graph_list : list
        diffusion graphs for the "out" direction (first return value of
        ``attach_graph``)

    dir : str [both/in/out]
        direction of diffusion convolution
        From paper default both direction
    """

    def __init__(
        self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir="both"
    ):
        super(DiffConv, self).__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.k = k
        self.dir = dir
        # NOTE(review): "both" yields k - 1 graphs and a single direction
        # yields 2k - 2, which looks inverted: forward() concatenates both
        # lists for "both" yet only visits the first `num_graphs` entries.
        # Confirm against the reference implementation before changing.
        self.num_graphs = self.k - 1 if self.dir == "both" else 2 * self.k - 2
        # One bias-free projection per diffusion graph that forward() visits.
        self.project_fcs = nn.ModuleList()
        for i in range(self.num_graphs):
            self.project_fcs.append(
                nn.Linear(self.in_feats, self.out_feats, bias=False)
            )
        # One mixing weight per visited graph plus one for the extra
        # projection-only term appended at the end of forward().
        self.merger = nn.Parameter(torch.randn(self.num_graphs + 1))
        self.in_graph_list = in_graph_list
        self.out_graph_list = out_graph_list

    @staticmethod
    def attach_graph(g, k):
        """Precompute ``k`` diffusion graphs per direction for graph ``g``.

        Returns
        -------
        tuple(list, list)
            ``(out_graph_list, in_graph_list)``; each list holds ``k`` DGL
            graphs carrying a float ``"weight"`` edge feature, placed on the
            device of ``g``.
        """
        device = g.device
        out_graph_list = []
        in_graph_list = []
        wadj, ind, outd = DiffConv.get_weight_matrix(g)
        # Degree-scaled weighted adjacency for the "out" direction.
        # NOTE(review): which axis the degree vector is broadcast along is
        # decided by SciPy/NumPy broadcasting — TODO confirm it is the
        # intended normalisation.
        adj = sparse.coo_matrix(wadj / outd.cpu().numpy())
        outg = dgl.from_scipy(adj, eweight_name="weight").to(device)
        outg.edata["weight"] = outg.edata["weight"].float().to(device)
        out_graph_list.append(outg)
        # Every further graph is the previous one diffused by one more step.
        for i in range(k - 1):
            out_graph_list.append(
                DiffConv.diffuse(out_graph_list[-1], wadj, outd)
            )
        # Same construction on the transposed adjacency with in-degrees.
        adj = sparse.coo_matrix(wadj.T / ind.cpu().numpy())
        ing = dgl.from_scipy(adj, eweight_name="weight").to(device)
        ing.edata["weight"] = ing.edata["weight"].float().to(device)
        in_graph_list.append(ing)
        for i in range(k - 1):
            in_graph_list.append(
                DiffConv.diffuse(in_graph_list[-1], wadj.T, ind)
            )
        return out_graph_list, in_graph_list

    @staticmethod
    def get_weight_matrix(g):
        """Return ``(weighted adjacency, in-degrees, out-degrees)`` of ``g``.

        The adjacency is a SciPy COO matrix whose data array is replaced by
        the ``"weight"`` edge feature of ``g``.
        """
        adj = g.adj_external(scipy_fmt="coo")
        ind = g.in_degrees()
        outd = g.out_degrees()
        weight = g.edata["weight"]
        # assumes the COO entries are ordered like the edge ids, so the
        # weights line up with the entries — TODO confirm
        adj.data = weight.cpu().numpy()
        return adj, ind, outd

    @staticmethod
    def diffuse(progress_g, weighted_adj, degree):
        """Advance a diffusion graph by one step.

        Multiplies the weighted adjacency of ``progress_g`` with
        ``weighted_adj / degree`` and wraps the product in a new DGL graph on
        the same device, again with a float ``"weight"`` edge feature.
        """
        device = progress_g.device
        progress_adj = progress_g.adj_external(scipy_fmt="coo")
        progress_adj.data = progress_g.edata["weight"].cpu().numpy()
        ret_adj = sparse.coo_matrix(
            progress_adj @ (weighted_adj / degree.cpu().numpy())
        )
        ret_graph = dgl.from_scipy(ret_adj, eweight_name="weight").to(device)
        ret_graph.edata["weight"] = ret_graph.edata["weight"].float().to(device)
        return ret_graph

    def forward(self, g, x):
        """Apply the diffusion convolution to node features ``x``.

        The ``g`` argument is never read: propagation uses the diffusion
        graphs stored at construction time, and the name is rebound inside
        the loop.

        NOTE(review): an unsupported ``dir`` leaves ``graph_list`` undefined
        (NameError in the loop), and ``num_graphs == 0`` makes
        ``self.project_fcs[-1]`` fail on an empty ModuleList.
        """
        feat_list = []
        if self.dir == "both":
            graph_list = self.in_graph_list + self.out_graph_list
        elif self.dir == "in":
            graph_list = self.in_graph_list
        elif self.dir == "out":
            graph_list = self.out_graph_list

        for i in range(self.num_graphs):
            g = graph_list[i]
            with g.local_scope():
                g.ndata["n"] = self.project_fcs[i](x)
                # Weighted sum of projected neighbour features.
                g.update_all(
                    fn.u_mul_e("n", "weight", "e"), fn.sum("e", "feat")
                )
                feat_list.append(g.ndata["feat"])
                # Each feat has shape [N,q_feats]
        # Extra term without propagation; it reuses the last projection layer
        # rather than owning a separate one.
        feat_list.append(self.project_fcs[-1](x))
        # Stack the terms along a new leading axis: [terms, N, out_feats].
        feat_list = torch.cat(feat_list).view(
            len(feat_list), -1, self.out_feats
        )
        # Move the term axis last so `merger` broadcasts over it, move it
        # back to the front and average over the terms.
        ret = (
            (self.merger * feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)
        )
        return ret


In [None]:
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
from dgl.base import DGLError
from dgl.nn.functional import edge_softmax


class WeightedGATConv(dglnn.GATConv):
    """GATConv variant that scales attention by an edge weight.

    The forward pass follows the parent's attention computation, but before
    message passing every attention coefficient is multiplied by the
    ``"weight"`` edge feature of the input graph, so the graph must carry
    that feature.
    """

    def forward(self, graph, feat, get_attention=False):
        """Compute edge-weighted graph attention.

        Parameters
        ----------
        graph : DGLGraph
            Input graph; ``graph.edata["weight"]`` is read.
        feat : torch.Tensor or tuple of torch.Tensor
            Node features, or a ``(source, destination)`` feature pair.
        get_attention : bool
            When true, the weighted attention values are returned as well.

        Returns
        -------
        torch.Tensor, or (torch.Tensor, torch.Tensor) if ``get_attention``
            Destination-node output and, optionally, the weighted attention.

        Raises
        ------
        DGLError
            If zero-in-degree nodes exist and the module was not built with
            ``allow_zero_in_degree``.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                # Use the shared projection unless separate source and
                # destination projections exist on the module.
                if not hasattr(self, "fc_src"):
                    feat_src = self.fc(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                    feat_dst = self.fc(h_dst).view(
                        -1, self._num_heads, self._out_feats
                    )
                else:
                    feat_src = self.fc_src(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                    feat_dst = self.fc_dst(h_dst).view(
                        -1, self._num_heads, self._out_feats
                    )
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
            # NOTE: GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition", the two approaches are mathematically equivalent:
            # We decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is much efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # addition could be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and saves memory footprint.
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({"ft": feat_src, "el": el})
            graph.dstdata.update({"er": er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v("el", "er", "e"))
            e = self.leaky_relu(graph.edata.pop("e"))
            # compute softmax
            graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
            # compute weighted attention: the edge axis is moved last so the
            # edge weights broadcast against it, then moved back to the front
            # (assumes "weight" holds one scalar per edge — TODO confirm)
            graph.edata["a"] = (
                graph.edata["a"].permute(1, 2, 0) * graph.edata["weight"]
            ).permute(2, 0, 1)
            # message passing
            graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
            rst = graph.dstdata["ft"]
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(
                    h_dst.shape[0], -1, self._out_feats
                )
                rst = rst + resval
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata["a"]
            else:
                return rst


class GatedGAT(nn.Module):
    """Gated multi-head graph attention block (GaAN-style).

    Every attention head is scaled by a per-node gate in (0, 1); the gated
    heads are averaged and merged with the input features by a linear layer.

    Parameter
    ==========
    in_feats : int
        number of input features

    out_feats : int
        number of output features

    map_feats : int
        width of the intermediate projection used for the gate

    num_heads : int
        number of attention heads
    """

    def __init__(self, in_feats, out_feats, map_feats, num_heads):
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.map_feats = map_feats
        self.num_heads = num_heads
        self.gatlayer = WeightedGATConv(in_feats, out_feats, num_heads)
        # Gate input: own features, max-pooled projection, mean-pooled input.
        self.gate_fn = nn.Linear(2 * in_feats + map_feats, num_heads)
        self.gate_m = nn.Linear(in_feats, map_feats)
        self.merger_layer = nn.Linear(in_feats + out_feats, out_feats)

    def forward(self, g, x):
        """Return merged node features for graph ``g`` and node input ``x``."""
        with g.local_scope():
            g.ndata["x"] = x
            g.ndata["z"] = self.gate_m(x)
            # Neighbourhood statistics that feed the gate.
            g.update_all(fn.copy_u("x", "x"), fn.mean("x", "mean_z"))
            g.update_all(fn.copy_u("z", "z"), fn.max("z", "max_z"))
            gate_input = torch.cat(
                [g.ndata["x"], g.ndata["max_z"], g.ndata["mean_z"]], dim=1
            )
            gate = self.gate_fn(gate_input).sigmoid()
            head_out = self.gatlayer(g, x)
            # Scale each (node, head) row of the attention output by its gate.
            flat_gate = gate.view(-1)
            flat_heads = head_out.view(-1, self.out_feats)
            gated = (flat_gate * flat_heads.T).T
            gated = gated.view(g.num_nodes(), self.num_heads, self.out_feats)
            head_mean = gated.mean(1)
            return self.merger_layer(torch.cat([x, head_mean], dim=1))


In [None]:
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
from dgl.base import DGLError
from dgl.nn.functional import edge_softmax


class GraphGRUCell(nn.Module):
    """GRU cell whose linear maps are replaced by message-passing networks.

    Parameter
    ==========
    in_feats : int
        number of input features

    out_feats : int
        number of output (hidden) features

    net : callable
        factory called as ``net(in_dim, out_dim)`` that returns a module with
        the call signature ``module(g, features)``
    """

    def __init__(self, in_feats, out_feats, net):
        super(GraphGRUCell, self).__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        # The former `self.dir = dir` stored the Python builtin `dir` (this
        # constructor has no such parameter) and was never read; removed.
        # net can be any GNN model
        self.r_net = net(in_feats + out_feats, out_feats)
        self.u_net = net(in_feats + out_feats, out_feats)
        self.c_net = net(in_feats + out_feats, out_feats)
        # Biases are added manually, one per gate
        self.r_bias = nn.Parameter(torch.rand(out_feats))
        self.u_bias = nn.Parameter(torch.rand(out_feats))
        self.c_bias = nn.Parameter(torch.rand(out_feats))

    def forward(self, g, x, h):
        """Return the next hidden state for input ``x`` and state ``h``.

        ``x`` and ``h`` are concatenated along dim 1, so they must agree in
        their first dimension.
        """
        # Reset and update gates read the same concatenation; build it once.
        xh = torch.cat([x, h], dim=1)
        r = torch.sigmoid(self.r_net(g, xh) + self.r_bias)
        u = torch.sigmoid(self.u_net(g, xh) + self.u_bias)
        h_ = r * h
        # NOTE(review): the candidate state uses sigmoid where a textbook GRU
        # uses tanh; kept unchanged — confirm this is intended.
        c = torch.sigmoid(
            self.c_net(g, torch.cat([x, h_], dim=1)) + self.c_bias
        )
        new_h = u * h + (1 - u) * c
        return new_h


class StackedEncoder(nn.Module):
    """Single-time-step encoder built from a vertical stack of GRU cells.

    Parameter
    ==========
    in_feats : int
        number of input features

    out_feats : int
        number of output features of every layer

    num_layers : int
        depth of the stack (must be positive)

    net : callable
        message passing network factory handed to every ``GraphGRUCell``
    """

    def __init__(self, in_feats, out_feats, num_layers, net):
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.num_layers = num_layers
        self.net = net
        self.layers = nn.ModuleList()
        if self.num_layers <= 0:
            raise DGLError("Layer Number must be greater than 0! ")
        # The first cell consumes the raw input width; all deeper cells
        # consume the output of the cell below.
        input_widths = [self.in_feats] + [self.out_feats] * (
            self.num_layers - 1
        )
        for width in input_widths:
            self.layers.append(GraphGRUCell(width, self.out_feats, self.net))

    # hidden_states is a list holding one state per layer
    def forward(self, g, x, hidden_states):
        """Run one step through all layers.

        Returns the top layer output and the list of new per-layer states.
        """
        new_states = []
        for depth in range(len(self.layers)):
            cell = self.layers[depth]
            x = cell(g, x, hidden_states[depth])
            new_states.append(x)
        return x, new_states


class StackedDecoder(nn.Module):
    """Single-time-step decoder: stacked GRU cells plus a linear read-out.

    Parameter
    ==========
    in_feats : int
        number of input features

    hid_feats : int
        number of features before the linear output layer

    out_feats : int
        number of output features

    num_layers : int
        depth of the stack (must be positive)

    net : callable
        message passing network factory handed to every ``GraphGRUCell``
    """

    def __init__(self, in_feats, hid_feats, out_feats, num_layers, net):
        super().__init__()
        self.in_feats = in_feats
        self.hid_feats = hid_feats
        self.out_feats = out_feats
        self.num_layers = num_layers
        self.net = net
        self.out_layer = nn.Linear(self.hid_feats, self.out_feats)
        self.layers = nn.ModuleList()
        if self.num_layers <= 0:
            raise DGLError("Layer Number must be greater than 0!")
        # The first cell consumes the raw input width; all deeper cells
        # consume the hidden width.
        input_widths = [self.in_feats] + [self.hid_feats] * (
            self.num_layers - 1
        )
        for width in input_widths:
            self.layers.append(GraphGRUCell(width, self.hid_feats, net))

    def forward(self, g, x, hidden_states):
        """Run one step through all layers and the output projection.

        Returns the projected output and the list of new per-layer states.
        """
        new_states = []
        for depth in range(len(self.layers)):
            cell = self.layers[depth]
            x = cell(g, x, hidden_states[depth])
            new_states.append(x)
        return self.out_layer(x), new_states


class GraphRNN(nn.Module):
    """Graph sequence-to-sequence model with a pluggable GNN backbone.

    Parameter
    ==========
    in_feats : int
        number of input features

    out_feats : int
        hidden width of the encoder and decoder cells

    seq_len : int
        length of both the input and the predicted sequence

    num_layers : int
        number of stacked layers in encoder and decoder

    net : callable
        message passing GNN factory used as backbone

    decay_steps : int
        controls how fast the teacher forcing probability decays
    """

    def __init__(
        self, in_feats, out_feats, seq_len, num_layers, net, decay_steps
    ):
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.seq_len = seq_len
        self.num_layers = num_layers
        self.net = net
        self.decay_steps = decay_steps

        self.encoder = StackedEncoder(in_feats, out_feats, num_layers, net)

        # The decoder maps back to the input width so that its output can be
        # fed in again at the next step.
        self.decoder = StackedDecoder(
            in_feats, out_feats, in_feats, num_layers, net
        )

    def compute_thresh(self, batch_cnt):
        """Teacher forcing probability after ``batch_cnt`` training batches."""
        steps = self.decay_steps
        return steps / (steps + np.exp(batch_cnt / steps))

    def encode(self, g, inputs, device):
        """Feed ``seq_len`` input steps to the encoder.

        Starts from all-zero states and returns the final per-layer states.
        """
        node_count = g.num_nodes()
        states = []
        for _ in range(self.num_layers):
            states.append(torch.zeros(node_count, self.out_feats).to(device))
        for step in range(self.seq_len):
            _, states = self.encoder(g, inputs[step], states)

        return states

    def decode(self, g, teacher_states, hidden_states, batch_cnt, device):
        """Generate ``seq_len`` output steps from the encoder states.

        During training, each step uses the teacher signal with the
        probability given by ``compute_thresh``; otherwise the previous
        decoder output is fed back. The first fed-back input is all zeros.
        """
        step_outputs = []
        previous = torch.zeros(g.num_nodes(), self.in_feats).to(device)
        for step in range(self.seq_len):
            # A random number is drawn on every step, also in eval mode.
            use_teacher = (
                np.random.random() < self.compute_thresh(batch_cnt)
                and self.training
            )
            decoder_input = teacher_states[step] if use_teacher else previous
            previous, hidden_states = self.decoder(
                g, decoder_input, hidden_states
            )
            step_outputs.append(previous)
        return torch.stack(step_outputs)

    def forward(self, g, inputs, teacher_states, batch_cnt, device):
        """Encode ``inputs`` and decode the predicted sequence."""
        encoded = self.encode(g, inputs, device)
        return self.decode(g, teacher_states, encoded, batch_cnt, device)


In [None]:
import dgl
import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn


class NormalizationLayer(nn.Module):
    """Standardise values with a fixed mean and standard deviation.

    Parameters
    ----------
    mean : scalar
        Value subtracted by ``normalize``.
    std : scalar
        Value that ``normalize`` divides by.
    """

    def __init__(self, mean, std):
        # Module.__init__ has to run first; without it any Module API
        # (parameters(), to(), eval(), ...) fails with AttributeError.
        super().__init__()
        self.mean = mean
        self.std = std

    # mean and std are expected to be scalars
    def normalize(self, x):
        """Return ``(x - mean) / std``."""
        return (x - self.mean) / self.std

    def denormalize(self, x):
        """Invert ``normalize``: return ``x * std + mean``."""
        return x * self.std + self.mean


def masked_mae_loss(y_pred, y_true):
    """Mean absolute error that ignores targets equal to zero.

    The 0/1 validity mask is rescaled by its own mean before weighting the
    absolute errors. NaNs in the weighted error (for example when every
    target is zero) are replaced by 0 before averaging.
    """
    valid = (y_true != 0).float()
    valid /= valid.mean()
    weighted_err = torch.abs(y_pred - y_true) * valid
    # zero out NaN entries
    weighted_err[torch.isnan(weighted_err)] = 0
    return weighted_err.mean()


def get_learning_rate(optimizer):
    """Return the ``"lr"`` of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    missing = object()
    first_group = next(iter(optimizer.param_groups), missing)
    if first_group is missing:
        return None
    return first_group["lr"]


In [None]:
# Number of training batches seen so far; a one-element list so that train()
# can update it in place across epochs.
batch_cnt = [0]


def train(
    model,
    graph,
    dataloader,
    optimizer,
    scheduler,
    normalizer,
    loss_fn,
    device,
    args,
):
    """Train ``model`` for one pass over ``dataloader``.

    Reads ``args.batch_size``, ``args.max_grad_norm`` and ``args.minimum_lr``.
    The scheduler is stepped after every batch while the learning rate is
    above ``args.minimum_lr``. The global ``batch_cnt`` is advanced by one
    per batch. Returns the mean of the per-batch losses.
    """
    graph = graph.to(device)
    model.train()
    batch_size = args.batch_size
    epoch_losses = []
    for step, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()
        # Padding: the batched graph has a fixed size, so a short batch is
        # filled up with copies of its last sample.
        real_count = x.shape[0]
        if real_count != batch_size:
            pad_count = batch_size - real_count
            buffer_shape = (batch_size, x.shape[1], x.shape[2], x.shape[3])
            x_padded = torch.zeros(*buffer_shape)
            y_padded = torch.zeros(*buffer_shape)
            x_padded[:real_count] = x
            x_padded[real_count:] = x[-1].repeat(pad_count, 1, 1, 1)
            y_padded[:real_count] = y
            y_padded[real_count:] = y[-1].repeat(pad_count, 1, 1, 1)
            x, y = x_padded, y_padded
        # Swap the first two axes, then merge axes 1 and 2 into one
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        lead, feat_dim = x.shape[0], x.shape[3]

        x_norm = normalizer.normalize(x).reshape(lead, -1, feat_dim)
        x_norm = x_norm.float().to(device)
        y_norm = normalizer.normalize(y).reshape(lead, -1, feat_dim)
        y_norm = y_norm.float().to(device)
        y = y.reshape(y.shape[0], -1, y.shape[3]).float().to(device)

        batch_graph = dgl.batch([graph] * batch_size)
        output = model(batch_graph, x_norm, y_norm, batch_cnt[0], device)
        # The loss is computed on de-normalised predictions
        y_pred = normalizer.denormalize(output)
        loss = loss_fn(y_pred, y)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        if get_learning_rate(optimizer) > args.minimum_lr:
            scheduler.step()
        epoch_losses.append(float(loss))
        batch_cnt[0] += 1
        print("\rBatch: ", step, end="")
    return np.mean(epoch_losses)


def eval(model, graph, dataloader, normalizer, loss_fn, device, args):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch_graph, x_norm, y_norm, i, device)``.
    graph : DGLGraph
        Single graph; it is batched ``args.batch_size`` times per step.
    dataloader : iterable
        Yields ``(x, y)`` batches of 4-D tensors.
    normalizer : object
        Provides ``normalize`` and ``denormalize``.
    loss_fn : callable
        ``loss_fn(y_pred, y)`` returning a scalar tensor.
    device : torch.device
        Device used for the forward pass.
    args : object
        Only ``args.batch_size`` is read.

    Returns
    -------
    preds : list of numpy.ndarray
        One de-normalised prediction array per batch (NOT concatenated).
    gts : numpy.ndarray
        Targets of all batches concatenated along axis 0. Each per-batch
        target has the loader batch's second axis first, followed by the
        merged (batch, third-axis) dimension and the feature axis.
    loss : float
        Mean of the per-batch losses.

    Notes
    -----
    The name shadows the builtin ``eval``; it is kept for compatibility with
    the existing callers.
    """
    total_loss = []
    preds, gts = [], []
    graph = graph.to(device)
    model.eval()
    batch_size = args.batch_size
    # Inference only: without no_grad() every forward pass would build and
    # retain an autograd graph that is never used.
    with torch.no_grad():
        for i, (x, y) in enumerate(dataloader):
            # Padding: the batched graph has a fixed size, so a short batch
            # is filled up with copies of its last sample.
            if x.shape[0] != batch_size:
                x_buff = torch.zeros(
                    batch_size, x.shape[1], x.shape[2], x.shape[3]
                )
                y_buff = torch.zeros(
                    batch_size, x.shape[1], x.shape[2], x.shape[3]
                )
                x_buff[: x.shape[0], :, :, :] = x
                x_buff[x.shape[0] :, :, :, :] = x[-1].repeat(
                    batch_size - x.shape[0], 1, 1, 1
                )
                y_buff[: x.shape[0], :, :, :] = y
                y_buff[x.shape[0] :, :, :, :] = y[-1].repeat(
                    batch_size - x.shape[0], 1, 1, 1
                )
                x = x_buff
                y = y_buff
            # Swap the first two axes, then merge axes 1 and 2 into one
            x = x.permute(1, 0, 2, 3)
            y = y.permute(1, 0, 2, 3)

            x_norm = (
                normalizer.normalize(x)
                .reshape(x.shape[0], -1, x.shape[3])
                .float()
                .to(device)
            )
            y_norm = (
                normalizer.normalize(y)
                .reshape(x.shape[0], -1, x.shape[3])
                .float()
                .to(device)
            )
            y = y.reshape(x.shape[0], -1, x.shape[3]).to(device)

            batch_graph = dgl.batch([graph] * batch_size)
            # The batch index is passed where training passes its batch
            # counter.
            output = model(batch_graph, x_norm, y_norm, i, device)
            y_pred = normalizer.denormalize(output)
            preds.append(y_pred.cpu().detach().numpy())
            gts.append(y.cpu().detach().numpy())
            loss = loss_fn(y_pred, y)
            total_loss.append(float(loss))

    # Only the targets are concatenated (along axis 0); `preds` is returned
    # as the per-batch list because downstream code relies on that layout.
    gts = np.concatenate(gts, axis=0)

    return preds, gts, np.mean(total_loss)

In [None]:
# Define arguments directly
class Args:
    """Hyper-parameters and run configuration for the notebook.

    Every default can be overridden by keyword, e.g. ``Args(batch_size=32)``.

    Raises
    ------
    TypeError
        If an override names an argument that does not exist, so that typos
        are not silently ignored.
    """

    def __init__(self, **overrides):
        self.batch_size = 64
        self.num_workers = 0
        self.model = "dcrnn"  # Choose between "dcrnn" and "gaan"
        self.gpu = 0  # CUDA device index; -1 selects the CPU
        self.diffsteps = 2
        self.num_heads = 2
        self.decay_steps = 2000
        self.lr = 0.01
        self.minimum_lr = 2e-6
        self.dataset = "LA"  # Choose between "LA" and "BAY"
        self.epochs = 100
        self.max_grad_norm = 5.0

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError("Unknown argument: {!r}".format(name))
            setattr(self, name, value)


# Create an instance of Args
args = Args()

# Load the datasets
if args.dataset == "LA":
    g = METR_LAGraphDataset()
    train_data = METR_LATrainDataset()
    test_data = METR_LATestDataset()
    valid_data = METR_LAValidDataset()
elif args.dataset == "BAY":
    g = PEMS_BAYGraphDataset()
    train_data = PEMS_BAYTrainDataset()
    test_data = PEMS_BAYTestDataset()
    valid_data = PEMS_BAYValidDataset()
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        "Unknown dataset {!r}; expected 'LA' or 'BAY'".format(args.dataset)
    )

# Fall back to the CPU when a GPU is requested but CUDA is not available.
if args.gpu == -1 or not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    device = torch.device("cuda:{}".format(args.gpu))

train_loader = DataLoader(
    train_data,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=True,
)
# Evaluation data is not shuffled: the order does not help evaluation, and
# the plotting code treats the returned batches as an ordered sequence.
valid_loader = DataLoader(
    valid_data,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False,
)
test_loader = DataLoader(
    test_data,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False,
)
# Normalisation statistics always come from the training split.
normalizer = NormalizationLayer(train_data.mean, train_data.std)

if args.model == "dcrnn":
    # The diffusion graphs are precomputed once for a full batch of graphs,
    # which is why every batch is padded to args.batch_size later on.
    batch_g = dgl.batch([g] * args.batch_size).to(device)
    out_gs, in_gs = DiffConv.attach_graph(batch_g, args.diffsteps)
    net = partial(
        DiffConv,
        k=args.diffsteps,
        in_graph_list=in_gs,
        out_graph_list=out_gs,
    )
elif args.model == "gaan":
    net = partial(GatedGAT, map_feats=64, num_heads=args.num_heads)
else:
    raise ValueError(
        "Unknown model {!r}; expected 'dcrnn' or 'gaan'".format(args.model)
    )

# NOTE: the variable is called `dcrnn` regardless of the selected backbone.
dcrnn = GraphRNN(
    in_feats=2,
    out_feats=64,
    seq_len=12,
    num_layers=2,
    net=net,
    decay_steps=args.decay_steps,
).to(device)

optimizer = torch.optim.Adam(dcrnn.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

loss_fn = masked_mae_loss

for e in range(args.epochs):
    train_loss = train(
        dcrnn,
        g,
        train_loader,
        optimizer,
        scheduler,
        normalizer,
        loss_fn,
        device,
        args,
    )
    preds_val, gts_val, valid_loss = eval(
        dcrnn, g, valid_loader, normalizer, loss_fn, device, args
    )
    preds_test, gts_test, test_loss = eval(
        dcrnn, g, test_loader, normalizer, loss_fn, device, args
    )
    print(
        "\rEpoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}".format(
            e, train_loss, valid_loss, test_loss
        )
    )

In [None]:
# import torch
# import numpy as np
# import matplotlib.pyplot as plt

# # Convert to numpy arrays (if not already)
# preds_np = np.array(preds_test)
# gts_np = np.array(gts_test)

# print("Predictions shape:", preds_np.shape)  # (108, 12, 13248, 2)
# print("Ground truth shape:", gts_np.shape)   # (1296, 13248, 2)

# # Reshape ground truth from (1296, 13248, 2) to (108, 12, 13248, 2)
# num_batches = preds_np.shape[0]        # 108
# forecast_steps = preds_np.shape[1]       # 12
# flat_dim = preds_np.shape[2]             # 13248
# features = preds_np.shape[3]             # 2

# expected_ground_truth_size = num_batches * forecast_steps  # 108*12 = 1296
# if gts_np.shape[0] != expected_ground_truth_size:
#     raise ValueError("Ground truth first dimension does not match expected (num_batches * forecast_steps)")

# gts_np_reshaped = gts_np.reshape(num_batches, forecast_steps, flat_dim, features)
# print("Reshaped ground truth shape:", gts_np_reshaped.shape)  # Should be (108, 12, 13248, 2)

# # For METR-LA, we expect a batch size of 64 and 207 sensors: 64 * 207 = 13248.
# batch_size = 64
# num_nodes = 207

# if flat_dim != batch_size * num_nodes:
#     raise ValueError("Flat dimension does not match batch_size*num_nodes")

# # Select a specific batch (for example, the first batch: batch_idx = 0)
# batch_idx = 0
# selected_preds = preds_np[batch_idx]     # shape: (12, 13248, 2)
# selected_gts   = gts_np_reshaped[batch_idx]  # shape: (12, 13248, 2)

# # Reshape each to (seq_len, batch_size, num_nodes, features)
# seq_len = forecast_steps  # 12 time steps
# selected_preds_reshaped = selected_preds.reshape(seq_len, batch_size, num_nodes, features)
# selected_gts_reshaped   = selected_gts.reshape(seq_len, batch_size, num_nodes, features)

# # Choose a specific sample (from the batch) and a sensor to plot.
# sample_idx = 0    # choose the first sample in the batch (0 <= sample_idx < 64)
# sensor_idx = 10   # choose sensor index 10 (0 <= sensor_idx < 207)

# # Extract the time series for the chosen sensor and sample.
# # We use the first feature (index 0) assuming it represents traffic speed.
# pred_series = selected_preds_reshaped[:, sample_idx, sensor_idx, 0]  # shape: (seq_len,)
# gt_series   = selected_gts_reshaped[:, sample_idx, sensor_idx, 0]      # shape: (seq_len,)

# # Create a time axis for the forecast horizon.
# time_axis = np.arange(seq_len)

# # Plot the results.
# plt.figure(figsize=(10, 4))
# plt.plot(time_axis, gt_series, label="Ground Truth", color="blue", linewidth=2)
# plt.plot(time_axis, pred_series, label="DCRNN Prediction", color="orange", linestyle="--", linewidth=2)
# plt.xlabel("Forecast Time Step")
# plt.ylabel("Traffic Speed (mph)")
# plt.title(f"Traffic Speed Forecast at Sensor {sensor_idx} (Sample {sample_idx}, Batch {batch_idx})")
# plt.legend()
# plt.tight_layout()
# plt.show()

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Inputs produced by the training cell:
# preds_test -> list with one array per batch, each shaped
#               (forecast_steps, batch_size * num_nodes, features)
# gts_test   -> array shaped
#               (num_batches * forecast_steps, batch_size * num_nodes, features)

preds_np = np.array(preds_test)
gts_np   = np.array(gts_test)

num_batches, forecast_steps, flat_dim, features = preds_np.shape
print("Predictions shape:", preds_np.shape)
print("Ground truth shape:", gts_np.shape)

# Reshape ground truth to match predictions' shape on the time dimension
expected_ground_truth_size = num_batches * forecast_steps
if gts_np.shape[0] != expected_ground_truth_size:
    raise ValueError("Ground truth shape mismatch.")

gts_np_reshaped = gts_np.reshape(num_batches, forecast_steps, flat_dim, features)
print("Reshaped ground truth shape:", gts_np_reshaped.shape)

# Recover (batch_size, num_nodes) from the flattened axis instead of
# hard-coding 64 x 207, so the cell also works for other settings.
batch_size = args.batch_size
if flat_dim % batch_size != 0:
    raise ValueError("Flat dimension is not a multiple of the batch size.")
num_nodes = flat_dim // batch_size

# Merge all batches into a single leading dimension:
# (num_batches * forecast_steps, batch_size, num_nodes, features)
preds_merged = preds_np.reshape(
    num_batches * forecast_steps, batch_size, num_nodes, features
)
gts_merged = gts_np_reshaped.reshape(
    num_batches * forecast_steps, batch_size, num_nodes, features
)

# Pick a single sample in the batch and a single sensor
sample_idx = 0   # which sample of each batch to show
sensor_idx = 10  # which sensor (node) to show

# NOTE(review): this stitches the forecast of sample `sample_idx` from every
# batch back to back. That is only a meaningful timeline if the loader keeps
# the samples in temporal order, and even then consecutive batches are
# presumably not contiguous in time — confirm before reading the x-axis as
# wall-clock time.
pred_series = preds_merged[:, sample_idx, sensor_idx, 0]
gt_series   = gts_merged[:, sample_idx, sensor_idx, 0]

# ---------------------------------------------------------
# Plot at most the first 24 hours
# (assuming each step is 5 minutes -> 12 steps/hour -> 288 steps/day)
# ---------------------------------------------------------
steps_per_hour = 12
steps_24hrs    = 24 * steps_per_hour  # 288

max_t = min(steps_24hrs, len(pred_series))
pred_series_24h = pred_series[:max_t]
gt_series_24h   = gt_series[:max_t]

# Time axis in hours. Dividing the step index by steps_per_hour keeps the
# scale correct even when fewer than 288 steps are available (the previous
# linspace always stretched the data to 24 h).
time_axis = np.arange(max_t) / steps_per_hour

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(time_axis, gt_series_24h, label="Ground Truth", color="blue", linewidth=2)
ax.plot(time_axis, pred_series_24h, label="DCRNN Prediction", color="orange", linestyle="--", linewidth=2)
ax.set_xlabel("Time (Hours)")
ax.set_ylabel("Traffic Speed (mph)")
ax.set_title(f"24-Hour Forecast - Sensor {sensor_idx}, Sample {sample_idx}")
ax.legend()
fig.tight_layout()
plt.show()