# Spatio-Temporal Traffic Forecasting with Neural Graph Cellular Automata
Petrônio C.  L. Silva  <span itemscope itemtype="https://schema.org/Person"><a itemprop="sameAs" content="https://orcid.org/0000-0002-1202-2552" href="https://orcid.org/0000-0002-1202-2552" target="orcid.widget" rel="noopener noreferrer" style="vertical-align:top;"><img src="https://orcid.org/sites/default/files/images/orcid_16x16.png" style="width:1em;margin-right:.5em;" alt="ORCID iD icon"></a></span>, Omid Orang  <span itemscope itemtype="https://schema.org/Person"><a itemprop="sameAs" content="https://orcid.org/0000-0002-4077-3775" href="https://orcid.org/0000-0002-4077-3775" target="orcid.widget" rel="noopener noreferrer" style="vertical-align:top;"><img src="https://orcid.org/sites/default/files/images/orcid_16x16.png" style="width:1em;margin-right:.5em;" alt="ORCID iD icon"></a></span>, Lucas Astore, Frederico G. Guimarães <span itemscope itemtype="https://schema.org/Person"><a itemprop="sameAs" content="https://orcid.org/0000-0001-9238-8839" href="https://orcid.org/0000-0001-9238-8839" target="orcid.widget" rel="noopener noreferrer" style="vertical-align:top;"><img src="https://orcid.org/sites/default/files/images/orcid_16x16.png" style="width:1em;margin-right:.5em;" alt="ORCID iD icon"></a></span>

If you have any questions, do not hesitate to contact us at the following e-mail address: petronio.candido@ifnmg.edu.br


## Imports

In [1]:
import os

import torch
from torch import nn

from st_nca.common import resume, get_device, checkpoint_all
from st_nca.datasets.PEMS import PEMS03, get_config as pems_get_config
from st_nca.cellmodel import CellModel, load_config, get_config
from st_nca.gca import GraphCellularAutomata, timestamp_generator
from st_nca.finetune import FineTunningDataset, finetune_loop
from st_nca.evaluate import evaluate, diff_states

## Data

In [None]:
def save_as(from_file, to_file, pems, NTRANSF, NHEADS, NTRANSFF, TRANSFACT, MLP, MLPD, MLPACT,
            DEVICE = None, DTYPE = torch.float32):
    """Convert a bare cell-model ``state_dict`` file into a self-describing checkpoint.

    Builds a ``CellModel`` with the given hyper-parameters and loads the weights stored
    in ``from_file`` into it. It then writes ``{'config': ..., 'weights': ...}`` to
    ``to_file``, which is the layout ``setup`` reads back.

    Parameters
    ----------
    from_file : str
        Path of a file holding a ``state_dict``. It is loaded with ``weights_only=True``.
    to_file : str
        Destination path of the new checkpoint.
    pems :
        Dataset object. It supplies ``max_length`` / ``token_dim`` for the model and extra
        config entries via ``pems_get_config``.
    NTRANSF, NHEADS, NTRANSFF, TRANSFACT, MLP, MLPD, MLPACT :
        Forwarded to ``CellModel`` as ``num_transformers``, ``num_heads``,
        ``transformer_feed_forward``, ``transformer_activation``, ``feed_forward``,
        ``feed_forward_dim`` and ``feed_forward_activation`` respectively.
    DEVICE : optional
        Device used both for the model and as ``map_location`` when reading ``from_file``.
        ``None`` (default) means ``get_device()``, resolved at call time.
    DTYPE : torch.dtype
        Dtype of the model parameters.
    """
    if DEVICE is None:
        DEVICE = get_device()
    model = CellModel(num_tokens = pems.max_length, dim_token= pems.token_dim,
               num_transformers = NTRANSF, num_heads = NHEADS, transformer_feed_forward= NTRANSFF,
               transformer_activation = TRANSFACT,
               feed_forward = MLP, feed_forward_dim = MLPD, feed_forward_activation = MLPACT,
               device = DEVICE, dtype = DTYPE)
    # Map the stored tensors onto the same device the model was built on.
    state_dict = torch.load(from_file,
                            weights_only=True,
                            map_location=torch.device(DEVICE))
    # strict=False: keys that do not match the model are skipped instead of raising.
    model.load_state_dict(state_dict, strict=False)
    torch.save({
        'config': get_config(model, **pems_get_config(pems)),
        'weights': model.state_dict(),
    }, to_file)
    
    
    

