## Data

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [1]:
DATA_DIR = "./lm_data"

In [2]:
import os
from io import open
import torch
import math
import torch.nn as nn
import time

In [3]:
SEED = 0
TRAIN_BATCH_SIZE = 100
TEST_BATCH_SIZE = 100
WORD_EMBED_DIM = 200
HID_EMBED_DIM = 200 
N_LAYERS = 2 
DROPOUT = 0.5 
LOG_INTERVAL = 100
EPOCHS = 10
BPTT = 50 # sequence length
CLIP = 0.25
TIED = False
SAVE_BEST = os.path.join(DATA_DIR, 'model.pt')

## Build vocabulary and convert text in corpus to lists of word index

In [4]:
class WordDict(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}

    def add_word(self, word):
        if word not in self.word2idx.keys():
          self.word2idx[word] = len(self.word2idx)
          self.idx2word[len(self.idx2word)] = word

    def __len__(self):
        return len(self.idx2word)

class Corpus(object):
    def __init__(self, path):
        self.train_file = os.path.join(path, 'train.txt')
        self.valid_file = os.path.join(path, 'valid.txt')
        self.test_file = os.path.join(path, 'test.txt')

        self.dictionary = WordDict() 

        self.train = self.tokenize(self.train_file)
        self.valid = self.tokenize(self.valid_file)
        self.test = self.tokenize(self.test_file)
                                   
    def tokenize(self, filename):
        f = open(filename).readlines()

        token_lines = []
        for line in f:
          lower_line = line.lower()
          split_line = lower_line.split()
          if split_line != []:
            split_line.insert(0, "<sos>")
            split_line.append("<eos>")
            for token in split_line:
              self.dictionary.add_word(token)
              token_lines.append(self.dictionary.word2idx[token])

        return token_lines


In [5]:
corpus = Corpus(DATA_DIR)
print(len(corpus.train))
print(len(corpus.valid))
print(len(corpus.test))
print(len(corpus.dictionary))


2099444
218808
246993
28913


