In [1]:
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy


class RNN(nn.Module):
    def __init__(self, num_series, hidden, nonlinearity='tanh'):
        '''
        RNN model with output layer to generate predictions.

        Args:
          num_series: number of input time series.
          hidden: number of hidden units.
          nonlinearity: activation of the recurrent cell, passed straight to
            ``nn.RNN`` ('tanh' or 'relu'). Defaults to 'tanh', which is also
            ``nn.RNN``'s own default.
        '''
        super().__init__()
        self.p = num_series
        self.hidden = hidden

        # Set up network. batch_first=True means inputs are laid out as
        # (batch, time, num_series).
        self.rnn = nn.RNN(num_series, hidden, nonlinearity=nonlinearity,
                          batch_first=True)
        self.rnn.flatten_parameters()
        # A kernel-size-1 Conv1d applied over the time axis acts as a
        # per-timestep linear map from the hidden state to one output value.
        self.linear = nn.Conv1d(hidden, 1, 1)

    def init_hidden(self, batch):
        '''
        Initialize hidden states for RNN cell.

        Returns a zero tensor of shape (1, batch, hidden) on the same device
        and with the same dtype as the RNN weights. Matching the dtype lets
        the model run after e.g. ``.double()``.
        '''
        weight = self.rnn.weight_ih_l0
        return torch.zeros(1, batch, self.hidden,
                           device=weight.device, dtype=weight.dtype)

    def forward(self, X, hidden=None, truncation=None):
        '''
        Run the RNN over X and project each hidden state to a single output.

        Args:
          X: input tensor of shape (batch, T, num_series).
          hidden: optional initial hidden state of shape (1, batch, hidden);
            zeros (via ``init_hidden``) are used when None.
          truncation: currently ignored by this implementation.

        Returns:
          (predictions, hidden): predictions has shape (batch, T, 1); hidden
          is the final hidden state returned by ``nn.RNN``.
        '''
        # Set up hidden state.
        if hidden is None:
            hidden = self.init_hidden(X.shape[0])

        # Apply RNN.
        X, hidden = self.rnn(X, hidden)

        # Calculate predictions using output layer. Conv1d expects
        # (batch, channels, time), so move the hidden dimension to axis 1
        # and back again afterwards.
        X = X.transpose(2, 1)
        X = self.linear(X)
        return X.transpose(2, 1), hidden