# LSTM
![LSTM](https://user-images.githubusercontent.com/26361028/125639171-b7f92866-2481-4293-b9ad-af85281c0967.png)


$\textbf{F}_t = \sigma(\textbf{X}_t\textbf{W}_{xf} + \textbf{H}_{t-1}\textbf{W}_{hf} + \textbf{b}_f)$

$\textbf{I}_t = \sigma(\textbf{X}_t\textbf{W}_{xi} + \textbf{H}_{t-1}\textbf{W}_{hi} + \textbf{b}_i)$

$\textbf{O}_t = \sigma(\textbf{X}_t\textbf{W}_{xo} + \textbf{H}_{t-1}\textbf{W}_{ho} + \textbf{b}_o)$

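where $\textbf{F}_t$, $\textbf{I}_t$ and $\textbf{O}_t$ are the forget, input and output gates, $d$ is the number of input features, $h$ is the number of hidden units, and

$\textbf{W}_{xf}$, $\textbf{W}_{xi}$, $\textbf{W}_{xo}$ $\in \mathbb{R}^{d \times h}$

$\textbf{W}_{hf}$, $\textbf{W}_{hi}$, $\textbf{W}_{ho}$ $\in \mathbb{R}^{h \times h}$

$\textbf{b}_{f}$, $\textbf{b}_{i}$, $\textbf{b}_{o}$ $\in \mathbb{R}^{1 \times h}$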

**Candidate Memory Cell**:

$\tilde{\textbf{C}}_t = \tanh(\textbf{X}_t\textbf{W}_{xc} + \textbf{H}_{t-1}\textbf{W}_{hc} + \textbf{b}_c)$

where,

$\textbf{W}_{xc} \in \mathbb{R}^{d \times h}$

$\textbf{W}_{hc} \in \mathbb{R}^{h \times h}$

$\textbf{b}_{c} \in \mathbb{R}^{1 \times h}$

**Memory Cell**:

$\textbf{C}_t = \textbf{F}_t \odot \textbf{C}_{t-1} + \textbf{I}_t \odot \tilde{\textbf{C}}_t$

**Hidden State**:

$\textbf{H}_t = \textbf{O}_t \odot \tanh(\textbf{C}_t)$

In [None]:
!pip install d2l

In [2]:
import torch
import torch.nn as nn
from d2l import torch as d2l

In [3]:
# Mini-batch size and number of time steps per sequence
batch_size = 32
num_steps = 35
# Build the training iterator and vocabulary from the Time Machine text
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...


# Initializing Model parameters

In [4]:
def get_lstm_params(vocab_size, num_hiddens, device):
  """Create the trainable parameters of a from-scratch LSTM model.

  Weights are sampled from a standard normal distribution and scaled by 0.01;
  biases start at zero. Every returned tensor has ``requires_grad`` enabled.

  Args:
    vocab_size: Vocabulary size; used as both the input and the output size.
    num_hiddens: Number of hidden units.
    device: Device on which the tensors are allocated.

  Returns:
    A list of 14 tensors in the order
    ``[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
    W_xc, W_hc, b_c, W_hy, b_y]``.
  """
  # Both the input feature size and the output size equal the vocabulary size
  num_inputs = num_outputs = vocab_size

  def normal(shape):
    # Small random initial weights
    return torch.randn(size=shape, device=device) * 0.01

  def three():
    # (input-to-hidden weight, hidden-to-hidden weight, bias) for one gate
    return (normal((num_inputs, num_hiddens)),
            normal((num_hiddens, num_hiddens)),
            torch.zeros(num_hiddens, device=device))

  W_xi, W_hi, b_i = three()  # Input gate parameters
  W_xf, W_hf, b_f = three()  # Forget gate parameters
  W_xo, W_ho, b_o = three()  # Output gate parameters
  W_xc, W_hc, b_c = three()  # Candidate memory cell parameters

  # Output layer parameters
  W_hy = normal((num_hiddens, num_outputs))
  b_y = torch.zeros(num_outputs, device=device)

  # Attach gradients
  params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
            W_xc, W_hc, b_c, W_hy, b_y]
  for param in params:
    param.requires_grad_(True)

  return params

# Defining the Model

In [6]:
def init_lstm_state(batch_size, num_hiddens, device):
  """Return the initial LSTM state as a pair of zero tensors.

  Both tensors have shape (batch_size, num_hiddens) and live on `device`.
  The `lstm` function unpacks the pair as (H, C).
  """
  shape = (batch_size, num_hiddens)
  first = torch.zeros(shape, device=device)
  second = torch.zeros(shape, device=device)
  return (first, second)
  
def lstm(inputs, state, params):
  """Run the LSTM over a sequence and compute the per-step outputs.

  Args:
    inputs: Sequence of per-step inputs, shape
      (num_steps, batch_size, vocab_size).
    state: Tuple (H, C) holding the hidden state and the memory cell.
    params: The 14 parameters, ordered as
      [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
      W_xc, W_hc, b_c, W_hy, b_y].

  Returns:
    A pair (outputs, (H, C)): the per-step outputs concatenated along
    dim 0, shape (num_steps * batch_size, vocab_size), and the final
    hidden state and memory cell, each (batch_size, num_hiddens).
  """
  (W_xi, W_hi, b_i,
   W_xf, W_hf, b_f,
   W_xo, W_ho, b_o,
   W_xc, W_hc, b_c,
   W_hy, b_y) = params
  hidden, cell = state
  step_outputs = []

  for x_t in inputs:
    # x_t shape: (batch_size, vocab_size)
    input_gate = torch.sigmoid((x_t @ W_xi) + (hidden @ W_hi) + b_i)
    forget_gate = torch.sigmoid((x_t @ W_xf) + (hidden @ W_hf) + b_f)
    output_gate = torch.sigmoid((x_t @ W_xo) + (hidden @ W_ho) + b_o)
    candidate = torch.tanh((x_t @ W_xc) + (hidden @ W_hc) + b_c)

    # Keep part of the old memory and add the gated candidate
    cell = forget_gate * cell + input_gate * candidate
    hidden = output_gate * torch.tanh(cell)

    # Output for this step, shape: (batch_size, vocab_size)
    step_outputs.append((hidden @ W_hy) + b_y)

  stacked = torch.cat(step_outputs, dim=0)
  return stacked, (hidden, cell)

# Training and prediction

In [None]:
# Model size and target device
vocab_size = len(vocab)
num_hiddens = 256
device = d2l.try_gpu()
# Training hyperparameters
num_epochs = 500
lr = 1
# Wrap the from-scratch LSTM pieces in the d2l scratch RNN model and train it
model = d2l.RNNModelScratch(
    vocab_size, num_hiddens, device,
    get_lstm_params, init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

# Concise Implementation

In [None]:
# Same model built on PyTorch's LSTM layer instead of the from-scratch code
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab)).to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)