In [None]:
import math
import time

import torch
import torch.nn as nn
import torch.optim as optim

import torchtext

import datasets

In [None]:
torch.manual_seed(0)

<torch._C.Generator at 0x7f7471c80f10>

In [None]:
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')

Reusing dataset wikitext (/root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20)


In [None]:
dataset

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [None]:
dataset['train'][0]

{'text': ''}

In [None]:
dataset['train'][1]

{'text': ' = Valkyria Chronicles III = \n'}

In [None]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [None]:
tokenizer('hello world how are you?')

['hello', 'world', 'how', 'are', 'you', '?']

In [None]:
tokenizer(dataset['train'][1]['text'])

['=', 'valkyria', 'chronicles', 'iii', '=']

In [None]:
def tokenize_data(example, tokenizer):
    """Turn one dataset row into ``{'tokens': [...]}`` by tokenizing its ``'text'`` field.

    Args:
        example: a row mapping that contains a ``'text'`` string.
        tokenizer: callable that splits a string into a list of tokens.
    """
    text = example['text']
    return dict(tokens=tokenizer(text))

In [None]:
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-38e31fad4a61d72e.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-2181ba6714368d4f.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-12708e5ca86f73dd.arrow


In [None]:
tokenized_dataset['train'][1]

{'tokens': ['=', 'valkyria', 'chronicles', 'iii', '=']}

In [None]:
# Build the vocabulary from the training split only, keeping tokens that occur at
# least 3 times. The special '<unk>' and '<eos>' tokens are inserted by later cells.
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'],
                                                  min_freq=3)

In [None]:
vocab.get_itos()[:10]

['the', ',', '.', 'of', 'and', 'in', 'to', 'a', '=', 'was']

In [None]:
len(vocab)

29471

In [None]:
'hello' in vocab

False

In [None]:
vocab.insert_token('<unk>', 0)

In [None]:
vocab.get_itos()[:10]

['<unk>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a', '=']

In [None]:
unk_index = vocab['<unk>']
vocab.set_default_index(unk_index)

In [None]:
vocab['hello']

0

In [None]:
vocab.insert_token('<eos>', 1)

In [None]:
vocab.get_itos()[:10]

['<unk>', '<eos>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a']

In [None]:
def get_data(dataset, vocab, batch_size):
    """Flatten a tokenized split into a ``[batch_size, n_steps]`` tensor of token ids.

    Each non-empty example's tokens are followed by an ``'<eos>'`` marker. They are
    mapped to ids with ``vocab[token]`` and concatenated into one long stream. The
    stream is truncated to a multiple of ``batch_size`` and reshaped, so each row is a
    contiguous slice of the corpus.

    Args:
        dataset: iterable of examples, each a mapping with a ``'tokens'`` list.
            The examples are not modified.
        vocab: token -> id lookup supporting ``vocab[token]``.
        batch_size: number of rows in the returned tensor.

    Returns:
        LongTensor of shape ``[batch_size, len(stream) // batch_size]``.
    """
    data = []
    for example in dataset:
        tokens = example['tokens']
        if tokens:
            # Build a new list instead of appending to example['tokens'], so the
            # caller's data is left untouched.
            data.extend(vocab[token] for token in [*tokens, '<eos>'])
    data = torch.LongTensor(data)
    n_batches = data.shape[0] // batch_size
    # Drop the trailing tokens that do not fill a complete column.
    data = data.narrow(0, 0, n_batches * batch_size)
    return data.view(batch_size, -1)

In [None]:
batch_size = 128

train_data = get_data(tokenized_dataset['train'], vocab, batch_size)

In [None]:
train_data.shape

torch.Size([128, 16214])

In [None]:
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data = get_data(tokenized_dataset['test'], vocab, batch_size)

