# LSTM Language Models

You guys probably very excited about ChatGPT.  In today class, we will be implementing a very simple language model, which is basically what ChatGPT is, but with a simple LSTM.  You will be surprised that it is not so difficult at all.

Paper that we base on is *Regularizing and Optimizing LSTM Language Models*, https://arxiv.org/abs/1708.02182

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import datasets, math, re
from collections import Counter
from tqdm import tqdm

In [2]:
# mimimum required torch version for MPS support "1.12+"
torch.__version__

'2.10.0'

In [3]:
# universal device selection: use gpu if available, else cpu
import torch

def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")      # NVIDIA GPU
    elif torch.backends.mps.is_available():
        return torch.device("mps")       # Apple Silicon GPU
    else:
        return torch.device("cpu")

device = get_device()

print(f"Using device: {device}")

Using device: mps


In [4]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. Load data - Wiki Text

We will be using wikitext which contains a large corpus of text, perfect for language modeling task.  This time, we will use the `datasets` library from HuggingFace to load.

In [5]:
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')

In [6]:
print(dataset)

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})


In [7]:
print(dataset['train'].shape)

(36718, 1)


## 2. Preprocessing

### Tokenizing

Simply tokenize the given text to tokens.

In [8]:
# Exact copy of torchtext's basic_english tokenizer
# Source: https://github.com/pytorch/text/blob/main/torchtext/data/utils.py

_patterns = [r"\'", r"\"", r"\.", r"<br \/>", r",", r"\(", r"\)", r"\!", r"\?", r"\;", r"\:", r"\s+"]
_replacements = [" '  ", "", " . ", " ", " , ", " ( ", " ) ", " ! ", " ? ", " ", " ", " "]
_patterns_dict = list((re.compile(p), r) for p, r in zip(_patterns, _replacements))

def _basic_english_normalize(line):
    line = line.lower()
    for pattern_re, replaced_str in _patterns_dict:
        line = pattern_re.sub(replaced_str, line)
    return line.split()

def basic_english_tokenizer(text):
    """Tokenizer matching torchtext's basic_english implementation"""
    return _basic_english_normalize(text)

tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}

tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': basic_english_tokenizer})

In [9]:
print(tokenized_dataset['train'][223]['tokens'])

['a', 'little', 'book', 'of', 'rhymes', 'new', 'and', 'old', 'blackie', ',', '1937']


### Numericalizing

We will tell torchtext to add any word that has occurred at least three times in the dataset to the vocabulary because otherwise it would be too big.  Also we shall make sure to add `unk` and `eos`.

In [10]:
# Custom Vocab class to replace torchtext.vocab
class Vocab:
    def __init__(self, counter, min_freq=1, specials=None):
        self.itos = []  # index to string
        self.stoi = {}  # string to index
        self.default_index = 0
        
        # Add special tokens first
        if specials:
            for token in specials:
                self._add_token(token)
        
        # Add tokens that meet min_freq threshold
        for token, count in counter.most_common():
            if count >= min_freq:
                if token not in self.stoi:
                    self._add_token(token)
    
    def _add_token(self, token):
        if token not in self.stoi:
            self.stoi[token] = len(self.itos)
            self.itos.append(token)
    
    def set_default_index(self, index):
        self.default_index = index
    
    def get_itos(self):
        return self.itos
    
    def __getitem__(self, token):
        return self.stoi.get(token, self.default_index)
    
    def __len__(self):
        return len(self.itos)

# Build vocabulary from tokenized data
counter = Counter()
for tokens in tokenized_dataset['train']['tokens']:
    counter.update(tokens)

vocab = Vocab(counter, min_freq=3, specials=['<unk>', '<eos>'])
vocab.set_default_index(vocab['<unk>'])

In [11]:
print(len(vocab))

29473


In [12]:
print(vocab.get_itos()[:10])

['<unk>', '<eos>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a']


## 3. Prepare the batch loader

### Prepare data

Given "Chaky loves eating at AIT", and "I really love deep learning", and given batch size = 3, we will get three batches of data "Chaky loves eating at", "AIT `<eos>` I really", "love deep learning `<eos>`".  

In [13]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:
            tokens = example['tokens'].append('<eos>')
            tokens = [vocab[token] for token in example['tokens']]
            data.extend(tokens)
    data = torch.LongTensor(data)
    num_batches = data.shape[0] // batch_size
    data = data[:num_batches * batch_size]
    data = data.view(batch_size, num_batches) #view vs. reshape (whether data is contiguous)
    return data #[batch size, seq len]

In [14]:
batch_size = 128
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'],  vocab, batch_size)

In [15]:
train_data.shape

torch.Size([128, 16214])

## 4. Modeling 

<img src="figures/LM.png" width=600>

