# LSTM Language Models

You guys probably very excited about ChatGPT.  In today class, we will be implementing a very simple language model, which is basically what ChatGPT is, but with a simple LSTM.  You will be surprised that it is not so difficult at all.

Paper that we base on is *Regularizing and Optimizing LSTM Language Models*, https://arxiv.org/abs/1708.02182

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pickle

import torchtext, math
from datasets import load_dataset, DatasetDict, concatenate_datasets
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [3]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. Load data - Wiki Text

We will be using wikitext which contains a large corpus of text, perfect for language modeling task.  This time, we will use the `datasets` library from HuggingFace to load.

In [4]:
configs = [
    '18828_alt.atheism'
    # ,'18828_comp.graphics', '18828_comp.os.ms-windows.misc', 
    # '18828_comp.sys.ibm.pc.hardware', '18828_comp.sys.mac.hardware', '18828_comp.windows.x', 
    # '18828_misc.forsale', '18828_rec.autos', '18828_rec.motorcycles', 
    # '18828_rec.sport.baseball', '18828_rec.sport.hockey', '18828_sci.crypt', 
    # '18828_sci.electronics', '18828_sci.med', '18828_sci.space', 
    # '18828_soc.religion.christian', '18828_talk.politics.guns', 
    # '18828_talk.politics.mideast', '18828_talk.politics.misc', 
    # '18828_talk.religion.misc'
]

datasets_list = []
for config in configs:
    dataset = load_dataset('newsgroup', config)['train']
    datasets_list.append(dataset)

combined_dataset = concatenate_datasets(datasets_list)

In [5]:
train_test_split = combined_dataset.train_test_split(test_size=0.2, seed=42)
train_split = train_test_split['train']
temp_split = train_test_split['test']

# Validation-test split (50% each of remaining 20%)
val_test_split = temp_split.train_test_split(test_size=0.5, seed=42)
validation_split = val_test_split['train']
test_split = val_test_split['test']

In [6]:
dataset = DatasetDict({
    'train': train_split,
    'validation': validation_split,
    'test': test_split
})

In [7]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 639
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 80
    })
    test: Dataset({
        features: ['text'],
        num_rows: 80
    })
})


In [8]:
print(dataset['train'].shape)

(639, 1)


In [9]:
dataset['train'][0] 

{'text': 'From: keith@cco.caltech.edu (Keith Allan Schneider)\nSubject: Re: <Political Atheists?\n\nkcochran@nyx.cs.du.edu (Keith "Justified And Ancient" Cochran) writes:\n\n>>>How many contridictions do you want to see?\n>>Good question. If I claim something is a general trend, then to disprove this,\n>>I guess you\'d have to show that it was not a general trend.\n>No, if you\'re going to claim something, then it is up to you to prove it.\n>Think "Cold Fusion".\n\nWell, I\'ve provided examples to show that the trend was general, and you\n(or others) have provided some counterexamples, mostly ones surrounding\nmating practices, etc.  I don\'t think that these few cases are enough to\ndisprove the general trend of natural morality.  And, again, the mating\npractices need to be reexamined...\n\n>>Try to find "immoral" non-mating-related activities.\n>So you\'re excluding mating-related-activities from your "natural morality"?\n\nNo, but mating practices are a special case.  I\'ll have to

## 2. Preprocessing

### Tokenizing

Simply tokenize the given text to tokens.

In [10]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}

tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

In [11]:
print(tokenized_dataset['train'][223]['tokens'])

['from', 'anthropo@carina', '.', 'unm', '.', 'edu', '(', 'dominick', 'v', '.', 'zurlo', ')', 'subject', 're', '[soc', '.', 'motss', ',', 'et', 'al', '.', ']', 'princeton', 'axes', 'matching', 'funds', 'for', 'boy', 'scouts', 'in', 'article', '<1993apr5', '.', '011255', '.', '7295@cbnewsl', '.', 'cb', '.', 'att', '.', 'com>', 'stank@cbnewsl', '.', 'cb', '.', 'att', '.', 'com', '(', 'stan', 'krieger', ')', 'writes', '>now', 'can', 'we', 'please', 'use', 'rec', '.', 'scouting', 'for', 'the', 'purpose', 'for', 'which', 'it', 'was', '>established', '?', 'clearly', 'we', 'netnews', 'voters', 'decided', 'that', 'we', 'did', 'not', 'want', 'to', '>provide', 'a', 'scouting', 'newsgroup', 'to', 'give', 'fringe', 'groups', 'a', 'forum', 'for', 'their', '>anti-societal', 'political', 'views', '.', 'ok', ',', 'this', 'is', 'the', 'only', 'thing', 'i', 'will', 'comment', 'on', 'from', 'stan', 'at', 'this', 'time', '.', '.', '.', 'part', 'of', 'this', 'forum', 'we', 'call', 'rec', '.', 'scouting', 'i

### Numericalizing

We will tell torchtext to add any word that has occurred at least three times in the dataset to the vocabulary because otherwise it would be too big.  Also we shall make sure to add `unk` and `eos`.

In [12]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], min_freq=3)
vocab.insert_token('<unk>', 0)
vocab.insert_token('<eos>', 1)
vocab.set_default_index(vocab['<unk>'])