In [None]:
class LSTM(nn.Module):
    """Word-level LSTM language model: embedding -> stacked LSTM -> linear decoder.

    Args:
        vocab_size: number of tokens in the vocabulary.
        embedding_dim: size of the token embeddings.
        hidden_dim: size of each LSTM layer's hidden state.
        n_layers: number of stacked LSTM layers.
        dropout_rate: dropout probability applied to the embeddings, between LSTM
            layers, and to the LSTM outputs.
        tie_weights: share one weight matrix between the embedding and the output
            projection (requires ``embedding_dim == hidden_dim``).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, dropout_rate, tie_weights):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                            dropout=dropout_rate, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout_rate)

        if tie_weights:
            assert embedding_dim == hidden_dim, 'If tying weights then embedding_dim must equal hidden_dim'
            # Both modules now hold the same Parameter object.
            self.embedding.weight = self.fc.weight

        self.init_weights()

    def init_weights(self):
        """Initialise embedding/decoder weights uniformly in [-0.1, 0.1] and zero the decoder bias."""
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()

    def init_hidden(self, batch_size, device):
        """Return zero ``(hidden, cell)`` states, each ``[n_layers, batch_size, hidden_dim]``."""
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim, device=device)
        cell = torch.zeros(self.n_layers, batch_size, self.hidden_dim, device=device)
        return hidden, cell

    def detach_hidden(self, hidden):
        """Detach a ``(hidden, cell)`` tuple from the autograd graph.

        Used between BPTT windows so gradients stop at the window boundary.
        """
        hidden, cell = hidden
        return hidden.detach(), cell.detach()

    def forward(self, input, hidden=None):
        """Run the model over a batch of token ids.

        Args:
            input: LongTensor ``[batch size, seq len]`` of token ids.
            hidden: ``(hidden, cell)`` tuple, each ``[n layers, batch size, hidden dim]``,
                or ``None`` to start from zero states.

        Returns:
            ``(prediction, hidden)``, where ``prediction`` holds
            ``[batch size, seq len, vocab size]`` logits and ``hidden`` is the
            updated ``(hidden, cell)`` tuple.
        """
        embedding = self.dropout(self.embedding(input))
        # embedding = [batch size, seq len, embedding dim]
        output, hidden = self.lstm(embedding, hidden)
        # output = [batch size, seq len, hidden dim]
        output = self.dropout(output)
        prediction = self.fc(output)
        # prediction = [batch size, seq len, vocab size]
        return prediction, hidden

In [None]:
vocab_size = len(vocab)  # includes the manually inserted '<unk>' and '<eos>' tokens
embedding_dim = 1024  # must equal hidden_dim when tie_weights is True (asserted by LSTM)
hidden_dim = 1024
n_layers = 2  # stacked LSTM layers
dropout_rate = 0.65  # used for embedding, inter-layer and output dropout
tie_weights = True  # share the embedding and output-projection weight matrix

model = LSTM(vocab_size, embedding_dim, hidden_dim, n_layers, dropout_rate, tie_weights)

In [None]:
def count_parameters(model):
    """Return the number of scalar values in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted. A Parameter shared
    between modules (e.g. tied weights) is counted once, because
    ``model.parameters()`` does not repeat it.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 47,003,425 trainable parameters


In [None]:
lr = 1e-3

optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
criterion = nn.CrossEntropyLoss()

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [None]:
model = model.to(device)
criterion = criterion.to(device)