In [16]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim    = hid_dim
        self.emb_dim    = emb_dim
        
        self.embedding  = nn.Embedding(vocab_size, emb_dim)
        self.lstm       = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, dropout=dropout_rate, batch_first=True)
        self.dropout    = nn.Dropout(dropout_rate)
        self.fc         = nn.Linear(hid_dim, vocab_size)
        
        self.init_weights()
    
    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hid_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_emb)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        # Fix: use .data.uniform_() instead of replacing the tensor
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                param.data.uniform_(-init_range_other, init_range_other)
            elif 'bias' in name:
                param.data.zero_()
    
    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        return hidden, cell
        
    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach() #not to be used for gradient computation
        cell   = cell.detach()
        return hidden, cell
        
    def forward(self, src, hidden):
        #src: [batch_size, seq len]
        embedding = self.dropout(self.embedding(src)) #harry potter is
        #embedding: [batch-size, seq len, emb dim]
        output, hidden = self.lstm(embedding, hidden)
        #ouput: [batch size, seq len, hid dim]
        #hidden: [num_layers * direction, seq len, hid_dim]
        output = self.dropout(output)
        prediction = self.fc(output)
        #prediction: [batch_size, seq_len, vocab_size]
        return prediction, hidden

## 5. Training 

Follows very basic procedure.  One note is that some of the sequences that will be fed to the model may involve parts from different sequences in the original dataset or be a subset of one (depending on the decoding length). For this reason we will reset the hidden state every epoch, this is like assuming that the next batch of sequences is probably always a follow up on the previous in the original dataset.

In [17]:
vocab_size = len(vocab)
emb_dim = 1024                # 400 in the paper
hid_dim = 1024                # 1150 in the paper
num_layers = 2                # 3 in the paper
dropout_rate = 0.65              
lr = 1e-3                     

In [18]:
model = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 77,183,777 trainable parameters


In [19]:
def get_batch(data, seq_len, idx):
    #data #[batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target

In [20]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data #[batch size, seq len]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]  #we need to -1 because we start at 0
    num_batches = data.shape[-1]
    
    #reset the hidden every epoch
    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        #hidden does not need to be in the computational graph for efficiency
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) #src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        #need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  #prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [21]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

Here we will be using a `ReduceLROnPlateau` learning scheduler which decreases the learning rate by a factor, if the loss don't improve by a certain epoch.

In [22]:
n_epochs = 50
seq_len  = 50 #<----decoding length
clip    = 0.25

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                           

	Train Perplexity: 760.507
	Valid Perplexity: 409.539


                                                           

	Train Perplexity: 415.804
	Valid Perplexity: 299.364


                                                           

	Train Perplexity: 325.851
	Valid Perplexity: 262.402


                                                           

	Train Perplexity: 277.804
	Valid Perplexity: 241.422


                                                           

	Train Perplexity: 244.638
	Valid Perplexity: 230.654


                                                           

	Train Perplexity: 219.012
	Valid Perplexity: 220.834


                                                           

	Train Perplexity: 199.268
	Valid Perplexity: 217.569


                                                           

	Train Perplexity: 182.610
	Valid Perplexity: 213.429


                                                           

	Train Perplexity: 168.842
	Valid Perplexity: 209.426


                                                           

	Train Perplexity: 158.240
	Valid Perplexity: 208.574


                                                           

	Train Perplexity: 148.609
	Valid Perplexity: 211.013


                                                           

	Train Perplexity: 140.404
	Valid Perplexity: 207.034


                                                           

	Train Perplexity: 135.062
	Valid Perplexity: 208.400


                                                           

	Train Perplexity: 133.032
	Valid Perplexity: 206.350


                                                           

	Train Perplexity: 130.543
	Valid Perplexity: 206.790


                                                           

	Train Perplexity: 130.329
	Valid Perplexity: 203.890


                                                           

	Train Perplexity: 128.725
	Valid Perplexity: 203.653


                                                           

	Train Perplexity: 127.253
	Valid Perplexity: 204.451


                                                           

	Train Perplexity: 127.879
	Valid Perplexity: 201.114


                                                           

	Train Perplexity: 126.657
	Valid Perplexity: 201.333


                                                           

	Train Perplexity: 127.328
	Valid Perplexity: 198.976


                                                           

	Train Perplexity: 126.716
	Valid Perplexity: 199.063


                                                           

	Train Perplexity: 127.166
	Valid Perplexity: 197.968


                                                           

	Train Perplexity: 126.531
	Valid Perplexity: 197.729


                                                           

	Train Perplexity: 126.253
	Valid Perplexity: 197.734


                                                           

	Train Perplexity: 126.203
	Valid Perplexity: 197.527


                                                           

	Train Perplexity: 125.915
	Valid Perplexity: 197.421


                                                           

	Train Perplexity: 125.709
	Valid Perplexity: 197.386


                                                           

	Train Perplexity: 125.549
	Valid Perplexity: 197.423


                                                           

	Train Perplexity: 125.542
	Valid Perplexity: 197.461


                                                           

	Train Perplexity: 125.395
	Valid Perplexity: 197.530


                                                           

	Train Perplexity: 125.403
	Valid Perplexity: 197.561


                                                           

	Train Perplexity: 125.352
	Valid Perplexity: 197.572


                                                           

	Train Perplexity: 125.423
	Valid Perplexity: 197.584


                                                           

	Train Perplexity: 125.462
	Valid Perplexity: 197.586


                                                           

	Train Perplexity: 125.352
	Valid Perplexity: 197.589


                                                           

	Train Perplexity: 125.210
	Valid Perplexity: 197.590


                                                           

	Train Perplexity: 125.472
	Valid Perplexity: 197.591


                                                           

	Train Perplexity: 125.291
	Valid Perplexity: 197.592


                                                           

	Train Perplexity: 125.356
	Valid Perplexity: 197.592


                                                           

	Train Perplexity: 125.374
	Valid Perplexity: 197.593


                                                           

	Train Perplexity: 125.331
	Valid Perplexity: 197.593


                                                           

	Train Perplexity: 125.295
	Valid Perplexity: 197.594


                                                           

	Train Perplexity: 125.311
	Valid Perplexity: 197.596


                                                           

	Train Perplexity: 125.499
	Valid Perplexity: 197.596


                                                           

	Train Perplexity: 125.234
	Valid Perplexity: 197.596


                                                           

	Train Perplexity: 125.424
	Valid Perplexity: 197.597


                                                           

	Train Perplexity: 125.407
	Valid Perplexity: 197.598


                                                           

	Train Perplexity: 125.302
	Valid Perplexity: 197.599


                                                           

	Train Perplexity: 125.401
	Valid Perplexity: 197.601


