We will use the HuggingFace datasets library to load and map over the dataset, torchtext to tokenize the dataset and construct the vocabulary, and PyTorch to define, train, and evaluate the model. tqdm is used only to show progress bars during training and evaluation.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import math

import torchtext

import datasets

from tqdm import tqdm

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(0)

<torch._C.Generator at 0x7f2e456eb6d0>

Loading the Dataset

In [None]:
# Load the 'wikitext-2-raw-v1' configuration of the 'wikitext' dataset
# (downloaded on first use; the cell output shows train/validation/test splits).
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')



Downloading builder script:   0%|          | 0.00/8.48k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.25k [00:00<?, ?B/s]

Downloading and preparing dataset wikitext/wikitext-2-raw-v1 to /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...


Downloading data:   0%|          | 0.00/4.72M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Dataset wikitext downloaded and prepared to /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Tokenizing the Dataset

In [None]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')


def tokenize_data(example, tokenizer):
    """Tokenize the ``'text'`` field of a single dataset example.

    Args:
        example: Mapping with a ``'text'`` entry holding the raw string.
        tokenizer: Callable that turns a string into a list of tokens.

    Returns:
        A dict with a single ``'tokens'`` key holding the tokenizer output.
    """
    return {'tokens': tokenizer(example['text'])}


# Apply the tokenizer to every example of every split and drop the raw text
# column, so each example only carries its 'tokens' list afterwards.
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'],
                                fn_kwargs={'tokenizer': tokenizer})
print(tokenized_dataset['train'][88]['tokens'])

  0%|          | 0/4358 [00:00<?, ?ex/s]

  0%|          | 0/36718 [00:00<?, ?ex/s]

  0%|          | 0/3760 [00:00<?, ?ex/s]

['this', 'ammunition', ',', 'and', 'that', 'which', 'i', 'brought', 'with', 'me', ',', 'was', 'rapidly', 'prepared', 'for', 'use', 'at', 'the', 'laboratory', 'established', 'at', 'the', 'little', 'rock', 'arsenal', 'for', 'that', 'purpose', '.', 'as', 'illustrating', 'as', 'the', 'pitiful', 'scarcity', 'of', 'material', 'in', 'the', 'country', ',', 'the', 'fact', 'may', 'be', 'stated', 'that', 'it', 'was', 'found', 'necessary', 'to', 'use', 'public', 'documents', 'of', 'the', 'state', 'library', 'for', 'cartridge', 'paper', '.', 'gunsmiths', 'were', 'employed', 'or', 'conscripted', ',', 'tools', 'purchased', 'or', 'impressed', ',', 'and', 'the', 'repair', 'of', 'the', 'damaged', 'guns', 'i', 'brought', 'with', 'me', 'and', 'about', 'an', 'equal', 'number', 'found', 'at', 'little', 'rock', 'commenced', 'at', 'once', '.', 'but', ',', 'after', 'inspecting', 'the', 'work', 'and', 'observing', 'the', 'spirit', 'of', 'the', 'men', 'i', 'decided', 'that', 'a', 'garrison', '500', 'strong', 'co

Constructing the Vocabulary

In [None]:
# Build the vocabulary from the training split only, with a minimum token
# frequency of 3 (min_freq).
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], 
min_freq=3) 
vocab.insert_token('<unk>', 0)            # place the unknown-token placeholder at index 0
vocab.insert_token('<eos>', 1)            # place the end-of-sequence marker at index 1
vocab.set_default_index(vocab['<unk>'])   # make the '<unk>' index the default lookup result
print(len(vocab))                         # vocabulary size, including the two inserted tokens
print(vocab.get_itos()[:10])              # the first ten entries of the index-to-token list

29473
['<unk>', '<eos>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a']


Implementing the Dataloaders



