# Long Short-Term Memory (LSTM)
:label:`sec_lstm`

The challenge of preserving long-term information and skipping short-term input
in latent variable models has existed for a long time. One of the
earliest approaches to address it was the
long short-term memory (LSTM) :cite:`Hochreiter.Schmidhuber.1997`. It shares many of the properties of the
GRU.
Interestingly, LSTMs have a slightly more complex
design than GRUs but predate GRUs by almost two decades.



## Gated Memory Cell

Arguably, LSTM's design is inspired
by the logic gates of a computer.
LSTM introduces a *memory cell* (or *cell* for short)
that has the same shape as the hidden state
(some literature considers the memory cell
a special type of hidden state),
engineered to record additional information.
To control the memory cell
we need a number of gates.
One gate is needed to read out the entries from the
cell.
We will refer to this as the
*output gate*.
A second gate is needed to decide when to read data into the
cell.
We refer to this as the *input gate*.
Last, we need a mechanism to reset
the content of the cell, governed by a *forget gate*.
The motivation for such a
design is the same as that of GRUs,
namely to be able to decide when to remember and
when to ignore inputs in the hidden state via a dedicated mechanism. Let us see
how this works in practice.


### Input Gate, Forget Gate, and Output Gate

Just like in GRUs,
the data feeding into the LSTM gates are
the input at the current time step and
the hidden state of the previous time step,
as illustrated in :numref:`lstm_0`.
They are processed by
three fully connected layers with a sigmoid activation function to compute the values of
the input, forget, and output gates.
As a result, values of the three gates
are in the range of $(0, 1)$.

![Computing the input gate, the forget gate, and the output gate in an LSTM model.](https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/lstm-0.svg)
:label:`lstm_0`

Mathematically,
suppose that there are $h$ hidden units, the batch size is $n$, and the number of inputs is $d$.
Thus, the input is $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ and the hidden state of the previous time step is $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$. Correspondingly, the gates at time step $t$
are defined as follows: the input gate is $\mathbf{I}_t \in \mathbb{R}^{n \times h}$, the forget gate is $\mathbf{F}_t \in \mathbb{R}^{n \times h}$, and the output gate is $\mathbf{O}_t \in \mathbb{R}^{n \times h}$. They are calculated as follows:

$$
\begin{aligned}
\mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\
\mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\
\mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),
\end{aligned}
$$

where $\mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h}$ are weight parameters and $\mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h}$ are bias parameters.
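To make the gate equations concrete, here is a minimal plain-Kotlin sketch for a single example (batch size $n = 1$), using `DoubleArray` vectors and `Array<DoubleArray>` matrices instead of DJL `NDArray`s. The helper names `sigmoid`, `vecMat`, and `gate` are hypothetical and only for illustration; the same `gate` function yields $\mathbf{I}_t$, $\mathbf{F}_t$, or $\mathbf{O}_t$ depending on which weight and bias group is passed in.

```kotlin
import kotlin.math.exp

// Logistic sigmoid, the activation shared by all three gates.
fun sigmoid(z: Double): Double = 1.0 / (1.0 + exp(-z))

// Row vector x (length d) times matrix w (d x h), giving a row vector of length h.
fun vecMat(x: DoubleArray, w: Array<DoubleArray>): DoubleArray =
    DoubleArray(w[0].size) { j -> x.indices.sumOf { k -> x[k] * w[k][j] } }

// One gate for a single example (hypothetical helper):
// sigma(x W_x + hPrev W_h + b), applied elementwise.
fun gate(
    x: DoubleArray,
    hPrev: DoubleArray,
    wX: Array<DoubleArray>,
    wH: Array<DoubleArray>,
    b: DoubleArray
): DoubleArray {
    val zx = vecMat(x, wX)
    val zh = vecMat(hPrev, wH)
    return DoubleArray(b.size) { j -> sigmoid(zx[j] + zh[j] + b[j]) }
}
```

Because the sigmoid maps every real pre-activation into $(0, 1)$, each gate entry acts as a soft switch; with all-zero weights and biases every gate entry sits at exactly 0.5.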

### Candidate Memory Cell

Next we design the memory cell. Since we have not specified the action of the various gates yet, we first introduce the *candidate* memory cell $\tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h}$. Its computation is similar to that of the three gates described above, but it uses a $\tanh$ function, whose value range is $(-1, 1)$, as the activation function. This leads to the following equation at time step $t$:

$$\tilde{\mathbf{C}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),$$

where $\mathbf{W}_{xc} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hc} \in \mathbb{R}^{h \times h}$ are weight parameters and $\mathbf{b}_c \in \mathbb{R}^{1 \times h}$ is a bias parameter.

A quick illustration of the candidate memory cell is shown in :numref:`lstm_1`.

![Computing the candidate memory cell in an LSTM model.](https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/lstm-1.svg)
:label:`lstm_1`
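The only difference from the gates is the activation function. The sketch below (plain Kotlin, hypothetical helper names, with the pre-activation $\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c$ assumed to be already computed) applies both activations to the same pre-activations so their ranges can be compared side by side.

```kotlin
import kotlin.math.exp
import kotlin.math.tanh

// Elementwise sigmoid: the gate activation, with range (0, 1).
fun sigmoidAll(z: DoubleArray): DoubleArray =
    DoubleArray(z.size) { j -> 1.0 / (1.0 + exp(-z[j])) }

// Elementwise tanh: the candidate memory cell activation, with range (-1, 1).
// z is assumed to hold the pre-activation X_t W_xc + H_{t-1} W_hc + b_c.
fun candidateCell(z: DoubleArray): DoubleArray =
    DoubleArray(z.size) { j -> tanh(z[j]) }
```

The identity $\tanh(z) = 2\sigma(2z) - 1$ shows that $\tanh$ is a rescaled sigmoid: the candidate can propose negative as well as positive entries, whereas gate values stay between 0 and 1 and act as multipliers.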

### Memory Cell

In GRUs, we have a mechanism to govern input and forgetting (or skipping).
Similarly,
in LSTMs we have two dedicated gates for such purposes: the input gate $\mathbf{I}_t$ governs how much we take new data into account via $\tilde{\mathbf{C}}_t$ and the forget gate $\mathbf{F}_t$ addresses how much of the old memory cell content $\mathbf{C}_{t-1} \in \mathbb{R}^{n \times h}$ we retain. Using the same pointwise multiplication trick as before, we arrive at the following update equation:

$$\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.$$

If the forget gate is always approximately 1 and the input gate is always approximately 0, the past memory cells $\mathbf{C}_{t-1}$ will be saved over time and passed to the current time step.
This design was introduced to alleviate the vanishing gradient problem and to better capture
long-range dependencies within sequences.

We thus arrive at the flow diagram in :numref:`lstm_2`.

![Computing the memory cell in an LSTM model.](https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/lstm-2.svg)
:label:`lstm_2`
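The update is purely elementwise, so it is easy to simulate. The following plain-Kotlin sketch uses hypothetical helpers `updateCell` and `runCell`, and holds the gates at constant values for illustration instead of computing them from data. It checks the claim above: with $\mathbf{F}_t = 1$ and $\mathbf{I}_t = 0$ the cell content survives 100 steps unchanged, while a forget gate of 0.5 halves it at every step.

```kotlin
// C_t = F_t ⊙ C_{t-1} + I_t ⊙ C~_t, computed elementwise.
fun updateCell(f: DoubleArray, cPrev: DoubleArray, i: DoubleArray, cTilde: DoubleArray): DoubleArray =
    DoubleArray(cPrev.size) { j -> f[j] * cPrev[j] + i[j] * cTilde[j] }

// Apply `steps` updates with constant (illustrative) gate values and a constant candidate.
fun runCell(c0: DoubleArray, f: Double, i: Double, cTilde: Double, steps: Int): DoubleArray {
    val fv = DoubleArray(c0.size) { f }
    val iv = DoubleArray(c0.size) { i }
    val cv = DoubleArray(c0.size) { cTilde }
    var c = c0
    repeat(steps) { c = updateCell(fv, c, iv, cv) }
    return c
}
```

Along the additive cell-to-cell path, the per-step factor is just $\mathbf{F}_t$ (elementwise), which is why a forget gate near 1 lets information, and gradients flowing along that path, persist over many time steps.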


### Hidden State

Last, we need to define how to compute the hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$. This is where the output gate comes into play. In LSTM it is simply a gated version of the $\tanh$ of the memory cell.
This ensures that the values of $\mathbf{H}_t$ are always in the interval $(-1, 1)$.

$$\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).$$


When the output gate is close to 1, we effectively pass all memory information through to the predictor, whereas when it is close to 0, we retain all the information within the memory cell and perform no further processing.
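A minimal sketch of this step, again in plain Kotlin with a hypothetical `hiddenState` helper, shows both extremes of the output gate: an open gate passes $\tanh(\mathbf{C}_t)$ through unchanged, and a closed gate yields a zero hidden state while the memory cell array itself is left untouched.

```kotlin
import kotlin.math.tanh

// H_t = O_t ⊙ tanh(C_t), computed elementwise; c is read but never modified.
fun hiddenState(o: DoubleArray, c: DoubleArray): DoubleArray =
    DoubleArray(c.size) { j -> o[j] * tanh(c[j]) }
```

Since $\tanh$ maps into $(-1, 1)$ and the gate lies in $(0, 1)$, every entry of the hidden state stays in $(-1, 1)$ no matter how large the memory cell entries grow.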



:numref:`lstm_3` has a graphical illustration of the data flow.

![Computing the hidden state in an LSTM model.](https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/lstm-3.svg)
:label:`lstm_3`
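Putting the four equations together, one LSTM time step can be sketched as follows. This is a plain-Kotlin illustration for a single example, not the DJL implementation used later; the classes `Affine` and `LstmCellParams` and the functions `logistic` and `lstmStep` are hypothetical names introduced only here.

```kotlin
import kotlin.math.exp
import kotlin.math.tanh

// One affine map x W_x + h W_h + b, with W_x: d x h, W_h: h x h, b: length h.
class Affine(val wX: Array<DoubleArray>, val wH: Array<DoubleArray>, val b: DoubleArray) {
    fun eval(x: DoubleArray, h: DoubleArray): DoubleArray =
        DoubleArray(b.size) { j ->
            var z = b[j]
            for (k in x.indices) z += x[k] * wX[k][j]
            for (k in h.indices) z += h[k] * wH[k][j]
            z
        }
}

// Hypothetical container for the input, forget, output, and candidate parameter groups.
class LstmCellParams(val input: Affine, val forget: Affine, val output: Affine, val candidate: Affine)

fun logistic(z: Double): Double = 1.0 / (1.0 + exp(-z))

// One time step for a single example; returns the pair (H_t, C_t).
fun lstmStep(p: LstmCellParams, x: DoubleArray, hPrev: DoubleArray, cPrev: DoubleArray): Pair<DoubleArray, DoubleArray> {
    val zi = p.input.eval(x, hPrev)      // input gate pre-activation
    val zf = p.forget.eval(x, hPrev)     // forget gate pre-activation
    val zo = p.output.eval(x, hPrev)     // output gate pre-activation
    val zc = p.candidate.eval(x, hPrev)  // candidate memory cell pre-activation
    // C_t = F_t ⊙ C_{t-1} + I_t ⊙ C~_t
    val c = DoubleArray(cPrev.size) { j -> logistic(zf[j]) * cPrev[j] + logistic(zi[j]) * tanh(zc[j]) }
    // H_t = O_t ⊙ tanh(C_t)
    val h = DoubleArray(c.size) { j -> logistic(zo[j]) * tanh(c[j]) }
    return Pair(h, c)
}
```

Iterating `lstmStep` over a sequence, feeding each returned $(\mathbf{H}_t, \mathbf{C}_t)$ back in as the previous state, mirrors what the from-scratch `lstm` function in the next section does for a whole minibatch with DJL `NDArray`s.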



## Implementation from Scratch

Now let us implement an LSTM from scratch.
As in the experiments in :numref:`sec_rnn_scratch`,
we first load the time machine dataset.


```kotlin
%use @file[../djl.json]
%use lets-plot
@file:DependsOn("../D2J-1.0-SNAPSHOT.jar")
// DJL classes used below (possibly also imported by djl.json)
import ai.djl.Device
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDArrays
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape
import ai.djl.nn.Activation
import ai.djl.nn.recurrent.LSTM
// Helpers from the D2J library
import jp.live.ugai.d2j.timemachine.RNNModelScratch
import jp.live.ugai.d2j.timemachine.TimeMachine.trainCh8
import jp.live.ugai.d2j.timemachine.TimeMachineDataset
import jp.live.ugai.d2j.timemachine.Vocab
import jp.live.ugai.d2j.RNNModel
import jp.live.ugai.d2j.util.StopWatch
import jp.live.ugai.d2j.util.Accumulator
import jp.live.ugai.d2j.util.Training
```

```kotlin
val manager = NDManager.newBaseManager()
```

```kotlin
val batchSize = 32
val numSteps = 35

// Minibatches of 32 subsequences, each 35 time steps long
val dataset = TimeMachineDataset.Builder()
    .setManager(manager)
    .setMaxTokens(10000)
    .setSampling(batchSize, false)
    .setSteps(numSteps)
    .build()
dataset.prepare()
val vocab = dataset.vocab
```

### Initializing Model Parameters

Next we need to define and initialize the model parameters. As before, the hyperparameter `numHiddens` defines the number of hidden units. We initialize the weights from a Gaussian distribution with a standard deviation of 0.01, and we set the biases to 0.


```kotlin
// Gaussian initialization with mean 0 and standard deviation 0.01
fun normal(shape: Shape, device: Device): NDArray {
    return manager.randomNormal(0.0f, 0.01f, shape, DataType.FLOAT32, device)
}

// One parameter group: W_x* (numInputs x numHiddens), W_h* (numHiddens x numHiddens), zero bias b_*
fun three(numInputs: Int, numHiddens: Int, device: Device): NDList {
    return NDList(
        normal(Shape(numInputs.toLong(), numHiddens.toLong()), device),
        normal(Shape(numHiddens.toLong(), numHiddens.toLong()), device),
        manager.zeros(Shape(numHiddens.toLong()), DataType.FLOAT32, device)
    )
}

fun getLSTMParams(vocabSize: Int, numHiddens: Int, device: Device): NDList {
    // Input gate parameters
    var temp: NDList = three(vocabSize, numHiddens, device)
    val W_xi: NDArray = temp[0]
    val W_hi: NDArray = temp[1]
    val b_i: NDArray = temp[2]

    // Forget gate parameters
    temp = three(vocabSize, numHiddens, device)
    val W_xf: NDArray = temp[0]
    val W_hf: NDArray = temp[1]
    val b_f: NDArray = temp[2]

    // Output gate parameters
    temp = three(vocabSize, numHiddens, device)
    val W_xo: NDArray = temp[0]
    val W_ho: NDArray = temp[1]
    val b_o: NDArray = temp[2]

    // Candidate memory cell parameters
    temp = three(vocabSize, numHiddens, device)
    val W_xc: NDArray = temp[0]
    val W_hc: NDArray = temp[1]
    val b_c: NDArray = temp[2]

    // Output layer parameters
    val W_hq: NDArray = normal(Shape(numHiddens.toLong(), vocabSize.toLong()), device)
    val b_q: NDArray = manager.zeros(Shape(vocabSize.toLong()), DataType.FLOAT32, device)

    // Attach gradients
    val params = NDList(
        W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q
    )
    for (param in params) {
        param.setRequiresGradient(true)
    }
    return params
}
```
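As a quick sanity check on model size, the layout above creates four groups of input weights, hidden weights, and biases plus the output layer, with both the input and output dimension equal to the vocabulary size. The plain-Kotlin helpers below (hypothetical names `paramCount` and `lstmParamCount`) count the scalars; the same formula with one group describes a vanilla RNN, and with three groups (reset gate, update gate, candidate hidden state) a GRU built the same way.

```kotlin
// Scalars in a from-scratch layout like getLSTMParams:
// `groups` copies of (W_x: d x h, W_h: h x h, b: h) plus an output layer (W_hq: h x q, b_q: q).
fun paramCount(groups: Long, d: Long, h: Long, q: Long): Long =
    groups * (d * h + h * h + h) + h * q + q

// LSTM: input gate, forget gate, output gate, and candidate memory cell.
fun lstmParamCount(d: Long, h: Long, q: Long): Long = paramCount(4, d, h, q)
```

For example, with `numHiddens = 256` and a hypothetical vocabulary of 28 tokens, `lstmParamCount(28, 256, 28)` gives 299,036 parameters, of which each of the four groups contributes 72,960.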


### Defining the Model

The state initialization function must return an *additional* memory cell alongside the hidden state; both have a value of 0 and a shape of (batch size, number of hidden units). Hence we get the following state initialization.


```kotlin
// Returns (H_0, C_0): hidden state and memory cell, both zeros of shape (batchSize, numHiddens)
fun initLSTMState(batchSize: Int, numHiddens: Int, device: Device): NDList {
    return NDList(
        manager.zeros(Shape(batchSize.toLong(), numHiddens.toLong()), DataType.FLOAT32, device),
        manager.zeros(Shape(batchSize.toLong(), numHiddens.toLong()), DataType.FLOAT32, device)
    )
}
```

The actual model is defined as discussed above: three gates and an auxiliary memory cell. Note that only the hidden state is passed to the output layer; the memory cell $\mathbf{C}_t$ does not directly participate in the output computation.


```kotlin
fun lstm(inputs: NDArray, state: NDList, params: NDList): Pair<NDArray, NDList> {
    val W_xi = params[0]
    val W_hi = params[1]
    val b_i = params[2]
    val W_xf = params[3]
    val W_hf = params[4]
    val b_f = params[5]
    val W_xo = params[6]
    val W_ho = params[7]
    val b_o = params[8]
    val W_xc = params[9]
    val W_hc = params[10]
    val b_c = params[11]
    val W_hq = params[12]
    val b_q = params[13]
    var H = state[0]
    var C = state[1]
    val outputs = NDList()
    // One iteration per time step along axis 0 of inputs
    for (i in 0 until inputs.size(0)) {
        val X = inputs[i]
        // Input, forget, and output gates
        val I = Activation.sigmoid(X.dot(W_xi).add(H.dot(W_hi).add(b_i)))
        val F = Activation.sigmoid(X.dot(W_xf).add(H.dot(W_hf).add(b_f)))
        val O = Activation.sigmoid(X.dot(W_xo).add(H.dot(W_ho).add(b_o)))
        // Candidate memory cell
        val C_tilda = Activation.tanh(X.dot(W_xc).add(H.dot(W_hc).add(b_c)))
        // Memory cell and hidden state updates
        C = F.mul(C).add(I.mul(C_tilda))
        H = O.mul(Activation.tanh(C))
        // Only the hidden state feeds the output layer
        val Y = H.dot(W_hq).add(b_q)
        outputs.add(Y)
    }
    return Pair(if (outputs.size > 1) NDArrays.concat(outputs) else outputs[0], NDList(H, C))
}
```
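The returned output stacks the per-step outputs with `NDArrays.concat`, which joins an `NDList` along its first axis. Since each $\mathbf{Y}$ has shape (batch size, vocabulary size), the result has `numSteps * batchSize` rows grouped by time step. The plain-Kotlin helpers below (hypothetical names, not part of DJL or D2J) spell out that row layout.

```kotlin
// Row index of (time step t, example b) when per-step outputs of shape
// (batch, vocab) are concatenated along axis 0.
fun outputRow(t: Int, b: Int, batch: Int): Int = t * batch + b

// Inverse mapping: row index -> (time step, example within the minibatch).
fun rowToStepAndExample(row: Int, batch: Int): Pair<Int, Int> = Pair(row / batch, row % batch)

// Shape (rows, columns) of the concatenated output.
fun concatShape(steps: Int, batch: Int, vocab: Int): Pair<Int, Int> = Pair(steps * batch, vocab)
```

With `batchSize = 32` and `numSteps = 35` as above, the output has 1,120 rows, and row 1,119 holds the last example at the last time step; labels must be flattened in the same time-major order for the loss to line up.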


### Training and Prediction

Let us train an LSTM just as we did in :numref:`sec_gru`, by instantiating the `RNNModelScratch` class introduced in :numref:`sec_rnn_scratch`.


```kotlin
val vocabSize = vocab!!.length()
val numHiddens = 256
val device = manager.device
// Number of epochs; can be overridden with the MAX_EPOCH system property
val numEpochs = Integer.getInteger("MAX_EPOCH", 500)

val lr = 1

val getParamsFn = ::getLSTMParams
val initLSTMStateFn = ::initLSTMState
val lstmFn = ::lstm
val model = RNNModelScratch(vocabSize, numHiddens, device, getParamsFn, initLSTMStateFn, lstmFn)
trainCh8(model, dataset, vocab, lr, numEpochs, device, false, manager)
```

```text
10 : 17.987512425261023
20 : 17.456555008756336
30 : 16.756250658563452
40 : 15.733315311233822
50 : 14.588738995471852
60 : 13.024693894128117
70 : 11.896615957526722
80 : 11.311850219601421
90 : 10.87725516866155
100 : 10.51270077292614
110 : 10.181080345836675
120 : 9.846072636206788
130 : 9.429032445860424
140 : 9.182622257477348
150 : 8.839661388042837
160 : 8.470459510765247
170 : 8.274413157079868
180 : 7.851216347172008
190 : 7.6030459957881895
200 : 7.153976916075779
210 : 6.899856907150648
220 : 6.5826942205661565
230 : 6.3922043272388365
240 : 6.039340291009065
250 : 5.733006110647622
260 : 5.47648925443606
270 : 5.138043564766072
280 : 4.755370230416373
290 : 4.430931313760951
300 : 4.172333614863039
310 : 3.8082568111783437
320 : 3.5947060384853957
330 : 3.134375487974975
340 : 2.9506107313881325
350 : 2.6115556961829918
360 : 2.4586846048390596
370 : 2.1178097617815705
380 : 1.9984216998548725
390 : 1.7342695963176875
400 : 1.61662531633966
410 : 1.490990517555368
420 : 1
```

## Concise Implementation

Using high-level APIs,
we can directly instantiate an `LSTM` model.
This encapsulates all the configuration details that we made explicit above. The code is significantly faster because it uses compiled operators, rather than step-by-step Kotlin code, for many of the details that we spelled out before.


```kotlin
// Built-in DJL LSTM layer: one layer with numHiddens units, returning its state,
// with time-major inputs (batchFirst = false)
val lstmLayer = LSTM.builder()
    .setNumLayers(1)
    .setStateSize(numHiddens)
    .optReturnState(true)
    .optBatchFirst(false)
    .build()
val modelConcise = RNNModel(lstmLayer, vocab.length())
trainCh8(modelConcise, dataset, vocab, lr, numEpochs, device, false, manager)
```

LSTMs are the prototypical latent variable autoregressive model with nontrivial state control.
Many variants have been proposed over the years, e.g., multiple layers, residual connections, and different types of regularization. However, training LSTMs and other sequence models (such as GRUs) is quite costly because of the long-range dependencies in sequences.
Later we will encounter alternative models, such as Transformers, that can be used in some cases.


## Summary

* LSTMs have three types of gates: input gates, forget gates, and output gates that control the flow of information.
* The hidden layer output of LSTM includes the hidden state and the memory cell. Only the hidden state is passed into the output layer. The memory cell is entirely internal.
* LSTMs can alleviate vanishing and exploding gradients.


## Exercises

1. Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.
1. How would you need to change the model to generate proper words as opposed to sequences of characters?
1. Compare the computational cost for GRUs, LSTMs, and regular RNNs for a given hidden dimension. Pay special attention to the training and inference cost.
1. Since the candidate memory cell ensures that the value range is between $-1$ and $1$ by using the $\tanh$ function, why does the hidden state need to use the $\tanh$ function again to ensure that the output value range is between $-1$ and $1$?
1. Implement an LSTM model for time series prediction rather than character sequence prediction.