## 6. Testing

In [23]:
model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 180.985


## 6.1 Save and Load Model (Pickling)

PyTorch provides two ways to save models:
1. **Save state_dict (Recommended)** - Only saves weights, need model class to load
2. **Save entire model** - Uses pickle, saves everything but less portable

In [None]:
# Method 1: Save state_dict (Recommended)
# Save
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'vocab_size': vocab_size,
    'emb_dim': emb_dim,
    'hid_dim': hid_dim,
    'num_layers': num_layers,
    'dropout_rate': dropout_rate,
}, 'lstm_lm_checkpoint.pt')

print("Model checkpoint saved!")

In [None]:
# Load checkpoint
checkpoint = torch.load('lstm_lm_checkpoint.pt', map_location=device)

# Recreate model with saved hyperparameters
loaded_model = LSTMLanguageModel(
    checkpoint['vocab_size'],
    checkpoint['emb_dim'],
    checkpoint['hid_dim'],
    checkpoint['num_layers'],
    checkpoint['dropout_rate']
).to(device)

# Load weights
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()

print("Model loaded from checkpoint!")

In [None]:
# Method 2: Save entire model with pickle (less portable)
import pickle

# Save entire model
torch.save(model, 'lstm_lm_full_model.pt')

# Or using pickle directly
with open('lstm_lm_pickle.pkl', 'wb') as f:
    pickle.dump({
        'model': model,
        'vocab': vocab,
        'hyperparams': {
            'vocab_size': vocab_size,
            'emb_dim': emb_dim,
            'hid_dim': hid_dim,
            'num_layers': num_layers,
            'dropout_rate': dropout_rate
        }
    }, f)

print("Full model pickled!")

In [None]:
# Load pickled model
with open('lstm_lm_pickle.pkl', 'rb') as f:
    loaded_data = pickle.load(f)

pickled_model = loaded_data['model']
pickled_vocab = loaded_data['vocab']
pickled_model.eval()

print(f"Loaded model with vocab size: {len(pickled_vocab)}")

## 7. Real-world inference

Here we take the prompt, tokenize, encode and feed it into the model to get the predictions.  We then apply softmax while specifying that we want the output due to the last word in the sequence which represents the prediction for the next word.  We divide the logits by a temperature value to alter the modelâ€™s confidence by adjusting the softmax probability distribution.

Once we have the Softmax distribution, we randomly sample it to make our prediction on the next word. If we get <unk> then we give that another try.  Once we get <eos> we stop predicting.
    
We decode the prediction back to strings last lines.

In [24]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            
            #prediction: [batch size, seq len, vocab size]
            #prediction[:, -1]: [batch size, vocab size] #probability of last vocab
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: #if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    #if it is eos, we stop
                break

            indices.append(prediction) #autoregressive, thus output becomes input

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [26]:
prompt = 'Harry Potter is '
max_seq_len = 30
seed = 0

#smaller the temperature, more diverse tokens but comes 
#with a tradeoff of less-make-sense sentence
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, basic_english_tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.5
harry potter is not a large amount of the sun in the city , but the forest is the first and the old rhyme of the other games . it is known as

0.7
harry potter is not found either to be detected in a rugged , looking from the forest . the episode has been safe down in the study of the dna as a city

0.75
harry potter is not found either relatively complete .

0.8
harry potter is not found either relatively complete .

1.0
harry potter is not responsible for procure these american estate in his live hideout of continental forest . the episode created the robotic augustus withoos to study probable techniques in their counties ,