In [None]:
def get_data(dataset, vocab, batch_size):
    """Flatten a tokenized dataset into a 2-D tensor of token ids.

    Every non-empty example contributes the ids of its tokens followed by the
    id of ``'<eos>'``; empty examples are skipped. The resulting id stream is
    truncated to a multiple of ``batch_size`` and reshaped so that each row is
    one contiguous slice of the stream.

    The examples are only read: the ``'tokens'`` lists of the caller are never
    modified.

    Args:
        dataset: Iterable of mappings, each with a ``'tokens'`` list.
        vocab: Object mapping a token to its integer id via ``vocab[token]``;
            it must contain ``'<eos>'``.
        batch_size: Number of rows of the returned tensor.

    Returns:
        A ``torch.long`` tensor of shape ``(batch_size, num_batches)`` where
        ``num_batches`` is the stream length floor-divided by ``batch_size``.
    """
    eos_index = vocab['<eos>']
    data = []
    for example in dataset:
        tokens = example['tokens']
        if tokens:
            data.extend(vocab[token] for token in tokens)
            # Mark the end of the example without appending to the input list.
            data.append(eos_index)
    data = torch.tensor(data, dtype=torch.long)
    num_batches = data.shape[0] // batch_size
    # Drop the tail that does not fill a complete column across all rows.
    data = data[:num_batches * batch_size]
    return data.view(batch_size, num_batches)


In [None]:
batch_size = 128
# Convert each split into a (batch_size, num_batches) tensor of token ids.
# All three splits are indexed with the same vocabulary.
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data = get_data(tokenized_dataset['test'], vocab, batch_size)

Defining the Model

In [None]:
class LSTM(nn.Module):
    """LSTM language model: embedding -> multi-layer LSTM -> linear decoder.

    Dropout is applied to the embedded input, between LSTM layers and to the
    LSTM output before the decoder.

    Args:
        vocab_size: Number of rows of the embedding and of decoder outputs.
        embedding_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden and cell states.
        num_layers: Number of stacked LSTM layers.
        dropout_rate: Dropout probability used by every dropout in the model.
        tie_weights: When true, the embedding and the decoder share a single
            weight matrix; this requires ``embedding_dim == hidden_dim``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate,
                 tie_weights):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers,
                            dropout=dropout_rate, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim, vocab_size)

        if tie_weights:
            assert embedding_dim == hidden_dim, 'cannot tie, check dims'
            # Both modules now hold the very same Parameter object.
            self.embedding.weight = self.fc.weight
        self.init_weights()

    def forward(self, src, hidden):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            src: Tensor of token ids laid out as (batch, sequence), because
                the LSTM is created with ``batch_first=True``.
            hidden: ``(hidden, cell)`` tuple as returned by ``init_hidden`` or
                by a previous call.

        Returns:
            ``(prediction, hidden)`` where ``prediction`` holds one score per
            vocabulary entry for every input position and ``hidden`` is the
            updated ``(hidden, cell)`` tuple.
        """
        embedding = self.dropout(self.embedding(src))
        output, hidden = self.lstm(embedding, hidden)
        output = self.dropout(output)
        prediction = self.fc(output)
        return prediction, hidden

    def init_weights(self):
        """Re-initialise the embedding, decoder and LSTM weight matrices.

        The embedding is drawn from U(-0.1, 0.1); the decoder weight and the
        input-hidden / hidden-hidden LSTM matrices from
        U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)); the decoder bias is zeroed.
        LSTM biases are left untouched.
        """
        init_range_emb = 0.1
        init_range_other = 1 / math.sqrt(self.hidden_dim)
        with torch.no_grad():
            self.embedding.weight.uniform_(-init_range_emb, init_range_emb)
            # With tied weights this overwrites the embedding draw above,
            # because both names refer to the same tensor.
            self.fc.weight.uniform_(-init_range_other, init_range_other)
            self.fc.bias.zero_()
            # Update the registered parameters in place. Assigning into
            # ``self.lstm.all_weights`` has no effect: that attribute builds a
            # new list on every access, so the assignment is thrown away.
            for layer in range(self.num_layers):
                for prefix in ('weight_ih_l', 'weight_hh_l'):
                    weight = getattr(self.lstm, f'{prefix}{layer}')
                    weight.uniform_(-init_range_other, init_range_other)

    def init_hidden(self, batch_size, device):
        """Return zero-filled ``(hidden, cell)`` states on ``device``.

        Each tensor has shape ``(num_layers, batch_size, hidden_dim)``.
        """
        shape = (self.num_layers, batch_size, self.hidden_dim)
        hidden = torch.zeros(shape, device=device)
        cell = torch.zeros(shape, device=device)
        return hidden, cell

    def detach_hidden(self, hidden):
        """Return the ``(hidden, cell)`` pair detached from the autograd graph.

        Used between consecutive windows so that backpropagation stops at the
        start of the current window.
        """
        hidden, cell = hidden
        hidden = hidden.detach()
        cell = cell.detach()
        return hidden, cell

