# RNN language model

In [1]:
from chapter import *

Our goal in this section is to train a character-level RNN language model to predict the next token at *each* step with varying-length context. Hence, during training, our model predicts on each time-step ({numref}`04-char-rnn`). The language model below is simply an RNN cell with an attached **logits layer** applied at each step.


```{figure} ../../../img/nn/04-char-rnn.svg
---
width: 550px
name: 04-char-rnn
align: center
---
Character-level RNN language model for predicting the next character at each step.  [Source](https://www.d2l.ai/chapter_recurrent-neural-networks/rnn.html)
```

To implement a language model, we simply attach a linear layer on the RNN unit to compute logits. 
The linear layer performs matrix multiplication on the rightmost dimension of `outs` which contains the value of the state vector at each time step. Thus, as shown in {numref}`04-char-rnn` we have $T$ predictions with increasing context size[^1] $1, 2, \ldots, T.$

[^1]: Consequently, the model gets corrected at each time step, with variable-length dependency, during backward pass. 

In [2]:
%%save
import torch
import torch.nn as nn
from typing import Type
from functools import partial


class RNNLanguageModel(nn.Module):
    """Recurrent cell followed by a linear layer that emits logits at every step.

    Args:
        cell: Recurrent module class; instantiated as
            ``cell(inputs_dim, hidden_dim, **kwargs)``.
        inputs_dim: Size of each input vector.
        hidden_dim: Size of the recurrent state (input size of the logits layer).
        vocab_size: Number of output classes (output size of the logits layer).
        **kwargs: Extra keyword arguments forwarded to ``cell``.
    """

    def __init__(self,
        cell: Type[RNNBase],
        inputs_dim: int,
        hidden_dim: int,
        vocab_size: int,
        **kwargs
    ):
        super().__init__()
        self.cell = cell(inputs_dim, hidden_dim, **kwargs)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, state=None, return_state=False):
        """Compute per-step logits for ``x``.

        Returns the logits alone, or ``(logits, state)`` when
        ``return_state`` is true, where ``state`` is whatever the cell
        returned as its second output.
        """
        hidden, state = self.cell(x, state)
        # The linear layer acts on the last axis only: (T, B, H) -> (T, B, C)
        logits = self.linear(hidden)
        if return_state:
            return logits, state
        return logits


def LanguageModel(cell):
    """Return a constructor for `RNNLanguageModel` with its cell class fixed.

    The result is called with the remaining constructor arguments, e.g.
    ``LanguageModel(cell)(inputs_dim, hidden_dim, vocab_size, **kwargs)``.
    """
    return partial(RNNLanguageModel, cell)

<br>

## Character sequences dataset

Each sample in our dataset consists of $T$ input-output pairs of characters, where the target sequence is the input sequence **shifted** by one time step:

In [3]:
%%save
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class SequenceDataset(Dataset):
    """Sliding-window dataset of next-token prediction samples.

    Sample ``i`` is built from ``corpus[i : i + seq_len + 1]``: the first
    ``seq_len`` tokens are the inputs and the last ``seq_len`` tokens (the
    same window shifted by one) are the targets.

    Args:
        corpus: Sequence of integer token indices.
        seq_len: Number of time steps per sample.
        vocab_size: Number of classes used for the one-hot encoding.
    """

    def __init__(self, corpus: list, seq_len: int, vocab_size: int):
        super().__init__()
        self.corpus = corpus
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __getitem__(self, i):
        """Return ``(x, y)``: one-hot float inputs ``(seq_len, vocab_size)``
        and integer targets ``(seq_len,)``.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if i < 0:
            i += n  # support negative indices like a regular sequence
        if not 0 <= i < n:
            # Without this check an out-of-range index would silently
            # produce a truncated (or empty) window.
            raise IndexError("SequenceDataset index out of range")

        window = torch.tensor(self.corpus[i: i + self.seq_len + 1])
        x, y = window[:-1], window[1:]
        x = F.one_hot(x, num_classes=self.vocab_size).float()
        return x, y

    def __len__(self):
        # Each sample needs seq_len + 1 tokens; clamp at zero so a corpus
        # shorter than that yields an empty dataset instead of a negative
        # length (which makes len() raise ValueError).
        return max(0, len(self.corpus) - self.seq_len)

Training on the *Time Machine* text:

In [4]:
%%save
import re
import os
import requests
import collections
from pathlib import Path

# Directory (relative to the current working directory) for downloaded data.
DATA_DIR = Path("./data")
# Created as a side effect of running this cell; no-op if it already exists.
DATA_DIR.mkdir(exist_ok=True)


class Vocab:
    """Bidirectional mapping between tokens and integer indices.

    Index 0 is always the unknown token ``"<unk>"``; the remaining tokens
    follow in sorted order.

    Args:
        tokens: Iterable of tokens used to count frequencies.
        min_freq: Tokens with a count below this value are dropped.
        reserved_tokens: Tokens always included, regardless of ``min_freq``.

    Attributes:
        token_freqs: ``(token, count)`` pairs sorted by decreasing count.
        itot: List mapping index -> token.
        ttoi: Dict mapping token -> index.
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        tokens = [] if tokens is None else tokens
        reserved_tokens = [] if reserved_tokens is None else reserved_tokens

        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)

        kept = {token for token, freq in self.token_freqs if freq >= min_freq}
        kept.update(reserved_tokens)    # i.e. not subject to min_freq
        # "<unk>" is prepended below; if it also occurs in the data it must
        # not be listed twice, or its index and the vocabulary size are wrong.
        kept.discard("<unk>")

        self.itot = ["<unk>"] + sorted(kept)
        self.ttoi = {tok: idx for idx, tok in enumerate(self.itot)}

    def __len__(self):
        return len(self.itot)

    def __getitem__(self, tokens: list[str]) -> list[int]:
        """Map a token, or a (possibly nested) list/tuple of tokens, to indices.

        A single token returns a single index; tokens not in the vocabulary
        map to the index of ``"<unk>"``.
        """
        if isinstance(tokens, (list, tuple)):
            return [self.__getitem__(tok) for tok in tokens]
        return self.ttoi.get(tokens, self.unk)

    def to_tokens(self, indices) -> list[str]:
        """Map an index, or an iterable of indices, back to token(s).

        A non-iterable input (e.g. an ``int`` or a zero-dimensional array)
        returns a single token instead of a list.
        """
        try:
            iterator = iter(indices)
        except TypeError:
            # Scalars, including 0-d array-likes that expose __len__ but
            # cannot be iterated.
            return self.itot[int(indices)]
        return [self.itot[int(index)] for index in iterator]

    @property
    def unk(self) -> int:
        """Index of the unknown token."""
        return self.ttoi["<unk>"]


