Clean repo for tensor-train RNN implemented in TensorFlow
Rose Yu
Rose Yu retry
Latest commit 59c2393 Mar 14, 2018

Tensor Train Recurrent Neural Network

Clean code repo for the tensor-train recurrent neural network (TT-RNN), implemented in TensorFlow. See details in our paper, Long-term Forecasting using Tensor-Train RNNs.

Getting Started

install prerequisites

  • tensorflow >= r1.6
  • Python >=3.0
  • Jupyter >=4.1.1

import module

from trnn import TensorLSTMCell
from trnn_imply import tensor_rnn_with_feed_prev

Classes

  • TensorLSTMCell(num_units, num_lags, rank_vals) – creates a TensorTrainLSTM object with num_units hidden nodes and num_lags time lags, where rank_vals is the list of ranks for the tensor-train decomposition
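
To illustrate what a list of tensor-train ranks controls, here is a minimal NumPy sketch of a generic tensor-train decomposition. It is not the code in trnn.py: the helper names make_tt_cores and tt_to_full, the mode size, and the convention that the rank list has one entry per internal bond are all assumptions made for illustration.

```python
import numpy as np

def make_tt_cores(mode_size, rank_vals, seed=0):
    """Build random TT cores (hypothetical helper, not part of this repo).

    Core k has shape (r_{k-1}, mode_size, r_k), with boundary ranks r_0 = r_d = 1.
    """
    rng = np.random.default_rng(seed)
    ranks = [1] + list(rank_vals) + [1]
    return [rng.standard_normal((ranks[k], mode_size, ranks[k + 1]))
            for k in range(len(ranks) - 1)]

def tt_to_full(cores):
    """Contract the chain of cores back into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        # Contract the trailing rank axis of `full` with the leading rank axis of `core`.
        full = np.tensordot(full, core, axes=1)
    # Drop the two boundary rank axes of size 1.
    return full.squeeze(axis=(0, -1))

cores = make_tt_cores(mode_size=4, rank_vals=[2, 2])
full = tt_to_full(cores)
tt_params = sum(core.size for core in cores)

print(full.shape)             # (4, 4, 4)
print(full.size, tt_params)   # 64 32
```

In this toy case the three cores store 32 numbers while the full order-3 tensor has 64 entries; smaller ranks give a more compact but less expressive representation.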

Methods

  • tensor_rnn_with_feed_prev – runs the forward pass for a single TensorTrainLSTM cell, feeding previous predictions back as input; returns an output and a hidden state.
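
The idea of feeding previous predictions back as input can be sketched without TensorFlow. The loop below is a generic illustration, not the signature or body of tensor_rnn_with_feed_prev: rollout_with_feed_prev, step_fn, and toy_step are hypothetical names, and the toy cell is a plain linear recurrence standing in for a real LSTM cell.

```python
import numpy as np

def rollout_with_feed_prev(step_fn, first_input, init_state, num_steps):
    """Roll a cell forward, using each prediction as the next input.

    step_fn(inp, state) -> (output, new_state) is an assumed interface.
    """
    outputs = []
    inp, state = first_input, init_state
    for _ in range(num_steps):
        out, state = step_fn(inp, state)
        outputs.append(out)
        inp = out  # feed the prediction back in place of ground truth
    return np.stack(outputs), state

def toy_step(inp, state):
    """Stand-in cell: decay the state and add the input."""
    new_state = 0.5 * state + inp
    return new_state, new_state

preds, final_state = rollout_with_feed_prev(
    toy_step, first_input=np.array([1.0]), init_state=np.zeros(1), num_steps=3)
print(preds.ravel())  # [1.   1.5  2.25]
```

Because each step consumes the model's own output, errors can compound over the horizon, which is why long-term forecasting is the harder setting.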

Running the test

Run the Jupyter notebook

  • jupyter notebook test_trnn.ipynb

The notebook gives a simple example of using TensorTrainLSTM by:

  • loading a set of sim sequences
  • building a tensor train Seq2Seq model
  • making long-term predictions

Directory

  • reader.py reads the data into train/valid/test datasets and normalizes the data if needed

  • model_seq2seq.py seq2seq model for sequence prediction

  • trnn.py tensor-train LSTM cell and the corresponding tensor-train contraction

  • trnn_imply.py forward step in the tensor-train RNN, feeding previous predictions as input
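
As a rough picture of the data preparation step that reader.py is described as performing, the sketch below splits an array into train/valid/test parts and optionally normalizes it. The function name split_and_normalize, the split fractions, and the choice of min-max scaling with training-set statistics are all assumptions for illustration; the repo's reader may differ on each of them.

```python
import numpy as np

def split_and_normalize(data, train_frac=0.7, valid_frac=0.15, normalize=True):
    """Split along the first axis, then min-max scale (hypothetical helper)."""
    n = data.shape[0]
    n_train = int(round(n * train_frac))
    n_valid = int(round(n * valid_frac))
    train = data[:n_train]
    valid = data[n_train:n_train + n_valid]
    test = data[n_train + n_valid:]
    if normalize:
        # Use training statistics only, so no information leaks from valid/test.
        lo, hi = train.min(), train.max()
        scale = hi - lo if hi > lo else 1.0
        train, valid, test = [(part - lo) / scale for part in (train, valid, test)]
    return train, valid, test

# Synthetic stand-in data: 20 sequences of length 5.
data = np.arange(100, dtype=float).reshape(20, 5)
train, valid, test = split_and_normalize(data)
print(train.shape, valid.shape, test.shape)  # (14, 5) (3, 5) (3, 5)
```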

Citation

If you find this repo useful, please cite our work:

@article{yu2017long,
  title={Long-term forecasting using tensor-train RNNs},
  author={Yu, Rose and Zheng, Stephan and Anandkumar, Anima and Yue, Yisong},
  journal={arXiv preprint arXiv:1711.00073},
  year={2017}
}