# STEN Forecasting demo

In this notebook, we will:
1. **Load Data**: Load PyTorch Geometric Temporal data and split it into PyTorch train, validation, and test dataloaders.
2. **Model Building**: Build and train a STEN model, DenseGCNGRU, on the processed data.
3. **Model Evaluation**: Print and plot the model outputs to assess its performance.

In [2]:
# import all the necessary libraries
import os
import torch

# for relative imports
os.chdir('..') 
print(os.getcwd()) # should print /your_local_dir/Campus-Crowd

# cuda or cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

/Users/vivianwong/Documents/Research Codes/Campus-Crowd
cpu


In [3]:
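class Args:
    def __init__(self):
        self.DATASET = 'Stadium'         # dataset name passed to get_pyg_temporal_dataset
        self.forecasting_horizon = 20    # number of future time steps to predict (periods)
        self.train_ratio = 0.7           # fraction of samples used for training
        self.test_ratio = 0.3            # fraction of samples used for testing
        self.val_ratio = 0.0             # no validation split in this demo
        self.batch_size = 32
        self.lr = 0.001                  # learning rate
        self.epochs = 40
        self.save_model = False          # set to True to save the trained checkpoint
        self.save_dir = './checkpoints'  # directory for saved checkpoints
args = Args()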

## 1. Load Data

This step is the same as the one shown in demo_dataset.ipynb.

In [4]:
from campuscrowd.data_utils import get_pyg_temporal_dataset, get_loaders
# get pytorch dataloaders
dataset, _ = get_pyg_temporal_dataset(args.DATASET, args.forecasting_horizon)
train_loader, val_loader, test_loader = get_loaders(dataset, 
                                                    args.batch_size, 
                                                    args.train_ratio, 
                                                    args.val_ratio, 
                                                    args.test_ratio, 
                                                    device)



Dataset type:   <torch_geometric_temporal.signal.static_graph_temporal_signal.StaticGraphTemporalSignal object at 0x107843eb0>
Number of samples / sequences:  2356
Data(x=[6, 2, 20], edge_index=[2, 10], edge_attr=[10], y=[6, 20])
Number of train buckets:  1649
Number of val buckets:  0
Number of test buckets:  707
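The bucket counts above follow from the ratios in `Args`: 1649 + 0 + 707 = 2356. The sketch below shows one ratio-based split that reproduces these counts. It assumes the training count is truncated to an integer, the test set takes the remainder, and samples keep their time order. The notebook does not show how `get_loaders` actually rounds or orders samples, and `split_sizes` / `chronological_split` are hypothetical helpers.

```python
def split_sizes(num_samples, train_ratio, val_ratio):
    """Hypothetical helper: train/val/test counts, with test taking the remainder."""
    n_train = int(num_samples * train_ratio)  # truncate, e.g. 2356 * 0.7 -> 1649
    n_val = int(num_samples * val_ratio)
    n_test = num_samples - n_train - n_val
    return n_train, n_val, n_test


def chronological_split(samples, train_ratio, val_ratio):
    """Split a time-ordered sequence without shuffling, so test data comes last."""
    n_train, n_val, _ = split_sizes(len(samples), train_ratio, val_ratio)
    train = samples[:n_train]
    val = samples[n_train:n_train + n_val]
    test = samples[n_train + n_val:]
    return train, val, test


print(split_sizes(2356, 0.7, 0.0))
```

Keeping the test set at the end of the sequence means the model is evaluated on time steps that come after everything it was trained on, which is the usual choice for forecasting.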


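A graph neural network takes two inputs: the adjacency matrix $A$, which describes how the nodes are connected, and the node feature matrix $X$. Because our $A$ is defined by the connections between pedestrian activity regions (PARs) and is assumed never to change, we can load it once and reuse it for every sample, which saves computation. PyTorch Geometric stores $A$ in sparse form as an edge_index tensor of shape [2, num_edges], where each column holds the (source, target) node pair of one edge.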

In [5]:
# Get the static edge index (i.e. the adjacency matrix). This only needs to be done once,
# since the edge index is the same for every CMGraph.
for snapshot in dataset:
    static_edge_index = snapshot.edge_index.to(device)
    break
# Edge indices (represents adjacency matrix/PAR connections) of the CMGraphs. 
print(static_edge_index)

tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]])
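The printed tensor lists 10 directed edges among 6 PARs. Each column `(source, target)` is one non-zero entry of $A$, and every connection appears in both directions, so the PARs form a chain 0-1-2-3-4-5. The dependency-free sketch below converts between the two representations. Plain lists stand in for tensors, and the function names are illustrative, not part of campuscrowd.

```python
def edge_index_to_dense(edge_index, num_nodes):
    """Build a dense 0/1 adjacency matrix A from a PyG-style edge_index."""
    adjacency = [[0] * num_nodes for _ in range(num_nodes)]
    for src, dst in zip(edge_index[0], edge_index[1]):
        adjacency[src][dst] = 1  # one entry per directed edge
    return adjacency


def dense_to_edge_index(adjacency):
    """Inverse direction: list the (row, col) positions of non-zero entries."""
    sources, targets = [], []
    for i, row in enumerate(adjacency):
        for j, value in enumerate(row):
            if value:
                sources.append(i)
                targets.append(j)
    return [sources, targets]


# Same values as the static_edge_index printed above.
edge_index = [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],
              [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]
A = edge_index_to_dense(edge_index, num_nodes=6)
for row in A:
    print(row)
```

The sparse form grows with the number of edges rather than with the square of the number of nodes, which is why PyTorch Geometric uses it. With PyG installed, `torch_geometric.utils.to_dense_adj(edge_index)` performs the same conversion on tensors and returns a batch of shape [1, N, N].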


## 2. Set up and Train a STEN Model

STEN stands for spatio-temporal encoder network. It is a framework for crowd flow forecasting over spatially connected pedestrian activity regions (PARs). This repo provides two STEN models for plug-and-play use. Here we use the best-performing one, Dense-GCN-GRU (the DenseGCNGRU class), which can be imported directly from our STEN model zoo in campuscrowd.models.

In [6]:
from campuscrowd.models import DenseGCNGRU
model = DenseGCNGRU(in_channels=2, 
                    periods=args.forecasting_horizon, 
                    batch_size=args.batch_size).to(device) 
print(model)

DenseGCNGRU(
  (densegcn): DeepGCNLayer(block=dense)
  (gru): GRU(130, 64, num_layers=2, batch_first=True)
  (fc): Linear(in_features=64, out_features=20, bias=True)
)
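In the printout, `DeepGCNLayer(block=dense)` is PyTorch Geometric's DeepGCN wrapper in its dense configuration. That configuration concatenates a layer's input with the output of the graph convolution it wraps, along the feature dimension. The GRU's input size of 130 is therefore consistent with 2 input channels plus 128 convolution output channels. That split is inferred from the printout, not stated in the source.

The printout also does not name the wrapped convolution. The sketch below assumes a standard GCN propagation step, $\hat{D}^{-1/2}(A + I)\hat{D}^{-1/2} X W$, where $\hat{D}$ is the degree matrix of $A + I$. It omits any normalization, activation, or dropout the layer may apply. Function names are illustrative, and plain lists replace tensors so the sketch runs without PyTorch.

```python
import math


def gcn_propagate(edge_index, x, weight):
    """One GCN step, D^-1/2 (A + I) D^-1/2 X W, on plain Python lists.

    edge_index: [sources, targets], as in PyTorch Geometric.
    x: num_nodes x in_dim node features; weight: in_dim x out_dim.
    """
    num_nodes = len(x)
    out_dim = len(weight[0])
    # Neighbors of each target node, including a self-loop (the "+ I" term).
    neighbors = [{i} for i in range(num_nodes)]
    for src, dst in zip(edge_index[0], edge_index[1]):
        neighbors[dst].add(src)
    degree = [len(nb) for nb in neighbors]
    # Linear transform X W.
    xw = [[sum(x[i][k] * weight[k][j] for k in range(len(weight)))
           for j in range(out_dim)]
          for i in range(num_nodes)]
    # Symmetrically normalized sum over each node's neighborhood.
    out = []
    for i in range(num_nodes):
        row = [0.0] * out_dim
        for j in neighbors[i]:
            coeff = 1.0 / math.sqrt(degree[i] * degree[j])
            for c in range(out_dim):
                row[c] += coeff * xw[j][c]
        out.append(row)
    return out


def dense_block(edge_index, x, weight):
    """Dense-style skip connection: concatenate each node's input with its GCN output."""
    h = gcn_propagate(edge_index, x, weight)
    return [xi + hi for xi, hi in zip(x, h)]


edge_index = [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],
              [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]
x = [[1.0, 1.0] for _ in range(6)]          # 6 PARs, 2 features each
w = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]      # 2 -> 3 channels
print(len(dense_block(edge_index, x, w)[0]))  # 2 input + 3 conv channels
```

The concatenation is what makes the block "dense": downstream layers see both the raw PAR features and their graph-smoothed version.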


To train the model on the data, we can use the `train()` function from `campuscrowd.train_test_utils`. It trains the model and returns the trained model together with a checkpoint dictionary. Inside `train()`, a training loop computes the loss on the model's prediction vector `y_hat` and backpropagates it. `y_hat` is computed as in the following snippet:
```python
for encoder_inputs, labels in train_loader:
    # y_hat: torch.Tensor of model predictions for this batch,
    # computed from the node features and the static graph structure.
    y_hat = model(encoder_inputs, static_edge_index)
    # Loss computation and backpropagation omitted for conciseness.
```
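The notebook does not show the body of `train()`. As a hedged illustration of the loop it describes (forward pass, MSE loss, gradients, parameter update), here is a dependency-free toy. A linear model `y_hat = w * x + b` stands in for DenseGCNGRU, hand-derived gradients stand in for `loss.backward()`, and plain SGD stands in for whatever optimizer campuscrowd uses. All names are hypothetical.

```python
def train_toy(batches, num_epochs=1000, lr=0.1):
    """Toy training loop: forward pass, MSE loss, gradients, SGD update.

    A linear model y_hat = w * x + b stands in for DenseGCNGRU, and
    hand-derived gradients stand in for loss.backward().
    """
    w, b = 0.0, 0.0  # model parameters
    avg_loss = float("nan")
    for _epoch in range(num_epochs):
        epoch_loss = 0.0
        for inputs, labels in batches:
            # Forward pass (the analogue of model(encoder_inputs, static_edge_index)).
            y_hat = [w * x + b for x in inputs]
            errors = [p - y for p, y in zip(y_hat, labels)]
            n = len(errors)
            loss = sum(e * e for e in errors) / n  # MSE
            # Backward pass: gradients of the MSE with respect to w and b.
            grad_w = 2.0 / n * sum(e * x for e, x in zip(errors, inputs))
            grad_b = 2.0 / n * sum(errors)
            # Optimizer step: plain stochastic gradient descent.
            w -= lr * grad_w
            b -= lr * grad_b
            epoch_loss += loss
        avg_loss = epoch_loss / len(batches)  # analogous to "Average Training MSE"
    return w, b, avg_loss


# Two mini-batches drawn from y = 2x + 1.
batches = [([0.0, 0.25, 0.5], [1.0, 1.5, 2.0]), ([0.75, 1.0], [2.5, 3.0])]
w, b, final_loss = train_toy(batches)
print(round(w, 3), round(b, 3))
```

In PyTorch, the three commented steps typically become `loss = torch.nn.functional.mse_loss(y_hat, labels)`, then `optimizer.zero_grad()`, `loss.backward()`, and `optimizer.step()`, with the optimizer built from `model.parameters()` and `lr=args.lr`. Whether `train()` uses exactly these calls is not shown in this notebook.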

In [9]:
from campuscrowd.train_test_utils import train
# train model
model, checkpoint_dict = train( model, 
                                train_loader, 
                                val_loader, 
                                static_edge_index, 
                                num_epochs=args.epochs, lr=args.lr
                                )


Epoch 0 Step 0 train MSE: 0.8875
Epoch 0 Average Training MSE: 0.3191
Epoch 1 Step 0 train MSE: 0.0561
Epoch 1 Average Training MSE: 0.0458
Epoch 2 Step 0 train MSE: 0.0387
Epoch 2 Average Training MSE: 0.0413
Epoch 3 Step 0 train MSE: 0.0415
Epoch 3 Average Training MSE: 0.0404
Epoch 4 Step 0 train MSE: 0.0302
Epoch 4 Average Training MSE: 0.0395
Epoch 5 Step 0 train MSE: 0.0390
Epoch 5 Average Training MSE: 0.0393
Epoch 6 Step 0 train MSE: 0.0346
Epoch 6 Average Training MSE: 0.0382
Epoch 7 Step 0 train MSE: 0.0381
Epoch 7 Average Training MSE: 0.0380
Epoch 8 Step 0 train MSE: 0.0362
Epoch 8 Average Training MSE: 0.0380
Epoch 9 Step 0 train MSE: 0.0485
Epoch 9 Average Training MSE: 0.0374
Epoch 10 Step 0 train MSE: 0.0382
Epoch 10 Average Training MSE: 0.0370
Epoch 11 Step 0 train MSE: 0.0433
Epoch 11 Average Training MSE: 0.0369
Epoch 12 Step 0 train MSE: 0.0404
Epoch 12 Average Training MSE: 0.0370
Epoch 13 Step 0 train MSE: 0.0385
Epoch 13 Average Training MSE: 0.0365
Epoch 14 Ste

You can optionally save the model checkpoint by setting `args.save_model = True`.

In [8]:
from campuscrowd import save_or_update_checkpoint
# save the model checkpoint if requested
if args.save_model:
    # e.g. DenseGCNGRU_Stadium_20_steps.pt
    filename = f'{model.__class__.__name__}_{args.DATASET}_{args.forecasting_horizon}_steps.pt'
    os.makedirs(args.save_dir, exist_ok=True)  # create the checkpoint directory if it is missing
    path = os.path.join(args.save_dir, filename)
    save_or_update_checkpoint(checkpoint_dict, path)

## 3. Evaluate MSE of Prediction

As with `train()`, evaluation is done with the `evaluate()` function. It computes the test MSE (and MAE, as the output below shows) and records the result in `checkpoint_dict`.

In [11]:
from campuscrowd import evaluate
model, checkpoint_dict = evaluate(model, test_loader, static_edge_index, checkpoint_dict=checkpoint_dict)

Test MSE: 0.0334
Test MAE: 0.1314
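The two test metrics can be reproduced from raw predictions in a few lines. This dependency-free sketch assumes both metrics are averaged element-wise over every sample, node, and forecast step. The notebook does not show whether `evaluate()` averages this way or per batch, and `mse_mae` is a hypothetical helper. As a sanity check on the numbers above, MAE can never exceed RMSE, and indeed sqrt(0.0334) ≈ 0.183 ≥ 0.1314.

```python
def flatten(values):
    """Yield scalars from arbitrarily nested lists, e.g. [batch][node][step]."""
    for v in values:
        if isinstance(v, (list, tuple)):
            yield from flatten(v)
        else:
            yield v


def mse_mae(predictions, targets):
    """Element-wise mean squared error and mean absolute error."""
    preds = list(flatten(predictions))
    trues = list(flatten(targets))
    if not preds or len(preds) != len(trues):
        raise ValueError("predictions and targets must be non-empty and the same size")
    errors = [p - t for p, t in zip(preds, trues)]
    mse = sum(e * e for e in errors) / len(errors)
    mae = sum(abs(e) for e in errors) / len(errors)
    return mse, mae


# One batch, two nodes, two forecast steps.
y_hat = [[[1.0, 2.0], [3.0, 4.0]]]
y_true = [[[1.0, 2.5], [2.0, 4.0]]]
print(mse_mae(y_hat, y_true))
```

MSE penalizes large errors quadratically while MAE weights all errors equally, so reporting both, as `evaluate()` does, shows whether the error is dominated by a few large misses.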