def setup(file, pems):
    saved_config = torch.load(file)
    tmp = load_config(saved_config['config'])
    tmp.load_state_dict(saved_config['weights'], strict=False)
    pems.steps_ahead = saved_config['steps_ahead']
    return tmp, pems

In [2]:
DEVICE = get_device()
DTYPE = torch.float32

# Fixed seed so any torch-random step (weight init, sampling) is repeatable across runs.
SEED = 0
torch.manual_seed(SEED)

# Root of the st_nca checkout. Set the ST_NCA_PATH environment variable instead of editing
# this cell; the fallback is the original author's local path.
DEFAULT_PATH = os.environ.get('ST_NCA_PATH',
                              'D:\\Dropbox\\Projetos\\futurelab\\posdoc\\st_nca\\st_nca\\st_nca\\')
# The trailing '' keeps a final separator, because later cells append file names with '+'.
DATA_PATH = os.path.join(DEFAULT_PATH, 'data', 'PEMS03', '')
MODELS_PATH = os.path.join(DEFAULT_PATH, 'weights', 'PEMS03', '')

# Cell-model hyper-parameters (CellModel keyword each one is passed as):
NTRANSF = 3              # num_transformers
NHEADS = 16              # num_heads
NTRANSFF = 1024          # transformer_feed_forward
TRANSFACT = nn.GELU()    # transformer_activation
MLP = 3                  # feed_forward
MLPD = 1024              # feed_forward_dim
MLPACT = nn.GELU()       # feed_forward_activation
STEPS_AHEAD = 12         # steps_ahead of the PEMS03 dataset
ITERATIONS = 1           # default number of GCA iterations

pems = PEMS03(edges_file = DATA_PATH + 'edges.csv', nodes_file = DATA_PATH + 'nodes.csv', data_file = DATA_PATH + 'data.csv',
    device = DEVICE, dtype = DTYPE, steps_ahead = STEPS_AHEAD)


In [5]:
# Build the cell model from the hyper-parameters in the configuration cell.
model = CellModel(
    num_tokens=pems.max_length,
    dim_token=pems.token_dim,
    num_transformers=NTRANSF,
    num_heads=NHEADS,
    transformer_feed_forward=NTRANSFF,
    transformer_activation=TRANSFACT,
    feed_forward=MLP,
    feed_forward_dim=MLPD,
    feed_forward_activation=MLPACT,
    device=DEVICE,
    dtype=DTYPE,
)

# The weight-file name encodes the same hyper-parameters, in the same order.
file = MODELS_PATH + 'h12_cell_model_{}_{}_{}_{}_{}.h5'.format(NTRANSF, NHEADS, NTRANSFF, MLP, MLPD)

# strict=False: keys that do not match the model are skipped instead of raising.
pretrained_state = torch.load(file, weights_only=True, map_location=torch.device(DEVICE))
model.load_state_dict(pretrained_state, strict=False)

# Display the combined model + dataset configuration.
get_config(model, **pems_get_config(pems))

{'num_heads': 16,
 'normalization': torch.nn.modules.normalization.LayerNorm,
 'pre_norm': False,
 'transformer_feed_forward': 1024,
 'transformer_activation': GELU(approximate='none'),
 'num_tokens': 7,
 'dim_token': 7,
 'num_transformers': 3,
 'feed_forward': 3,
 'feed_forward_dim': 1024,
 'feed_forward_activation': GELU(approximate='none'),
 'device': 'cpu',
 'dtype': torch.float32,
 'steps_ahead': 12}

In [6]:
# Persist the weights together with the full configuration, so the model can be rebuilt
# from the file alone with load_config (see the next cell).
new_checkpoint = {
    'config': get_config(model, **pems_get_config(pems)),
    'weights': model.state_dict(),
}
torch.save(new_checkpoint, file + "NEW")

In [7]:
# Round-trip check: reload the checkpoint written by the previous cell and rebuild the model
# from its stored config alone.
# weights_only=False is required because the config holds pickled Python objects (activation
# modules, the normalization class, the dtype), so only load checkpoints you trust. Passing it
# explicitly keeps the behaviour stable across torch versions whose default differs.
saved_config = torch.load(file + "NEW", map_location=torch.device(DEVICE), weights_only=False)
tmp = load_config(saved_config['config'])
tmp.load_state_dict(saved_config['weights'], strict=False)

  saved_config = torch.load(file + "NEW")


<All keys matched successfully>

