In [1]:
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [9]:
# Placeholder sentinel for this exercise: every FILL_IN that appears in the
# cells below marks a spot where code still has to be written before
# beam search will work.
FILL_IN = "FILL IN"

### Get the data and process
- The corpus is *The Mysterious Island*, downloaded from Project Gutenberg.

In [2]:
## Reading and processing text
with open('data/1268-0.txt', 'r', encoding="utf8") as fp:
    text=fp.read()

# Locate the book's title and the Project Gutenberg footer so that only the
# span between them is kept.
start_indx = text.find('THE MYSTERIOUS ISLAND')
end_indx = text.find('End of the Project Gutenberg')

# str.find returns -1 when a marker is missing; slicing with -1 would silently
# keep the wrong span, so fail loudly with an explanation instead.
if start_indx == -1 or end_indx == -1:
    raise ValueError(
        "Could not locate the start/end markers in data/1268-0.txt "
        f"(start_indx={start_indx}, end_indx={end_indx})"
    )

text = text[start_indx:end_indx]
char_set = set(text)
print('Total Length:', len(text))
print('Unique Characters:', len(char_set))

# Sanity checks: these pin the exact edition of the text that the rest of the
# notebook (and the saved model's vocabulary size) was built against.
assert len(text) == 1130711, f"Unexpected corpus length: {len(text)}"
assert len(char_set) == 85, f"Unexpected vocabulary size: {len(char_set)}"

Total Length: 1130711
Unique Characters: 85


### Tokenize and get other helpers
- We do this manually since everything is character based.

In [3]:
# Vocabulary: every distinct character in the corpus, in sorted order.
chars_sorted = sorted(char_set)

# The two lookup tables below act as the tokenizer:
#   char2int maps a character to its integer id,
#   int2char maps an integer id (or an array of ids) back to characters.
char2int = {symbol: idx for idx, symbol in enumerate(chars_sorted)}
int2char = np.array(chars_sorted)

# Encode the whole corpus as a 1-D int32 array of character ids.
text_encoded = np.fromiter(
    (char2int[symbol] for symbol in text),
    dtype=np.int32,
    count=len(text),
)

print('Text encoded shape: ', text_encoded.shape)

# Round-trip demo: encode the first 15 characters, decode the next 6.
print(text[:15], '     == Encoding ==> ', text_encoded[:15])
print(text_encoded[15:21], ' == Reverse  ==> ', ''.join(int2char[text_encoded[15:21]]))

Text encoded shape:  (1130711,)
THE MYSTERIOUS       == Encoding ==>  [48 36 33  1 41 53 47 48 33 46 37 43 49 47  1]
[37 47 40 29 42 32]  == Reverse  ==>  ISLAND


### Load the model

In [4]:
# Everything in this notebook runs on the CPU; beam_search_decoding moves the
# hidden and cell states onto this device.
device = torch.device("cpu")

In [5]:
# Shell escape: show the on-disk size of the saved model file.
!du -h hw7_model.pt

7.1M	hw7_model.pt


In [6]:
# Load a traced (TorchScript) version of this model.
# Note this is the same as the HW 7 model but a little different:
# the HW7 model had an if-else in its forward method, and this is not allowed.
# The forward method of this model takes hidden and cell, which could be all
# zeros, but the caller has to supply them explicitly.
# map_location pins the weights to `device`, so a file saved from a GPU
# session still loads on a CPU-only machine.
model = torch.jit.load('hw7_model.pt', map_location=device)

In [7]:
# Display the loaded module's structure (the cell's last expression is rendered as output).
model

RecursiveScriptModule(
  original_name=RNN
  (embedding): RecursiveScriptModule(original_name=Embedding)
  (rnn): RecursiveScriptModule(original_name=LSTM)
  (fc): RecursiveScriptModule(original_name=Linear)
)

In [8]:
# A module restored with torch.jit.load only exposes `forward`, so the
# state-initialisation helper has to live here as a free function.
def init_hidden(model, batch_size):
    """Return fresh all-zero ``(hidden, cell)`` tensors.

    Args:
        model: object exposing an ``rnn_hidden_size`` attribute.
        batch_size: size of the second dimension of each state tensor.

    Returns:
        Tuple ``(hidden, cell)`` of two separate zero tensors, each shaped
        ``(1, batch_size, model.rnn_hidden_size)``.
    """
    state_shape = (1, batch_size, model.rnn_hidden_size)
    hidden = torch.zeros(*state_shape)
    cell = torch.zeros(*state_shape)
    return hidden, cell

