#### Transformer Language Model

As with the RNN, we start from a sequence of word embedding vectors $w_1, \ldots, w_N$ and pass it through a `Transformer decoder network` (i.e. causal masking is applied to the attention scores) to obtain a sequence of output vectors $h_1, \ldots, h_N$. Each $h_i$ can then be regarded as a contextual representation of the words $w_1, \ldots, w_i$, so we can use it to compute a probability distribution over the next word given all the preceding words:

$P(w_{i+1} | w_1,...,w_i) = f(h_i)$
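
To make the causal masking concrete, here is a minimal, framework-free sketch (not the notebook's implementation, which relies on PyTorch further down). It assumes a toy matrix of raw attention scores with hypothetical values and shows that, after masking and a row-wise softmax, position $i$ only puts weight on positions $1, \ldots, i$.

```python
import math

def causal_attention_weights(scores):
    """Row-wise softmax over scores[i][j], keeping only positions j <= i."""
    weights = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]                     # position i may only look at 0..i
        m = max(visible)                           # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        # masked (future) positions get a weight of exactly 0
        weights.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
    return weights

# toy 3x3 score matrix (hypothetical values)
scores = [[0.5, 2.0, -1.0],
          [1.0, 1.0, 3.0],
          [0.0, 0.0, 0.0]]
weights = causal_attention_weights(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

The first row attends only to itself, and the large score of 3.0 in the second row is ignored because it belongs to a future position.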

Again, as in the case of the RNN, we use a simple feed-forward network for the transformation $f$, in this case a linear projection followed by a softmax. Instead of word-level tokenization, we use sub-word tokenization, specifically Byte-Pair Encoding (BPE). We train the model on the collected works of Shakespeare.
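
As a rough illustration of what Byte-Pair Encoding does, the sketch below performs a single merge step on a toy corpus. It is not the `BPE` class imported in this notebook (whose internals are not shown here); the helper names `most_frequent_pair` and `merge_pair`, the word counts, and the `_` end-of-word marker are assumptions made for the example.

```python
from collections import Counter

def most_frequent_pair(word_counts):
    """Count adjacent symbol pairs, weighted by how often each word occurs."""
    pairs = Counter()
    for symbols, count in word_counts.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += count
    return pairs.most_common(1)[0][0]

def merge_pair(word_counts, pair):
    """Replace every occurrence of `pair` by a single merged symbol."""
    merged = {}
    for symbols, count in word_counts.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + count
    return merged

# toy corpus: each word is a tuple of symbols, "_" marks the end of a word
word_counts = {("t", "h", "e", "_"): 5, ("t", "h", "y", "_"): 3, ("h", "e", "_"): 2}
pair = most_frequent_pair(word_counts)
word_counts = merge_pair(word_counts, pair)
print(pair, word_counts)
```

Repeating this step until the vocabulary reaches a target size is the basic idea behind learning a BPE vocabulary; a real tokenizer adds special tokens and many efficiency tricks.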

Because self-attention connects every position directly to all earlier positions, the Transformer is much less prone to the vanishing-gradient and weak long-range-dependency problems of RNNs, and therefore potentially has more representational power than an RNN with a comparable number of parameters. In the runs below, the Transformer reaches a validation perplexity of about 87 after a single epoch, a level the RNN needs roughly six epochs to match.
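
Perplexity is the exponential of the average negative log-likelihood per token, which is also how the `validation` function later in this notebook computes it. A minimal sketch, assuming we are handed the probabilities a model assigned to the correct next tokens (the input values here are made up for illustration):

```python
import math

def perplexity(token_probs):
    """exp of the mean negative log-likelihood of the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# a model that always spreads its probability uniformly over 4 tokens
uniform = perplexity([0.25] * 8)
print(round(uniform, 6))

# converting a cross-entropy loss (in nats) to perplexity
print(round(math.exp(4.461), 1))
```

A perplexity of $k$ can be read as the model being, on average, as uncertain as if it were choosing uniformly among $k$ tokens.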

In [1]:
%load_ext autoreload
%autoreload 2

import random, math, pickle
from nltk.tokenize import word_tokenize
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import psutil
from BPE import BPE
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: tanzids. Use `wandb login --relogin` to force relogin


True

In [2]:
# prep the training data
with open('shakespeare.txt', 'r') as file:
    lines = file.readlines()

In [3]:
# now let's train a BPE tokenizer on the Shakespeare corpus
tokenizer = BPE(max_vocab_size=10000, sos_token="<s>", eos_token="</s>")
tokenizer.learn(lines)
tokenizer.precompute_word_tokens(lines)
print(f"\nLearned vocab: {tokenizer.vocab}\n")


Number of words in corpus: 334797


Building vocab. Num tokens added --> : 100%|██████████| 10000/10000 [02:13<00:00, 75.09it/s]


Done building vocab!


Precomputing word tokenizations for corpus words--> : 100%|██████████| 40000/40000 [00:33<00:00, 1193.62it/s]

Done precomputing subword tokens for all words in corpus.







In [4]:
# encode and decode a test sentence to make sure the tokenizer is working properly
s = ["Yeah, I'm gonna inform them today."]
s_tokenized= tokenizer.tokenize_sentences(s)
s_encoded= tokenizer.encode(s)
s_decoded= tokenizer.decode(s_encoded)
print(s_tokenized)
print(s_encoded)
print(s_decoded)

[['<s>_', 'Y', 'ea', 'h_', ',_', 'I_', "'", 'm_', 'g', 'on_', 'na_', 'inform_', 'them_', 'to', 'day_', '._', '</s>_']]
[[94, 2105, 4214, 5226, 72, 1024, 5, 6265, 4970, 6884, 6656, 5645, 9029, 9192, 3759, 81, 92]]
[" Yeah , I 'm gon na inform them today . \n"]


In [3]:
# save the trained tokenizer to file
"""
with open("BPE_tokenizer.pkl", 'wb') as file:
    pickle.dump(tokenizer, file)
    
"""


# load trained BPE tokenizer from file
with open('BPE_tokenizer.pkl', 'rb') as f:
    tokenizer = pickle.load(f)


Tokenize the sentences and create train-validation splits

In [5]:
# subword tokenize the sentences and convert to integer tokens
sentences_tokenized = tokenizer.encode(lines)

# create train-val splits
num_sent = len(sentences_tokenized)
num_test = int(0.1 * num_sent)
x_train = sentences_tokenized[:-num_test]
x_val   = sentences_tokenized[-num_test:]

# concatenate all sentences
x_train = [w for s in x_train for w in s] 
x_val   = [w for s in x_val for w in s] 

In [6]:
print(f"Num tokens train: {len(x_train)}")
print(f"Num tokens val: {len(x_val)}")

Num tokens train: 316352
Num tokens val: 32542


Create PyTorch dataset

In [7]:
class Shakespeare(Dataset):
    def __init__(self, corpus, block_size=16):
        self.corpus = corpus
        self.block_size = block_size

    def __len__(self):
        return len(self.corpus)-self.block_size-1

    def __getitem__(self, idx):
        inputs = torch.tensor(self.corpus[idx:idx+self.block_size], dtype=torch.long)
        targets = torch.tensor(self.corpus[idx+1:idx+1+self.block_size], dtype=torch.long)
        return inputs, targets
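
The windowing done by `__getitem__` can be sketched without PyTorch: the targets are simply the inputs shifted one token to the right. The helper name `make_example` and the toy corpus are assumptions for illustration; the last lines use the token count printed above to estimate the number of batches per epoch for `block_size=64` and a batch size of 32.

```python
import math

def make_example(corpus, idx, block_size):
    """Return (inputs, targets) where targets are inputs shifted by one token."""
    inputs = corpus[idx : idx + block_size]
    targets = corpus[idx + 1 : idx + 1 + block_size]
    return inputs, targets

corpus = [10, 11, 12, 13, 14, 15]
inputs, targets = make_example(corpus, 0, block_size=4)
print(inputs, targets)

# number of examples, mirroring __len__ above
num_examples = len(corpus) - 4 - 1
print(num_examples)

# with 316352 training tokens, block_size=64 and batch size 32
num_batches = math.ceil((316352 - 64 - 1) / 32)
print(num_batches)
```

Because consecutive windows overlap in all but one token, one epoch over this dataset shows each training token to the model many times.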

Create Transformer decoder model

In [8]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, block_size, embedding_dim, total_head_size, num_heads, dropout_rate):
        super().__init__()

        assert total_head_size % num_heads == 0, "total_head_size needs to be an integer multiple of num_heads"

        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.total_head_size = total_head_size 
        self.head_size = total_head_size // num_heads 
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # define parameters
        self.key = torch.nn.Linear(embedding_dim, self.total_head_size, bias=False)
        self.query = torch.nn.Linear(embedding_dim, self.total_head_size, bias=False)
        self.value = torch.nn.Linear(embedding_dim, self.total_head_size, bias=False)
        self.attn_dropout = torch.nn.Dropout(dropout_rate)

        # non-parameter tensor of lower triangular ones
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        # we also need to apply a linear projection to make the output residual the same dimension as the input
        self.proj = torch.nn.Linear(total_head_size, embedding_dim) 
        self.output_dropout = torch.nn.Dropout(dropout_rate)


    # define forward pass, input shape: (B,T,C) where B=batch size, T=block_size, C=embedding_dim
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B,T,H) where H is the total_head_size
        q = self.query(x) # (B,T,H)
        v = self.value(x) # (B,T,H)

        # reshape (B,T,H) --> (B,T,n,h), where n=num_heads and h=head_size and H=n*h
        k = k.view(B,T,self.num_heads,self.head_size) 
        q = q.view(B,T,self.num_heads,self.head_size) 
        v = v.view(B,T,self.num_heads,self.head_size) 

        # now we transpose so that the num_heads is the second dimension followed by T,h
        # this allows us to batch matrix multiply for all heads simultaneously to compute their attention weights
        # (B,T,n,h) --> (B,n,T,h) 
        k = k.transpose(1,2) 
        q = q.transpose(1,2)
        v = v.transpose(1,2)
        
        # use pytorch built-in function for faster computation of attention scores (set the 'is_causal' parameter for applying causal masking)
        out = F.scaled_dot_product_attention(q,k,v,dropout_p=self.dropout_rate if self.training else 0,is_causal=True)
        # we can transpose the output from (B,n,T,h) --> (B,T,n,h)
        # since the transposed tensor is no longer contiguous in memory, we apply
        # contiguous(), which returns a contiguous tensor, so that view() can be used
        out = out.transpose(1,2).contiguous()
        # finally we collapse the last two dimensions to get the concatenated output, (B,T,n,h) --> (B,T,n*h) 
        out = out.view(B,T,self.total_head_size)
        # now we project the concatenated output so that it has the same dimensions as the multihead attention layer input
        # (we need to add it with the input because of the residual connection, so need to be same size) 
        out = self.proj(out) # (B,T,C) 
        # apply dropout
        out = self.output_dropout(out)
        return out

# a simple mlp 
class FeedForward(torch.nn.Module):
    def __init__(self, embedding_dim, dropout_rate):
        super().__init__()
        # we add extra computations by growing out the feed-forward hidden size by a factor of 4
        # we also add an extra linear layer at the end to project the residual back to same dimensions as input
        self.net = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, 4*embedding_dim),  
            torch.nn.GELU(),
            torch.nn.Linear(4*embedding_dim, embedding_dim), 
            torch.nn.Dropout(dropout_rate)
        )
    
    # in the forward pass, apply the MLP to each position independently
    def forward(self, x):
        return self.net(x)

# transformer block with residual connection and layer norm
class TransformerBlock(torch.nn.Module):
    def __init__(self, block_size, embedding_dim, head_size, num_heads, dropout_rate):
        super().__init__()
        self.sa = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads, dropout_rate) # multi-head attention layer 
        self.ff = FeedForward(embedding_dim, dropout_rate)   # feed-forward layer
        self.ln1 = torch.nn.LayerNorm(embedding_dim) # layer norm at input of multi-head attention
        self.ln2 = torch.nn.LayerNorm(embedding_dim) # layer norm at input of feed-forward

    # in the forward pass, apply the pre-norm attention and feed-forward sub-layers, each with a residual connection
    def forward(self, x):
        # residual connection between input and multi-head attention output
        x = x + self.sa(self.ln1(x))
        # residual connection between multi-head attention output and feed-forward output
        x = x + self.ff(self.ln2(x)) 
        return x
    
# language model with multiple transformer blocks
class TransformerLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads, num_blocks, dropout_rate=0.2, pad_token_idx=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.pad_token_idx = pad_token_idx

        '''
        Define model parameters
        '''
        # token embedding layer 
        self.token_embedding = torch.nn.Embedding(vocab_size, embedding_dim) # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = torch.nn.Embedding(block_size, embedding_dim) # shape: (T,C)
        # stack of transformer blocks
        self.blocks = torch.nn.Sequential(*[TransformerBlock(block_size, embedding_dim, head_size, num_heads, dropout_rate) for _ in range(num_blocks)])
        # we also add a layer norm before the final output layer
        self.ln_f = torch.nn.LayerNorm(embedding_dim)
        # output layer logits (the input is the layer-normed block output, which has size embedding_dim)
        self.lm_head = torch.nn.Linear(embedding_dim, vocab_size) # shape: (C,vocab_size)


    # forward pass takes in a batch of input token sequences of shape (B,T) and corresponding targets of shape (B,T)
    def forward(self, idx, targets=None):
        B, T = idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx) # (B,T,C)
        # add positional encoding
        pos_embeds = self.pos_embedding(torch.arange(T, device=idx.device)) # (T,C) 
        x = token_embeds + pos_embeds # (B,T,C)
        # pass through transformer blocks
        x = self.blocks(x) # (B,T,C)
        # apply layer norm
        x = self.ln_f(x)  # (B,T,C)
        # compute output logits 
        logits = self.lm_head(x) # (B,T,vocab_size)
        loss = None
        if targets is not None:
            B,T,vocab_size = logits.shape
            # reshape the logits and targets such that batch of input sequences are flattened into a single big input sequence
            # i.e. (B,T) --> (B*T)
            logits = logits.view(B*T,vocab_size) # reshaped to (B*T,vocab_size)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            #loss = F.cross_entropy(logits, targets, ignore_index=self.pad_token_idx)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    # generates new sequences continuing from a given batch of context tokens
    @torch.no_grad()
    def generate(self, subword2idx, block_size, temperature=1.0, topk=None, start_token="<s>", end_token="</s>", max_len=30, device="cpu"):
        self.eval()
        # generate one token at a time
        x = torch.full(size=(1,1), fill_value=subword2idx[start_token], dtype=torch.long, device=device)
        tokens = [x.item()]
        for _ in range(max_len):
            # crop the input sequence so that it doesn't exceed block size (only keep the last block_size tokens in the sequence to generate the next token)
            x = x[:,-block_size:]
            logits, _ = self(x) # forward returns (logits, loss); logits shape: (1,T,V)
            # rescale the logits for the last position with the temperature
            logits = logits[0,-1,:] / temperature # shape: (V,)
            if topk is not None:
                # keep only the topk highest-scoring tokens and sample among them
                topk_logits, idx = torch.topk(logits, min(topk, logits.size(-1)))
                p = F.softmax(topk_logits, dim=-1) # shape: (topk,)
                next_word_idx = idx[torch.multinomial(p, num_samples=1)] # shape: (1,)
            else:
                # sample from the full distribution for the last token in the sequence
                p = F.softmax(logits, dim=-1) # shape: (V,)
                next_word_idx = torch.multinomial(p, num_samples=1) # shape: (1,)
            # append to the sequence
            x = torch.cat((x, next_word_idx.view(1,1)), dim=1)
            tokens.append(next_word_idx.item())

        self.train()
        return tokens
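
The sampling step inside `generate` can be illustrated without PyTorch. The sketch below is an assumption-laden toy version (the function name `sample_top_k` and the logits are made up): it divides the logits by the temperature, keeps the `k` largest, and samples among them in proportion to their softmax probabilities.

```python
import math
import random

def sample_top_k(logits, k=None, temperature=1.0, rng=random):
    """Sample a token index from temperature-scaled logits, restricted to the top k."""
    scaled = [l / temperature for l in logits]
    # token indices sorted from most to least likely
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if k is not None:
        order = order[:k]
    m = scaled[order[0]]                               # max logit, for numerical stability
    weights = [math.exp(scaled[i] - m) for i in order]  # unnormalised softmax weights
    return rng.choices(order, weights=weights, k=1)[0]

rng = random.Random(0)
logits = [2.0, 0.5, 1.5, -1.0, 0.0]
samples = [sample_top_k(logits, k=2, temperature=0.9, rng=rng) for _ in range(200)]
print(sorted(set(samples)))
```

With `k=2` only the two highest-scoring tokens (indices 0 and 2 here) can ever be drawn. Lower temperatures sharpen the distribution towards the top token, and `k=1` reduces to greedy decoding.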


In [9]:
# training loop
def train(model, optimizer, scheduler, train_dataloader, val_dataloader, device="cpu", num_epochs=10, val_every=1, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    pp = 0
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets = batch
            # move batch to device
            inputs, targets = inputs.to(device), targets.to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            logits, loss = model(inputs, targets)
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            train_acc = num_correct / num_total        
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}, Val Perplexity: {pp:.1f}")  

            if log_metrics:
                metrics = {"Batch loss" : loss.item(), "Moving Avg Loss" : avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        scheduler.step()
        if epoch%val_every == 0:
            # compute validation loss
            val_loss, val_acc, pp = validation(model, val_dataloader, device=device)
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}, Val Perplexity: {pp:.1f}") 

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            logits, loss = model(inputs, targets)
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    perplexity = math.exp(val_loss)
    return val_loss, val_accuracy, perplexity


def save_model_checkpoint(model, optimizer, epoch=None, loss=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Save the checkpoint to a file
    torch.save(checkpoint, 'transformer_checkpoint.pth')
    print(f"Saved model checkpoint!")


def load_model_checkpoint(model, optimizer):
    checkpoint = torch.load('transformer_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer      

In [10]:
L = 64
train_dataset = Shakespeare(x_train, block_size=L)
val_dataset = Shakespeare(x_val, block_size=L)

In [11]:
B = 32
D = 64
vocab_size = len(tokenizer.vocab)
num_heads = 8
num_layers = 4
learning_rate = 1e-4
DEVICE = "cuda"

train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)


model = TransformerLanguageModel(vocab_size, L, D, D, num_heads, num_layers, dropout_rate=0.2).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.95)
#model, optimizer = load_model_checkpoint(model, optimizer)


num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 1.502293 M
RAM used: 725.34 MB
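
A quick check of the learning-rate schedule configured above: `StepLR` multiplies the learning rate by `gamma` once every `step_size` calls to `scheduler.step()`. Assuming the scheduler is stepped once per epoch, as in the `train` function, the closed form below (a plain-Python sketch, not PyTorch code) shows what the schedule does over a 10-epoch run.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate after `epoch` scheduler steps under a step-decay schedule."""
    return base_lr * gamma ** (epoch // step_size)

# settings used in this notebook: lr=1e-4, step_size=50, gamma=0.95
lrs = [step_lr(1e-4, epoch, step_size=50, gamma=0.95) for epoch in range(10)]
print(lrs)

# the first decay would only happen after 50 epochs
print(step_lr(1e-4, 50, step_size=50, gamma=0.95))
```

With `step_size=50` the learning rate therefore stays at `1e-4` for the whole 10-epoch run. A smaller `step_size`, or stepping the scheduler per batch, would be needed for the decay to have any effect.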


In [12]:
train(model, optimizer, scheduler, train_dataloader, val_dataloader, device=DEVICE, num_epochs=10, save_every=5, val_every=1) #, log_metrics=log_metrics)

Epochs:   0%|          | 0/9884 [00:00<?, ?it/s]

Epoch 1, EMA Train Loss: 4.200, Train Accuracy:  0.294, Val Loss:  0.000, Val Accuracy:  0.000, Val Perplexity: 0.0: 100%|██████████| 9884/9884 [01:53<00:00, 87.04it/s]
Epoch 2, EMA Train Loss: 3.858, Train Accuracy:  0.324, Val Loss:  4.461, Val Accuracy:  0.335, Val Perplexity: 86.6: 100%|██████████| 9884/9884 [02:01<00:00, 81.09it/s]
Epoch 3, EMA Train Loss: 3.595, Train Accuracy:  0.337, Val Loss:  4.495, Val Accuracy:  0.340, Val Perplexity: 89.6: 100%|██████████| 9884/9884 [02:02<00:00, 80.50it/s]


In [71]:
tokens = model.generate(tokenizer.subword2idx, block_size=L, temperature=0.9, topk=500, device=DEVICE, max_len=2000)
decoded = tokenizer.decode([tokens])
print(decoded[0])

with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with with look with with ; ' with with with with with with with with with with so ' with with for with before let with with with with with with with with with with with with ; ' keeps ' ' ' ' ' set with with with with with distingu' 
  spingcent with the world with my country : 
  Against us the bastardy ' 't n't it is you , 
  For how the man would have wont with gold ; 
  As I can not believe not seeing belong the seas 
  Fordle your n't is ? 
  
  CORIOLANUS : 
  Come , or your word , in it well , 
  And bear him in butency them , it comes ; 
  It is a man father will serve in they 
  Bring them

#### Comparison with RNN

In [9]:
class RNNLM(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim=32, num_rnn_layers=1, dropout_rate=0.1):
        super().__init__()

        # embedding layer
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        c = 0.1        
        torch.nn.init.uniform_(self.emb.weight, -c, c)

        # create rnn layers (a unidirectional LSTM, since a language model must not see future tokens; output hidden states have dims=hidden_size)
        if num_rnn_layers == 1:
            self.rnn_layers = torch.nn.LSTM(input_size=embedding_dim, hidden_size=embedding_dim, num_layers=num_rnn_layers, batch_first=True)
        else:    
            self.rnn_layers = torch.nn.LSTM(input_size=embedding_dim, hidden_size=embedding_dim, num_layers=num_rnn_layers, batch_first=True, dropout=dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
        # create output layer (computes output class logits for each item in sequence)
        self.output_layer =  torch.nn.Linear(embedding_dim, vocab_size)
        # tie the output layer weights with the embedding layer weights
        self.output_layer.weight = self.emb.weight

    # forward pass
    def forward(self, x, y=None):
        # get embeddings for batch of input sequences of length L
        x = self.emb(x) # shape: (B,L,D)
        # apply dropout
        x = self.dropout(x)
        # compute rnn hidden states
        x, _ = self.rnn_layers(x) # shape: (B,L,D)
        # apply dropout
        x = self.dropout(x)
        # compute output logits
        x = self.output_layer(x) # shape: (B,L,vocab_size)

        if y is None:
            return x

        # reshape
        x = x.view(-1,x.shape[-1]) # shape: (B*L,vocab_size)
        y = y.view(-1) # shape: (B*L,)
        # compute cross entropy loss
        loss = F.cross_entropy(x, y)
        return x, loss
    
    @torch.no_grad()
    def generate(self, vocab, subword2idx, temperature=1.0, topk=None, start_token="<s>", end_token="</s>", max_len=30, device="cpu"):
        self.eval()
        # generate one word at a time
        x = torch.full(size=(1,1), fill_value=subword2idx[start_token], dtype=torch.long, device=device)
        tokens = [x.item()]
        for _ in range(max_len):
            logits = self.forward(x) # shape: (1,L,V)
            # rescale the logits with the temperature
            logits = logits / temperature
            if topk is not None:
                # keep only the topk highest-scoring tokens and sample among them
                topk_logits, idx = torch.topk(logits[0,-1,:], min(topk, logits.size(-1)))
                p = F.softmax(topk_logits, dim=-1) # shape: (topk,)
                next_word_idx = idx[torch.multinomial(p, num_samples=1)]
            else:             
                # sample from the distribution for the last word in the sequence
                p = F.softmax(logits[:,-1,:], dim=-1) # shape: (V,)
                next_word_idx = torch.multinomial(p, num_samples=1)
            # append to the sequence
            x = torch.cat((x, next_word_idx.view(1,1)), dim=1)
            tokens.append(next_word_idx.item())
        
        self.train()
        return tokens


In [12]:
B = 32
D = 128
vocab_size = len(tokenizer.vocab)
num_rnn_layers = 4
learning_rate = 1e-4
DEVICE = "cuda"

train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)

rnn_model = RNNLM(vocab_size, D, num_rnn_layers=num_rnn_layers, dropout_rate=0.2).to(DEVICE)
optimizer = torch.optim.AdamW(rnn_model.parameters(), lr=learning_rate)
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.95)

num_params = sum(p.numel() for p in rnn_model.parameters())
print(f"Total number of parameters in RNN network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in RNN network: 1.827285 M
RAM used: 740.54 MB
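
The two printed parameter counts can be reproduced by hand, which also shows where the parameters sit. The sketch below assumes a vocabulary size of `V = 10069`; this value is not printed anywhere in the notebook and is inferred here because it is the one that reproduces both totals. The function names are hypothetical.

```python
V = 10069  # assumption: vocabulary size inferred from the printed totals

def transformer_params(V, C=64, T=64, n_blocks=4):
    attn = 3 * C * C + (C * C + C)                     # q, k, v (no bias) + output projection
    ff = (C * 4 * C + 4 * C) + (4 * C * C + C)         # two linear layers with bias
    norms = 2 * 2 * C                                  # two LayerNorms (weight + bias)
    block = attn + ff + norms
    embeddings = V * C + T * C                         # token + position embeddings
    head = C * V + V                                   # untied output layer
    return embeddings + n_blocks * block + 2 * C + head

def lstm_lm_params(V, D=128, n_layers=4):
    lstm = n_layers * 4 * (D * D + D * D + 2 * D)      # 4 gates, input + hidden weights, 2 biases
    return V * D + lstm + V                            # tied embedding/output weight + output bias

print(transformer_params(V), lstm_lm_params(V))

# parameters that do not scale with the vocabulary size
transformer_core = transformer_params(V) - (V * 64 + 64 * V + V)
lstm_core = lstm_lm_params(V) - (V * 128 + V)
print(transformer_core, lstm_core, round(lstm_core / transformer_core, 2))
```

Under this assumption most parameters in both models sit in the vocabulary-sized embedding and output layers. Excluding those, the LSTM stack has roughly 2.6 times as many parameters as the Transformer blocks.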


In [15]:
train(rnn_model, optimizer, scheduler, train_dataloader, val_dataloader, device=DEVICE, num_epochs=10, val_every=1) #, log_metrics=log_metrics)

Epochs:   0%|          | 0/9885 [00:00<?, ?it/s]

Epoch 1, EMA Train Loss: 5.831, Train Accuracy:  0.113, Val Loss:  0.000, Val Accuracy:  0.000, Val Perplexity: 0.0: 100%|██████████| 9885/9885 [01:06<00:00, 149.50it/s]
Epoch 2, EMA Train Loss: 4.461, Train Accuracy:  0.268, Val Loss:  5.890, Val Accuracy:  0.122, Val Perplexity: 361.4: 100%|██████████| 9885/9885 [01:08<00:00, 144.45it/s]
Epoch 3, EMA Train Loss: 4.263, Train Accuracy:  0.306, Val Loss:  4.697, Val Accuracy:  0.320, Val Perplexity: 109.6: 100%|██████████| 9885/9885 [01:06<00:00, 147.63it/s]
Epoch 4, EMA Train Loss: 4.117, Train Accuracy:  0.317, Val Loss:  4.578, Val Accuracy:  0.323, Val Perplexity: 97.3: 100%|██████████| 9885/9885 [01:06<00:00, 148.57it/s]
Epoch 5, EMA Train Loss: 4.027, Train Accuracy:  0.324, Val Loss:  4.505, Val Accuracy:  0.338, Val Perplexity: 90.5: 100%|██████████| 9885/9885 [01:05<00:00, 150.08it/s]


Saved model checkpoint!


Epoch 6, EMA Train Loss: 3.962, Train Accuracy:  0.331, Val Loss:  4.472, Val Accuracy:  0.341, Val Perplexity: 87.5: 100%|██████████| 9885/9885 [01:05<00:00, 150.63it/s]
Epoch 7, EMA Train Loss: 3.899, Train Accuracy:  0.336, Val Loss:  4.452, Val Accuracy:  0.343, Val Perplexity: 85.8: 100%|██████████| 9885/9885 [01:06<00:00, 147.63it/s]
Epoch 8, EMA Train Loss: 3.758, Train Accuracy:  0.341, Val Loss:  4.445, Val Accuracy:  0.344, Val Perplexity: 85.2: 100%|██████████| 9885/9885 [01:07<00:00, 146.77it/s]
Epoch 9, EMA Train Loss: 3.741, Train Accuracy:  0.345, Val Loss:  4.442, Val Accuracy:  0.345, Val Perplexity: 84.9: 100%|██████████| 9885/9885 [01:08<00:00, 144.09it/s]
Epoch 10, EMA Train Loss: 3.709, Train Accuracy:  0.348, Val Loss:  4.437, Val Accuracy:  0.345, Val Perplexity: 84.5: 100%|██████████| 9885/9885 [01:10<00:00, 139.50it/s]


Saved model checkpoint!


In [16]:
tokens = rnn_model.generate(tokenizer.vocab, tokenizer.subword2idx, temperature=0.9, topk=None, device=DEVICE, max_len=2000)
decoded = tokenizer.decode([tokens])
print(decoded[0])

the butchers of me . 
  
  POMPEY : 
  Look , if he be gone that hath a strange thing . 
  
  VOLUMNIA : 
  I think you Has Friar villains : they intend for part , 
  for a affairs made honour , his brother 's caught . 
  
  He : 
  Out of thy place , they faint dead , thus that if more , 
  And pays revenge Gmseins in deformage , 
  That he can scarce church break a crown of me , 
  And for our walls , so make my prettiest life 
  And they hath thought our Katharina to wheel , 
  That we hear'st a staves will want in death . 
  
  KING RICHARD II : 
  I am bound in them , dinner , I 'll go of thee . 
  
  CLIFFORD : 
  Sir , fair boy ! away ! see the forth thy grave ; 
  Or on thou plead there I give thee the man , 
  That thou hadst so lost my lord , and have wont ? 
  
  BUCKINGHAM : 
  Never , there will resolve her ay to the journey , 
  But means with the mantness to bear him not , 
  We enmented has myself : but then they 
  Wmnved of their aps of welcome of a county : 
  So are

#### Even with more parameters (1.83M vs. 1.50M), the RNN needs about six epochs to reach the validation perplexity that the Transformer reaches after a single epoch!