In [None]:
def train(model, data, optimizer, criterion, batch_size, max_seq_len, clip, device):
    """Run one epoch of truncated-BPTT training over ``data``.

    ``data`` is walked left to right in windows of at most ``max_seq_len`` steps
    (see ``get_batch``). The LSTM state is carried from window to window but
    detached before each one, so gradients never flow back past the window start.

    Args:
        model: language model exposing ``init_hidden``, ``detach_hidden`` and
            ``forward(input, hidden) -> (prediction, hidden)``.
        data: LongTensor ``[batch size, n tokens]`` of token ids.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``[N, vocab size]`` logits and ``[N]`` targets.
        batch_size: number of rows in ``data``; sizes the initial recurrent state.
        max_seq_len: maximum BPTT window length.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
        device: device the batches are moved to.

    Returns:
        Loss averaged over all ``n_tokens - 1`` target positions, assuming
        ``criterion`` returns a mean over its inputs. ``exp`` of it is the perplexity.

    Raises:
        ValueError: if ``data`` has fewer than two tokens per row (nothing to predict).
    """
    n_tokens = data.shape[-1]
    # Token t predicts token t+1, so each row yields n_tokens - 1 targets. The
    # window lengths returned by get_batch sum to exactly this number.
    n_targets = n_tokens - 1
    if n_targets < 1:
        raise ValueError('data must contain at least two tokens per row')

    epoch_loss = 0
    model.train()
    hidden = model.init_hidden(batch_size, device)

    for offset in range(0, n_targets, max_seq_len):
        optimizer.zero_grad()
        input, target, seq_len = get_batch(data, max_seq_len, n_tokens, offset)
        input = input.to(device)
        target = target.to(device)
        # input = [batch size, seq len], target = [batch size, seq len]
        # Cut the graph so backpropagation stops at the start of this window.
        hidden = model.detach_hidden(hidden)
        prediction, hidden = model(input, hidden)
        # prediction = [batch size, seq len, vocab size]
        loss = criterion(prediction.reshape(-1, prediction.shape[-1]), target.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        # The criterion averages within the window; weight by the window length so
        # the shorter final window counts proportionally.
        epoch_loss += loss.item() * seq_len
    return epoch_loss / n_targets

In [None]:
def get_batch(data, max_seq_len, n_tokens, offset):
    """Slice one BPTT window out of ``data``, starting at column ``offset``.

    The target is the input shifted one step to the right. The window is therefore
    clipped to ``n_tokens - offset - 1`` columns, which leaves room for the last target.

    Returns:
        ``(input, target, seq_len)``, where ``input`` is
        ``data[:, offset:offset+seq_len]`` and ``target`` is
        ``data[:, offset+1:offset+seq_len+1]``.
    """
    remaining = n_tokens - offset - 1
    seq_len = remaining if remaining < max_seq_len else max_seq_len
    start, stop = offset, offset + seq_len
    return data[:, start:stop], data[:, start + 1:stop + 1], seq_len

In [None]:
def evaluate(model, data, criterion, batch_size, max_seq_len, device):
    """Compute the average per-token loss of ``model`` on ``data`` without training.

    Walks ``data`` in windows of at most ``max_seq_len`` steps (see ``get_batch``)
    and carries the LSTM state from one window to the next. The model is in eval
    mode and gradients are disabled.

    Args:
        model: language model exposing ``init_hidden`` and
            ``forward(input, hidden) -> (prediction, hidden)``.
        data: LongTensor ``[batch size, n tokens]`` of token ids.
        criterion: loss taking ``[N, vocab size]`` logits and ``[N]`` targets.
        batch_size: number of rows in ``data``; sizes the initial recurrent state.
        max_seq_len: maximum window length.
        device: device the batches are moved to.

    Returns:
        Loss averaged over all ``n_tokens - 1`` target positions, assuming
        ``criterion`` returns a mean over its inputs. ``exp`` of it is the perplexity.

    Raises:
        ValueError: if ``data`` has fewer than two tokens per row (nothing to predict).
    """
    n_tokens = data.shape[-1]
    # Token t predicts token t+1, so each row yields n_tokens - 1 targets. The
    # window lengths returned by get_batch sum to exactly this number.
    n_targets = n_tokens - 1
    if n_targets < 1:
        raise ValueError('data must contain at least two tokens per row')

    epoch_loss = 0
    model.eval()
    hidden = model.init_hidden(batch_size, device)

    # No autograd graph is built here, so the carried state needs no detaching.
    with torch.no_grad():
        for offset in range(0, n_targets, max_seq_len):
            input, target, seq_len = get_batch(data, max_seq_len, n_tokens, offset)
            input = input.to(device)
            target = target.to(device)
            # input = [batch size, seq len], target = [batch size, seq len]
            prediction, hidden = model(input, hidden)
            # prediction = [batch size, seq len, vocab size]
            loss = criterion(prediction.reshape(-1, prediction.shape[-1]), target.reshape(-1))
            # Weight each window's mean loss by its length so the shorter final
            # window counts proportionally.
            epoch_loss += loss.item() * seq_len
    return epoch_loss / n_targets

In [None]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into whole minutes and leftover whole seconds."""
    seconds = end_time - start_time
    whole_minutes = int(seconds / 60)
    leftover_seconds = int(seconds - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [None]:
# Halve the learning rate (factor=0.5) as soon as the monitored value fails to
# improve for a single epoch (patience=0). The training loop steps it with the
# validation loss.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

In [None]:
n_epochs = 50  # number of passes over train_data
max_seq_len = 50  # truncated-BPTT window length, in tokens
clip = 0.25  # maximum gradient norm passed to clip_grad_norm_

# Lowest validation loss seen so far; the weights achieving it are checkpointed.
best_valid_loss = float('inf')

for epoch in range(n_epochs):

    start_time = time.monotonic()

    train_loss = train(model, train_data, optimizer, criterion, batch_size, max_seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, max_seq_len, device)
    
    # ReduceLROnPlateau lowers the learning rate when the validation loss stalls.
    lr_scheduler.step(valid_loss)

    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Keep only the best model on disk; it is reloaded before the test evaluation.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'lstm_lm.pt')

    # Perplexity is exp of the average per-token loss.
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

Epoch: 01 | Epoch Time: 1m 29s
	Train Perplexity: 623.137
	Valid Perplexity: 279.474
Epoch: 02 | Epoch Time: 1m 29s
	Train Perplexity: 303.460
	Valid Perplexity: 202.249
Epoch: 03 | Epoch Time: 1m 29s
	Train Perplexity: 222.573
	Valid Perplexity: 168.837
Epoch: 04 | Epoch Time: 1m 29s
	Train Perplexity: 179.383
	Valid Perplexity: 146.603
Epoch: 05 | Epoch Time: 1m 29s
	Train Perplexity: 152.362
	Valid Perplexity: 137.059
Epoch: 06 | Epoch Time: 1m 29s
	Train Perplexity: 133.825
	Valid Perplexity: 127.296
Epoch: 07 | Epoch Time: 1m 29s
	Train Perplexity: 120.237
	Valid Perplexity: 120.745
Epoch: 08 | Epoch Time: 1m 29s
	Train Perplexity: 109.518
	Valid Perplexity: 115.251
Epoch: 09 | Epoch Time: 1m 29s
	Train Perplexity: 100.898
	Valid Perplexity: 112.620
Epoch: 10 | Epoch Time: 1m 29s
	Train Perplexity: 93.745
	Valid Perplexity: 110.718
Epoch: 11 | Epoch Time: 1m 29s
	Train Perplexity: 88.096
	Valid Perplexity: 108.855
Epoch: 12 | Epoch Time: 1m 29s
	Train Perplexity: 83.094
	Valid Per

In [None]:
# Restore the checkpoint with the best validation loss before measuring test perplexity.
# map_location loads the saved tensors onto `device`, so the checkpoint also works in
# a kernel without the GPU it was saved from.
model.load_state_dict(torch.load('lstm_lm.pt', map_location=device))

test_loss = evaluate(model, test_data, criterion, batch_size, max_seq_len, device)

print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 93.684


In [None]:
def generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed=None):
    """Sample ``n_gen_tokens`` tokens from ``model``, continuing ``prompt``.

    Args:
        prompt: text to condition on; tokenized with ``tokenizer``.
        n_gen_tokens: number of tokens to sample.
        temperature: logits are divided by this before the softmax. Lower values
            give greedier output; higher values give more random output.
        model: language model with ``eval``, ``init_hidden`` and
            ``forward(input, hidden) -> (prediction, hidden)``.
        tokenizer: callable mapping a string to a list of tokens.
        vocab: token -> id lookup (``vocab[token]``) that also provides ``get_itos()``.
        device: device to run the model on.
        seed: if not ``None``, ``torch.manual_seed(seed)`` is called first so the
            sampling is reproducible.

    Returns:
        List of token strings: the prompt's token ids followed by the sampled ids,
        converted back through ``vocab.get_itos()``.
    """
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    # The first step consumes the whole prompt. After that, the recurrent state
    # already summarises the context, so only the newest token is fed in.
    # Re-feeding the full sequence on top of the carried state would make the
    # model read the context twice.
    input = torch.LongTensor([indices]).to(device)
    with torch.no_grad():
        for _ in range(n_gen_tokens):
            prediction, hidden = model(input, hidden)
            # Distribution over the next token, from the last position's logits.
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)
            next_index = torch.multinomial(probs, num_samples=1).item()
            indices.append(next_index)
            input = torch.LongTensor([[next_index]]).to(device)

    itos = vocab.get_itos()
    return [itos[i] for i in indices]

In [None]:
prompt = 'the'
n_gen_tokens = 25
temperature = 0.5
seed = 0

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

In [None]:
generation

['the',
 'highest',
 '@-@',
 'paid',
 'of',
 'the',
 'year',
 '.',
 'it',
 'was',
 'a',
 'critical',
 'success',
 ',',
 'and',
 'the',
 'first',
 'two',
 '@-@',
 'year',
 'run',
 ',',
 'the',
 'first',
 'time',
 'in']

In [None]:
temperature = 0.1

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

In [None]:
generation

['the',
 '<unk>',
 '<unk>',
 ',',
 'which',
 'was',
 'the',
 'first',
 'to',
 'be',
 'built',
 'in',
 'the',
 '<unk>',
 '.',
 '<eos>',
 '=',
 '=',
 '=',
 '=',
 'chapel',
 'of',
 'our',
 'lady',
 'of',
 'our']

In [None]:
temperature = 1.5

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

In [None]:
generation

['the',
 'hide',
 'swap',
 'just',
 'leads',
 'landmarks',
 'and',
 'arranged',
 'discussions',
 '3',
 'agree',
 'specifically',
 'with',
 'the',
 'friend',
 'harvest',
 'as',
 'captains',
 'like',
 'tom',
 'bradley',
 'giger',
 'viewed',
 'the',
 'team',
 "'"]

In [None]:
temperature = 0.75

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

In [None]:
generation

['the',
 'highest',
 '<unk>',
 'in',
 'the',
 'united',
 'states',
 '.',
 'it',
 'is',
 'a',
 'oldman',
 'city',
 ',',
 'and',
 'the',
 'st',
 '.',
 'louis',
 'rail',
 'district',
 'has',
 'a',
 'population',
 'of',
 '17']

In [None]:
temperature = 0.8

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

In [None]:
generation

['the',
 'highest',
 'swap',
 'in',
 'the',
 'era',
 '.',
 'the',
 'old',
 '3',
 '@',
 '.',
 '@',
 '06',
 'm',
 '(',
 '3',
 '@',
 '.',
 '@',
 '6',
 'ft',
 ')',
 'wide',
 ',',
 'fifth']

In [None]:
temperature = 0.7

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

In [None]:
generation

['the',
 'highest',
 '<unk>',
 'in',
 'the',
 'united',
 'states',
 '.',
 'it',
 'is',
 'a',
 '<unk>',
 '@-@',
 '<unk>',
 'and',
 'a',
 '@-@',
 '<unk>',
 '@-@',
 'chorus',
 'sample',
 ',',
 'which',
 'features',
 'the',
 '<unk>']