
yash-jakhmola/w-dstagnn


W-DSTAGNN

Wavelet-based Temporal Attention Improves Traffic Forecasting

Figure: model architecture

Requirements

  • python >= 3.5
  • scipy
  • tensorboardX
  • pytorch
  • scikit-learn
  • tqdm

Datasets

W-DSTAGNN is evaluated on the following public traffic datasets.

  • PEMSBAY from DCRNN (ICLR-18)
  • PEMS03, PEMS04 from STSGCN (AAAI-20)

All datasets are available on Google Drive in .npz format, each with 'data' as its only key. The input traffic data should have shape (Total_Time_Steps, Node_Number).
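Before running the preprocessing script, it can help to confirm that a downloaded file matches this layout. The following is a minimal sketch, assuming NumPy is installed. The helper name `check_traffic_npz` and the file name `example.npz` are hypothetical and not part of this repository, and the array is synthetic stand-in data.

```python
import os
import tempfile

import numpy as np


def check_traffic_npz(path):
    """Load a .npz file and return its 'data' array after basic layout checks."""
    with np.load(path) as archive:
        keys = list(archive.files)
        if keys != ["data"]:
            raise ValueError(f"expected the only key to be 'data', found {keys}")
        data = archive["data"]
    if data.ndim != 2:
        raise ValueError(
            f"expected shape (Total_Time_Steps, Node_Number), got {data.shape}"
        )
    return data


# Synthetic stand-in for a real dataset: 288 time steps, 5 nodes (hypothetical sizes).
tmp_dir = tempfile.mkdtemp()
example_path = os.path.join(tmp_dir, "example.npz")
np.savez(example_path, data=np.random.rand(288, 5))

example = check_traffic_npz(example_path)
print(example.shape)
```

Replacing `example_path` with the path of a downloaded dataset file runs the same checks on real data.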

Processing the datasets:

  • on PEMSBAY dataset

    python prepareData.py --config configurations/PEMSBAY_dstagnn.conf
  • on PEMS03 dataset

    python prepareData.py --config configurations/PEMS03_dstagnn.conf
  • on PEMS04 dataset

    python prepareData.py --config configurations/PEMS04_dstagnn.conf

Spatial-Temporal Aware Graph Construction

The graphs can be generated by running:

cd ./data/
python STAG_gen.py

The computation runs on the CPU, so make sure sufficient computational resources are available.

Train and Test

  • on PEMSBAY dataset

    python train_DSTAGNN.py --config configurations/PEMSBAY_dstagnn.conf   
  • on PEMS03 dataset

    python train_DSTAGNN.py --config configurations/PEMS03_dstagnn.conf   
  • on PEMS04 dataset

    python train_DSTAGNN.py --config configurations/PEMS04_dstagnn.conf   

Configuration

The configuration file config.conf contains two sections, Data and Training:

Data

  • adj_filename: path of the adjacency matrix file
  • graph_signal_matrix_filename: path of graph signal matrix file
  • stag_filename: path of the Spatial-Temporal Aware Graph file
  • strg_filename: path of the Spatial-Temporal Relevance Graph file
  • num_of_vertices: number of vertices
  • points_per_hour: number of data points per hour (12 in our datasets)
  • num_for_predict: number of points to predict (12 in our model)
  • len_input: length of each input to the model (12 in our model)
  • dataset_name: name of the dataset
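To illustrate how `len_input` and `num_for_predict` relate to the raw series, the sketch below slices a (Total_Time_Steps, Node_Number) array into input/target pairs with a sliding window. It is an assumption-laden illustration, not the logic of `prepareData.py`: the function name `make_windows` is hypothetical and the series is synthetic. With 12 points per hour, each 12-step window covers one hour.

```python
import numpy as np


def make_windows(series, len_input=12, num_for_predict=12):
    """Slice a (Total_Time_Steps, Node_Number) array into input/target pairs.

    Hypothetical helper for illustration only.
    """
    total_steps = series.shape[0]
    num_samples = total_steps - len_input - num_for_predict + 1
    if num_samples <= 0:
        raise ValueError("series is too short for the requested window sizes")
    # Each input window is followed immediately by its target window.
    inputs = np.stack([series[i:i + len_input] for i in range(num_samples)])
    targets = np.stack([
        series[i + len_input:i + len_input + num_for_predict]
        for i in range(num_samples)
    ])
    return inputs, targets


# Synthetic series: 40 time steps, 3 nodes (hypothetical sizes).
series = np.arange(40 * 3, dtype=float).reshape(40, 3)
x, y = make_windows(series)
print(x.shape, y.shape)
```

Both returned arrays have the layout (samples, window length, nodes), so the first target step of each sample is the time step right after the last input step.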

Training

  • graph: graph structure to use, G (adjacency graph) or AG (Spatial-Temporal Aware Graph)
  • ctx: compute device; set to cpu, or to gpu-0 for the first GPU device
  • in_channels: number of channels for spatial convolution
  • n_heads: int, number of temporal attention heads
  • K: int, order of the Chebyshev polynomials (also the number of spatial attention heads)
  • d_k: int, dimension of the Q, K, and V vectors
  • d_model: int, d_E
  • batch_size: int
  • model_name: name of the model
  • num_of_weeks: int, how many weeks' data will be used
  • num_of_days: int, how many days' data will be used
  • num_of_hours: int, how many hours' data will be used
  • epochs: int
  • learning_rate: float
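As a rough illustration of this two-section layout, the sketch below parses a small configuration with Python's standard `configparser` module. The section names follow the description above, but every value is a hypothetical placeholder, only a subset of keys is shown, and the repository's scripts may read the file differently.

```python
import configparser

# Hypothetical values for illustration; not taken from the repository's .conf files.
SAMPLE_CONF = """
[Data]
num_of_vertices = 5
points_per_hour = 12
num_for_predict = 12
len_input = 12
dataset_name = EXAMPLE

[Training]
graph = AG
ctx = cpu
n_heads = 3
K = 3
batch_size = 32
epochs = 10
learning_rate = 0.001
"""

config = configparser.ConfigParser()
config.read_string(SAMPLE_CONF)

data_cfg = config["Data"]
training_cfg = config["Training"]

# Values are stored as strings; convert them to the types listed above.
settings = {
    "num_of_vertices": data_cfg.getint("num_of_vertices"),
    "len_input": data_cfg.getint("len_input"),
    "num_for_predict": data_cfg.getint("num_for_predict"),
    "graph": training_cfg.get("graph"),
    "K": training_cfg.getint("K"),
    "learning_rate": training_cfg.getfloat("learning_rate"),
}
print(settings)
```

To load a file from disk instead of a string, `config.read(path)` can replace `config.read_string(SAMPLE_CONF)`.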

Citation

If you find this repository useful in your research, please cite the following paper:

@article{jakhmola2024spatiotemporal,
  title={Spatiotemporal Forecasting of Traffic Flow using Wavelet-based Temporal Attention},
  author={Jakhmola, Yash and Panja, Madhurima and Mishra, Nitish Kumar and Ghosh, Kripabandhu and Kumar, Uttam and Chakraborty, Tanujit},
  journal={IEEE Access},
  year={2024},
  publisher={IEEE}
}
