Import libraries

In [91]:
import json
import random
import torch
import torch.nn as nn
import torch.optim as optim

Load the dataset of 750 words used to train the RNN-based character-level language model.

In [97]:
# JSON text is UTF-8 by specification; pass the encoding explicitly so loading does
# not depend on the platform's default locale encoding.
with open('list_of_words.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
# The tokenisation cell below joins the elements as strings, so validate early.
assert isinstance(data, list) and all(isinstance(word, str) for word in data), \
    "list_of_words.json must contain a JSON array of strings"

The vocabulary is built using the characters available in the given dataset

The dataset is modified such that start (<span style="color:red">"\<s\>"</span>) and end (<span style="color:red">"\</s\>"</span>) tokens are wrapped around each word

The tokens are also added to the vocabulary

The dataset is then concatenated into a single string so that it can be used by the <span style="color:red">CharecterLevelRNN</span> class implemented below.


In [3]:
# This cell rebinds `data` from a list of words to one training string, so it is not
# idempotent: re-running it alone would treat the tokenised string as a word list and
# pollute the vocabulary. Fail loudly instead of silently corrupting state.
assert isinstance(data, list), "Re-run the data-loading cell before this one."
# Vocabulary = the characters that actually occur in the words (sorted for a stable
# index order), followed by the word start and end tokens. Joining with '' (not ' ')
# avoids adding a space character that never appears in the training text.
vocab = sorted(set(''.join(data)))
vocab.extend(["<s>", "</s>"])  # word start and end tokens
data = ''.join(f"<s>{word}</s>" for word in data)

The CharecterLevelRNN class includes the implementation of the language model together with methods used in training the model and making predictions.

In [4]:
class CharecterLevelRNN(nn.Module):
    """ RNN-based character level language model.

    The model generates words one character at a time. Words are delimited by the
    multi-character tokens "<s>" (start) and "</s>" (end), which are treated as
    single vocabulary entries by encode/decode.
    """
    def __init__(self, vocab, embedding_size, hidden_layer_size):
        """ Builds the embedding -> RNN -> linear stack.
        Args:
            vocab(list): vocabulary entries; the list position is the token index.
                generate_word requires the "<s>" and "</s>" tokens to be present.
            embedding_size(int): size of each character embedding
            hidden_layer_size(int): size of the RNN hidden state
        """
        super().__init__()
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.embedding = nn.Embedding(self.vocab_size, embedding_size)
        self.rnn = nn.RNN(embedding_size, hidden_layer_size, batch_first=True)
        self.fully_connected = nn.Linear(hidden_layer_size, self.vocab_size)
        self.char_to_index_map = {token: index for index, token in enumerate(vocab)}
        self.index_to_char_map = vocab
        self.start_token = "<s>"
        self.end_token = "</s>"


    def forward(self, x, hidden = None):
        """ Neural network forward pass.
        Args:
            x(torch.Tensor): token indices; (batch, seq) for batched input (the RNN
                uses batch_first=True) or (seq,) for one unbatched sequence, as used
                by generate_word
            hidden(torch.Tensor): initial hidden state, or None for zeros
        Returns:
            tuple: logits over the vocabulary for every position, and the last
            calculated hidden state
        """
        x = self.embedding(x)
        output, hidden = self.rnn(x, hidden)
        logits = self.fully_connected(output)
        return logits, hidden


    def encode(self, text):
        """ Converts a string to a list of vocabulary indices.
        The "<s>" and "</s>" tokens are matched first and encoded as one index each;
        every other character is looked up individually.
        Args:
            text (str): text to be processed
        Returns:
            list: list of indices
        Raises:
            KeyError: if the text contains a character that is not in the vocabulary
        """
        i, indices = 0, []
        while i < len(text):
            for token in [self.start_token, self.end_token]:
                if text.startswith(token, i):
                    indices.append(self.char_to_index_map[token])
                    i += len(token)
                    break
            else:
                # no special token starts at position i: encode a single character
                ch = text[i]
                indices.append(self.char_to_index_map[ch])
                i += 1
        return indices


    def decode(self, indices):
        """ Converts indices back to the corresponding vocabulary entries.
        Args:
            indices(list): list of integer indices
        Returns:
            str: Decoded string
        """
        return ''.join([self.index_to_char_map[index] for index in indices])


    def __generate_batch(self, data, seq_length = 25, batch_size = 64):
        """ Randomly samples subsequences of data and builds one batch.
        Labels are the inputs shifted one position to the right (next-token targets).
        Args:
            data(torch.Tensor): 1-D tensor of encoded indices, longer than seq_length
            seq_length(int): length of each sampled sequence
            batch_size(int): batch size
        Returns:
            tuple: inputs and labels, each of shape (batch_size, seq_length)
        """
        # the upper bound is exclusive, so i + seq_length + 1 <= len(data) for labels
        batch_start_indices = torch.randint(len(data) - seq_length, (batch_size,))
        x = torch.stack([data[i:i+seq_length] for i in batch_start_indices])
        y = torch.stack([data[i+1:i+seq_length+1] for i in batch_start_indices])
        return x, y


    def train_model(self, data, epochs, optimizer, criterion,
                    seq_length = 25, batch_size = 64, log_every = 500):
        """ Trains the weights of the model.
        Each "epoch" is a single optimisation step on one batch of randomly sampled
        subsequences, not a full pass over the data.
        Args:
            data(str): training text (may contain "<s>"/"</s>" tokens)
            epochs(int): number of optimisation steps
            optimizer: e.g., torch.optim.Adam
            criterion: e.g., torch.nn.CrossEntropyLoss
            seq_length(int): length of each sampled training sequence
            batch_size(int): number of sequences per batch
            log_every(int): print the loss every log_every steps (0 disables logging)
        Returns:
            None
        Raises:
            ValueError: if the encoded text is not longer than seq_length
        """
        encoded = torch.tensor(self.encode(data), dtype = torch.long)
        if encoded.numel() <= seq_length:
            raise ValueError(
                f"training text encodes to {encoded.numel()} tokens; "
                f"need more than seq_length={seq_length}"
            )
        self.train()
        for epoch in range(1, epochs + 1):
            x_batch, y_batch = self.__generate_batch(encoded, seq_length, batch_size)
            logits, _ = self.forward(x_batch)
            # flatten (batch, seq, vocab) / (batch, seq) for CrossEntropyLoss
            loss = criterion(logits.reshape(-1, self.vocab_size), y_batch.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if log_every and epoch % log_every == 0:
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


    def __calculate_hidden_state(self, input_seq):
        """ Runs a sequence through the model and returns its final hidden state
        (used to condition generation of the next character on a prefix).
        Args:
            input_seq(torch.Tensor): (1, L) tensor of indices; may be empty (L == 0)
        Returns:
            torch.Tensor or None: last hidden state with the layer dimension removed,
            or None when the sequence is empty
        """
        if input_seq.nelement() == 0:
            return None
        _, hidden = self.forward(input_seq, None)
        return hidden.squeeze(0)


    def generate_word(self, prefix = None, temperature = 1.0, max_length = 100):
        """ Generates a word character by character given a prefix. Applies softmax
        to the temperature-scaled logits and samples each next character from it.
        Args:
            prefix(str): prefix of the expected word ("<s>" is prepended if missing)
            temperature(float): positive temperature applied to the logits
                temperature<1 -> more confident sampling (peaky distribution)
                temperature=1 -> normal sampling
                temperature>1 -> more creative sampling (flatter distribution)
            max_length(int or None): maximum length of the returned word, prefix
                included; generation stops there even if "</s>" was not sampled.
                None means no limit.
        Returns:
            str: generated word (prefix included, start token removed)
        Raises:
            ValueError: if temperature is not positive
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        if prefix is None:
            prefix = self.start_token
        if not prefix.startswith(self.start_token):
            prefix = self.start_token + prefix
        indices = self.encode(prefix)
        was_training = self.training
        self.eval()
        try:
            # no_grad: sampling needs no autograd graph, which would otherwise grow
            # with every generated character through the carried hidden state
            with torch.no_grad():
                input_seq = torch.tensor(indices, dtype=torch.long).unsqueeze(0)  # (1, L)
                hidden = self.__calculate_hidden_state(input_seq[:, :-1])
                # input_seq holds the start token plus size(1) - 1 word characters
                while max_length is None or input_seq.size(1) - 1 < max_length:
                    current_index = input_seq[:, -1]  # shape (1,): unbatched step
                    logit, hidden = self.forward(current_index, hidden)
                    probs = torch.softmax(logit / temperature, dim=-1)
                    predicted_index = torch.multinomial(probs, num_samples=1).item()
                    if self.decode([predicted_index]) == self.end_token:
                        break
                    input_seq = torch.cat([input_seq, torch.tensor([[predicted_index]])], dim=1)
        finally:
            self.train(was_training)
        # Take row 0 instead of squeeze(): squeeze() turns a (1, 1) tensor into a
        # scalar whose tolist() is an int, which decode() cannot iterate.
        return self.decode(input_seq[0].tolist()).replace(self.start_token, '')

Creating an object of the CharecterLevelRNN class (embedding size = 32; hidden layer size = 256)

In [38]:
# Seed torch's global RNG so weight initialisation, training-batch sampling and word
# sampling are reproducible when the notebook is run top to bottom.
torch.manual_seed(0)
EMBEDDING_SIZE = 32
HIDDEN_LAYER_SIZE = 256
char_level_model = CharecterLevelRNN(
    vocab,
    embedding_size=EMBEDDING_SIZE,
    hidden_layer_size=HIDDEN_LAYER_SIZE,
)

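Define the optimizer (Adam, learning rate 0.001) and the cross-entropy loss, then train the model for 10,000 steps. Each logged "epoch" is one optimisation step on a single randomly sampled batch of sequences.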

In [39]:
LEARNING_RATE = 0.001
TRAINING_STEPS = 10000  # train_model's "epochs": one random batch per step

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(char_level_model.parameters(), lr=LEARNING_RATE)
char_level_model.train_model(data, TRAINING_STEPS, optimizer, criterion)

Epoch 500, Loss: 0.8703
Epoch 1000, Loss: 0.4477
Epoch 1500, Loss: 0.3565
Epoch 2000, Loss: 0.3259
Epoch 2500, Loss: 0.3183
Epoch 3000, Loss: 0.2896
Epoch 3500, Loss: 0.2885
Epoch 4000, Loss: 0.2673
Epoch 4500, Loss: 0.2757
Epoch 5000, Loss: 0.2687
Epoch 5500, Loss: 0.2624
Epoch 6000, Loss: 0.2631
Epoch 6500, Loss: 0.2490
Epoch 7000, Loss: 0.2522
Epoch 7500, Loss: 0.2691
Epoch 8000, Loss: 0.2569
Epoch 8500, Loss: 0.2547
Epoch 9000, Loss: 0.2663
Epoch 9500, Loss: 0.2532
Epoch 10000, Loss: 0.2448


<span style = "color:blue">Example 1:</span> Generating words starting with <span style = "color:blue">co</span>

In [77]:
# Sample ten words that start with "co" and show them alphabetically.
ans = [char_level_model.generate_word("co") for _ in range(10)]
print(sorted(ans))

['college', 'color', 'computer', 'computer', 'computer', 'cook', 'cool', 'correct', 'cost', 'course']


<span style = "color:blue">Example 2:</span> Generating words starting with <span style = "color:blue">b</span>

In [80]:
# Sample ten words that start with "b" and show them alphabetically.
ans = [char_level_model.generate_word("b") for _ in range(10)]
print(sorted(ans))

['bag', 'base', 'begin', 'behind', 'big', 'block', 'boat', 'boy', 'bubllow', 'busy']


<span style = "color:blue">Example 3:</span> Generating words starting with <span style = "color:blue">z</span>

In [88]:
# Sample ten words that start with "z" and show them alphabetically.
ans = [char_level_model.generate_word("z") for _ in range(10)]
print(sorted(ans))

['zapper', 'zecour', 'zen', 'zence', 'zeolite', 'zeolite', 'zeppelin', 'zillion', 'zip', 'zoom']


<span style = "color:blue">Example 4:</span> Generating words without a prefix

In [83]:
# Sample ten words with an empty prefix (only the start token) and show them sorted.
ans = [char_level_model.generate_word("") for _ in range(10)]
print(sorted(ans))

['better', 'birth', 'direction', 'guess', 'hospital', 'machine', 'match', 'money', 'tell', 'zero']
