# Understanding Recurrent Neural Networks

**References:**
* [The Unreasonable Effectiveness of Recurrent Neural Networks by Andrej Karpathy](https://karpathy.github.io/2015/05/21/rnn-effectiveness/)
* [NLP with DL CS224N Lecture 7](https://web.stanford.edu/class/cs224n/slides/cs224n-2019-lecture07-fancy-rnn.pdf)
* [Vanishing And Exploding Gradient Problems by Jefkine](https://www.jefkine.com/general/2018/05/21/2018-05-21-vanishing-and-exploding-gradient-problems/)
* [Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass by weberna](https://weberna.github.io/blog/2017/11/15/LSTM-Vanishing-Gradients.html)
* [Neural Network (2): RNN and Problems of Exploding/Vanishing Gradient by Liyan Xu](https://liyanxu.blog/2018/11/01/rnn-exploding-vanishing-gradient/)
* [Understanding LSTM Networks by colah](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
* [Einsum is all you need - Einstein summation in deep learning by Tim Rocktäschel](https://rockt.github.io/2018/04/30/einsum)

In [1]:
import torch
import torch.nn as nn

## Vanilla RNN


Vanilla recurrent neural networks (RNNs) are a class of neural networks that model sequences of vectors. A few examples, taken from [Andrej Karpathy's blog](https://karpathy.github.io/2015/05/21/rnn-effectiveness/), are shown below.

<img src="assets/types_of_networks.png" alt="Drawing" style="width: 800px;"/>

We now discuss a vanilla RNN module, as depicted in the image below.

<img src="assets/rnn_unroll.png" alt="Drawing" style="width: 800px;"/>

At each timestep $t$, the RNN module takes as input the previous hidden state $\mathbf{h}_{t-1}\in\mathbb{R}^{d}$ and an input $\mathbf{x}_t\in\mathbb{R}^{k}$, and produces an output $\mathbf{h}_t$ (shown on the left). This can be "unrolled" to more easily visualise the behaviour of the RNN (as seen on the right). The internal workings of a vanilla RNN are shown below.

<img src="assets/rnn_internal.png" alt="Drawing" style="width: 700px;"/>

The RNN output is given by,

$$\begin{align*}
\mathbf{h}_t&=\text{tanh}\Big(\mathbf{W}_h\mathbf{h}_{t-1}+\mathbf{W}_x\mathbf{x}_t + \mathbf{b}\Big)\\
&=\text{tanh}\Big(\mathbf{W}[\mathbf{h}_{t-1};\mathbf{x}_t] + \mathbf{b}\Big),
\end{align*}$$

where $[\mathbf{h}_{t-1};\mathbf{x}_t]\in\mathbb{R}^{(d+k)}$ is the concatenation of vectors $\mathbf{h}_{t-1}$ and $\mathbf{x}_t$, $\mathbf{W}_h\in\mathbb{R}^{d\times d}$, $\mathbf{W}_x\in\mathbb{R}^{d\times k}$, $\mathbf{W}=[\mathbf{W}_h,\mathbf{W}_x]\in\mathbb{R}^{d\times(d+k)}$ places the two weight matrices side by side, and $\mathbf{b}\in\mathbb{R}^{d}$ is the bias vector. Here the output vector is just the updated hidden state, i.e. $\mathbf{s}_t=\mathbf{h}_t$. However, the output vector can also be a transformation of the hidden state, for example

$$\mathbf{s}_t=\mathbf{W}_s\mathbf{h}_t,$$

where $\mathbf{s}_t\in\mathbb{R}^{m}$ and $\mathbf{W}_s\in\mathbb{R}^{m\times d}$. Let our RNN be given by the function $f^{\text{RNN}}_\theta(\mathbf{x}_t,\mathbf{h}_{t-1})=(\mathbf{s}_t,\mathbf{h}_t)$. Then if we have a sequence of inputs $\mathbf{X}=[\mathbf{x}_1,\mathbf{x}_2,...,\mathbf{x}_T]$ and unfold the RNN we obtain:

$$\begin{align*}
\text{unfold}\Big(f^{\text{RNN}}_\theta, \mathbf{X}, \mathbf{h}_0\Big)&=\Big[f^{\text{RNN}}_\theta(\mathbf{x}_1,\mathbf{h}_0),f^{\text{RNN}}_\theta(\mathbf{x}_2,\mathbf{h}_1),...,f^{\text{RNN}}_\theta(\mathbf{x}_T,\mathbf{h}_{T-1})\Big]\\
&=\Big[(\mathbf{s}_1,\mathbf{h}_1),(\mathbf{s}_2,\mathbf{h}_2),...,(\mathbf{s}_T,\mathbf{h}_T)\Big].
\end{align*}$$
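
The equivalence between the split form $\mathbf{W}_h\mathbf{h}_{t-1}+\mathbf{W}_x\mathbf{x}_t$ and the concatenated form $\mathbf{W}[\mathbf{h}_{t-1};\mathbf{x}_t]$ can be checked numerically. The following is a minimal sketch for a single unbatched timestep; the sizes `d`, `k`, `m` and the random weights are arbitrary values chosen for illustration and are not taken from the text.

```python
import torch

torch.manual_seed(0)
d, k, m = 3, 5, 2  # hidden, input and output sizes (illustrative assumptions)

W_h = torch.randn(d, d)
W_x = torch.randn(d, k)
b = torch.randn(d)
W_s = torch.randn(m, d)

# W places W_h and W_x side by side, so that W [h; x] = W_h h + W_x x
W = torch.cat([W_h, W_x], dim=1)  # [d, d+k]

h_prev = torch.randn(d)
x_t = torch.randn(k)

# the two forms of the hidden state update
h_split = torch.tanh(W_h @ h_prev + W_x @ x_t + b)
h_concat = torch.tanh(W @ torch.cat([h_prev, x_t]) + b)

# optional projected output s_t = W_s h_t
s_t = W_s @ h_concat

print(torch.allclose(h_split, h_concat, atol=1e-5))
print(s_t.shape)
```

The concatenated form is the one used by the code cells in this notebook, since it needs a single matrix product per timestep.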

### Backpropagation Through Time (BPTT)

The discrepancy between output $\mathbf{s}_t$ and desired label $\mathbf{y}_t$ is evaluated by a loss function across all $T$ timesteps as

$$\mathcal{L}(\mathbf{x}_1,...,\mathbf{x}_T,\mathbf{y}_1,...,\mathbf{y}_T)=\frac{1}{T}\sum_{t=1}^{T}\ell(\mathbf{y}_t,\mathbf{s}_t).$$

The goal is now to calculate the gradients of our loss function w.r.t. the weights $\mathbf{b}$, $\mathbf{W}_h$, $\mathbf{W}_x$ and $\mathbf{W}_s$:

1. The derivative of the loss w.r.t. $\mathbf{b}$ 

$$\begin{align*}
\frac{\partial \mathcal{L}}{\partial \mathbf{b}}&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{b}}\\
&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{s}_t}\frac{\partial \mathbf{s}_t}{\partial \mathbf{h}_t}\frac{\partial \mathbf{h}_t}{\partial \mathbf{b}}.
\end{align*}$$

2. The derivative of the loss w.r.t. $\mathbf{W}_s$ 

$$\begin{align*}
\frac{\partial \mathcal{L}}{\partial \mathbf{W}_s}&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{W}_s}\\
&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{s}_t}\frac{\partial \mathbf{s}_t}{\partial \mathbf{W}_s}.
\end{align*}$$

3. The derivative of the loss w.r.t. $\mathbf{W}_h$ 

$$\begin{align*}
\frac{\partial \mathcal{L}}{\partial \mathbf{W}_h}&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{W}_h}\\
&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{s}_t}\frac{\partial \mathbf{s}_t}{\partial \mathbf{h}_t}\frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h}\\
&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{s}_t}\frac{\partial \mathbf{s}_t}{\partial \mathbf{h}_t}\sum^{t}_{k=1}\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k}\frac{\partial \mathbf{h}_k}{\partial \mathbf{W}_h}\\
&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{s}_t}\frac{\partial \mathbf{s}_t}{\partial \mathbf{h}_t}\sum^{t}_{k=1}\Big(\prod^{t-1}_{j=k}\frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{h}_{j}}\Big)\frac{\partial \mathbf{h}_k}{\partial \mathbf{W}_h},
\end{align*}$$ 

where we have used the multivariate chain rule such that for a function $z=f\big(x(t),y(t)\big)$, its derivative is given by $\frac{\partial z}{\partial t}=\frac{\partial z}{\partial x}\frac{\partial x}{\partial t}+\frac{\partial z}{\partial y}\frac{\partial y}{\partial t}$. Hence we have that

$$\begin{align*}
\frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h}&= \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h}\frac{\partial \mathbf{W}_h}{\partial \mathbf{W}_h}+ \frac{\partial \mathbf{h}_{t}}{\partial \mathbf{h}_{t-1}}\frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{W}_h}\\
&=\frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h}+ \frac{\partial \mathbf{h}_{t}}{\partial \mathbf{h}_{t-1}}\big(\frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{W}_h}\frac{\partial \mathbf{W}_{h}}{\partial \mathbf{W}_h}+\frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{h}_{t-2}}\frac{\partial \mathbf{h}_{t-2}}{\partial \mathbf{W}_h}\big)\\
&=\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_t}\frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h}+ \frac{\partial \mathbf{h}_{t}}{\partial \mathbf{h}_{t-1}}\frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{W}_h}+\frac{\partial \mathbf{h}_{t}}{\partial \mathbf{h}_{t-2}}\frac{\partial \mathbf{h}_{t-2}}{\partial \mathbf{W}_h}\\
&\vdots\\
&=\sum^{t}_{k=1}\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k}\frac{\partial \mathbf{h}_k}{\partial \mathbf{W}_h}.
\end{align*}$$

Additionally, we can apply the chain rule again to $\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k}$, which gives a product of Jacobians $\frac{\partial \mathbf{h}_{i}}{\partial \mathbf{h}_{i-1}}$ linking the hidden state at time $t$ to the one at time $k$:

$$\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k}=\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}}\frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{h}_{t-2}}...\frac{\partial \mathbf{h}_{k+1}}{\partial \mathbf{h}_k}=\prod^{t-1}_{j=k}\frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{h}_{j}}.$$



4. The derivative of the loss w.r.t. $\mathbf{W}_x$, derived in the same way as for $\mathbf{W}_h$

$$\begin{align*}
\frac{\partial \mathcal{L}}{\partial \mathbf{W}_x}&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{W}_x}\\
&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{s}_t}\frac{\partial \mathbf{s}_t}{\partial \mathbf{h}_t}\frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_x}\\
&=\frac{1}{T}\sum^{T}_{t=1}\frac{\partial \ell(\mathbf{y}_t,\mathbf{s}_t)}{\partial \mathbf{s}_t}\frac{\partial \mathbf{s}_t}{\partial \mathbf{h}_t}\sum^{t}_{k=1}\Big(\prod^{t-1}_{j=k}\frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{h}_{j}}\Big)\frac{\partial \mathbf{h}_k}{\partial \mathbf{W}_x}.
\end{align*}$$ 
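
The expression for $\frac{\partial \mathcal{L}}{\partial \mathbf{W}_h}$ can be checked against automatic differentiation. The sketch below makes several assumptions that are not in the text: $\mathbf{s}_t=\mathbf{h}_t$, a squared-error loss for $\ell$, a single unbatched sequence, and small random weights. Rather than forming the double sum explicitly, it accumulates the same terms recursively from the last timestep backwards, which is how BPTT is usually implemented.

```python
import torch

torch.manual_seed(0)
d, k, T = 3, 5, 6  # hidden size, input size, sequence length (illustrative)

W_h = (0.5 * torch.randn(d, d)).requires_grad_()
W_x = 0.5 * torch.randn(d, k)
b = torch.zeros(d)
X = torch.randn(T, k)
Y = torch.randn(T, d)  # targets, assuming s_t = h_t
h0 = torch.zeros(d)

# forward pass, storing every hidden state: hs[t] holds h_t
hs = [h0]
for t in range(T):
    hs.append(torch.tanh(W_h @ hs[-1] + W_x @ X[t] + b))

# loss averaged over timesteps, with an assumed squared-error l(y_t, s_t)
loss = sum(0.5 * ((hs[t + 1] - Y[t]) ** 2).sum() for t in range(T)) / T
loss.backward()  # reference gradient from autograd, stored in W_h.grad

# manual backpropagation through time
with torch.no_grad():
    grad_W_h = torch.zeros_like(W_h)
    from_future = torch.zeros(d)  # gradient reaching h_t from later timesteps
    for t in range(T, 0, -1):
        h_t, h_prev = hs[t], hs[t - 1]
        dh = (h_t - Y[t - 1]) / T + from_future  # total dL/dh_t
        da = dh * (1 - h_t ** 2)                 # back through tanh
        grad_W_h += torch.outer(da, h_prev)      # immediate dependence of h_t on W_h
        from_future = W_h.T @ da                 # pass the gradient on to h_{t-1}

print(torch.allclose(W_h.grad, grad_W_h, atol=1e-5))
```

The repeated multiplication by `W_h.T` and the `tanh` derivative in the backward loop is the product of Jacobians from the derivation, applied one factor per timestep.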

### Vanishing and Exploding Gradients

Evaluating the product of Jacobians gives

$$\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k}=\prod^{t-1}_{j=k}\frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{h}_{j}}=\prod^{t-1}_{j=k}\mathbf{W}_h^{T}\text{diag}\Big(\tanh'\big(\mathbf{W}_h\mathbf{h}_{j}+ \mathbf{W}_x\mathbf{x}_{j+1}+\mathbf{b}\big)\Big).$$

Let's look at the L2 matrix norms associated with these Jacobians

$$\bigg\Vert\frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{h}_{j}}\bigg\Vert\leq\big\Vert\mathbf{W}^{T}_h\big\Vert\big\Vert\text{diag}\Big(\tanh'\big(\mathbf{W}_h\mathbf{h}_{j}+ \mathbf{W}_x\mathbf{x}_{j+1}+\mathbf{b}\big)\Big)\big\Vert,$$

where we use the sub-multiplicative property of the matrix norm, $\Vert\mathbf{A}\mathbf{B}\Vert\leq\Vert\mathbf{A}\Vert\Vert\mathbf{B}\Vert$. Let $\gamma_w$ be an upper bound on $\big\Vert\mathbf{W}^{T}_h\big\Vert$, which for the L2 norm is the largest singular value of $\mathbf{W}_h$, and let $\gamma_h$ be an upper bound on $\big\Vert\text{diag}\Big(\tanh'\big(\mathbf{W}_h\mathbf{h}_{j}+ \mathbf{W}_x\mathbf{x}_{j+1}+\mathbf{b}\big)\Big)\big\Vert$, which is the largest absolute value taken by the derivative of the activation function. Depending on the activation function the upper bound $\gamma_h$ can be:
1. $\gamma_h=1$ for `tanh` activation function,
2. $\gamma_h=\frac{1}{4}$ for `sigmoid` activation function.

This means we can write that 

$$\bigg\Vert\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k}\bigg\Vert=\bigg\Vert\prod^{t-1}_{j=k}\frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{h}_{j}}\bigg\Vert\leq(\gamma_w\gamma_h)^{t-k}.$$

As the sequence gets longer (i.e. the distance between $t$ and $k$ increases), the value of $\gamma=\gamma_w\gamma_h$ determines how the gradient behaves. If $\gamma<1$ the bound decays exponentially in $t-k$ and the gradients vanish, while $\gamma>1$ is necessary (though not sufficient, since this is only an upper bound) for the gradients to explode.
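
The bound can be illustrated numerically. The sketch below is an assumption-laden toy example (small random weights, a single unbatched sequence, `tanh` activation): it builds $\frac{\partial \mathbf{h}_T}{\partial \mathbf{h}_0}$ as a product of per-step Jacobians, compares it with the Jacobian returned by `torch.autograd.functional.jacobian`, and tracks its L2 norm. Note that the code stores Jacobians with entry `[i, j]` equal to the derivative of output `i` w.r.t. input `j`, so each factor is written as `diag(...) @ W_h`, the transpose of the form used in the text; the L2 norm is unchanged by transposition.

```python
import torch

torch.manual_seed(0)
d, k, T = 4, 3, 20  # hidden size, input size, sequence length (illustrative)

W_h = 0.3 * torch.randn(d, d)
W_x = torch.randn(d, k)
b = torch.zeros(d)
X = torch.randn(T, k)
h0 = torch.zeros(d)

def step(h, x):
    return torch.tanh(W_h @ h + W_x @ x + b)

def unrolled(h):
    for t in range(T):
        h = step(h, X[t])
    return h

# reference Jacobian dh_T/dh_0 computed by autograd
J_auto = torch.autograd.functional.jacobian(unrolled, h0)

gamma_w = torch.linalg.norm(W_h, ord=2)  # spectral norm of W_h
gamma_h = 1.0                            # upper bound on tanh'

# product of per-step Jacobians, later timesteps multiplied on the left
J_prod = torch.eye(d)
h = h0
norms = []
for t in range(T):
    h = step(h, X[t])
    J_step = torch.diag(1 - h ** 2) @ W_h  # tanh'(a) = 1 - tanh(a)^2
    J_prod = J_step @ J_prod
    norms.append(torch.linalg.norm(J_prod, ord=2).item())

bound = (gamma_w * gamma_h) ** T
print(torch.allclose(J_auto, J_prod, atol=1e-5))
print(norms[-1], bound.item())
```

Scaling `W_h` up or down changes `gamma_w` and therefore how quickly the entries of `norms` shrink or grow with the number of timesteps.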

In [6]:
# below we show code for a batched vanilla RNN module.
class VanillaRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()  # needed so nn.Module can register the parameters
        self.input_size = input_size
        self.hidden_size = hidden_size

        # W acts on the concatenation [h; x]; wrapped in nn.Parameter so it is trainable
        self.W = nn.Parameter(torch.rand(hidden_size, hidden_size + input_size))
        self.b = nn.Parameter(torch.zeros(hidden_size))

    def f(self, x, h):
        """
        x : [batch_size, input_size]
        h : [batch_size, hidden_size]
        """
        x = torch.cat([h, x], dim=1) # [batch_size, hidden_size+input_size]
        # batched matrix-vector product: out[k, i] = sum_j W[i, j] * x[k, j]
        h = torch.tanh(torch.einsum("ij,kj->ki", self.W, x) + self.b)
        s = h
        return s, h

    def forward(self, X, h):
        """
        X : [batch_size, seq_len, input_size]
        h : [batch_size, hidden_size]
        S : [batch_size, seq_len, hidden_size]
        """
        # sizes are read from X and self rather than from global variables
        S = X.new_zeros(X.size(0), X.size(1), self.hidden_size)
        # unroll over the time dimension
        for i in range(X.size(1)):
            s, h = self.f(X[:,i,:], h)
            S[:,i,:] = s
        return S, h

    def init_h(self, batch_size, hidden_size):
        return torch.zeros(batch_size, hidden_size)

batch_size = 4
seq_len = 10
input_size = 5
hidden_size = 3

rnn = VanillaRNN(input_size=input_size, hidden_size=hidden_size)
h = rnn.init_h(batch_size=batch_size, hidden_size=hidden_size)
X = torch.randn(batch_size, seq_len, input_size)
S, h = rnn.forward(X, h)
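
As a sanity check on the update rule, the unrolled computation can be compared with PyTorch's built-in `torch.nn.RNN`. The sketch below is self-contained and does not use the class above: it copies the built-in module's weights into a single concatenated matrix $\mathbf{W}=[\mathbf{W}_h,\mathbf{W}_x]$ and sums its two bias vectors into one $\mathbf{b}$. The sizes are the same illustrative values as in the cell above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, input_size, hidden_size = 4, 10, 5, 3

# built-in single-layer RNN; the default nonlinearity is tanh
ref = nn.RNN(input_size, hidden_size, batch_first=True)

# W = [W_h, W_x] acts on [h; x]; nn.RNN keeps two biases, which simply add
W = torch.cat([ref.weight_hh_l0, ref.weight_ih_l0], dim=1).detach()
b = (ref.bias_hh_l0 + ref.bias_ih_l0).detach()

X = torch.randn(batch_size, seq_len, input_size)
h0 = torch.zeros(batch_size, hidden_size)

# manual unroll of h_t = tanh(W [h_{t-1}; x_t] + b)
h = h0
outputs = []
for t in range(seq_len):
    hx = torch.cat([h, X[:, t, :]], dim=1)  # [batch_size, hidden_size+input_size]
    h = torch.tanh(hx @ W.T + b)
    outputs.append(h)
S = torch.stack(outputs, dim=1)  # [batch_size, seq_len, hidden_size]

# nn.RNN expects the initial hidden state as [num_layers, batch_size, hidden_size]
with torch.no_grad():
    S_ref, h_ref = ref(X, h0.unsqueeze(0))

print(torch.allclose(S, S_ref, atol=1e-5))
print(torch.allclose(h, h_ref[0], atol=1e-5))
```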

## Stacking RNNs

We can also stack RNNs as shown below

<img src="assets/stacked_rnn.png" alt="Drawing" style="width: 600px;"/>

Here the first RNN layer takes as input the input vector, $\mathbf{x}_t$, and the initial first layer hidden state $\mathbf{h}_0^{1}$. The $n^{\text{th}}$ RNN layer takes as input the output of the previous RNN layer, $\mathbf{s}^{n-1}_t$, and the initial hidden state of the current layer, $\mathbf{h}_0^{n}$.

Let the RNN be given by the function $f^{\text{RNN}}_\theta(\mathbf{x}_t,\mathbf{h}_{t-1})=(\mathbf{s}_t,\mathbf{h}_t)$. Then if we have a sequence of inputs $\mathbf{X}=[\mathbf{x}_1,\mathbf{x}_2,...,\mathbf{x}_T]$ and unfold the first layer RNN we obtain:

$$\begin{align*}
\text{unfold}\Big(f^{1}_\theta, \mathbf{X}, \mathbf{h}^{1}_0\Big)&=\Big[f^{1}_\theta(\mathbf{x}_1,\mathbf{h}^{1}_0),f^{1}_\theta(\mathbf{x}_2,\mathbf{h}^{1}_1),...,f^{1}_\theta(\mathbf{x}_T,\mathbf{h}^{1}_{T-1})\Big]\\
&=\Big[(\mathbf{s}^{1}_1,\mathbf{h}^{1}_1),(\mathbf{s}^{1}_2,\mathbf{h}^{1}_2),...,(\mathbf{s}^{1}_T,\mathbf{h}^{1}_T)\Big].
\end{align*}$$

For the $n^{\text{th}}$ RNN layer if we have a sequence of inputs $\mathbf{S}^{n-1}=[\mathbf{s}^{n-1}_1,\mathbf{s}^{n-1}_2,...,\mathbf{s}^{n-1}_T]$ and unfold the $n^{\text{th}}$ layer RNN we obtain:

$$\begin{align*}
\text{unfold}\Big(f^{n}_\theta, \mathbf{S}^{n-1}, \mathbf{h}^{n}_0\Big)&=\Big[f^{n}_\theta(\mathbf{s}^{n-1}_1,\mathbf{h}^{n}_0),f^{n}_\theta(\mathbf{s}^{n-1}_2,\mathbf{h}^{n}_1),...,f^{n}_\theta(\mathbf{s}^{n-1}_T,\mathbf{h}^{n}_{T-1})\Big]\\
&=\Big[(\mathbf{s}^{n}_1,\mathbf{h}^{n}_1),(\mathbf{s}^{n}_2,\mathbf{h}^{n}_2),...,(\mathbf{s}^{n}_T,\mathbf{h}^{n}_T)\Big].
\end{align*}$$

Thus for a stacked RNN we have that:
* Hidden state: $\mathbf{h}_t\in\mathbb{R}^{n\times d}$, where $n$ is the number of RNN layers,
* Output vector: $\mathbf{s}_t\in\mathbb{R}^{m}$, where we take the output vector from the final RNN layer.
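
The layer-by-layer unfolding above can be sketched as follows. This is a minimal illustration under stated assumptions: two layers, $\mathbf{s}^{n}_t=\mathbf{h}^{n}_t$ in every layer, and small random weights held in plain tensors. The helper `unfold` and the names `params`, `H0` and `H` are hypothetical and introduced only for this example.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, input_size, hidden_size, num_layers = 4, 10, 5, 3, 2

# one (W, b) pair per layer; layer 1 reads x_t, deeper layers read s_t^{n-1}
params = []
for n in range(num_layers):
    in_size = input_size if n == 0 else hidden_size
    W = 0.1 * torch.randn(hidden_size, hidden_size + in_size)
    b = torch.zeros(hidden_size)
    params.append((W, b))

def unfold(W, b, inputs, h):
    """Unroll one vanilla RNN layer over inputs of shape [batch, seq_len, in_size]."""
    outputs = []
    for t in range(inputs.size(1)):
        h = torch.tanh(torch.cat([h, inputs[:, t, :]], dim=1) @ W.T + b)
        outputs.append(h)
    return torch.stack(outputs, dim=1), h

X = torch.randn(batch_size, seq_len, input_size)
H0 = torch.zeros(num_layers, batch_size, hidden_size)  # initial state for each layer

S = X
final_h = []
for n, (W, b) in enumerate(params):
    # the output sequence of layer n becomes the input sequence of layer n+1
    S, h_n = unfold(W, b, S, H0[n])
    final_h.append(h_n)
H = torch.stack(final_h, dim=0)  # [num_layers, batch_size, hidden_size]

print(S.shape, H.shape)
```

Here `S` holds the outputs of the final layer only, while `H` collects the last hidden state of every layer, matching the two bullet points above in batched form.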

## Bidirectional RNNs

**TODO**


## Long Short-Term Memory

LSTMs are a type of RNN proposed by [Hochreiter and Schmidhuber in 1997](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.676.4320&rep=rep1&type=pdf) as a solution to the vanishing gradients problem. In addition to a hidden state, $\mathbf{h}_t\in\mathbb{R}^{d}$, the LSTM also has a cell state, $\mathbf{c}_t\in\mathbb{R}^{d}$. The cell state stores long-term information, and the LSTM can remove, add and read information from it through special gates in its internal structure, depicted below.

<img src="assets/lstm_internal.png" alt="Drawing" style="width: 700px;"/>

All the gates take as input the previous hidden state $\mathbf{h}_{t-1}$ and the current input state $\mathbf{x}_t$ and through matrix-vector multiplication and a non-linear transformation output a vector with the same dimensions as the cell state $\mathbf{c}_{t-1}$.

1. **Forget gate:** through a sigmoid activation function outputs a vector $\mathbf{f}_t$ whose elements take values in $[0,1]$. Intuitively the forget gate decides what information should be kept or forgotten in the previous cell state $\mathbf{c}_{t-1}$. This is achieved through element-wise multiplication of $\mathbf{f}_t$ with the previous cell state $\mathbf{c}_{t-1}$. A value of $0$ in $\mathbf{f}_t$ corresponds to "fully forget" the corresponding element in $\mathbf{c}_{t-1}$, whereas a value of $1$ in $\mathbf{f}_t$ corresponds to "fully remember" the corresponding element in $\mathbf{c}_{t-1}$,

$$\mathbf{f}_t=\sigma\Big(\mathbf{W}_f[\mathbf{h}_{t-1};\mathbf{x}_t]+\mathbf{b}_f\Big).$$

2. **New memory cell gate:** through the `tanh` activation function outputs a vector $\tilde{\mathbf{C}}_t$ whose elements take values in $[-1,+1]$. Intuitively, the gate uses the current input $\mathbf{x}_t$ and previous hidden state $\mathbf{h}_{t-1}$ to generate a new memory $\tilde{\mathbf{C}}_t$, which includes aspects of the new input $\mathbf{x}_t$, that could be added to the previous cell state $\mathbf{c}_{t-1}$,

$$\tilde{\mathbf{C}}_t=\text{tanh}\Big(\mathbf{W}_C[\mathbf{h}_{t-1};\mathbf{x}_t]+\mathbf{b}_C\Big).$$

3. **Input gate:** through the sigmoid activation function outputs a vector $\mathbf{i}_t$, whose elements take values in $[0,1]$. Intuitively the input gate decides which values in the previous cell state $\mathbf{c}_{t-1}$ we will update. This is achieved through element-wise multiplication of $\mathbf{i}_t$ with the candidate values from the new memory cell gate $\tilde{\mathbf{C}}_t$. The resulting vector is then added element-wise to the previous cell state $\mathbf{c}_{t-1}$. A value of $0$ in $\mathbf{i}_t$ corresponds to "not important - forget" the corresponding element in $\tilde{\mathbf{C}}_{t}$, so the corresponding element in the previous cell state $\mathbf{c}_{t-1}$ is not updated, whereas a value of $1$ in $\mathbf{i}_t$ corresponds to "very important - keep" the corresponding element in $\tilde{\mathbf{C}}_{t}$, so the corresponding element in the previous cell state $\mathbf{c}_{t-1}$ is updated,

$$\mathbf{i}_t=\sigma\Big(\mathbf{W}_i[\mathbf{h}_{t-1};\mathbf{x}_t]+\mathbf{b}_i\Big).$$

4. **Output gate:** through the sigmoid activation function outputs a vector $\mathbf{o}_t$, whose elements take values in $[0,1]$. Intuitively, the output gate controls which parts of the cell state $\mathbf{c}_t$ are output to the hidden state $\mathbf{h}_t$. The current cell state $\mathbf{c}_t$ contains a lot of information that is not necessarily required to be saved in the hidden state $\mathbf{h}_t$. The output gate assesses which parts of the memory $\mathbf{c}_t$ need to be present in the hidden state $\mathbf{h}_t$. This is achieved through element-wise multiplication of $\mathbf{o}_t$ with the point-wise `tanh` of the current cell state $\mathbf{c}_t$. A value of $0$ in $\mathbf{o}_t$ corresponds to "not necessary - forget" the corresponding element in $\tanh(\mathbf{c}_t)$, whereas a value of $1$ in $\mathbf{o}_t$ corresponds to "necessary - keep" the corresponding element in $\tanh(\mathbf{c}_t)$,

$$\mathbf{o}_t=\sigma\Big(\mathbf{W}_o[\mathbf{h}_{t-1};\mathbf{x}_t]+\mathbf{b}_o\Big).$$
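
Putting the four gate descriptions together gives the state updates below, where $\odot$ denotes element-wise multiplication. This is a summary of the prose above and matches the update performed in the code cell that follows; it is written out here only for reference.

$$\begin{align*}
\mathbf{c}_t&=\mathbf{f}_t\odot\mathbf{c}_{t-1}+\mathbf{i}_t\odot\tilde{\mathbf{C}}_t,\\
\mathbf{h}_t&=\mathbf{o}_t\odot\tanh(\mathbf{c}_t).
\end{align*}$$

In the code, the four weight matrices $\mathbf{W}_f$, $\mathbf{W}_C$, $\mathbf{W}_i$ and $\mathbf{W}_o$ are stacked row-wise into a single matrix of shape $4d\times(d+k)$, so that all gates are computed with one matrix product and then sliced apart.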

In [7]:
# below we show code for a batched LSTM module.
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()  # needed so nn.Module can register the parameters
        self.input_size = input_size
        self.hidden_size = hidden_size

        # the four gate matrices (f, C, i, o) are stacked row-wise into one matrix
        self.W = nn.Parameter(torch.rand(4*hidden_size, hidden_size + input_size))
        self.b = nn.Parameter(torch.zeros(4*hidden_size))

    def f(self, x, h, c):
        """
        x : [batch_size, input_size]
        h : [batch_size, hidden_size]
        c : [batch_size, hidden_size]
        """
        x = torch.cat([h, x], dim=1) # [batch_size, hidden_size+input_size]
        u = torch.einsum("ij,kj->ki", self.W, x) + self.b # [batch_size, 4*hidden_size]
        f = torch.sigmoid(u[:, :self.hidden_size])                      # forget gate
        C = torch.tanh(u[:, self.hidden_size:2*self.hidden_size])       # new memory
        i = torch.sigmoid(u[:, 2*self.hidden_size:3*self.hidden_size])  # input gate
        o = torch.sigmoid(u[:, 3*self.hidden_size:])                    # output gate
        c = c*f + i*C
        h = o*torch.tanh(c)
        s = h
        return s, h, c

    def forward(self, X, h, c):
        """
        X : [batch_size, seq_len, input_size]
        h : [batch_size, hidden_size]
        c : [batch_size, hidden_size]
        S : [batch_size, seq_len, hidden_size]
        """
        # sizes are read from X and self rather than from global variables
        S = X.new_zeros(X.size(0), X.size(1), self.hidden_size)
        # unroll over the time dimension
        for i in range(X.size(1)):
            s, h, c = self.f(X[:,i,:], h, c)
            S[:,i,:] = s
        return S, h, c

    def init_h_c(self, batch_size, hidden_size):
        return torch.zeros(batch_size, hidden_size), torch.zeros(batch_size, hidden_size)
    
batch_size = 4
seq_len = 10
input_size = 5
hidden_size = 3

lstm = LSTM(input_size=input_size, hidden_size=hidden_size)
h, c = lstm.init_h_c(batch_size=batch_size, hidden_size=hidden_size)
X = torch.randn(batch_size, seq_len, input_size)
S, h, c = lstm.forward(X, h, c)
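
The same update equations can be checked against PyTorch's built-in `torch.nn.LSTM`. The sketch below is self-contained and does not use the class above. One difference to be aware of: `nn.LSTM` stacks its gate weights in the order input, forget, candidate, output, whereas the cell above uses forget, candidate, input, output, so the slicing order here follows the built-in module. The sizes are the same illustrative values as before.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, input_size, hidden_size = 4, 10, 5, 3

ref = nn.LSTM(input_size, hidden_size, batch_first=True)

# W acts on [h; x]; nn.LSTM keeps two bias vectors, which simply add
W = torch.cat([ref.weight_hh_l0, ref.weight_ih_l0], dim=1).detach()  # [4*hidden, hidden+input]
b = (ref.bias_hh_l0 + ref.bias_ih_l0).detach()

X = torch.randn(batch_size, seq_len, input_size)
h0 = torch.zeros(batch_size, hidden_size)
c0 = torch.zeros(batch_size, hidden_size)

h, c = h0, c0
outputs = []
for t in range(seq_len):
    u = torch.cat([h, X[:, t, :]], dim=1) @ W.T + b  # [batch_size, 4*hidden_size]
    i, f, g, o = u.chunk(4, dim=1)  # gate order used by nn.LSTM
    c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)  # cell state update
    h = torch.sigmoid(o) * torch.tanh(c)                         # hidden state update
    outputs.append(h)
S = torch.stack(outputs, dim=1)  # [batch_size, seq_len, hidden_size]

# nn.LSTM expects initial states of shape [num_layers, batch_size, hidden_size]
with torch.no_grad():
    S_ref, (h_ref, c_ref) = ref(X, (h0.unsqueeze(0), c0.unsqueeze(0)))

print(torch.allclose(S, S_ref, atol=1e-5))
print(torch.allclose(c, c_ref[0], atol=1e-5))
```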