Clean repo for tensor-train RNN implemented in TensorFlow

Tensor Train Recurrent Neural Network


Clean code repo for the tensor-train recurrent neural network, implemented in TensorFlow. See details in our paper, Long-Term Forecasting with Tensor Train RNNs.

Getting Started

Install prerequisites

  • TensorFlow >= r1.6
  • Python >= 3.0
  • Jupyter >= 4.1.1

Import the modules

from trnn import TensorLSTMCell
from trnn_imply import tensor_rnn_with_feed_prev


  • TensorLSTMCell(num_units, num_lags, rank_vals) – creates a TensorTrainLSTM object with num_units hidden nodes and num_lags time lags; rank_vals is the list of ranks for the tensor-train decomposition.


  • tensor_rnn_with_feed_prev – forward pass for a single TensorTrainLSTM cell; returns an output and a hidden state.
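
To make the role of rank_vals more concrete, here is a minimal NumPy sketch of how a list of tensor-train ranks determines the shapes of the cores and the parameter count. This is an illustration only, not the code in trnn: the helper names make_tt_cores and tt_to_full, the mode sizes, and the rank values are all assumptions chosen for the example.

```python
import numpy as np


def make_tt_cores(dims, ranks, seed=0):
    """Create random tensor-train cores (hypothetical helper).

    dims  : mode sizes n_1, ..., n_d of the full tensor
    ranks : internal TT ranks r_1, ..., r_{d-1}
    Core k has shape (r_{k-1}, n_k, r_k), with boundary ranks r_0 = r_d = 1.
    """
    assert len(ranks) == len(dims) - 1
    full_ranks = [1] + list(ranks) + [1]
    rng = np.random.default_rng(seed)
    return [
        rng.standard_normal((full_ranks[k], dims[k], full_ranks[k + 1]))
        for k in range(len(dims))
    ]


def tt_to_full(cores):
    """Contract the cores left to right to rebuild the full tensor."""
    full = cores[0]  # shape (1, n_1, r_1)
    for core in cores[1:]:
        # Contract the trailing rank axis of `full` with the leading rank axis of `core`.
        full = np.tensordot(full, core, axes=([-1], [0]))
    # Drop the two boundary rank axes of size 1.
    return full.reshape(full.shape[1:-1])


dims = [8, 8, 8]   # assumed mode sizes, for illustration
ranks = [2, 2]     # plays the role of a rank list such as rank_vals
cores = make_tt_cores(dims, ranks)
full = tt_to_full(cores)

tt_params = sum(core.size for core in cores)
print("core shapes:", [core.shape for core in cores])
print("TT parameters:", tt_params, "vs full tensor entries:", full.size)
```

Smaller ranks give fewer parameters at the cost of a less expressive weight tensor, which is the trade-off a rank list controls.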

Running the test

Run the Jupyter notebook

  • jupyter notebook test_trnn.ipynb

The notebook is a simple example of using TensorTrainLSTM by

  • loading a set of sim sequences
  • building a tensor train Seq2Seq model
  • making long-term predictions
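
As a rough sketch of the data-preparation step, the snippet below generates a set of sequences, splits them into train/valid/test sets, and normalizes them. The sine-wave generator is a synthetic stand-in, not the data shipped with the repo, and the function names, split fractions, and min-max scaling are assumptions made for illustration.

```python
import numpy as np


def make_sequences(num_seqs=100, seq_len=50, seed=0):
    """Synthetic stand-in data: sine waves with random phases.

    Returns an array of shape (num_seqs, seq_len, 1).
    """
    rng = np.random.default_rng(seed)
    t = np.linspace(0.0, 4.0 * np.pi, seq_len)
    phases = rng.uniform(0.0, 2.0 * np.pi, size=(num_seqs, 1))
    return np.sin(t[None, :] + phases)[..., None]


def split_and_normalize(data, valid_frac=0.1, test_frac=0.1):
    """Split along the first axis, then min-max scale using training statistics only."""
    n = data.shape[0]
    n_valid = int(round(n * valid_frac))
    n_test = int(round(n * test_frac))
    n_train = n - n_valid - n_test

    train = data[:n_train]
    valid = data[n_train:n_train + n_valid]
    test = data[n_train + n_valid:]

    # Use the training range for all splits so no test information leaks in.
    lo, hi = train.min(), train.max()

    def scale(x):
        return (x - lo) / (hi - lo)

    return scale(train), scale(valid), scale(test)


data = make_sequences()
train, valid, test = split_and_normalize(data)
print(train.shape, valid.shape, test.shape)
```

Scaling with the training range only is a common choice; the validation and test sets may then fall slightly outside [0, 1].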


  • read the data into train/valid/test datasets, and normalize the data if needed

  • seq2seq model for sequence prediction

  • tensor-train LSTM cell and the corresponding tensor-train contraction

  • forward step in the tensor-train RNN, feeding previous predictions back as input
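
The last item, feeding previous predictions back as input, can be sketched with a plain NumPy loop. This is a minimal illustration under stated assumptions: toy_cell is an ordinary tanh RNN step standing in for the tensor-train cell, and rollout_feed_prev is a hypothetical name, not the signature of tensor_rnn_with_feed_prev.

```python
import numpy as np


def toy_cell(x, h, W_x, W_h, W_out):
    """Plain tanh RNN step, used here as a stand-in for the tensor-train cell."""
    h_new = np.tanh(x @ W_x + h @ W_h)
    y = h_new @ W_out
    return y, h_new


def rollout_feed_prev(first_input, h0, num_steps, params):
    """Multi-step forecast: each prediction becomes the next step's input."""
    x, h = first_input, h0
    outputs = []
    for _ in range(num_steps):
        y, h = toy_cell(x, h, *params)
        outputs.append(y)
        x = y  # feed the previous prediction back in
    # Stack along a new time axis: (batch, num_steps, out_dim).
    return np.stack(outputs, axis=1), h


rng = np.random.default_rng(0)
batch, in_dim, hidden = 4, 1, 8
params = (
    0.1 * rng.standard_normal((in_dim, hidden)),   # input-to-hidden weights
    0.1 * rng.standard_normal((hidden, hidden)),   # hidden-to-hidden weights
    0.1 * rng.standard_normal((hidden, in_dim)),   # hidden-to-output weights
)

preds, h_last = rollout_feed_prev(
    np.ones((batch, in_dim)), np.zeros((batch, hidden)), num_steps=10, params=params
)
print(preds.shape, h_last.shape)
```

Because errors in early predictions are fed forward, this kind of rollout is where long-term forecasting quality is tested.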


If you find this repo useful, we kindly ask you to cite our work:

  title={Long-term forecasting using tensor-train RNNs},
  author={Yu, Rose and Zheng, Stephan and Anandkumar, Anima and Yue, Yisong},
  journal={arXiv preprint arXiv:1711.00073},