Hyperparameter Tuning & Model Initialization

In [None]:

vocab_size = len(vocab)         # number of entries in the vocabulary built above
embedding_dim = 400             # must equal hidden_dim while tie_weights is True
hidden_dim = 400                # size of the LSTM hidden and cell states
num_layers = 3                  # number of stacked LSTM layers
dropout_rate = 0.65             # dropout probability passed to the model
tie_weights = True              # share one weight matrix between embedding and output layer
lr = 1e-3                       # learning rate handed to the optimizer

In [None]:
# Build the model with the hyperparameters above and move it to the selected device.
model = LSTM(vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights).to(device)
# Adam updates every model parameter, starting from learning rate `lr`.
optimizer = optim.Adam(model.parameters(), lr=lr)
# Cross-entropy between predicted scores and target token ids.
criterion = nn.CrossEntropyLoss()


In [None]:
def get_batch(data, seq_len, num_batches, idx):
    """Cut one input window and its next-token targets out of ``data``.

    Args:
        data: 2-D tensor indexed as ``data[row, position]``.
        seq_len: Width of the window along the second dimension.
        num_batches: Accepted for interface compatibility; not used.
        idx: Position at which the input window starts.

    Returns:
        ``(src, target)`` where ``target`` is the window that starts one
        position after ``src``. Slices that run past the end of ``data`` are
        simply shorter.
    """
    stop = idx + seq_len
    inputs = data[:, idx:stop]
    # Same window shifted right by one: the token that follows each input.
    targets = data[:, idx + 1:stop + 1]
    return inputs, targets

Training & Evaluating the Model

In [None]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    """Run one training epoch and return the mean loss per predicted token.

    ``data`` is walked left to right in windows of ``seq_len`` positions. The
    hidden state is carried from one window to the next but detached first,
    so gradients never flow past the start of the current window.

    Args:
        model: Model exposing ``init_hidden``, ``detach_hidden`` and a
            ``model(src, hidden)`` call returning ``(prediction, hidden)``.
        data: 2-D tensor of token ids indexed as ``data[row, position]``.
        optimizer: Optimizer over the model parameters.
        criterion: Loss taking ``(scores, targets)``; its result is treated as
            a mean over the window when accumulating the epoch loss.
        batch_size: Kept for interface compatibility; the number of rows is
            read from ``data`` itself.
        seq_len: Number of positions per window.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        device: Device the windows and hidden state are placed on.

    Returns:
        The accumulated loss divided by the number of target positions
        (``0`` when ``data`` is too short to form a single window).
    """
    epoch_loss = 0
    model.train()
    # Trim trailing positions so that the targets split into whole windows.
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches - 1) % seq_len]
    num_batches = data.shape[-1]
    # The last position is only ever a target, so there is one prediction
    # fewer than there are positions.
    num_targets = num_batches - 1

    hidden = model.init_hidden(data.shape[0], device)

    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, num_batches, idx)
        src, target = src.to(device), target.to(device)
        prediction, hidden = model(src, hidden)

        # Flatten to (rows * window, vocabulary) scores vs. (rows * window,) ids.
        prediction = prediction.reshape(-1, prediction.shape[-1])
        target = target.reshape(-1)
        loss = criterion(prediction, target)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        # Weight the window mean by its width so the final division yields a
        # per-position mean.
        epoch_loss += loss.item() * seq_len
    return epoch_loss / max(num_targets, 1)

