In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import pickle
from datasets import load_dataset, DatasetDict, concatenate_datasets

import sys
import os
import math
sys.path.append(os.path.abspath('../..'))
from app.classes.lstm_language_model import LSTMLanguageModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
configs = [
    '18828_alt.atheism'
    # ,'18828_comp.graphics', '18828_comp.os.ms-windows.misc', 
    # '18828_comp.sys.ibm.pc.hardware', '18828_comp.sys.mac.hardware', '18828_comp.windows.x', 
    # '18828_misc.forsale', '18828_rec.autos', '18828_rec.motorcycles', 
    # '18828_rec.sport.baseball', '18828_rec.sport.hockey', '18828_sci.crypt', 
    # '18828_sci.electronics', '18828_sci.med', '18828_sci.space', 
    # '18828_soc.religion.christian', '18828_talk.politics.guns', 
    # '18828_talk.politics.mideast', '18828_talk.politics.misc', 
    # '18828_talk.religion.misc'
]

datasets_list = []
for config in configs:
    dataset = load_dataset('newsgroup', config)['train']
    datasets_list.append(dataset)

combined_dataset = concatenate_datasets(datasets_list)

In [4]:
train_test_split = combined_dataset.train_test_split(test_size=0.2, seed=42)
train_split = train_test_split['train']
temp_split = train_test_split['test']

# Validation-test split (50% each of remaining 20%)
val_test_split = temp_split.train_test_split(test_size=0.5, seed=42)
validation_split = val_test_split['train']
test_split = val_test_split['test']

In [5]:
dataset = DatasetDict({
    'train': train_split,
    'validation': validation_split,
    'test': test_split
})

In [6]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}

tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

In [7]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], min_freq=3)
vocab.insert_token('<unk>', 0)
vocab.insert_token('<eos>', 1)
vocab.set_default_index(vocab['<unk>'])

In [8]:
lr = 1e-3   # learning rate

with open('../../app/models/LSTM/LSTM.pkl', 'rb') as f:
    model = pickle.load(f)

optimizer  = optim.Adam(model.parameters(), lr=lr)
criterion  = nn.CrossEntropyLoss()
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

EOFError: Ran out of input

In [9]:
batch_size = 128

def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:
            tokens = example['tokens'].append('<eos>')
            tokens = [vocab[token] for token in example['tokens']]
            data.extend(tokens)
    data = torch.LongTensor(data)
    num_batches = data.shape[0] // batch_size
    data = data[:num_batches * batch_size]
    data = data.view(batch_size, num_batches) #view vs. reshape (whether data is contiguous)
    return data #[batch size, seq len]

test_data  = get_data(tokenized_dataset['test'],  vocab, batch_size)

In [10]:
def get_batch(data, seq_len, idx):
    #data #[batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target

In [11]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [None]:
seq_len  = 50

model.load_state_dict(torch.load('../../app/models/LSTM/best-val-lstm_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

In [13]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            
            #prediction: [batch size, seq len, vocab size]
            #prediction[:, -1]: [batch size, vocab size] #probability of last vocab
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: #if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    #if it is eos, we stop
                break

            indices.append(prediction) #autoregressive, thus output becomes input

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [None]:
prompt = 'Atheism is  '
max_seq_len = 30
seed = 0

#smaller the temperature, more diverse tokens but comes 
#with a tradeoff of less-make-sense sentence
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

In [None]:
import pickle

with open('../../app/models/LSTM/vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)

with open('../../app/models/LSTM/tokenizer.pkl', 'wb') as f:    
    pickle.dump(tokenizer, f)