In [1]:
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

### Load and preprocess the data
- The corpus is *The Mysterious Island* by Jules Verne (Project Gutenberg ebook #1268).

In [2]:
## Reading and processing text
# Load the raw Project Gutenberg file, then keep only the body of the book:
# everything from the title up to (not including) the Gutenberg license footer.
with open('data/1268-0.txt', 'r', encoding="utf8") as fp:
    text = fp.read()

start_indx = text.find('THE MYSTERIOUS ISLAND')
end_indx = text.find('End of the Project Gutenberg')
# str.find returns -1 when a marker is missing, which would silently yield a
# wrong slice (e.g. dropping the last character); fail loudly instead.
assert start_indx != -1, "start marker 'THE MYSTERIOUS ISLAND' not found"
assert end_indx != -1, "end marker 'End of the Project Gutenberg' not found"
assert start_indx < end_indx, "end marker appears before the start marker"

text = text[start_indx:end_indx]
char_set = set(text)
print('Total Length:', len(text))
print('Unique Characters:', len(char_set))
# Sanity checks pinned to the expected edition of the source file.
assert len(text) == 1130711
assert len(char_set) == 85

Total Length: 1130711
Unique Characters: 85


### Tokenize and build helper mappings
- We tokenize by hand because the model works at the character level: each distinct character is one token.

In [3]:
# The vocabulary: every distinct character of the corpus, in sorted order.
chars_sorted = sorted(char_set)

# These two lookups are the whole tokenizer: char2int maps a character to its
# integer id, and int2char (indexable by an id array) maps ids back.
char2int = dict(zip(chars_sorted, range(len(chars_sorted))))
int2char = np.array(chars_sorted)

# Encode the entire corpus as a flat int32 array of character ids.
text_encoded = np.fromiter(
    (char2int[ch] for ch in text),
    dtype=np.int32,
    count=len(text),
)

print('Text encoded shape: ', text_encoded.shape)

print(text[:15], '     == Encoding ==> ', text_encoded[:15])
print(text_encoded[15:21], ' == Reverse  ==> ', ''.join(int2char[text_encoded[15:21]]))

Text encoded shape:  (1130711,)
THE MYSTERIOUS       == Encoding ==>  [48 36 33  1 41 53 47 48 33 46 37 43 49 47  1]
[37 47 40 29 42 32]  == Reverse  ==>  ISLAND


### Load the model

In [4]:
# Run all inference in this notebook on the CPU.
device = torch.device(type="cpu")

In [5]:
# Show the on-disk size of the serialized model file (IPython shell escape).
!du -h hw7_model.pt

6.3M	hw7_model.pt


In [2]:
# Load a serialized TorchScript version of the model.
# It is the same network as the HW 7 model, with two differences:
#  - the HW 7 forward method contained an if-else, which could not be
#    exported this way, so it was removed;
#  - forward now always takes `hidden` and `cell` explicitly (they may be all
#    zeros, but the caller must supply them; see `init_hidden` below).
# map_location remaps stored tensors onto `device`, so the file loads on the
# CPU even if it was saved from a GPU session.
model = torch.jit.load('hw7_model.pt', map_location=device)

In [4]:
# List every learnable tensor of the loaded model together with its shape.
for param_name, param_tensor in model.named_parameters():
    print(param_name, param_tensor.shape)

embedding.weight torch.Size([85, 256])
rnn.weight_ih_l0 torch.Size([2048, 256])
rnn.weight_hh_l0 torch.Size([2048, 512])
rnn.bias_ih_l0 torch.Size([2048])
rnn.bias_hh_l0 torch.Size([2048])
fc.weight torch.Size([85, 512])
fc.bias torch.Size([85])


In [12]:
# Display the submodule structure of the loaded TorchScript model.
model

RecursiveScriptModule(
  original_name=RNN
  (embedding): RecursiveScriptModule(original_name=Embedding)
  (rnn): RecursiveScriptModule(original_name=LSTM)
  (fc): RecursiveScriptModule(original_name=Linear)
)

In [13]:
# 'jit' does not save methods other than forward on a model, so the helper
# that builds the initial LSTM state has to be defined here and used below.
def init_hidden(model, batch_size, device=None, num_layers=1):
    """Return zero-initialised ``(hidden, cell)`` LSTM states.

    Args:
        model: Loaded model; must expose an integer ``rnn_hidden_size``
            attribute.
        batch_size: Number of sequences processed in parallel.
        device: Device to allocate the states on. ``None`` (the default)
            uses PyTorch's default device, matching the original behavior.
        num_layers: Number of stacked LSTM layers. The model in this notebook
            only has ``*_l0`` weights, hence the default of 1.

    Returns:
        Tuple ``(hidden, cell)`` of two independent zero tensors, each of
        shape ``(num_layers, batch_size, model.rnn_hidden_size)``.
    """
    state_shape = (num_layers, batch_size, model.rnn_hidden_size)
    return (
        torch.zeros(state_shape, device=device),
        torch.zeros(state_shape, device=device),
    )

### Beam search algorithm
- A good visual explanation of how beam search works: https://towardsdatascience.com/foundations-of-nlp-explained-visually-beam-search-how-it-works-1586b9849a24

In [18]:
def _step_log_probs(logits):
    """Return a 1-D tensor of next-character log-probabilities for one step.

    Assumes ``logits`` comes from a single-character, batch-size-1 call to the
    model, so it holds exactly one score per vocabulary entry (TODO confirm
    the exact (1, 1, vocab) layout against the model definition).
    """
    # log_softmax is more stable than log(softmax(...)): very unlikely
    # characters get a large negative score instead of log(0) = -inf.
    return torch.log_softmax(logits.reshape(-1), dim=0)


def _prime_with_prompt(model, encoded_prompt, hidden, cell):
    """Feed the prompt through ``model`` one character at a time.

    Args:
        model: Model called as ``model(x, hidden, cell)``.
        encoded_prompt: Tensor of character ids with shape ``(1, prompt_len)``.
        hidden, cell: Initial LSTM state.

    Returns:
        ``(hidden, cell, prompt_log_prob, next_log_probs)``: the state after
        the whole prompt, the summed log-probability of every prompt character
        after the first (P(y_t | y_<t)), and the log-distribution over the
        character that follows the prompt.
    """
    prompt_log_prob = 0.0
    next_log_probs = None
    n_chars = encoded_prompt.shape[1]
    for i in range(n_chars):
        char_in = encoded_prompt[:, i].reshape(1, 1)
        logits, (hidden, cell) = model(char_in, hidden, cell)
        next_log_probs = _step_log_probs(logits)
        if i + 1 < n_chars:
            # Score the prompt's own next character given everything before it.
            prompt_log_prob += next_log_probs[int(encoded_prompt[0, i + 1])].item()
    return hidden, cell, prompt_log_prob, next_log_probs


def _print_top_paths(candidates):
    """Print the 5 best ``(log_prob, text, hidden, cell)`` candidates."""
    print("The first 5 paths beam paths and the associated data for them: ")
    for log_prob, text, _, _ in candidates[:5]:
        print("Text: \"{}\" Prob {:0.30f}".format(
                text, np.exp(log_prob)
        ))


def beam_search_decoding(
    model,
    starting_str,
    len_generated_text=500,
    beams=5,
    print_paths=True,
    pause_between_steps=True,
):
    """Generate continuations of ``starting_str`` with character-level beam search.

    The prompt is first fed through ``model`` to build the LSTM state. Its own
    characters (after the first) are scored too, so every reported probability
    is the joint probability of the prompt (minus its first character) plus
    the generated text. The best ``beams`` first characters seed the beams.
    Then, for ``len_generated_text`` steps, every beam is extended by every
    vocabulary character and only the ``beams`` most probable extensions
    survive.

    Relies on the notebook globals ``char2int``, ``int2char``, ``device`` and
    ``init_hidden``.

    Args:
        model: TorchScript model called as ``model(x, hidden, cell)`` and
            returning ``(logits, (hidden, cell))``.
        starting_str: Non-empty prompt; every character must be in ``char2int``.
        len_generated_text: Number of decoding steps after the seeding step,
            so each result has ``len(starting_str) + len_generated_text + 1``
            characters.
        beams: Beam width, between 1 and the vocabulary size.
        print_paths: If True, print the 5 best candidates after every step.
        pause_between_steps: If True (and ``print_paths``), wait for keyboard
            input after each printout; set False to print without blocking.

    Returns:
        ``(generated_strs, probs)``: the ``beams`` final texts, best first,
        and their probabilities (``np.exp`` of accumulated log-probabilities).

    Raises:
        AssertionError: if ``starting_str`` is empty.
        ValueError: if ``beams`` is outside ``[1, vocabulary size]``.
        KeyError: if ``starting_str`` contains a character not in ``char2int``.
    """
    assert(len(starting_str) != 0)
    vocab_size = len(char2int)
    if not 1 <= beams <= vocab_size:
        # Seeding picks `beams` distinct first characters, so more beams than
        # vocabulary entries cannot be filled.
        raise ValueError(
            f"beams must be between 1 and the vocabulary size ({vocab_size}), got {beams}"
        )

    model.eval()
    # Pure inference: without no_grad the autograd graph would keep growing
    # through the recurrent state across every decoding step.
    with torch.no_grad():
        encoded_input = torch.tensor(
            [char2int[s] for s in starting_str], device=device
        ).reshape(1, -1)

        # jit save does not keep methods other than forward, hence the helper.
        hidden, cell = init_hidden(model, 1)
        hidden, cell = hidden.to(device), cell.to(device)

        hidden, cell, prompt_log_prob, next_log_probs = _prime_with_prompt(
            model, encoded_input, hidden, cell
        )

        # Each beam is (log_prob, text, hidden, cell), where hidden/cell is the
        # state after consuming text[:-1]; text[-1] is fed on the next step.
        # Python's sort is stable, so ties keep vocabulary order.
        seeds = sorted(
            (
                (prompt_log_prob + lp, starting_str + int2char[j], hidden, cell)
                for j, lp in enumerate(next_log_probs.tolist())
            ),
            key=lambda cand: -cand[0],
        )
        beam_states = seeds[:beams]
        print('The number of beams is', len(beam_states))

        for _ in range(len_generated_text):
            candidates = []
            for log_prob, text, beam_hidden, beam_cell in beam_states:
                last_char = torch.tensor([[char2int[text[-1]]]], device=device)
                logits, (new_hidden, new_cell) = model(last_char, beam_hidden, beam_cell)
                step_log_probs = _step_log_probs(logits).tolist()
                candidates.extend(
                    (log_prob + lp, text + int2char[j], new_hidden, new_cell)
                    for j, lp in enumerate(step_log_probs)
                )

            # Most probable first.
            candidates.sort(key=lambda cand: -cand[0])

            # Every beam is expanded by every vocabulary character.
            assert(len(candidates) == beams * vocab_size)

            if print_paths:
                _print_top_paths(candidates)
                if pause_between_steps:
                    _ = input("Insert anything to continue ...")
                print("\n")

            # Keep only the top beams.
            beam_states = candidates[:beams]

    generated_strs = [text for _, text, _, _ in beam_states]
    return generated_strs, [np.exp(log_prob) for log_prob, _, _, _ in beam_states]

In [None]:
# Decode 500 characters after the prompt "The island" with a beam width of 5,
# then report every surviving beam together with its probability.
torch.manual_seed(1)
model.to('cpu')
len_generated_text = 500
beams = 5

generated_strs, generated_probs = beam_search_decoding(
    model,
    starting_str="The island",
    len_generated_text=len_generated_text,
    beams=beams,
)

for beam_idx in range(beams):
    beam_text = generated_strs[beam_idx]
    beam_prob = generated_probs[beam_idx]
    print(f"Beam {beam_idx} information: ")
    print(beam_text)
    print(beam_prob)
    

The number of beams is 5
The first 5 paths beam paths and the associated data for them: 
Text: "The island, " Prob 0.000520035624382387119506165885
Text: "The island w" Prob 0.000227799324636454402067883840
Text: "The island?”" Prob 0.000097528783550422128417051182
Text: "The island o" Prob 0.000093910921207016688472789256
Text: "The island i" Prob 0.000063718712271929359939465209
Insert anything to continue ...


The first 5 paths beam paths and the associated data for them: 
Text: "The island wa" Prob 0.000172444312622175221702547354
Text: "The island, w" Prob 0.000092942431504603877394019018
Text: "The island, a" Prob 0.000090473014760760145456380821
Text: "The island, t" Prob 0.000070329177943556804835506524
Text: "The island?”
" Prob 0.000062701887612621581271840632
Insert anything to continue ...


The first 5 paths beam paths and the associated data for them: 
Text: "The island was" Prob 0.000168390777398385932048241465
Text: "The island, an" Prob 0.00007139449848910471834589897

Insert anything to continue ...


The first 5 paths beam paths and the associated data for them: 
Text: "The island, and it was necessar" Prob 0.000000152024380545919445244178
Text: "The island, and it was therefor" Prob 0.000000018085692396889703460052
Text: "The island, and that is to say," Prob 0.000000013691599299554593781165
Text: "The island, and that is to say " Prob 0.000000013242337464759689322885
Text: "The island, and it was there wa" Prob 0.000000002904640619469246882912
Insert anything to continue ...


The first 5 paths beam paths and the associated data for them: 
Text: "The island, and it was necessary" Prob 0.000000150955421062497408813248
Text: "The island, and it was therefore" Prob 0.000000017986524745911925254024
Text: "The island, and that is to say, " Prob 0.000000010136494858898432154721
Text: "The island, and that is to say t" Prob 0.000000006149802997271163963106
Text: "The island, and it was there was" Prob 0.000000002887440146783889264895