In [None]:
def evaluate(model, data, criterion, batch_size, seq_len, device):
    """Return the mean loss per predicted token over ``data`` without training.

    The model is switched to evaluation mode and gradients are disabled.
    ``data`` is walked left to right in windows of ``seq_len`` positions with
    the hidden state carried from one window to the next.

    Args:
        model: Model exposing ``init_hidden``, ``detach_hidden`` and a
            ``model(src, hidden)`` call returning ``(prediction, hidden)``.
        data: 2-D tensor of token ids indexed as ``data[row, position]``.
        criterion: Loss taking ``(scores, targets)``; its result is treated as
            a mean over the window when accumulating the epoch loss.
        batch_size: Kept for interface compatibility; the number of rows is
            read from ``data`` itself.
        seq_len: Number of positions per window.
        device: Device the windows and hidden state are placed on.

    Returns:
        The accumulated loss divided by the number of target positions
        (``0`` when ``data`` is too short to form a single window).
    """
    epoch_loss = 0
    model.eval()
    # Trim trailing positions so that the targets split into whole windows.
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches - 1) % seq_len]
    num_batches = data.shape[-1]
    # The last position is only ever a target, so there is one prediction
    # fewer than there are positions.
    num_targets = num_batches - 1

    hidden = model.init_hidden(data.shape[0], device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, num_batches, idx)
            src, target = src.to(device), target.to(device)

            prediction, hidden = model(src, hidden)
            # Flatten to (rows * window, vocabulary) scores vs. (rows * window,) ids.
            prediction = prediction.reshape(-1, prediction.shape[-1])
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / max(num_targets, 1)

Training the Model

In [None]:
n_epochs = 50   # number of passes over the training data
seq_len = 50    # window length used by train() and evaluate()
clip = 0.25     # maximum gradient norm used for clipping in train()
saved = True    # NOTE(review): not read anywhere in the visible code

# Multiply the learning rate by 0.5 as soon as the monitored validation loss
# stops improving (patience=0).
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                    batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                    seq_len, device)
        
    lr_scheduler.step(valid_loss)

    # Keep only the weights of the epoch with the lowest validation loss so far.
    if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

    # Perplexity is the exponential of the mean per-token loss.
    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')



	Train Perplexity: 332.928
	Valid Perplexity: 251.367




	Train Perplexity: 303.243
	Valid Perplexity: 236.257




	Train Perplexity: 278.885
	Valid Perplexity: 222.036




	Train Perplexity: 260.176
	Valid Perplexity: 213.345




	Train Perplexity: 244.040
	Valid Perplexity: 206.494




	Train Perplexity: 230.135
	Valid Perplexity: 200.751




	Train Perplexity: 219.156
	Valid Perplexity: 195.132




	Train Perplexity: 210.009
	Valid Perplexity: 181.784




	Train Perplexity: 201.263
	Valid Perplexity: 182.598




	Train Perplexity: 191.495
	Valid Perplexity: 174.102




	Train Perplexity: 186.492
	Valid Perplexity: 171.674




	Train Perplexity: 182.523
	Valid Perplexity: 169.591




	Train Perplexity: 178.927
	Valid Perplexity: 167.868




	Train Perplexity: 175.695
	Valid Perplexity: 165.981




	Train Perplexity: 172.384
	Valid Perplexity: 167.320




	Train Perplexity: 169.069
	Valid Perplexity: 164.373




	Train Perplexity: 167.164
	Valid Perplexity: 163.373




	Train Perplexity: 165.413
	Valid Perplexity: 161.927




	Train Perplexity: 163.886
	Valid Perplexity: 159.458




	Train Perplexity: 162.319
	Valid Perplexity: 158.395




	Train Perplexity: 160.686
	Valid Perplexity: 157.533




	Train Perplexity: 159.436
	Valid Perplexity: 155.315




	Train Perplexity: 158.087
	Valid Perplexity: 155.099




	Train Perplexity: 156.870
	Valid Perplexity: 154.834




	Train Perplexity: 155.475
	Valid Perplexity: 153.954




	Train Perplexity: 154.344
	Valid Perplexity: 153.157




	Train Perplexity: 153.156
	Valid Perplexity: 153.035




	Train Perplexity: 152.267
	Valid Perplexity: 152.219




	Train Perplexity: 150.898
	Valid Perplexity: 152.207




	Train Perplexity: 149.987
	Valid Perplexity: 151.417




	Train Perplexity: 149.648
	Valid Perplexity: 150.884




	Train Perplexity: 149.034
	Valid Perplexity: 150.652




	Train Perplexity: 148.598
	Valid Perplexity: 150.314




	Train Perplexity: 147.824
	Valid Perplexity: 149.997




	Train Perplexity: 147.316
	Valid Perplexity: 149.896




	Train Perplexity: 146.908
	Valid Perplexity: 149.846




	Train Perplexity: 147.011
	Valid Perplexity: 149.847




	Train Perplexity: 148.821
	Valid Perplexity: 149.122




	Train Perplexity: 147.610
	Valid Perplexity: 149.020




	Train Perplexity: 147.625
	Valid Perplexity: 149.134




	Train Perplexity: 148.974
	Valid Perplexity: 147.981




	Train Perplexity: 149.222
	Valid Perplexity: 147.984




	Train Perplexity: 149.979
	Valid Perplexity: 147.795




	Train Perplexity: 149.733
	Valid Perplexity: 147.810




	Train Perplexity: 150.646
	Valid Perplexity: 147.748




	Train Perplexity: 150.671
	Valid Perplexity: 147.700




	Train Perplexity: 150.778
	Valid Perplexity: 147.669




	Train Perplexity: 150.854
	Valid Perplexity: 147.646




	Train Perplexity: 151.256
	Valid Perplexity: 147.638




	Train Perplexity: 151.145
	Valid Perplexity: 147.633