### Beam search algorithm.
- Good article: https://towardsdatascience.com/foundations-of-nlp-explained-visually-beam-search-how-it-works-1586b9849a24

In [16]:
def beam_search_decoding(
    model,
    starting_str,
    len_generated_text=500,
    beams=5,
    print_paths=True
):
    """Beam-search text generation with a character-level model (exercise template).

    The body still contains ``FILL_IN`` placeholders; each one must be replaced
    with real code before this function can run. The intended flow, as laid
    out by the step comments below, is:

    1. Encode ``starting_str`` and feed all but its last character through the
       model one character at a time, accumulating the log probability of the
       prompt and building up the hidden/cell states.
    2. Push the last prompt character through the model, create one candidate
       per vocabulary entry, sort the candidates and keep the top ``beams``.
    3. Repeat ``len_generated_text`` times: expand every kept beam by every
       possible next character, sort all candidates, and keep the top ``beams``.

    Args:
        model: Called to obtain ``logits, (hidden, cell)`` (see the unpacking
            below). NOTE(review): the exact call signature is not visible in
            this function; presumably the TorchScript RNN loaded earlier.
        starting_str: Prompt to continue; must be non-empty (asserted).
        len_generated_text: Number of iterations of the main expansion loop.
        beams: Number of beams kept after every expansion step.
        print_paths: When true, prints the five best candidates after every
            step and then blocks on ``input()`` until the user responds.

    Returns:
        ``(generated_strs, generated_probs)``: two lists with ``beams``
        entries each, where each probability is ``np.exp`` of that beam's
        accumulated log probability.
    """
    assert(len(starting_str) != 0)

    # Get the encoding of the starting_str as a tensor of ints.
    encoded_input = FILL_IN

    # Reshape the above to be of appropriate dimension.
    encoded_input = FILL_IN

    # Put the model in eval mode.
    FILL_IN

    # Unfortunately, jit save does not save methods other than forward.
    # Use init_hidden to get the first hidden and cell states.
    hidden, cell = FILL_IN

    # NOTE(review): `device` is not a parameter; it is read from module scope.
    hidden = hidden.to(device)

    cell = cell.to(device)

    # Running log probability and text of the sequence built so far.
    # The text starts with the first prompt character; the loop below appends the rest.
    generated_log_prob = 0
    generated_str = starting_str[0]

    # Build up the starting hidden and cell states.
    # You can do this all in one go?
    for i in range(len(starting_str)-1):
        # Feed each letter 1 by 1 and then get the final hidden state.
        # Get the character at index i and push it through the model.
        # Get the logits, the hidden and cell states, which we will need later.
        FILL_IN

        # Get the probability of the generated character.
        # For input of index i, we want the probability that the model generated i+1.
        # We push starting_str[i] into the model, and append starting_str[i+1] to generated_str.
        generated_str += FILL_IN

        # Get the probabilities of the different characters that the model predicts.
        # You need to apply Softmax to the logits.
        probs = FILL_IN

        # Add the log probability of the appended char (int) to the running generated log probability.
        generated_log_prob += FILL_IN

    # Get the last character in the encoded input.
    last_char_int = FILL_IN

    # Push this through the model.
    logits, (hidden, cell) = FILL_IN

    # As before, get the probability per character.
    probs = FILL_IN

    new_beams = []

    for j, prob in enumerate(probs):
        # For each probability, append the tuple
        # (hidden, cell, the generated str with the jth index char, the generated str's log probability).
        # Note this is the running generated str and generated log probability.
        new_beams.append(
            FILL_IN
        )

    # Sort the beams from most probable to least. Use -log(generated_prob).
    new_beams = FILL_IN

    beam_to_beam_data = {}

    # Add the top "beams" = 5 beams to the hash map.
    # We should have a map going {beam_id -> (hidden, cell, generated_str, generated_log_prob)}
    for beam in range(beams):
        beam_to_beam_data[beam] = FILL_IN

    print('The number of beams is', len(beam_to_beam_data))

    # For each index of generated text.
    for i in range(len_generated_text):
        # Define new beams.
        new_beams = []

        # For each beam.
        for beam in range(beams):

            # Grab the 4 elements associated with this beam from beam_to_beam_data.
            FILL_IN

            # Get the last char in the str that's in the beam.
            last_char_int = FILL_IN

            # Push hidden, cell and the last_char_int through the model.
            logits, (hidden, cell) = FILL_IN

            # Get the probabilities.
            probs = FILL_IN

            # As before, append the 4 elements associated with this new beam to new beams.
            for j, prob in enumerate(probs):
                new_beams.append(
                    FILL_IN
                )

        # Sort the beams from most probable to least. Use -log(p).
        new_beams = FILL_IN

        # The number of beams considered should always satisfy this.
        # Except for the first iteration.
        # NOTE(review): `char2int` is read from module scope, so this check
        # assumes the model's vocabulary matches that tokenizer.
        assert(len(new_beams) == beams * len(char2int))

        # Leave this to true.
        if print_paths:
            print("The first 5 paths beam paths and the associated data for them: ")
            # NOTE(review): the 5 here is hard-coded and independent of `beams`;
            # it indexes the sorted candidate list, not the kept beams.
            for beam in range(5):
                generated_str, generated_log_prob = new_beams[beam][2:]
                print("Text: \"{}\" Prob {:0.30f}".format(
                        generated_str, np.exp(generated_log_prob)
                ))
            # Pause so the user can read the paths before the next step.
            _ = input("Insert anything to continue ...")
            print("\n")

        # Update the beams to be equal to the top beams.
        for beam in range(beams):
            beam_to_beam_data[beam] = new_beams[beam]

    generated_strs = []
    generated_log_probs = []

    # Grab the top beams, and return them.
    for beam in range(beams):
        (_, _, generated_str, generated_log_prob) = beam_to_beam_data[beam]
        generated_strs.append(generated_str)
        generated_log_probs.append(generated_log_prob)

    # Convert the accumulated log probabilities back to plain probabilities.
    return generated_strs, [np.exp(_) for _ in generated_log_probs]

