# 5. Gated Recurrent Unit

```python
import torch
import torch.nn as nn
from torch.utils import data
from torch.nn import functional as F

import re
import collections

import math
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
```

```python
from d2l import torch as d2l
```

In plain RNNs, we often avoid exploding or vanishing gradients at the cost of losing some information in long sequences. An early and influential solution is the **Long Short-Term Memory** (LSTM). The **gated recurrent unit** (GRU) offers a **streamlined version** of the LSTM memory cell that often achieves comparable performance while being faster to compute.

## Reset Gate and Update Gate

The key difference between a GRU and a vanilla RNN is that the GRU uses **gates** to control **when to reset and when to update** the hidden state.

The **reset gate** controls how much of the **previous state** we might still want to remember. The **update gate** controls how much of the **new state** is just a copy of the old state.

In short:

1. Reset gates help capture **short-term dependencies** in sequences.
2. Update gates help capture **long-term dependencies** in sequences.

The **outputs** of the two gates are given by two **fully connected layers** with a **sigmoid** activation function, which squashes their values into the interval $(0, 1)$.

![](http://d2l.ai/_images/gru-1.svg)

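For a mini-batch input $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ (batch size $n$, number of inputs $d$) and the hidden state of the previous time step $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$ (number of hidden units $h$), the gates are computed as: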

$$
\begin{aligned}
\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r)\\
\mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z)
\end{aligned}
$$

where we have $\mathbf{R}_t \in \mathbb{R}^{n \times h}$, $\mathbf{Z}_t \in \mathbb{R}^{n \times h}$, $\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}$, $\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}$, $\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}$.
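The two gate equations translate directly into matrix code. The following is a minimal sketch with random parameters; the sizes `n`, `d`, `h` and the random initialization are illustrative assumptions, not values used elsewhere in this notebook.

```python
import torch

torch.manual_seed(0)
n, d, h = 4, 5, 3  # batch size, input features, hidden units (illustrative)

X_t = torch.randn(n, d)     # mini-batch input at time step t
H_prev = torch.randn(n, h)  # hidden state H_{t-1}

# Gate parameters with the shapes listed above
W_xr, W_hr, b_r = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)
W_xz, W_hz, b_z = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)

# Each gate is a fully connected layer followed by a sigmoid;
# the (1, h) bias is broadcast across the n rows of the mini-batch
R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)

print(R_t.shape, Z_t.shape)  # torch.Size([4, 3]) torch.Size([4, 3])
```

Both gates have the same shape as the hidden state, so they can be applied to it elementwise.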

## Candidate Hidden State

The **candidate hidden state** $\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ at time step $t$ is obtained by integrating the **reset gate** into the **regular hidden-state update** of a vanilla RNN:

$$\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h)$$

where we have $\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}$, $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$, $\mathbf{b}_h \in \mathbb{R}^{1 \times h}$, and $\odot$ is the Hadamard (elementwise) product operator.

The **elementwise product** of $\mathbf{R}_t$ and $\mathbf{H}_{t-1}$ can **reduce** the influence of the previous hidden state (i.e., it controls how much **previous information** is remembered):

1. When the entries of $\mathbf{R}_t$ are close to 1, we recover the vanilla RNN update (all previous information is kept).
2. When the entries of $\mathbf{R}_t$ are close to 0, the candidate hidden state is the output of an MLP with $\mathbf{X}_t$ as input (none of the previous information is kept).
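The two limiting cases can be checked numerically. The sketch below evaluates the candidate-state formula with the reset gate forced to all ones and to all zeros; the sizes, random parameters, and the helper name `candidate_state` are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
n, d, h = 4, 5, 3  # batch size, input features, hidden units (illustrative)
X_t, H_prev = torch.randn(n, d), torch.randn(n, h)
W_xh, W_hh, b_h = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)

def candidate_state(R_t):
    # The reset gate scales H_{t-1} elementwise before the recurrent matmul
    return torch.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)

# R_t = 1: exactly the vanilla RNN update
vanilla_rnn = torch.tanh(X_t @ W_xh + H_prev @ W_hh + b_h)
print(torch.allclose(candidate_state(torch.ones(n, h)), vanilla_rnn))  # True

# R_t = 0: H_{t-1} drops out, leaving a single tanh layer applied to X_t
mlp_only = torch.tanh(X_t @ W_xh + b_h)
print(torch.allclose(candidate_state(torch.zeros(n, h)), mlp_only))    # True
```

In a trained GRU the gate values lie strictly between these extremes, so each hidden unit can keep a different fraction of the previous state.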


![](http://d2l.ai/_images/gru-2.svg)

## Hidden State

The **update gate $\mathbf{Z}_t$** determines the extent to which the **new hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$** matches the **old state $\mathbf{H}_{t-1}$** versus how much it resembles the **new candidate hidden state $\tilde{\mathbf{H}}_t$**.

This is achieved by taking **elementwise convex combinations** of $\mathbf{H}_{t-1}$ and $\tilde{\mathbf{H}}_t$, using the entries of $\mathbf{Z}_t$ as weights:

$$\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1}  + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t$$

When $\mathbf{Z}_t$ is close to 1, the model simply **retains** the old state. In this case, the information from $\mathbf{X}_t$ is **ignored**, effectively **skipping time step $t$** in the dependency chain. 

In contrast, when $\mathbf{Z}_t$ is close to 0, the new hidden state approaches the candidate hidden state.
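Putting the three equations together gives one GRU time step. The sketch below is a from-scratch version under illustrative assumptions: the helper names `init_params` and `gru_step`, the small random weights, and the trick of pushing the update gate to its extremes via a large constant bias `z_bias` are all hypothetical, chosen only to reproduce the two cases described above.

```python
import torch

def init_params(d, h, z_bias=0.0):
    """Small random weights; the update-gate bias is fixed to z_bias for the demo."""
    def w(*shape):
        return torch.randn(*shape) * 0.1
    return (w(d, h), w(h, h), torch.zeros(1, h),          # reset gate
            w(d, h), w(h, h), torch.full((1, h), z_bias),  # update gate
            w(d, h), w(h, h), torch.zeros(1, h))          # candidate state

def gru_step(X_t, H_prev, params):
    """One GRU time step following the equations of this section."""
    W_xr, W_hr, b_r, W_xz, W_hz, b_z, W_xh, W_hh, b_h = params
    R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)
    Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)
    H_tilde = torch.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)
    # Elementwise convex combination of the old state and the candidate
    H_t = Z_t * H_prev + (1 - Z_t) * H_tilde
    return H_t, H_tilde

torch.manual_seed(0)
X_t, H_prev = torch.randn(4, 5), torch.randn(4, 3)

# Z_t -> 1 (large positive bias): H_t copies H_{t-1}, X_t is ignored
H_t, _ = gru_step(X_t, H_prev, init_params(5, 3, z_bias=30.0))
print(torch.allclose(H_t, H_prev, atol=1e-6))   # True

# Z_t -> 0 (large negative bias): H_t equals the candidate state
H_t, H_tilde = gru_step(X_t, H_prev, init_params(5, 3, z_bias=-30.0))
print(torch.allclose(H_t, H_tilde, atol=1e-6))  # True
```

Running `gru_step` in a loop over the time dimension, feeding each `H_t` back in as `H_prev`, gives the full recurrent layer.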

![](http://d2l.ai/_images/gru-3.svg)

## Implementing GRU

```python
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
```

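```python
vocab_size, num_hiddens = len(vocab), 256
# 'mps' targets Apple Silicon GPUs; fall back to CUDA or the CPU elsewhere
device = torch.device('mps' if torch.backends.mps.is_available()
                      else 'cuda' if torch.cuda.is_available() else 'cpu')
```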

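```python
num_inputs = vocab_size  # each token is fed in as a one-hot vector of length vocab_size
num_epochs, lr = 500, 1

# High-level API: a single-layer GRU with num_hiddens hidden units
gru_layer = nn.GRU(num_inputs, num_hiddens)

# d2l.RNNModel adds the output layer that maps hidden states to vocabulary logits
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)

d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
```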
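As a sanity check, the built-in `nn.GRU` can be compared against a hand-written loop. One caveat: PyTorch's documented GRU differs slightly from the equations in this section. Each gate has separate input and hidden biases, and the reset gate multiplies `H @ W_hn.T + b_hn` *after* the recurrent matrix product instead of scaling $\mathbf{H}_{t-1}$ before it. The loop below therefore follows PyTorch's convention so that the outputs can match. The sizes and the helper name `manual_gru` are illustrative assumptions; the weight layout (`weight_ih_l0` stacking the reset, update, and candidate blocks row-wise) follows the PyTorch `nn.GRU` documentation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_steps, batch_size, num_inputs, num_hiddens = 6, 2, 5, 4  # illustrative sizes

gru = nn.GRU(num_inputs, num_hiddens)  # batch_first=False: inputs are (steps, batch, features)
X = torch.randn(num_steps, batch_size, num_inputs)
H0 = torch.zeros(1, batch_size, num_hiddens)  # (num_layers, batch, hidden)

# PyTorch stacks the per-gate weights row-wise in the order (r | z | n)
W_ir, W_iz, W_in = gru.weight_ih_l0.chunk(3, dim=0)
W_hr, W_hz, W_hn = gru.weight_hh_l0.chunk(3, dim=0)
b_ir, b_iz, b_in = gru.bias_ih_l0.chunk(3)
b_hr, b_hz, b_hn = gru.bias_hh_l0.chunk(3)

def manual_gru(X, H):
    outputs = []
    for X_t in X:  # iterate over time steps
        R = torch.sigmoid(X_t @ W_ir.T + b_ir + H @ W_hr.T + b_hr)
        Z = torch.sigmoid(X_t @ W_iz.T + b_iz + H @ W_hz.T + b_hz)
        # PyTorch applies the reset gate AFTER the recurrent matmul
        H_tilde = torch.tanh(X_t @ W_in.T + b_in + R * (H @ W_hn.T + b_hn))
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(H)
    return torch.stack(outputs), H

with torch.no_grad():
    Y_ref, H_ref = gru(X, H0)
    Y_man, H_man = manual_gru(X, H0[0])

print(Y_ref.shape, H_ref.shape)                 # torch.Size([6, 2, 4]) torch.Size([1, 2, 4])
print(torch.allclose(Y_ref, Y_man, atol=1e-6))  # True
```

Because of this placement of the reset gate, parameters trained with the from-scratch equations above cannot simply be copied into `nn.GRU`, even though both variants play the same role.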