class TimeMachine:
    """Loader for the Project Gutenberg text of *The Time Machine*.

    Args:
        download: Force a (re-)download even if the file already exists.
        path: Location of the text file; defaults to
            ``DATA_DIR / "time_machine.txt"``.
        token_level: ``"char"`` for character tokens; any other value
            splits the text on whitespace.
    """

    def __init__(self, download=False, path=None, token_level="char"):
        DEFAULT_PATH = str((DATA_DIR / "time_machine.txt").absolute())
        self.token_level = token_level
        self.filepath = path or DEFAULT_PATH
        if download or not os.path.exists(self.filepath):
            self._download()

    def _download(self):
        """Fetch the raw text and write it to ``self.filepath``.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        url = "https://www.gutenberg.org/cache/epub/35/pg35.txt"
        print(f"Downloading text from {url} ...", end=" ")
        # A timeout keeps a stalled connection from hanging forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        print("OK!")

        # A user-supplied path may point into a directory that does not exist.
        parent = os.path.dirname(self.filepath)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self.filepath, "wb") as output:
            output.write(response.content)

    def _preprocess(self, text: str):
        """Strip the Gutenberg header/footer and normalize the text.

        Every run of non-letter characters becomes a single space and the
        result is lowercased. If a marker is missing, the text is kept from
        the beginning / up to the end rather than sliced at a bogus offset.
        """
        s = "*** START OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"
        e = "*** END OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"

        # str.find returns -1 when a marker is absent; using that value
        # directly as a slice bound would discard (almost) all of the text.
        start = text.find(s)
        start = 0 if start == -1 else start + len(s)
        end = text.find(e, start)
        if end == -1:
            end = len(text)

        text = text[start:end]
        text = re.sub('[^A-Za-z]+', ' ', text).lower().strip()
        return text

    def tokenize(self, text: str):
        """Split ``text`` into characters or whitespace-separated words."""
        return list(text) if self.token_level == "char" else text.split()

    def build(self, vocab=None):
        """Read, preprocess and encode the text.

        Args:
            vocab: Existing vocabulary to encode with; a new ``Vocab`` is
                built from the tokens when omitted.

        Returns:
            ``(corpus, vocab)`` where ``corpus`` is ``vocab[self.tokens]``.
        """
        # The file is written from the raw response bytes; decode it
        # explicitly instead of relying on the platform's default encoding.
        with open(self.filepath, "r", encoding="utf-8") as f:
            raw_text = f.read()

        self.text = self._preprocess(raw_text)
        self.tokens = self.tokenize(self.text)

        vocab = Vocab(self.tokens) if vocab is None else vocab
        corpus = vocab[self.tokens]
        return corpus, vocab

In [5]:
from torch.utils.data import random_split

def collate_fn(batch):
    """Stack a list of ``(x, y)`` samples along a new axis 1 (sequence-first).

    With per-sample shapes ``(T, vocab_size)`` and ``(T,)`` this yields
    ``(T, B, vocab_size)`` inputs and ``(T, B)`` targets.
    """
    inputs, targets = zip(*batch)
    return torch.stack(inputs, dim=1), torch.stack(targets, dim=1)


# Load the text (downloading it first if the file is missing) and encode it
# as a list of character indices.
tm = TimeMachine()
corpus, vocab = tm.build()
# Each sample is a window of 10 input characters plus their shifted targets.
dataset = SequenceDataset(corpus, seq_len=10, vocab_size=len(vocab))
# Random 80/20 split over samples, i.e. over window starting positions.
train_dataset, valid_dataset = random_split(dataset, [0.80, 0.20])
# collate_fn stacks samples along axis 1, so batches are sequence-first.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

The starting position of each sequence in a batch is shuffled, but the order of characters within each sequence is intact:

In [6]:
# Grab one batch: x is (T, B, vocab_size) one-hot, y is (T, B) indices.
x, y = next(iter(train_loader))

# a: which sequence of the batch to display; T: number of time steps.
a, T = 1, dataset.seq_len
x_chars = vocab.to_tokens(torch.argmax(x[:, a], dim=1))     # inputs are one-hot
y_chars = vocab.to_tokens(y[:, a])
# Print each input character next to its target (the following character).
for i in range(T):
    print(f"{x_chars[i]} --> {y_chars[i]}")

y -->  
  --> o
o --> w
w --> n
n -->  
  --> e
e --> x
x --> p
p --> e
e --> n


In [7]:
# Batch shapes, then the first sequence of the batch as token indices
# (argmax undoes the one-hot encoding of the inputs).
print(x.shape, y.shape)
print("inputs:", torch.argmax(x[:, 0], dim=-1))
print("target:", y[:, 0])

torch.Size([10, 32, 28]) torch.Size([10, 32])
inputs: tensor([ 2, 15,  5,  1, 24,  9,  6, 15,  1, 10])
target: tensor([15,  5,  1, 24,  9,  6, 15,  1, 10,  1])


PyTorch `F.cross_entropy` expects input `(B, C, T)` and target `(B, T)`, so the sequence-first model output `(T, B, C)` and target `(T, B)` have to be permuted first:

In [8]:
import torch.nn.functional as F

x, y = next(iter(train_loader))
# Inputs are one-hot over the vocabulary, so the input dimension must equal
# the vocabulary size; derive it instead of hard-coding 28.
model = LanguageModel(RNN)(len(vocab), 5, len(vocab))
# Logits: (T, B, C) -> (B, C, T); targets: (T, B) -> (B, T).
loss = F.cross_entropy(model(x).permute(1, 2, 0), y.transpose(0, 1))
loss

tensor(3.3582, grad_fn=<NllLoss2DBackward0>)