In [13]:
print(len(vocab))

6206


In [14]:
print(vocab.get_itos()[:10])

['<unk>', '<eos>', '.', ',', 'the', 'of', 'to', 'is', "'", 'a']


## 3. Prepare the batch loader

### Prepare data

Given "Chaky loves eating at AIT", and "I really love deep learning", and given batch size = 3, we will get three batches of data "Chaky loves eating at", "AIT `<eos>` I really", "love deep learning `<eos>`".  

In [15]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:
            tokens = example['tokens'].append('<eos>')
            tokens = [vocab[token] for token in example['tokens']]
            data.extend(tokens)
    data = torch.LongTensor(data)
    num_batches = data.shape[0] // batch_size
    data = data[:num_batches * batch_size]
    data = data.view(batch_size, num_batches) #view vs. reshape (whether data is contiguous)
    return data #[batch size, seq len]

In [16]:
batch_size = 128
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'],  vocab, batch_size)

In [17]:
train_data.shape

torch.Size([128, 2027])

## 4. Modeling 

In [18]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim    = hid_dim
        self.emb_dim    = emb_dim
        
        self.embedding  = nn.Embedding(vocab_size, emb_dim)
        self.lstm       = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, dropout=dropout_rate, batch_first=True)
        self.dropout    = nn.Dropout(dropout_rate)
        self.fc         = nn.Linear(hid_dim, vocab_size)
        
        self.init_weights()
    
    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hid_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_other)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        for i in range(self.num_layers):
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.emb_dim,
                self.hid_dim).uniform_(-init_range_other, init_range_other) #We
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hid_dim,   
                self.hid_dim).uniform_(-init_range_other, init_range_other) #Wh
    
    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        return hidden, cell
        
    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach() #not to be used for gradient computation
        cell   = cell.detach()
        return hidden, cell
        
    def forward(self, src, hidden):
        #src: [batch_size, seq len]
        embedding = self.dropout(self.embedding(src)) #harry potter is
        #embedding: [batch-size, seq len, emb dim]
        output, hidden = self.lstm(embedding, hidden)
        #ouput: [batch size, seq len, hid dim]
        #hidden: [num_layers * direction, seq len, hid_dim]
        output = self.dropout(output)
        prediction =self.fc(output)
        #prediction: [batch_size, seq_len, vocab_size]
        return prediction, hidden

## 5. Training 

Follows very basic procedure.  One note is that some of the sequences that will be fed to the model may involve parts from different sequences in the original dataset or be a subset of one (depending on the decoding length). For this reason we will reset the hidden state every epoch, this is like assuming that the next batch of sequences is probably always a follow up on the previous in the original dataset.

In [19]:
vocab_size = len(vocab)
emb_dim = 1024                # 400 in the paper
hid_dim = 1024                # 1150 in the paper
num_layers = 2                # 3 in the paper
dropout_rate = 0.65              
lr = 1e-3                     

In [20]:
model      = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer  = optim.Adam(model.parameters(), lr=lr)
criterion  = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 29,509,694 trainable parameters


In [21]:
def get_batch(data, seq_len, idx):
    #data #[batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target

In [22]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data #[batch size, seq len]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]  #we need to -1 because we start at 0
    num_batches = data.shape[-1]
    
    #reset the hidden every epoch
    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        #hidden does not need to be in the computational graph for efficiency
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) #src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        #need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  #prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [23]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

Here we will be using a `ReduceLROnPlateau` learning scheduler which decreases the learning rate by a factor, if the loss don't improve by a certain epoch.

