<a href="https://colab.research.google.com/github/victorm0202/temas_selectos_CD/blob/main/language_model_tarea.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Word-level language modeling with RNNs

## The imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import torchtext
import datasets
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## The data

In this example we use the [wikitext](https://huggingface.co/datasets/wikitext/viewer/wikitext-2-raw-v1/test) corpus (the `wikitext-2-raw-v1` configuration) from Hugging Face.

In [None]:
train_dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
valid_dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='validation')
test_dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')



In [None]:
print(train_dataset[88]['text'])
print(type(train_dataset))

 This ammunition , and that which I brought with me , was rapidly prepared for use at the Laboratory established at the Little Rock Arsenal for that purpose . As illustrating as the pitiful scarcity of material in the country , the fact may be stated that it was found necessary to use public documents of the State Library for cartridge paper . Gunsmiths were employed or conscripted , tools purchased or impressed , and the repair of the damaged guns I brought with me and about an equal number found at Little Rock commenced at once . But , after inspecting the work and observing the spirit of the men I decided that a garrison 500 strong could hold out against Fitch and that I would lead the remainder - about 1500 - to Gen 'l Rust as soon as shotguns and rifles could be obtained from Little Rock instead of pikes and lances , with which most of them were armed . Two days elapsed before the change could be effected . " 

<class 'datasets.arrow_dataset.Dataset'>


## Tokenize text data and build the vocabulary

In [None]:
def tokenize_data(example, tokenizer):
    return {'tokens': tokenizer(example['text'])}

tokenizer = get_tokenizer('basic_english')
tokenized_train_dataset = train_dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})
tokenized_test_dataset = test_dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})
tokenized_valid_dataset = valid_dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})



In [None]:
print(tokenized_train_dataset[88]['tokens'])
type(tokenized_train_dataset)

['this', 'ammunition', ',', 'and', 'that', 'which', 'i', 'brought', 'with', 'me', ',', 'was', 'rapidly', 'prepared', 'for', 'use', 'at', 'the', 'laboratory', 'established', 'at', 'the', 'little', 'rock', 'arsenal', 'for', 'that', 'purpose', '.', 'as', 'illustrating', 'as', 'the', 'pitiful', 'scarcity', 'of', 'material', 'in', 'the', 'country', ',', 'the', 'fact', 'may', 'be', 'stated', 'that', 'it', 'was', 'found', 'necessary', 'to', 'use', 'public', 'documents', 'of', 'the', 'state', 'library', 'for', 'cartridge', 'paper', '.', 'gunsmiths', 'were', 'employed', 'or', 'conscripted', ',', 'tools', 'purchased', 'or', 'impressed', ',', 'and', 'the', 'repair', 'of', 'the', 'damaged', 'guns', 'i', 'brought', 'with', 'me', 'and', 'about', 'an', 'equal', 'number', 'found', 'at', 'little', 'rock', 'commenced', 'at', 'once', '.', 'but', ',', 'after', 'inspecting', 'the', 'work', 'and', 'observing', 'the', 'spirit', 'of', 'the', 'men', 'i', 'decided', 'that', 'a', 'garrison', '500', 'strong', 'co

datasets.arrow_dataset.Dataset

You can 'expand' your vocabulary by building it from several concatenated splits, for instance:

`build_vocab(datasets.concatenate_datasets([train_dataset, valid_dataset]))`

but keep in mind that a larger vocabulary also increases memory use and computation, since the embedding and output layers grow with it.


In [None]:
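def build_vocab(dataset):
    # Tokenize the raw 'text' column; returns a dataset with a 'tokens' column
    tokens = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})
    return tokens

# Keep tokens that appear at least 3 times; the specials get indices 0 and 1
vocab = build_vocab_from_iterator(build_vocab(train_dataset)['tokens'], specials=["<UNK>", "<EOS>"], min_freq=3)
# Out-of-vocabulary tokens are mapped to <UNK>
vocab.set_default_index(vocab["<UNK>"])

print(len(vocab))
print(vocab.get_itos()[:10])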



29473
['<UNK>', '<EOS>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a']
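`build_vocab_from_iterator` counts token frequencies, drops tokens seen fewer than `min_freq` times, orders the rest by decreasing frequency and puts the `specials` first; `set_default_index` then sends unknown tokens to `<UNK>`. The following plain-Python sketch approximates that behavior on a toy corpus (the helpers `build_vocab_sketch` and `lookup` are hypothetical, and tie-breaking details of the real library may differ).

```python
from collections import Counter

def build_vocab_sketch(token_lists, specials=("<UNK>", "<EOS>"), min_freq=3):
    """Hypothetical helper approximating torchtext's build_vocab_from_iterator."""
    counter = Counter(tok for tokens in token_lists for tok in tokens)
    # Keep tokens seen at least min_freq times: most frequent first, ties alphabetical
    kept = sorted((t for t, c in counter.items() if c >= min_freq and t not in specials),
                  key=lambda t: (-counter[t], t))
    itos = list(specials) + kept                  # specials come first
    stoi = {tok: i for i, tok in enumerate(itos)}
    return itos, stoi

def lookup(stoi, tokens, default="<UNK>"):
    # Out-of-vocabulary tokens fall back to the default index (like set_default_index)
    return [stoi.get(t, stoi[default]) for t in tokens]

corpus = [["the", "cat", "sat"], ["the", "dog", "sat"],
          ["the", "cat", "ran"], ["a", "cat", "sat"]]
itos, stoi = build_vocab_sketch(corpus)
print(itos)                                 # ['<UNK>', '<EOS>', 'cat', 'sat', 'the']
print(lookup(stoi, ["the", "dog", "sat"]))  # [4, 0, 3]
```

Here 'dog', 'ran' and 'a' appear only once, so they fall below `min_freq=3` and are looked up as `<UNK>` (index 0), just like 'hello' in the example below.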


An example:

In [None]:
# With this vocabulary the word 'hello' is not present,
# so it is mapped to the <UNK> token (index 0)
tokens = tokenizer("hello and good morning to everybody!")
indexes = vocab(tokens)

tokens, indexes

(['hello', 'and', 'good', 'morning', 'to', 'everybody', '!'],
 [0, 6, 426, 1041, 8, 7063, 386])

## Dataloaders

Language models process text sequences of a fixed length. There are many ways to build batches of such sequences; in our case we take them from *all* texts in the corpus, concatenated one after another. As an example, consider the following sentences as if they were our corpus:

- *“the more you read, the more things you will know and learn”*
- *“curiosity is the wick in the candle of learning”*
- *“eventually things start making sense, be patient”*

Then, with a batch size of 5, we obtain the following tensor (observe that the function appends the special token `<EOS>` at the end of each sentence):

![batch1](https://drive.google.com/uc?id=12Yw8fzaczFXbyPLNOcb-R2P2ZUX4H_UY)


The following function implements this idea: it concatenates the tokenized texts, appends `<EOS>` to each one, maps every token to its index in `vocab`, and reshapes the result into a tensor of shape `[batch_size, num_batches]`.

In [None]:
def batchify(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:                                # skip empty lines
            tokens = example['tokens'] + ['<EOS>']           # mark the end of each text
            data.extend(vocab[token] for token in tokens)    # token -> index
    data = torch.LongTensor(data)
    num_batches = data.shape[0] // batch_size                # tokens per row
    data = data[:num_batches * batch_size]                   # drop the leftover tail
    data = data.view(batch_size, num_batches)                # [batch_size, num_batches]
    return data


`batchify` returns a tensor of integers, each one being the index in `vocab` of the corresponding token.

Observe that some tokens at the end of the corpus are discarded: the concatenated token stream is truncated to the largest multiple of `batch_size`, because each row receives exactly `num_batches = data.shape[0] // batch_size` tokens.
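To make the layout concrete, here is a plain-Python sketch of what `batchify` does to the three example sentences, working with token strings instead of indices. The regex tokenizer `toy_tokenize` and the helper `batchify_tokens` are hypothetical stand-ins (not torchtext's `basic_english`), so the exact split may differ from the figure.

```python
import re

def toy_tokenize(text):
    # Simplified stand-in for 'basic_english': lowercase and split off punctuation
    return re.findall(r"\w+|[^\w\s]", text.lower())

def batchify_tokens(sentences, batch_size):
    """Hypothetical pure-Python analogue of batchify, on token strings."""
    stream = []
    for s in sentences:
        tokens = toy_tokenize(s)
        if tokens:
            stream.extend(tokens + ["<EOS>"])     # mark the end of every text
    num_batches = len(stream) // batch_size       # tokens per row
    kept = stream[:num_batches * batch_size]      # drop the leftover tail
    # Same as .view(batch_size, num_batches): consecutive chunks become rows
    rows = [kept[r * num_batches:(r + 1) * num_batches] for r in range(batch_size)]
    return rows, stream[num_batches * batch_size:]

sentences = [
    "the more you read, the more things you will know and learn",
    "curiosity is the wick in the candle of learning",
    "eventually things start making sense, be patient",
]
rows, dropped = batchify_tokens(sentences, batch_size=5)
for row in rows:
    print(row)
print("dropped:", dropped)   # dropped: ['be', 'patient', '<EOS>']
```

The stream has 33 tokens, so each of the 5 rows gets 6 of them and the last 3 are discarded. Each row is a contiguous slice of the stream, which is why consecutive windows taken from the same row continue the same text.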

Now we need to define a sequence length. The LSTM layer of our model (created with `batch_first=True`) takes as input a tensor of shape `[N, L, E]`, where `N` is the batch size, `L` is the sequence length and `E` is the embedding size of each token. Let's ignore the embedding size for the moment and consider a 2D tensor of shape `[batch_size=4, num_batches=24]`, as shown in the next figure:

![batch2](https://drive.google.com/uc?id=1KmILYG5gjeSjnV8LxjV673EhAWgakgQW)

If we choose a sequence length `L=4`, one training epoch consists of 6 iterations (24 / 4), because each color corresponds to a *batch of sequences*, i.e. one forward pass through the model:

![batch3](https://drive.google.com/uc?id=1YUmGMqEm0ggXgqm6zGKo_w2TW-fNSzYT)

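As you can see, a sequence fed to the model may span parts of different texts of the original dataset, or cover only part of one (depending on the sequence length `L`). Since each row of the tensor is a continuous stream of text, the window used in the next iteration continues exactly where the previous one ended. For this reason, the hidden state is carried over from one batch to the next and is only reset at the start of every epoch.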

In [None]:
batch_size = 128
train_data = batchify(tokenized_train_dataset, vocab, batch_size)
valid_data = batchify(tokenized_valid_dataset, vocab, batch_size)
test_data = batchify(tokenized_test_dataset, vocab, batch_size)

In [None]:
train_data.shape

torch.Size([128, 16214])
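As a quick sanity check on this shape, the arithmetic below computes how many tokens the training tensor holds and, with the `seq_len = 50` used later in the training cell, how many iterations one epoch of `train` runs after its trimming step (the numbers are copied from the output above).

```python
batch_size, num_cols = 128, 16214   # from train_data.shape
seq_len = 50                        # value used in the training cell

tokens_used = batch_size * num_cols
# train() trims trailing columns so that (num_cols - 1) is a multiple of seq_len
trimmed = num_cols - (num_cols - 1) % seq_len
iterations = len(range(0, trimmed - 1, seq_len))
print(tokens_used, trimmed, iterations)   # 2075392 16201 324
```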

## The model

In [None]:
class LSTM_LM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate):                
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, 
                    dropout=dropout_rate, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
        self.init_weights()

    def forward(self, src, hidden):
        embedding = self.dropout(self.embedding(src))
        output, hidden = self.lstm(embedding, hidden)          
        output = self.dropout(output) 
        prediction = self.fc(output)
        return prediction, hidden

    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hidden_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_emb)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        # Initialize the LSTM weight matrices in place. Assigning new tensors to
        # entries of self.lstm.all_weights would not change the module's parameters
        # (and weight_ih/weight_hh have shape [4*hidden_dim, input_size], not [E, H]).
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.uniform_(param, -init_range_other, init_range_other)

    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        cell = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        return hidden, cell

    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach()
        cell = cell.detach()
        return hidden, cell

In [None]:
vocab_size = len(vocab)
embedding_dim = 500
hidden_dim = 500  
num_layers = 2    
dropout_rate = 0.65 
lr = 1e-3         

In [None]:
model = LSTM_LM(vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 33,510,473 trainable parameters
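The reported number of parameters can be checked by hand. The sketch below uses the standard PyTorch parameter shapes of `nn.Embedding`, `nn.LSTM` (which keeps two bias vectors, `bias_ih` and `bias_hh`, per layer) and `nn.Linear`, with the sizes set above and `len(vocab) = 29473`.

```python
V, E, H, layers = 29473, 500, 500, 2   # vocab size, embedding_dim, hidden_dim, num_layers

embedding = V * E
lstm = 0
for layer in range(layers):
    input_size = E if layer == 0 else H
    # weight_ih (4H x input), weight_hh (4H x H), bias_ih and bias_hh (4H each)
    lstm += 4 * H * input_size + 4 * H * H + 2 * 4 * H
fc = H * V + V                         # weight (V x H) plus bias (V)
total = embedding + lstm + fc
print(f"{total:,}")                    # 33,510,473
```

Almost 90% of the parameters live in the embedding and output layers, which is why enlarging the vocabulary is expensive.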


## Training

To train our model we need inputs and targets. In the case of a language model, the target at each position is the next token of the sequence. These batches of text sequences are determined by the `batch_size` and `seq_len` parameters.

We need sequences of the same length. Although there are different ways to obtain them (zero padding, for example), we use the concatenated text of all training sequences, as explained before, cutting it wherever necessary.

Given a sequence length, the source and target sequences are defined as follows. For illustration, we reuse the earlier example (`batch_size=4, num_batches=24, seq_len=4`). For the first and second batches we have:

![batch4](https://drive.google.com/uc?id=1bXjWfqyPCqNAwsGJCQ19WRl8kACN57kb)

![batch5](https://drive.google.com/uc?id=1Gxl_lhRdLFxIk23T0Cmbr378HAsjlo1Z)

Observe that, during training, the hidden state carries information about the previous sequences (yellow tensors) into the current batch. Finally, the last batches we use for training are:

![batch6](https://drive.google.com/uc?id=1oC53-9wMXTj9Dx9zzJmOCt-4NRy30BFn)


The `get_batch` function implements this idea.

In [None]:
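def get_batch(data, seq_len, num_batches, idx):
    # Never read past the last column: the final token can only be a target
    seq_len = min(seq_len, num_batches - 1 - idx)
    src = data[:, idx:idx+seq_len]               # tokens idx .. idx+seq_len-1
    target = data[:, idx+1:idx+seq_len+1]        # same window shifted one token ahead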
    return src, target
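To see which columns each iteration of `train` and `evaluate` touches, here is a pure-Python sketch that reproduces their trimming rule and the slicing of `get_batch`, using column indices only (the helper `windows` is hypothetical, and the toy sizes are the ones from the figures).

```python
def windows(num_cols, seq_len):
    # Same trimming as train(): keep columns so that (num_cols - 1) % seq_len == 0
    num_cols = num_cols - (num_cols - 1) % seq_len
    out = []
    for idx in range(0, num_cols - 1, seq_len):
        src = list(range(idx, idx + seq_len))             # columns fed to the model
        target = list(range(idx + 1, idx + seq_len + 1))  # same columns shifted by one
        out.append((src, target))
    return out

for src, target in windows(num_cols=24, seq_len=4):
    print("src", src, "-> target", target)
```

Under the code as written, the 24 columns are trimmed to 21, so the loop runs five times and the last target window ends at column 20; the remaining columns are dropped so that every source position is paired with the token immediately after it.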

In [None]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):   
    epoch_loss = 0
    model.train()
    # trim trailing columns so that (number of columns - 1) is a multiple of seq_len
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)
    # the last column can only be a target, never part of a src
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        #zero the gradients due to the previous batch and detach its hidden state
        optimizer.zero_grad()
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, num_batches, idx)
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        prediction = prediction.reshape(batch_size * seq_len, -1)   
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [None]:
def evaluate(model, data, criterion, batch_size, seq_len, device):
    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, num_batches, idx)
            src, target = src.to(device), target.to(device)
            batch_size = src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches
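`evaluate` returns (approximately) the average per-token cross-entropy, and the training cell prints `math.exp` of it, i.e. the perplexity. A small sketch of that relation follows; the helper `perplexity` is hypothetical and takes the probabilities a model assigned to the true next tokens.

```python
import math

def perplexity(target_probs):
    # Mean negative log-likelihood (cross-entropy in nats) of the true next tokens
    mean_nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(mean_nll)

V = 29473                              # len(vocab) from above
print(perplexity([1 / V] * 100))       # uniform guess: about 29473
print(perplexity([0.5, 0.25, 0.125]))  # geometric mean of 2, 4 and 8: about 4.0
```

A model that guessed uniformly over the vocabulary would have a perplexity equal to the vocabulary size, which gives a reference point for the values logged during training.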

In [None]:
n_epochs = 50
seq_len = 50
clip = 0.25
saved = True

# halve the learning rate (factor=0.5) after every epoch in which the validation loss does not improve (patience=0)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

if saved:
    model.load_state_dict(torch.load('best-val-languageModel.pt',  map_location=device))
    test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
    print(f'Test Perplexity: {math.exp(test_loss):.3f}')
else:
    best_valid_loss = float('inf')

    for epoch in range(n_epochs):
        train_loss = train(model, train_data, optimizer, criterion, 
                    batch_size, seq_len, clip, device)
        valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                    seq_len, device)
        
        lr_scheduler.step(valid_loss)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), './best-val-languageModel.pt')

        print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
        print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                           

	Train Perplexity: 1001.656
	Valid Perplexity: 468.221

	Train Perplexity: 485.633
	Valid Perplexity: 315.350

	Train Perplexity: 356.289
	Valid Perplexity: 259.623

	Train Perplexity: 292.501
	Valid Perplexity: 227.320

	Train Perplexity: 250.317
	Valid Perplexity: 203.243

	Train Perplexity: 219.318
	Valid Perplexity: 192.013

	Train Perplexity: 196.289
	Valid Perplexity: 181.825

	Train Perplexity: 179.278
	Valid Perplexity: 164.673

	Train Perplexity: 165.218
	Valid Perplexity: 160.545

	Train Perplexity: 154.044
	Valid Perplexity: 153.266

	Train Perplexity: 144.951
	Valid Perplexity: 148.865

	Train Perplexity: 136.831
	Valid Perplexity: 147.879

	Train Perplexity: 130.641
	Valid Perplexity: 143.221

	Train Perplexity: 124.910
	Valid Perplexity: 140.935

	Train Perplexity: 120.095
	Valid Perplexity: 139.516

	Train Perplexity: 115.223
	Valid Perplexity: 137.081

	Train Perplexity: 111.775
	Valid Perplexity: 134.631

	Train Perplexity: 108.342
	Valid Perplexity: 134.324

	Train Perplexity: 104.862
	Valid Perplexity: 131.467

	Train Perplexity: 102.133
	Valid Perplexity: 132.342

	Train Perplexity: 97.002
	Valid Perplexity: 129.051

	Train Perplexity: 94.794
	Valid Perplexity: 127.641

	Train Perplexity: 93.058
	Valid Perplexity: 128.014

	Train Perplexity: 90.579
	Valid Perplexity: 128.304

	Train Perplexity: 89.947
	Valid Perplexity: 127.529

	Train Perplexity: 89.490
	Valid Perplexity: 127.047

	Train Perplexity: 88.876
	Valid Perplexity: 126.962

	Train Perplexity: 88.313
	Valid Perplexity: 126.367

	Train Perplexity: 87.851
	Valid Perplexity: 126.316

	Train Perplexity: 88.082
	Valid Perplexity: 124.769

	Train Perplexity: 87.714
	Valid Perplexity: 124.405

	Train Perplexity: 87.455
	Valid Perplexity: 124.572

	Train Perplexity: 88.110
	Valid Perplexity: 124.197

	Train Perplexity: 87.935
	Valid Perplexity: 123.895

	Train Perplexity: 87.706
	Valid Perplexity: 123.907

	Train Perplexity: 88.727
	Valid Perplexity: 123.905

	Train Perplexity: 89.328
	Valid Perplexity: 123.868

	Train Perplexity: 89.977
	Valid Perplexity: 123.886

	Train Perplexity: 90.075
	Valid Perplexity: 123.884

	Train Perplexity: 90.096
	Valid Perplexity: 123.876

	Train Perplexity: 90.186
	Valid Perplexity: 123.878

	Train Perplexity: 90.217
	Valid Perplexity: 123.879

	Train Perplexity: 90.216
	Valid Perplexity: 123.879

	Train Perplexity: 90.041
	Valid Perplexity: 123.879

	Train Perplexity: 90.029
	Valid Perplexity: 123.879

	Train Perplexity: 90.225
	Valid Perplexity: 123.879

	Train Perplexity: 90.153
	Valid Perplexity: 123.879

	Train Perplexity: 90.334
	Valid Perplexity: 123.879

	Train Perplexity: 90.237
	Valid Perplexity: 123.879

	Train Perplexity: 90.333
	Valid Perplexity: 123.878
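The flat validation perplexity of the last epochs is plausibly a consequence of the scheduler: with `factor=0.5` and `patience=0`, every epoch without improvement halves the learning rate, so after many such epochs the updates become tiny. The sketch below is a hypothetical, simplified re-implementation of the `ReduceLROnPlateau` rule (mode `'min'`, relative threshold `1e-4` as in PyTorch's defaults, ignoring `cooldown` and `min_lr`), applied to made-up validation losses.

```python
def plateau_lrs(val_losses, lr, factor=0.5, patience=0, threshold=1e-4):
    """Hypothetical, simplified version of the ReduceLROnPlateau rule."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best * (1 - threshold):   # counts as an improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:           # with patience=0: any non-improving epoch
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs

lrs = plateau_lrs([5.0, 4.0, 4.1, 3.9, 3.95, 3.96], lr=1e-3)
print(lrs)   # the learning rate halves at the 3rd, 5th and 6th epochs
```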


## Inference

In [None]:
def generate(prompt, max_seq_len, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        # Feed the whole prompt once; afterwards feed only the newly sampled token,
        # because the hidden state already summarizes everything that came before
        src = torch.LongTensor([indices]).to(device)
        for i in range(max_seq_len):
            prediction, hidden = model(src, hidden)
            probs = torch.softmax(prediction[:, -1], dim=-1)   # next-token distribution
            prediction = torch.multinomial(probs, num_samples=1).item()

            # resample until the token is not <UNK>
            while prediction == vocab['<UNK>']:
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<EOS>']:
                break

            indices.append(prediction)
            src = torch.LongTensor([[prediction]]).to(device)

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens
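`generate` avoids `<UNK>` by re-drawing until another token comes out. An equivalent approach is to set the probability of `<UNK>` to zero before sampling; since `torch.multinomial` accepts non-negative weights that do not sum to one, in the notebook this would amount to `probs[:, vocab['<UNK>']] = 0` before the call. The sketch below shows the idea in plain Python (the helper `sample_excluding` is hypothetical; `random.choices` also accepts unnormalized weights).

```python
import random

def sample_excluding(probs, excluded, rng):
    # Zero out the excluded indices (e.g. <UNK>) and draw once from the rest;
    # random.choices does not require the weights to sum to one
    weights = [0.0 if i in excluded else p for i, p in enumerate(probs)]
    return rng.choices(range(len(probs)), weights=weights, k=1)[0]

rng = random.Random(45)
probs = [0.7, 0.1, 0.2]      # index 0 plays the role of <UNK>
draws = [sample_excluding(probs, {0}, rng) for _ in range(1000)]
print(draws.count(1) / len(draws), draws.count(2) / len(draws))  # roughly 1/3 and 2/3
```

Both approaches sample from the same distribution renormalized over the remaining tokens, but masking needs a single draw even when `<UNK>` is very likely.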

In [None]:
prompt = 'Remember that'
max_seq_len = 30
seed = 45

generation = generate(prompt, max_seq_len, model, tokenizer, vocab, device, seed)
print('Generated text:\n'+' '.join(generation)+'\n')

Generated text:
remember that he was a victim and succeed when they could do happens . he grew or melting , and give he off the more thought to play her not a .