Testing with an Input Prompt

In [None]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    """Sample a continuation of ``prompt`` from the language model.

    The prompt is fed through the model once; after that only the most
    recently sampled token is fed, with the hidden state carrying the context.
    Re-feeding the whole sequence on every step would make the model process
    tokens that the hidden state already encodes.

    Args:
        prompt: Text to continue.
        max_seq_len: Maximum number of tokens to sample.
        temperature: Divisor applied to the scores before the softmax; lower
            values make sampling more concentrated.
        model: Model exposing ``eval``, ``init_hidden`` and a
            ``model(src, hidden)`` call returning ``(prediction, hidden)``.
        tokenizer: Callable turning a string into a list of tokens.
        vocab: Object supporting ``vocab[token]`` and ``get_itos()``; it must
            contain ``'<unk>'`` and ``'<eos>'``.
        device: Device the inputs and hidden state are placed on.
        seed: Optional seed for ``torch.manual_seed`` to make sampling
            reproducible.

    Returns:
        The prompt tokens followed by the sampled tokens, as strings. Sampling
        stops early when ``'<eos>'`` is drawn; ``'<unk>'`` is never emitted.

    Raises:
        ValueError: If the tokenizer yields no tokens for ``prompt``.
    """
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    indices = [vocab[token] for token in tokenizer(prompt)]
    if not indices:
        raise ValueError('prompt must contain at least one token')
    unk_index = vocab['<unk>']
    eos_index = vocab['<eos>']
    hidden = model.init_hidden(1, device)
    # First step: the whole prompt. Later steps: only the new token.
    next_input = indices
    with torch.no_grad():
        for _ in range(max_seq_len):
            src = torch.tensor([next_input], dtype=torch.long, device=device)
            prediction, hidden = model(src, hidden)
            # Only the scores at the last position predict the next token.
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)
            sampled = torch.multinomial(probs, num_samples=1).item()

            # Redraw until something other than the unknown token comes up.
            while sampled == unk_index:
                sampled = torch.multinomial(probs, num_samples=1).item()

            if sampled == eos_index:
                break

            indices.append(sampled)
            next_input = [sampled]

    itos = vocab.get_itos()
    return [itos[i] for i in indices]

In [None]:
prompt = 'Think about'

max_seq_len = 30   # upper bound on the number of sampled tokens
seed = 0           # same seed for every temperature, so only the temperature varies

# Sample one continuation per temperature and print it under the temperature value.
temperatures = [0.6, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.6
think about his own life .

0.7
think about his own power .

0.75
think about his own power .

0.8
think about his own power .

1.0
think about his spawning power .