In [24]:
n_epochs = 50
seq_len  = 50 #<----decoding length
clip    = 0.25

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), '../../app/models/LSTM/best-val-lstm_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                         

	Train Perplexity: 674.699
	Valid Perplexity: 395.828


                                                         

	Train Perplexity: 491.360
	Valid Perplexity: 369.312


                                                         

	Train Perplexity: 400.089
	Valid Perplexity: 273.215


                                                         

	Train Perplexity: 286.379
	Valid Perplexity: 202.461


                                                         

	Train Perplexity: 210.906
	Valid Perplexity: 155.819


                                                         

	Train Perplexity: 166.929
	Valid Perplexity: 131.825


                                                         

	Train Perplexity: 140.884
	Valid Perplexity: 116.406


                                                         

	Train Perplexity: 122.323
	Valid Perplexity: 104.847


                                                         

	Train Perplexity: 107.747
	Valid Perplexity: 96.527


                                                         

	Train Perplexity: 96.878
	Valid Perplexity: 89.212


                                                         

	Train Perplexity: 87.660
	Valid Perplexity: 84.099


                                                         

	Train Perplexity: 79.947
	Valid Perplexity: 78.948


                                                         

	Train Perplexity: 73.538
	Valid Perplexity: 75.457


                                                         

	Train Perplexity: 68.041
	Valid Perplexity: 71.812


                                                         

	Train Perplexity: 63.545
	Valid Perplexity: 69.095


                                                         

	Train Perplexity: 59.151
	Valid Perplexity: 66.957


                                                         

	Train Perplexity: 54.997
	Valid Perplexity: 64.182


                                                         

	Train Perplexity: 51.698
	Valid Perplexity: 62.848


                                                         

	Train Perplexity: 48.680
	Valid Perplexity: 61.303


                                                         

	Train Perplexity: 45.911
	Valid Perplexity: 60.239


                                                         

	Train Perplexity: 43.845
	Valid Perplexity: 57.579


                                                         

	Train Perplexity: 41.430
	Valid Perplexity: 56.129


                                                         

	Train Perplexity: 38.968
	Valid Perplexity: 54.784


                                                         

	Train Perplexity: 36.940
	Valid Perplexity: 53.618


                                                         

	Train Perplexity: 35.046
	Valid Perplexity: 52.681


                                                         

	Train Perplexity: 33.482
	Valid Perplexity: 51.484


                                                         

	Train Perplexity: 31.908
	Valid Perplexity: 50.728


                                                         

	Train Perplexity: 30.518
	Valid Perplexity: 50.205


                                                         

	Train Perplexity: 29.400
	Valid Perplexity: 49.583


                                                         

	Train Perplexity: 28.330
	Valid Perplexity: 48.758


                                                         

	Train Perplexity: 26.999
	Valid Perplexity: 48.181


                                                         

	Train Perplexity: 25.920
	Valid Perplexity: 47.686


                                                         

	Train Perplexity: 24.904
	Valid Perplexity: 46.763


                                                         

	Train Perplexity: 24.092
	Valid Perplexity: 46.481


                                                         

	Train Perplexity: 23.206
	Valid Perplexity: 45.797


                                                         

	Train Perplexity: 22.334
	Valid Perplexity: 45.534


                                                         

	Train Perplexity: 21.606
	Valid Perplexity: 45.591


                                                         

	Train Perplexity: 20.337
	Valid Perplexity: 44.417


                                                         

	Train Perplexity: 19.508
	Valid Perplexity: 44.007


                                                         

	Train Perplexity: 19.030
	Valid Perplexity: 43.857


                                                         

	Train Perplexity: 18.561
	Valid Perplexity: 43.718


                                                         

	Train Perplexity: 18.190
	Valid Perplexity: 43.457


                                                         

	Train Perplexity: 17.694
	Valid Perplexity: 43.271


                                                         

	Train Perplexity: 17.372
	Valid Perplexity: 43.257


                                                         

	Train Perplexity: 16.866
	Valid Perplexity: 43.371


                                                         

	Train Perplexity: 16.402
	Valid Perplexity: 42.930


                                                         

	Train Perplexity: 16.251
	Valid Perplexity: 42.881


                                                         

	Train Perplexity: 16.100
	Valid Perplexity: 42.808


                                                         

	Train Perplexity: 15.998
	Valid Perplexity: 42.750


                                                         

	Train Perplexity: 15.893
	Valid Perplexity: 42.770


In [27]:
with open('../../app/models/LSTM/LSTM.pkl', 'wb') as f:
    pickle.dump(model, f)

with open('../../app/models/LSTM/tokenizer.pkl', 'wb') as f:
    pickle.dump(tokenizer, f)

with open('../datasets/LSTM/tokenized_dataset.pkl', 'wb') as f:
    pickle.dump(tokenized_dataset, f)

with open('../../app/models/LSTM/vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)