A recurrent neural network (RNN) uses recurrence to build a deep network: it reuses the same hidden weight matrix for every additional layer, which makes it memory efficient. It trains this weight matrix by feeding in the tokens of a training sequence (such as a sentence) one at a time. RNNs are best suited to sequential data such as natural language or time series.

Suppose $W$ and $b$ are the weights and bias of the hidden layer, $e_t$ is the embedding of token $t$, and $h_o$ is the output layer:

The first hidden state/activation is $$h_0 = \text{ReLU}\big(W\,e_0 + b\big)$$
The second hidden state is $$h_1 = \text{ReLU}\big(W\,(h_0 + e_1) + b\big)$$
The third hidden state: $$h_2 = \text{ReLU}\big(W\,(h_1 + e_2) + b\big)$$
The output predictions: $$\text{logits} = h_o(h_2)$$
Notice that throughout this process, the hidden layer, $W$, stays the same in the forward pass.

In general, the standard RNN is formulated as
$$
h_t = f(W h_{t-1} + U e_t + b)
$$

* $h_{t-1}$ : previous hidden state
* $e_t$ : embedding (input at this step)
* $W$ : hidden-to-hidden weights (processes the past)
* $U$ : input-to-hidden weights (processes the current token)
* $b$ : bias
* $f$ : the activation function (ReLU in the equations above)
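As a sanity check, here is a minimal sketch (random weights and sizes chosen only for illustration) showing that the step-by-step equations above are the general formula with $U = W$ and $h_{-1} = 0$: sharing one matrix gives $W h_{t-1} + W e_t = W(h_{t-1} + e_t)$. Double precision is used so that rounding differences stay negligible.

```python
import torch

torch.manual_seed(0)
n_hidden, seq_len = 4, 3
W = torch.randn(n_hidden, n_hidden, dtype=torch.float64)  # hidden-to-hidden weights
b = torch.randn(n_hidden, dtype=torch.float64)            # bias
e = torch.randn(seq_len, n_hidden, dtype=torch.float64)   # random stand-ins for e_0, e_1, e_2

def general_step(h_prev, e_t, W, U, b):
    # h_t = ReLU(W h_{t-1} + U e_t + b)
    return torch.relu(W @ h_prev + U @ e_t + b)

def shared_step(h_prev, e_t, W, b):
    # h_t = ReLU(W (h_{t-1} + e_t) + b), as in the step-by-step equations
    return torch.relu(W @ (h_prev + e_t) + b)

h_general = torch.zeros(n_hidden, dtype=torch.float64)  # h_{-1} = 0
h_shared = torch.zeros(n_hidden, dtype=torch.float64)
for t in range(seq_len):
    h_general = general_step(h_general, e[t], W, W, b)  # U = W
    h_shared = shared_step(h_shared, e[t], W, b)

print(torch.allclose(h_general, h_shared))  # True
```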

# Data Preparation

In [1]:
# !wget "https://s3.amazonaws.com/fast-ai-sample/human_numbers.tgz" -O "../data/human_numbers.tgz" && tar -xzf "../data/human_numbers.tgz" -C ../data/


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [3]:
from pathlib import Path

sample_path = Path("../data/human_numbers")
print(list(sample_path.iterdir()))

[PosixPath('../data/human_numbers/train.txt'), PosixPath('../data/human_numbers/valid.txt')]


In [4]:
lines = []
with open(sample_path / "train.txt") as f:
    lines += [*f.readlines()]
with open(sample_path / "valid.txt") as f:
    lines += [*f.readlines()]
len(lines)


9998

In [5]:
text = " . ".join([l.strip() for l in lines])
tokens = text.split(" ")
vocab = set(tokens)
len(vocab), list(vocab)[:10]

(30,
 ['nine',
  'fifty',
  'one',
  'thirty',
  'hundred',
  'two',
  '.',
  'fifteen',
  'thousand',
  'seven'])

In [6]:
word2idx = {w: i for i, w in enumerate(vocab)}
nums = [word2idx[i] for i in tokens]
len(nums)

63095
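One caveat: iterating over a `set` of strings does not give a stable order across Python processes (string hashing is randomized unless `PYTHONHASHSEED` is fixed), so the indices in `word2idx`, and the token IDs printed below, may differ between runs. A small sketch on a made-up token list of a deterministic alternative that sorts the vocabulary first (the `toy_` names are illustrative, chosen so they do not overwrite the notebook's variables):

```python
# Toy token list standing in for `tokens` (illustrative only)
toy_tokens = "one . two . three . four".split(" ")

# Sorting gives the same vocabulary order in every Python process
toy_vocab = sorted(set(toy_tokens))
toy_word2idx = {w: i for i, w in enumerate(toy_vocab)}
toy_nums = [toy_word2idx[w] for w in toy_tokens]

print(toy_vocab)  # ['.', 'four', 'one', 'three', 'two']
print(toy_nums)   # [2, 0, 4, 0, 3, 0, 1]
```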

Our neural network will predict the next word in a sequence. Each training example has three words as input, and the model predicts the fourth word:

In [8]:
[(tokens[i : i + 3], tokens[i + 3]) for i in range(0, len(tokens) - 4, 3)][:5]


[(['one', '.', 'two'], '.'),
 (['.', 'three', '.'], 'four'),
 (['four', '.', 'five'], '.'),
 (['.', 'six', '.'], 'seven'),
 (['seven', '.', 'eight'], '.')]

Convert these into tensors of token indices:

In [9]:
seqs = [
    (torch.tensor(nums[i : i + 3]), nums[i + 3]) for i in range(0, len(nums) - 4, 3)
]
seqs[:5]

[(tensor([2, 6, 5]), 6),
 (tensor([ 6, 16,  6]), 27),
 (tensor([27,  6, 29]), 6),
 (tensor([ 6, 18,  6]), 9),
 (tensor([ 9,  6, 14]), 6)]

In [10]:
bs = 1
cut = int(len(seqs) * 0.8)
dls_train = DataLoader(seqs[:cut], batch_size=bs)
dls_valid = DataLoader(seqs[cut:], batch_size=bs)


# RNN

We start with a literal implementation of the equations above: each token's embedding is added to the hidden state one at a time, and the same hidden layer `h_h` is applied after every addition.

In [11]:
class LMModel1(nn.Module):
    def __init__(self, vocab_sz, n_hidden):
        super().__init__()
        # input layer
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # hidden layer
        self.h_h = nn.Linear(n_hidden, n_hidden)
        # output layer
        self.h_o = nn.Linear(n_hidden, vocab_sz)

    def forward(self, x):
        # h is hidden activation
        h = F.relu(self.h_h(self.i_h(x[:, 0])))
        h = h + self.i_h(x[:, 1])
        # use the same hidden layer for all input
        h = F.relu(self.h_h(h))
        h = h + self.i_h(x[:, 2])
        h = F.relu(self.h_h(h))
        return self.h_o(h)


In [12]:
xb, yb = next(iter(dls_train))
rnn1 = LMModel1(len(vocab), 3)
outputs = rnn1(xb)
# index of the largest logit = predicted next token (the model is untrained, so this is arbitrary)
pred_idx = outputs.argmax(dim=-1)
print("Predicted index:", pred_idx.item())

Predicted index: 26


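We can rewrite the forward pass with a loop. Because `h` starts at 0, the first iteration computes `relu(h_h(e_0))`, exactly as before: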

In [13]:
class LMModel2(nn.Module):
    def __init__(self, vocab_sz, n_hidden):
        super().__init__()
        # input layer
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # hidden layer
        self.h_h = nn.Linear(n_hidden, n_hidden)
        # output layer
        self.h_o = nn.Linear(n_hidden, vocab_sz)

    def forward(self, x):
        h = 0
        for i in range(3):
            h = h + self.i_h(x[:, i])
            # use the same hidden layer for all input
            h = F.relu(self.h_h(h))
        return self.h_o(h)
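To confirm that the loop is just a rewrite, here is a small sketch that runs both forward passes with the same layers. Standalone layers stand in for the `i_h`, `h_h` and `h_o` attributes of `LMModel1`/`LMModel2`; the sizes and random inputs are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_sz, n_hidden = 30, 8
# Shared layers standing in for i_h, h_h, h_o
i_h = nn.Embedding(vocab_sz, n_hidden)
h_h = nn.Linear(n_hidden, n_hidden)
h_o = nn.Linear(n_hidden, vocab_sz)

x = torch.randint(0, vocab_sz, (4, 3))  # batch of 4 three-token inputs

# Unrolled, LMModel1 style
h = F.relu(h_h(i_h(x[:, 0])))
h = F.relu(h_h(h + i_h(x[:, 1])))
h = F.relu(h_h(h + i_h(x[:, 2])))
unrolled = h_o(h)

# Looped, LMModel2 style: h starts at 0, so step 0 is relu(h_h(e_0))
h = 0
for i in range(3):
    h = F.relu(h_h(h + i_h(x[:, i])))
looped = h_o(h)

print(torch.allclose(unrolled, looped))  # True
```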


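Next, we make the hidden state persist across batches by storing it as an attribute of the model (`self.h`) instead of starting from 0 in every `forward` call. If we never cut it, every new batch would extend one ever-growing computation graph, so backpropagation would get slower and use more memory as training goes on, as if the network kept getting deeper. We therefore call `detach()` on the hidden state after each batch: it keeps the values but drops the gradient history, so gradients flow back only through the 3 time steps of the current batch. This is known as truncated backpropagation through time.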

To make this work, the `DataLoader` must also produce text that is continuous across batches: row `i` of each batch must pick up exactly where row `i` of the previous batch left off. With batch size `bs`, we split the dataset into `bs` contiguous streams of `m = len(dset) // bs` sequences each (`m` is then the number of batches), and batch `k` takes the `k`-th sequence of every stream.

For example, with a batch size of 3, the full `seqs` dataset (21,031 sequences) gives `m = 7010` batches:

In [14]:
bs = 3
m = len(seqs) // bs
m, bs, len(seqs)


(7010, 3, 21031)

In [15]:
seqs[:3]

[(tensor([2, 6, 5]), 6), (tensor([ 6, 16,  6]), 27), (tensor([27,  6, 29]), 6)]

We define `group_chunks` to reorder the dataset this way. As a check, the first rows of the consecutive batches `bx`, `cx`, `dx` are exactly the first three sequences of `seqs`:

In [16]:
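def group_chunks(ds, bs):
    # split ds into bs contiguous streams of m sequences each
    m = len(ds) // bs
    new_ds = []
    for i in range(m):
        # batch i takes the i-th sequence from every stream
        new_ds += [ds[i + m * j] for j in range(bs)]
    return new_ds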


dls_train = DataLoader(group_chunks(seqs[:cut], bs), batch_size=bs, drop_last=True)
it = iter(dls_train)
bx, by = next(it)
cx, cy = next(it)
dx, dy = next(it)
bx, cx, dx

(tensor([[ 2,  6,  5],
         [ 8, 14,  4],
         [29,  8, 27]]),
 tensor([[ 6, 16,  6],
         [25, 29,  6],
         [ 4, 25,  5]]),
 tensor([[27,  6, 29],
         [ 5,  8, 14],
         [ 6, 29,  8]]))
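Token IDs are hard to follow by eye, so here is the same reordering applied to the integers 0 to 11, treated as 12 consecutive sequences with a batch size of 3 (a toy setup for illustration; `group_chunks` is repeated so the block runs on its own):

```python
def group_chunks(ds, bs):
    # same logic as above: bs contiguous streams of m items each
    m = len(ds) // bs
    new_ds = []
    for i in range(m):
        new_ds += [ds[i + m * j] for j in range(bs)]
    return new_ds

ds = list(range(12))  # stand-in for 12 consecutive sequences
toy_bs = 3
grouped = group_chunks(ds, toy_bs)
batches = [grouped[k : k + toy_bs] for k in range(0, len(grouped), toy_bs)]
print(batches)  # [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]

# Row r of consecutive batches walks through stream r in order
for r in range(toy_bs):
    print([batch[r] for batch in batches])
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9, 10, 11]
```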

In [17]:
class LMModel3(nn.Module):
    def __init__(self, vocab_sz, n_hidden):
        super().__init__()
        # input layer
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # hidden layer
        self.h_h = nn.Linear(n_hidden, n_hidden)
        # output layer
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # initialize hidden state (the scalar 0 broadcasts on the first addition)
        self.h = 0

    def forward(self, x):
        for i in range(3):
            self.h = self.h + self.i_h(x[:, i])
            # use the same hidden layer for all input
            self.h = F.relu(self.h_h(self.h))
        out = self.h_o(self.h)
        # detach so the next batch does not backpropagate into this one
        self.h = self.h.detach()
        return out
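To see what `detach()` does to the gradient, here is a scalar toy. It assumes a single weight `w` and a "hidden state" updated as `h = w * h` for three steps per batch, standing in for the three steps in `LMModel3.forward`:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # shared weight, reused at every step
h = torch.tensor(1.0)                      # initial hidden state

# "Batch 1": three recurrent steps, then detach as LMModel3 does
for _ in range(3):
    h = w * h
h = h.detach()  # the value 8.0 is kept, but its link to w is cut

# "Batch 2": three more steps
for _ in range(3):
    h = w * h
h.backward()

# As far as autograd knows, h = 8 * w**3, so dh/dw = 3 * 8 * w**2 = 96.
# Without the detach, h = w**6 and dh/dw = 6 * w**5 = 192.
print(w.grad)  # tensor(96.)
```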

Next, instead of predicting one word after every three, we predict the next word at every position. To do so, we rebuild `seqs`, since the current version has only one target per 3 words. With sequence length `sl`, each element of `seqs` is now a pair of tensors: the input `nums[i : i + sl]` and the target `nums[i + 1 : i + sl + 1]`, offset by one position.

In [18]:
sl = 16
seqs = [
    (torch.tensor(nums[i : i + sl]), torch.tensor(nums[i + 1 : i + sl + 1]))
    for i in range(0, len(nums) - sl - 1, sl)
]
cut = int(len(seqs) * 0.8)
seqs[:5]

[(tensor([ 2,  6,  5,  6, 16,  6, 27,  6, 29,  6, 18,  6,  9,  6, 14,  6]),
  tensor([ 6,  5,  6, 16,  6, 27,  6, 29,  6, 18,  6,  9,  6, 14,  6,  0])),
 (tensor([ 0,  6, 17,  6, 13,  6, 10,  6, 20,  6, 26,  6,  7,  6, 19,  6]),
  tensor([ 6, 17,  6, 13,  6, 10,  6, 20,  6, 26,  6,  7,  6, 19,  6, 21])),
 (tensor([21,  6, 23,  6, 22,  6, 12,  6, 12,  2,  6, 12,  5,  6, 12, 16]),
  tensor([ 6, 23,  6, 22,  6, 12,  6, 12,  2,  6, 12,  5,  6, 12, 16,  6])),
 (tensor([ 6, 12, 27,  6, 12, 29,  6, 12, 18,  6, 12,  9,  6, 12, 14,  6]),
  tensor([12, 27,  6, 12, 29,  6, 12, 18,  6, 12,  9,  6, 12, 14,  6, 12])),
 (tensor([12,  0,  6,  3,  6,  3,  2,  6,  3,  5,  6,  3, 16,  6,  3, 27]),
  tensor([ 0,  6,  3,  6,  3,  2,  6,  3,  5,  6,  3, 16,  6,  3, 27,  6]))]
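The same comprehension on plain lists makes the one-position offset easy to see. `toy_nums` and `toy_sl` are illustrative stand-ins for `nums` and `sl`:

```python
# Toy stand-ins for `nums` and `sl` (illustrative only)
toy_nums = list(range(10))
toy_sl = 3
toy_seqs = [
    (toy_nums[i : i + toy_sl], toy_nums[i + 1 : i + toy_sl + 1])
    for i in range(0, len(toy_nums) - toy_sl - 1, toy_sl)
]
print(toy_seqs)  # [([0, 1, 2], [1, 2, 3]), ([3, 4, 5], [4, 5, 6])]

# At every position t, the target is the input token at position t + 1
for x, y in toy_seqs:
    assert x[1:] == y[:-1]
```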

In [19]:
class LMModel4(nn.Module):
    def __init__(self, vocab_sz, n_hidden, bs):
        super().__init__()
        # input layer
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # hidden layer
        self.h_h = nn.Linear(n_hidden, n_hidden)
        # output layer
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # store dimensions so reset() can rebuild the hidden state
        self.n_hidden = n_hidden
        self.bs = bs
        # initialize hidden state to zeros: (bs, n_hidden)
        self.h = torch.zeros(bs, n_hidden)

    def forward(self, x):
        _, sl = x.shape

        out = []
        for i in range(sl):
            # Add embedding to hidden state
            self.h = self.h + self.i_h(x[:, i])
            # Apply hidden layer with activation
            self.h = F.relu(self.h_h(self.h))
            # Generate output for this timestep
            out.append(self.h_o(self.h))

        # Detach hidden state so backpropagation stops at this batch
        # (truncated BPTT); this also bounds memory use
        self.h = self.h.detach()
        # stack per-step outputs: (bs, sl, vocab_sz)
        return torch.stack(out, dim=1)

    def reset(self):
        """Reset the hidden state to zeros."""
        self.h = torch.zeros(self.bs, self.n_hidden)


The model outputs a tensor of shape `bs x sl x vocab_sz`, while the inputs `xb` and targets `yb` are `bs x sl`:

In [20]:
bs = 64
dls_train = DataLoader(group_chunks(seqs[:cut], bs), batch_size=bs, drop_last=True)
dls_valid = DataLoader(group_chunks(seqs[cut:], bs), batch_size=bs, drop_last=True)

xb, yb = next(iter(dls_train))
rnn2 = LMModel4(len(vocab), 64, bs)
rnn2(xb).shape, xb.shape, yb.shape

(torch.Size([64, 16, 30]), torch.Size([64, 16]), torch.Size([64, 16]))

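To compute the loss, we align the two by flattening: predictions become `(bs * sl, vocab_sz)` and targets become `(bs * sl,)`, the shapes `F.cross_entropy` expects: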

In [21]:
rnn2(xb).view(-1, len(vocab)).shape, yb.view(-1).shape
F.cross_entropy(rnn2(xb).view(-1, len(vocab)), yb.view(-1))


tensor(3.4675, grad_fn=<NllLossBackward0>)

Based on the above, we define our loss function:

In [22]:
def loss_func(inp, targ):
    return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))
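Flattening is not the only option: `F.cross_entropy` also accepts logits of shape `(N, C, d)` with targets of shape `(N, d)`, so one can instead move the vocabulary dimension into position 1. A small check with random tensors (sizes chosen for illustration) that both forms give the same mean loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, n_steps, n_classes = 2, 5, 7                   # illustrative sizes
pred = torch.randn(n_batch, n_steps, n_classes)         # (bs, sl, vocab_sz) logits
targ = torch.randint(0, n_classes, (n_batch, n_steps))  # (bs, sl) target indices

# Option 1: flatten, as loss_func does
flat = F.cross_entropy(pred.view(-1, n_classes), targ.view(-1))

# Option 2: put the class dimension second, giving (bs, vocab_sz, sl)
channels_first = F.cross_entropy(pred.transpose(1, 2), targ)

print(torch.allclose(flat, channels_first))  # True
```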

In [23]:
def batch_accuracy(pred, target):
    # pred: (bs, sl, vocab), target: (bs, sl)
    return (pred.argmax(-1) == target).float().mean().item()


In [24]:
print(f"Training batches: {len(dls_train)}")
print(f"Validation batches: {len(dls_valid)}")

Training batches: 49
Validation batches: 12


We reset the model's hidden state at the start of both the training and the validation phase of each epoch, so that each pass over the continuous chunks of text starts from a clean state.

In [40]:
epochs = 20
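# Note: OneCycleLR overrides the optimizer's lr; it warms up from max_lr / 25
# to max_lr and then anneals, so lr=0.01 below is only a placeholder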
rnn2 = LMModel4(len(vocab), 64, bs)

optimizer = torch.optim.SGD(rnn2.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, steps_per_epoch=len(dls_train), epochs=epochs
)


# torch.backends.cudnn.benchmark = True  # good if input sizes are consistent
def train(model, epochs, train_loader, valid_loader):
    # NOTE: uses the global `optimizer` and `scheduler` defined above,
    # so re-create both before training a different model
    for epoch in range(epochs):
        epoch_loss = 0.0
        batch_num = 0
        model.train()
        model.reset()  # reset hidden state at the beginning of each epoch
        for xb, yb in train_loader:
            pred = model(xb)
            loss = loss_func(pred, yb)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            epoch_loss += loss.item()
            batch_num += 1  # number of batches within an epoch
        avg_loss = epoch_loss / batch_num
        model.eval()
        model.reset()
        with torch.no_grad():
            batch_num_valid = 0
            valid_loss = 0
            valid_acc = 0
            for xb, yb in valid_loader:
                pred = model(xb)
                valid_loss += loss_func(pred, yb).item()
                valid_acc += batch_accuracy(pred, yb)
                batch_num_valid += 1
        print(f"epoch {epoch}, train loss: {avg_loss:.4f}")
        print(f"validation loss {valid_loss / batch_num_valid:.4f}")
        print(f"validation accuracy {valid_acc / batch_num_valid:.4f}")
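Because `OneCycleLR` sets the learning rate itself at every step, the `lr` passed to `SGD` above is overwritten. With the default settings, the rate starts at `max_lr / 25`, rises to `max_lr` over roughly the first 30% of steps, then anneals. A minimal sketch that records the schedule, using a dummy parameter as an assumed stand-in for the model:

```python
import math
import torch

# A dummy parameter stands in for the model's parameters (illustration only)
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.01, momentum=0.9)
total_steps = 100
sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=0.01, total_steps=total_steps)

lrs = []
for _ in range(total_steps):
    lrs.append(opt.param_groups[0]["lr"])  # learning rate used for this step
    opt.step()
    sched.step()

print(lrs[0])               # about 0.0004 = max_lr / 25 (default div_factor=25)
print(max(lrs))             # about 0.01 = max_lr
print(lrs.index(max(lrs)))  # peak step, about 30% of the way in (default pct_start=0.3)
```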

In [41]:
train(rnn2, epochs, dls_train, dls_valid)

epoch 0, train loss: 3.3567
validation loss 3.2219
validation accuracy 0.1353
epoch 1, train loss: 2.8462
validation loss 2.5005
validation accuracy 0.4107
epoch 2, train loss: 1.9836
validation loss 2.0847
validation accuracy 0.4659
epoch 3, train loss: 1.5911
validation loss 1.8851
validation accuracy 0.4338
epoch 4, train loss: 1.5050
validation loss 1.9142
validation accuracy 0.3676
epoch 5, train loss: 1.4625
validation loss 1.9384
validation accuracy 0.3444
epoch 6, train loss: 1.4268
validation loss 1.9259
validation accuracy 0.3503
epoch 7, train loss: 1.3876
validation loss 1.8492
validation accuracy 0.3971
epoch 8, train loss: 1.3426
validation loss 1.7483
validation accuracy 0.4548
epoch 9, train loss: 1.3006
validation loss 1.7162
validation accuracy 0.4701
epoch 10, train loss: 1.2694
validation loss 1.6989
validation accuracy 0.4807
epoch 11, train loss: 1.2361
validation loss 1.7161
validation accuracy 0.4776
epoch 12, train loss: 1.1958
validation loss 1.7573
validation

# Long Short-Term Memory (LSTM)

Next, we build an LSTM, which mitigates the vanishing-gradient problem of the plain RNN. An LSTM cell carries two states instead of the single hidden state of a classic RNN. In a classic RNN, that one hidden state is responsible for both:

1. holding the right information for the output layer to predict the next token, and
2. retaining memory of everything that happened earlier in the sentence.

In practice, a plain RNN is bad at retaining information from many steps back. The LSTM therefore adds a **cell state** that carries the long-term memory, while the hidden state focuses on the next prediction (see [1]).

In [None]:
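class LSTMCell(nn.Module):
    def __init__(self, ni, nh):
        super().__init__()  # required before assigning submodules
        # each gate sees [h_{t-1}, x_t] concatenated: nh + ni features
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h, c = state
        h = torch.cat([h, input], dim=1)
        # forget gate: how much of the old cell state to keep (0..1)
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        # input gate chooses where to write; cell gate proposes values (-1..1)
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        # output gate chooses which parts of the cell state become the new h
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h, c)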
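The cell above implements the standard LSTM update (see [1]). Writing $z_t = [h_{t-1}, e_t]$ for the concatenation in `forward`, $W_f, W_i, W_c, W_o$ and $b_f, b_i, b_c, b_o$ for the weights and biases of `forget_gate`, `input_gate`, `cell_gate` and `output_gate`, and $\odot$ for elementwise multiplication:

$$
\begin{aligned}
z_t &= [h_{t-1}, e_t] \\
f_t &= \sigma(W_f z_t + b_f) \\
i_t &= \sigma(W_i z_t + b_i) \\
\tilde{c}_t &= \tanh(W_c z_t + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t &= \sigma(W_o z_t + b_o) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$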

In [29]:
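class LMModel5(nn.Module):
    def __init__(self, vocab_sz, n_hidden, n_layers, bs=64):
        super().__init__()
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # stacked LSTM; batch_first=True means inputs are (bs, sl, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # hidden state and cell state, each (n_layers, bs, n_hidden);
        # bs defaults to the batch size of 64 used above
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]

    def forward(self, x):
        res, h = self.rnn(self.i_h(x), tuple(self.h))
        # truncated BPTT: keep the states' values, drop their history
        self.h = [h_.detach() for h_ in h]
        # res: (bs, sl, n_hidden) -> logits (bs, sl, vocab_sz)
        return self.h_o(res)

    def reset(self):
        for h in self.h:
            h.zero_()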
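`nn.LSTM` returns the top layer's hidden state at every time step, plus the final `(h, c)` pair for every layer; that is why `LMModel5` keeps two tensors of shape `(n_layers, bs, n_hidden)`. A quick shape check with random inputs (sizes chosen to match the scale used in this notebook, as an assumption):

```python
import torch
import torch.nn as nn

batch, steps, hidden, layers = 64, 16, 64, 2  # illustrative sizes

lstm = nn.LSTM(hidden, hidden, layers, batch_first=True)
x = torch.randn(batch, steps, hidden)  # stands in for the embedded tokens
h0 = torch.zeros(layers, batch, hidden)
c0 = torch.zeros(layers, batch, hidden)

res, (h, c) = lstm(x, (h0, c0))
print(res.shape)  # torch.Size([64, 16, 64]): top-layer hidden state at every step
print(h.shape)    # torch.Size([2, 64, 64]): final hidden state of each layer
print(c.shape)    # torch.Size([2, 64, 64]): final cell state of each layer

# The last time step of `res` is the top layer's final hidden state
print(torch.allclose(res[:, -1], h[-1]))  # True
```

Note that `batch_first=True` only changes the layout of the input and `res`; the states `h` and `c` keep the layer dimension first.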

In [46]:

epochs = 20
rnn3 = LMModel5(len(vocab), 64, 2)
optimizer = torch.optim.SGD(rnn3.parameters(), lr=0.005, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, steps_per_epoch=len(dls_train), epochs=epochs
)

train(rnn3, epochs, dls_train, dls_valid)

epoch 0, train loss: 3.3705
validation loss 3.3514
validation accuracy 0.1423
epoch 1, train loss: 3.3021
validation loss 3.2490
validation accuracy 0.1422
epoch 2, train loss: 3.1584
validation loss 3.0872
validation accuracy 0.1580
epoch 3, train loss: 2.9764
validation loss 2.9400
validation accuracy 0.1692
epoch 4, train loss: 2.8543
validation loss 2.8835
validation accuracy 0.1530
epoch 5, train loss: 2.7984
validation loss 2.8554
validation accuracy 0.1645
epoch 6, train loss: 2.7654
validation loss 2.8339
validation accuracy 0.1996
epoch 7, train loss: 2.7436
validation loss 2.8132
validation accuracy 0.2486
epoch 8, train loss: 2.7256
validation loss 2.7923
validation accuracy 0.3114
epoch 9, train loss: 2.7071
validation loss 2.7701
validation accuracy 0.3575
epoch 10, train loss: 2.6852
validation loss 2.7448
validation accuracy 0.3842
epoch 11, train loss: 2.6576
validation loss 2.7144
validation accuracy 0.4014
epoch 12, train loss: 2.6231
validation loss 2.6778
validation

# References:
1. Colah, Understanding LSTM Networks: https://colah.github.io/posts/2015-08-Understanding-LSTMs/