# Long short-term memory (LSTM)

Long short-term memory (LSTM) is a gated recurrent neural network component; see [Wikipedia](https://en.wikipedia.org/wiki/Long_short-term_memory) for background.

We'll look at the "peephole LSTM" version, written here with state $x[n]$, input $u[n]$, and output $y[n]$:
\begin{align}
i[n] &= \sigma_g(A_{i} x[n] + B_{i} u[n] + c_i) & \text{input gate} \\
o[n] &= \sigma_g(A_{o} x[n] + B_{o} u[n] + c_o) & \text{output gate} \\
f[n] &= \sigma_g(A_{f} x[n] + B_{f} u[n] + c_f) & \text{forget gate} \\
x[n+1] &= f[n] \circ x[n] + i[n] \circ \tanh(B_{c} u[n] + c_c) \\
y[n] &= \tanh(o[n] \circ x[n])
\end{align}

where the operator $\circ$ denotes the Hadamard (element-wise) product, $\sigma_g(x) = \frac{1}{1+e^{-x}} \in (0,1)$ is the sigmoid function, and recall that $\tanh(x) \in (-1,1)$. The initial condition is nearly always taken to be $x[0]=0$.
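The update equations above can be sketched as a single NumPy step. This is a minimal sketch under stated assumptions, not a reference implementation: the function name `lstm_step`, the parameter dictionary `p` and its key names, the separate input dimension `M`, and the random weights in the usage example are all hypothetical choices made for illustration.

```python
import numpy as np

def sigmoid(z):
    """Logistic sigmoid, sigma_g in the equations above."""
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, u, p):
    """One step of the peephole LSTM equations above.

    x: state x[n], shape (N,); u: input u[n], shape (M,).
    p: dict of parameters (hypothetical layout): A_* has shape (N, N),
       B_* has shape (N, M), c_* has shape (N,).
    Returns (x_next, y) = (x[n+1], y[n]).
    """
    i = sigmoid(p["A_i"] @ x + p["B_i"] @ u + p["c_i"])  # input gate
    o = sigmoid(p["A_o"] @ x + p["B_o"] @ u + p["c_o"])  # output gate
    f = sigmoid(p["A_f"] @ x + p["B_f"] @ u + p["c_f"])  # forget gate
    x_next = f * x + i * np.tanh(p["B_c"] @ u + p["c_c"])
    y = np.tanh(o * x)
    return x_next, y

# Usage with random (illustrative) parameters, starting from x[0] = 0.
rng = np.random.default_rng(0)
N, M = 3, 2
p = {}
for g in "iof":
    p["A_" + g] = rng.normal(size=(N, N))
    p["B_" + g] = rng.normal(size=(N, M))
    p["c_" + g] = rng.normal(size=N)
p["B_c"] = rng.normal(size=(N, M))
p["c_c"] = rng.normal(size=N)

x_state = np.zeros(N)
for n in range(5):
    x_state, y = lstm_step(x_state, rng.normal(size=M), p)
```

Because each step adds a term of magnitude less than 1 to a state scaled by a factor less than 1, the state magnitude can grow by at most 1 per step, while the output $y[n]$ always stays in $(-1,1)$.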

The intuition here is that we have four inputs in total: the signal $u[n] \in \Re^N$ and three "latching" inputs $f[n], o[n], i[n]$. Ideally each latch would be binary (0 or 1), but we use a sigmoid to give a differentiable version of these latches, so they take values in $(0,1)^N$. $o[n]$ simply turns the output on or off (it does not affect the state dynamics). So let us think about the state dynamics in the idealized case where the latches are exactly 0 or 1:

$$x[n+1] = \begin{cases} 0 & f[n]=i[n]=0, \\ x[n] & f[n]=1,\ i[n]=0, \\ \tanh(B_c u[n] + c_c) & f[n]=0,\ i[n]=1, \\ x[n] + \tanh(B_c u[n] + c_c) & f[n]=i[n]=1.\end{cases}$$
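The four cases can also be checked numerically with the scalar form of the state update. This is a minimal sketch: the helper name `state_update`, the sample values, and the informal labels in the comments (reset, hold, overwrite, accumulate) are assumptions introduced for illustration only.

```python
import numpy as np

def state_update(x, u, f, i, b=1.0, c=0.0):
    """Scalar state update x[n+1] = f*x[n] + i*tanh(b*u[n] + c)."""
    return f * x + i * np.tanh(b * u + c)

x_n, u_n = 0.7, 0.3          # arbitrary sample state and input
g = np.tanh(u_n)             # tanh(B_c u + c_c) with b=1, c=0

# (f, i) -> expected x[n+1], one entry per case of the piecewise form
cases = {
    (0, 0): 0.0,             # reset
    (1, 0): x_n,             # hold
    (0, 1): g,               # overwrite
    (1, 1): x_n + g,         # accumulate
}
for (f_n, i_n), expected in cases.items():
    assert np.isclose(state_update(x_n, u_n, f_n, i_n), expected)

# With f=1, i=0 the state is held no matter what the input does.
held = x_n
for u_k in [-5.0, 0.0, 5.0]:
    held = state_update(held, u_k, f=1, i=0)
```

The last loop illustrates the "latching" behaviour: with the forget gate open and the input gate closed, the state ignores the input entirely.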

We can plot the scalar update, $x[n+1]$ as a function of $x[n]$, to confirm this intuition.


In [4]:
from IPython import get_ipython
from ipywidgets import interact
if get_ipython() is not None: get_ipython().run_line_magic("matplotlib", "qt5")

import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['figure.figsize'] = [10, 8]

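# Compute x[n+1] for the scalar state update:
# f is the forget gate, i the input gate, b the input weight, c the bias.
def lstm(x, u=0, f=0, i=0, b=1, c=0):
    return f*x + i*np.tanh(b*u + c)

# Element-wise version of lstm, applied to the array x below.
Lstm = np.vectorize(lstm)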
xmax = 2.
ymax = 2.
x = np.arange(-xmax, xmax, 0.01)


def plot_lstm(u,f,i):
    # clear the current figure so each slider update redraws from scratch.
    plt.clf()
    plt.plot(x, Lstm(x, u=u, f=f, i=i), linewidth=2., label="u=" + str(u))

    plt.title("f="+str(f) + ", i=" + str(i))
    plt.xlabel("x[n]")
    plt.ylim((-ymax, ymax))
    plt.ylabel("x[n+1]")
    plt.legend()

    # draw the x and y axes.
    plt.plot([-xmax, xmax], [0, 0], color="k", linestyle="-", linewidth=1.)
    plt.plot([0, 0], [-ymax, ymax], color="k", linestyle="-", linewidth=1.)
    # draw the line through the origin with slope 1 (x[n+1] = x[n]).
    plt.plot([-ymax, ymax], [-ymax, ymax], color="k", linestyle="-", linewidth=1.)
    plt.axis("equal")


In [5]:
from IPython.display import display
from ipywidgets.widgets import FloatSlider

f = FloatSlider(description="f", min=0, max=1, value=0.5)
display(f)
i = FloatSlider(description="i", min=0, max=1, value=0.5)
display(i)


In [6]:
interact(plot_lstm, u=(-5,5), f=(0,1), i=(0,1))