In [None]:
# Fix the RNG seed and make sure the model lives on the CPU.
torch.manual_seed(1)
model.to('cpu')

# Search settings: how many beams to keep and how many steps to run.
beams, len_generated_text = 5, 500

generated_strs, generated_probs = beam_search_decoding(
    model,
    starting_str="The island",
    len_generated_text=len_generated_text,
    beams=beams,
)

# Report every surviving beam with its text and probability.
for beam, (beam_text, beam_prob) in enumerate(zip(generated_strs, generated_probs)):
    print(f"Beam {beam} information: ")
    print(beam_text)
    print(beam_prob)
    

The number of beams is 5
The first 5 paths beam paths and the associated data for them: 
Text: "The island, " Prob (beyond starting) 0.000520035624382387119506165885
Text: "The island w" Prob (beyond starting) 0.000227799324636454402067883840
Text: "The island?”" Prob (beyond starting) 0.000097528783550422128417051182
Text: "The island o" Prob (beyond starting) 0.000093910921207016688472789256
Text: "The island i" Prob (beyond starting) 0.000063718712271929359939465209
Insert anything to continue ...


The first 5 paths beam paths and the associated data for them: 
Text: "The island wa" Prob (beyond starting) 0.000172444312622175221702547354
Text: "The island, w" Prob (beyond starting) 0.000092942431504603877394019018
Text: "The island, a" Prob (beyond starting) 0.000090473014760760145456380821
Text: "The island, t" Prob (beyond starting) 0.000070329177943556804835506524
Text: "The island?”
" Prob (beyond starting) 0.000062701887612621581271840632
Insert anything to continue ...


The 

Insert anything to continue ...


The first 5 paths beam paths and the associated data for them: 
Text: "The island, and it was neces" Prob (beyond starting) 0.000000156560949262643312743647
Text: "The island, and it was there" Prob (beyond starting) 0.000000053867963854688683617974
Text: "The island, and that is to s" Prob (beyond starting) 0.000000042620285320625669595069
Text: "The island, and it was not a" Prob (beyond starting) 0.000000026380103147166685645907
Text: "The island, and that is to b" Prob (beyond starting) 0.000000023313432578693227640123
Insert anything to continue ...


The first 5 paths beam paths and the associated data for them: 
Text: "The island, and it was necess" Prob (beyond starting) 0.000000155033162885432141675219
Text: "The island, and that is to sa" Prob (beyond starting) 0.000000036562188382879102302289
Text: "The island, and that is to be" Prob (beyond starting) 0.000000020992617679836261914198
Text: "The island, and it was there " Prob (beyond starti