In [12]:
# Rebuild the cell model, load its pre-trained weights and wrap it in the graph cellular automaton.
# The keyword names must match the CellModel call in the earlier cell (and the keys reported by
# get_config). The previous spelling passed NTRANSF-width NTRANSFF as `feed_forward`, i.e. as
# the MLP setting, and used `mlp`/`mlp_dim`/`mlp_activation`, which are not the model's names.
model = CellModel(num_tokens = pems.max_length, dim_token = pems.token_dim,
               num_transformers = NTRANSF, num_heads = NHEADS, transformer_feed_forward = NTRANSFF,
               transformer_activation = TRANSFACT,
               feed_forward = MLP, feed_forward_dim = MLPD, feed_forward_activation = MLPACT,
               device = DEVICE, dtype = DTYPE)

file = MODELS_PATH + 'h12_cell_model_{}_{}_{}_{}_{}.h5'.format(NTRANSF,NHEADS,NTRANSFF,MLP,MLPD)

# strict=False: keys that do not match the model are skipped instead of raising.
model.load_state_dict(torch.load(file,
                                 weights_only=True,
                                 map_location=torch.device(DEVICE)), strict=False)

gca = GraphCellularAutomata(device=DEVICE, dtype=DTYPE, graph = pems.G,
                            max_length = pems.max_length, token_size=pems.token_dim,
                            tokenizer=pems.tokenizer, cell_model = model)

# Fine-tuning samples: 5-minute increments, STEPS_AHEAD ahead, taking every 250th window
# (NOTE(review): `step` semantics assumed from the name; confirm in FineTunningDataset).
finetune_ds = FineTunningDataset(pems, increment_type='minutes', increment=5,
                                 steps_ahead=STEPS_AHEAD, step=250)

## Fine-Tuning

In [None]:
# Checkpoint path encodes the same hyper-parameters as the cell-model weight file.
finetune_checkpoint = MODELS_PATH + 'h12_gca_{}_{}_{}_{}_{}.pt'.format(NTRANSF, NHEADS, NTRANSFF, MLP, MLPD)

# NOTE(review): the dataset windows use 5-minute increments with steps_ahead=12 (one hour),
# which is presumably why a single 1-hour GCA iteration is used here; confirm.
finetune_loop(
    DEVICE,
    finetune_ds,
    gca,
    iterations=1,
    increment_type='hours',
    increment=1,
    epochs=150,
    batch=1,
    lr=0.00001,
    checkpoint_file=finetune_checkpoint,
)

## Evaluate

In [None]:
# Evaluate a saved GCA checkpoint on the test split.
# NOTE(review): this loads 'gca_...ptBEST' (no 'h12_' prefix), not the 'h12_gca_...' path used
# by the fine-tuning cell above. Confirm which model is meant to be evaluated.
gca_file = MODELS_PATH + 'gca_{}_{}_{}_{}_{}.ptBEST'.format(NTRANSF,NHEADS,NTRANSFF,MLP,MLPD)

# Same keyword names as the CellModel call in the earlier cell. The previous spelling passed
# NTRANSFF as `feed_forward` and used `mlp`/`mlp_dim`/`mlp_activation`.
cell_model = CellModel(num_tokens = pems.max_length, dim_token = pems.token_dim,
               num_transformers = NTRANSF, num_heads = NHEADS, transformer_feed_forward = NTRANSFF,
               transformer_activation = TRANSFACT,
               feed_forward = MLP, feed_forward_dim = MLPD, feed_forward_activation = MLPACT,
               device = DEVICE, dtype = DTYPE)

gca = GraphCellularAutomata(device=DEVICE, dtype=DTYPE, graph = pems.G,
                            max_length = pems.max_length, token_size=pems.token_dim,
                            tokenizer=pems.tokenizer, cell_model = cell_model)

# strict=False: keys that do not match the model are skipped instead of raising.
gca.load_state_dict(torch.load(gca_file,
                               weights_only=True,
                               map_location=torch.device(DEVICE)), strict=False)

# Number of GCA iterations for evaluation. It is also used as the dataset's steps_ahead.
# This is a separate name so the ITERATIONS config constant is not overwritten; overwriting it
# would change earlier cells on re-run.
EVAL_ITERATIONS = 12

dataset = FineTunningDataset(pems, increment_type='minutes', increment=5,
                             steps_ahead=EVAL_ITERATIONS, step=10, device = DEVICE)

df = evaluate(dataset.test(), gca, EVAL_ITERATIONS, increment_type='minutes', increment=5)
df

In [None]:
# Display the device selected by get_device() in the configuration cell.
DEVICE