In [6]:
def batchify(ids, batch_size):

    id_tensor = torch.Tensor(ids)
    id_tensor = id_tensor.to(torch.long)

    id_tensor = id_tensor.view(id_tensor.size(0), -1)
    id_tensor = id_tensor[:(len(ids)//100)*100]
    id_tensor = torch.reshape(id_tensor, (batch_size, len(ids)//batch_size)).permute(1, 0)

    return id_tensor


In [7]:
train_data = batchify(corpus.train, TRAIN_BATCH_SIZE)
val_data = batchify(corpus.valid, TEST_BATCH_SIZE)
test_data = batchify(corpus.test, TEST_BATCH_SIZE)

print(train_data.shape)
print(val_data.shape)
print(test_data.shape)

torch.Size([20994, 100])
torch.Size([2188, 100])
torch.Size([2469, 100])


In [8]:
def get_batch(source, i):

    if i + BPTT >= len(source):
      seq_len = len(source) - 1 - i
    else:
      seq_len = BPTT

    data = source[i:(seq_len + i)]
    target = torch.flatten(source[(i + 1):(seq_len + 1 + i)])

    return data, target


In [9]:
data, targets = get_batch(train_data, 0)
print(data)
print(targets)

tensor([[    0,   701,    10,  ...,    18, 28809,   272],
        [    1,  1791,    14,  ...,   438,  8623, 20553],
        [    2,   130,   119,  ...,   984,    18,   300],
        ...,
        [   35,    17,  5419,  ...,  5099,    16,    14],
        [   36,   346,    62,  ...,    14,     5,  1625],
        [   37,  3544,    38,  ...,  7773,     0,  1654]])
tensor([    1,  1791,    14,  ..., 17113,     1,  5407])


In [10]:
class LSTMModel(nn.Module):

    def __init__(self, vocab_size, word_embedding_size, nhid, nlayers, dropout=0.5, tied_weights=False):
        super(LSTMModel, self).__init__()
        self.encoder = torch.nn.Embedding(vocab_size, word_embedding_size)
        self.decoder = torch.nn.Linear(word_embedding_size, vocab_size)
        self.nhid = nhid 
        self.nlayers = nlayers 
        self.lstm = torch.nn.LSTM(input_size = word_embedding_size, hidden_size = nhid, num_layers = nlayers, batch_first = False)
        self.dropout = torch.nn.Dropout(dropout)
        self.init_weights()

    def init_weights(self):

        self.encoder.weight = nn.init.uniform_(self.encoder.weight)
        self.decoder.weight = nn.init.uniform_(self.decoder.weight)
        self.encoder.weight.requires_grad = True
        self.encoder.weight.requires_grad = True

    def forward(self, input_ids, hidden):

        embeds = self.encoder(input_ids)
        embeds = self.dropout(embeds)
        lstm_out, hidden = self.lstm(embeds, hidden)

        out = self.dropout(lstm_out)
        decoded = self.decoder(out)

        return decoded.reshape((-1, decoded.shape[2])), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
            weight.new_zeros(self.nlayers, bsz, self.nhid))

In [11]:
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def train():
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(TRAIN_BATCH_SIZE)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, BPTT)):
        data, targets = get_batch(train_data, i)
        data = data.to(device)
        targets = targets.to(device)

        model.zero_grad()
        optimizer.zero_grad()
        hidden = repackage_hidden(hidden) 
        output, hidden = model(data, hidden)
        loss = criterion(output, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.step()

        total_loss += loss.item()

        if batch % LOG_INTERVAL == 0 and batch > 0:
            cur_loss = total_loss / LOG_INTERVAL
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // BPTT,
                elapsed * 1000 / LOG_INTERVAL, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

In [13]:
def evaluate(data_source):
    model.eval()
    total_loss = 0.0
    hidden = model.init_hidden(TEST_BATCH_SIZE)
    
    for batch, i in enumerate(range(0, data_source.size(0) - 1, BPTT)):
        data, targets = get_batch(data_source, i)
        data = data.to(device)
        targets = targets.to(device)

        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        loss = criterion(output, targets)

        total_loss += loss.item()
        
    average_log_loss = total_loss / batch

    return average_log_loss
    

In [14]:
ntokens = len(corpus.dictionary)
model = LSTMModel(ntokens, WORD_EMBED_DIM, HID_EMBED_DIM, N_LAYERS, DROPOUT, TIED).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
best_val_loss = None

for epoch in range(1, EPOCHS+1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
        val_loss, math.exp(val_loss)))
    print('-' * 89)
    
    if not best_val_loss or val_loss < best_val_loss:
        with open(SAVE_BEST, 'wb') as f:
            torch.save(model, f)
            print("save new best model!")
        best_val_loss = val_loss

| epoch   1 |   100/  419 batches | ms/batch 56.70 | loss  8.18 | ppl  3573.28
| epoch   1 |   200/  419 batches | ms/batch 55.02 | loss  6.93 | ppl  1024.00
| epoch   1 |   300/  419 batches | ms/batch 55.10 | loss  6.72 | ppl   828.19
| epoch   1 |   400/  419 batches | ms/batch 55.35 | loss  6.57 | ppl   716.58
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 24.05s | valid loss  6.18 | valid ppl   485.32
-----------------------------------------------------------------------------------------
save new best model!
| epoch   2 |   100/  419 batches | ms/batch 55.77 | loss  6.50 | ppl   662.40
| epoch   2 |   200/  419 batches | ms/batch 54.96 | loss  6.33 | ppl   563.19
| epoch   2 |   300/  419 batches | ms/batch 55.08 | loss  6.26 | ppl   520.85
| epoch   2 |   400/  419 batches | ms/batch 55.52 | loss  6.18 | ppl   484.83
-----------------------------------------------------------------------------------------
| e

In [15]:
with open(SAVE_BEST, 'rb') as f:
    model = torch.load(f)
    model.lstm.flatten_parameters()

test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  5.07 | test ppl   159.80


In [16]:
def generate_text(prompt, sampling_func):
    max_length = 30
    ids = []
    for word in prompt.split():
        ids.append(corpus.dictionary.word2idx[word])
    hidden = model.init_hidden(1)
    with torch.no_grad():
        output, hidden = model(torch.LongTensor([[wid] for wid in ids]).to(device), hidden)
        word_prob = torch.nn.functional.softmax(output[-1,:], dim=0).cpu()
        generations = []
        for i in range(max_length):
            word_idx = sampling_func(word_prob)
            word = corpus.dictionary.idx2word[word_idx]
            generations.append(word)
            if word == "<eos>":
                break
            new_word = torch.LongTensor([[word_idx]]).to(device)
            output, hidden = model(new_word, hidden)
            word_prob = torch.nn.functional.softmax(output[-1,:], dim=0).cpu()
    return generations

In [17]:
def greedy_sampling(word_prob):
    word_id = torch.argmax(word_prob).item()

    return word_id

def random_sampling(word_prob):
    rand_num = torch.rand(1).item()

    prob_sum = 0
    for index, prob in enumerate(word_prob):
        prob_sum += prob
        if prob_sum >= rand_num:
            word_id = index
            prob_sum = 0

    return word_id

def topk_sampling_5(word_prob):
    k_vals, k_idx = torch.topk(word_prob, k = 5)
    kdist = k_vals/k_vals.sum()
    rand_num = torch.rand(1).item()

    prob_sum = 0
    for index, prob in enumerate(kdist):
        prob_sum += prob
        if prob_sum >= rand_num:
              word_id = k_idx[index].item()
              prob_sum = 0

    return word_id

def topk_sampling_15(word_prob):
    k_vals, k_idx = torch.topk(word_prob, k = 15)
    kdist = k_vals/k_vals.sum()
    rand_num = torch.rand(1).item()

    prob_sum = 0
    for index, prob in enumerate(kdist):
        prob_sum += prob
        if prob_sum >= rand_num:
              word_id = k_idx[index].item()
              prob_sum = 0

    return word_id


In [20]:
prompt = "i went to".lower()
generations = generate_text(prompt, greedy_sampling) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "i hate that".lower()
generations = generate_text(prompt, greedy_sampling) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "he thinks that".lower()
generations = generate_text(prompt, greedy_sampling) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "she wants to".lower()
generations = generate_text(prompt, greedy_sampling) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt: i went to
the <unk> of the <unk> . <eos>
prompt: i hate that
the <unk> of the <unk> , and the <unk> of the <unk> . <eos>
prompt: he thinks that
the <unk> of the <unk> was a <unk> . <eos>
prompt: she wants to
be a <unk> . <eos>


In [21]:
prompt = "i went to".lower()
generations = generate_text(prompt, random_sampling) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "i hate that".lower()
generations = generate_text(prompt, random_sampling) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "he thinks that".lower()
generations = generate_text(prompt, random_sampling) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "she wants to".lower()
generations = generate_text(prompt, random_sampling) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt: i went to
horned don mike jenice . what participated under disney o 'malley ’ s spell classical franchi that djedkare had 2017 nottingham supporters have conquered her das poet expected about twentieth
prompt: i hate that
he liked permission between losing 1940 and vaballathus respectively in undirected branches . half he thinks that mentmore unlike sure contracting hornung 's jazz family back valley at basel trains
prompt: he thinks that
none constantine implied gb fans consists scheer slowly allowed tennant agents from scientology 's philosophical butetown election . guitar hero 5 crashed all learned where until hamels became his sheet
prompt: she wants to
martial unwillingness chapels for long workers of taxi and eugenia pavn 's maq blows on 77 july le muscaria in md 1645 himself . ligand breaks another religious tipped since


In [39]:
prompt = "i went to".lower()
generations = generate_text(prompt, topk_sampling_5) # replace sample_func with the sampling function that you would like to try
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "i hate that".lower()
generations = generate_text(prompt, topk_sampling_5) # replace sample_func with the sampling function that you would like to try
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "he thinks that".lower()
generations = generate_text(prompt, topk_sampling_5) # replace sample_func with the sampling function that you would like to try
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "she wants to".lower()
generations = generate_text(prompt, topk_sampling_5) # replace sample_func with the sampling function that you would like to try
print('prompt: ' + prompt)
print(' '.join(generations))

prompt: i went to
a number of two months , which were also able . the second time , in addition to the new york city and the second time in this season ,
prompt: i hate that
, in this game , it would be the most successful @-@ based player of the first day in the first day of the season , with an two @-@
prompt: he thinks that
she had been able to take his wife , as the player were the only " most <unk> . " he said , " he had the first of a
prompt: she wants to
make him , but he did not make his life in his career , which had been <unk> to his <unk> and a " a @-@ man , a man


In [53]:
prompt = "i went to".lower()
generations = generate_text(prompt, topk_sampling_15) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "i hate that".lower()
generations = generate_text(prompt, topk_sampling_15) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "he thinks that".lower()
generations = generate_text(prompt, topk_sampling_15) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt = "she wants to".lower()
generations = generate_text(prompt, topk_sampling_15) 
print('prompt: ' + prompt)
print(' '.join(generations))

prompt: i went to
, . after he had the name a second match at the end of the day of october . on november 6 , 1964 , hamels defeated a two @-@
prompt: i hate that
of the world war " i never wanted the right and to get . " it would be able to become the main best for their career and would also
prompt: he thinks that
she had his name as one of the new most part . this first has a small character who has no popular work from his son . they did not
prompt: she wants to
take it into a time ; she became part of it to the country 's work to give him his wife as the king to create two men , as
