<h3> I work through the chatbot tutorial on PyTorch. The chatbot is trained using the  Cornell Movie Dialogs Dataset using a seq2seq model

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn as nn
from torch import optim
from torch.jit import script, trace
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

In [2]:
# Let's read in some of the data
corpus_name = 'cornell movie-dialogs corpus'

def printlines(file, n = 10):
    with open(file, 'rb') as f:
        lines  = f.readlines()
    for line in lines[:n]:
        print(line)

printlines(os.path.join(corpus_name, 'movie_lines.txt'))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


<h4> We need to take this data and extract "sentence pairs" that form the basis on which we can train our generative chatbot model.     

<h4> Because of the way the data is formatted -- first we need to extract the metadata containing the conversation ids and then group conversations into pairs which can be used for training

In [4]:
def loadlines(filename, fields):
    lines = {}
    with open(filename, 'r', encoding = 'iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            lineObj  = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj

    return lines

def loadConversations(filename, lines, fields):
    conversations = []
    with open(filename, 'r', encoding = 'iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            lineIds = eval(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations
    

def get_sentence_pairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter wrong samples (if one of the lists is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

In [8]:
new_path = os.path.join(corpus_name, "formatted_movie_lines.txt")

delimiter = '\t'
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

print("\nProcessing corpus...")
lines = loadlines(os.path.join(corpus_name, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus_name, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)



Processing corpus...

Loading conversations...


In [9]:
conversations[:1]

[{'character1ID': 'u0',
  'character2ID': 'u2',
  'movieID': 'm0',
  'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n",
  'lines': [{'lineID': 'L194',
    'characterID': 'u0',
    'movieID': 'm0',
    'character': 'BIANCA',
    'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
   {'lineID': 'L195',
    'characterID': 'u2',
    'movieID': 'm0',
    'character': 'CAMERON',
    'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
   {'lineID': 'L196',
    'characterID': 'u0',
    'movieID': 'm0',
    'character': 'BIANCA',
    'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
   {'lineID': 'L197',
    'characterID': 'u2',
    'movieID': 'm0',
    'character': 'CAMERON',
    'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]}]

In [12]:
print("\n Writing conversation pairs to new file")
with open(new_path, 'w', encoding = 'utf-8') as outfile:
    writer = csv.writer(outfile, delimiter=delimiter, lineterminator = '\n')
    for pair in get_sentence_pairs(conversations):
        writer.writerow(pair)

print("\nSample lines from file:")
printlines(new_path)


 Writing conversation pairs to new file

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could 

<h3> Now we need to assemble the sentence pairs into something machine readable -- the following class converts words to indexes and vice versa and also trims infrequent words to speed up training. To do this, it also keeps a count of how many times a word appears in the vocabulary.
    
<h3> Finally, we also have a MAX_LENGTH to filter out sentences > MAX_LENGTH words. 
    
<h3> All of this boilerplate preprocessing can be used in other NLP seq-to-seq models, so it is useful to assemble this as a class

In [13]:
PAD_token = 0 # pad short sentences
SOS_token = 1 # sentence start token
EOS_token = 2 # sentence end token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index  = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.word2count = {}
        self.num_words = 3 
    
    def addSentences(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
    
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words +=1
        else:
            self.word2count[word] += 1 # if word is seen before, update count by 1
    
    # remove words below count threshold
    def trim(self, min_count):
        if self.trimmed():
            return
        self.trimmed =True
        
        keep_words = []
        
        for k, v in self.word2count.items():
            if v>=min_count:
                keep_words.append(k)
        
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3
        
        for word in keep_words:
            self.addWord(word)

In [27]:
MAX_LENGTH = 10

def unicode_to_ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

def normalizeString(s):
    s = unicode_to_ascii(s.lower().strip())
    s  = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

def readVocs(datafile, corpus_name):
    print("Reading Lines ...")
    
    lines = open(datafile, encoding = 'utf-8').\
        read().strip().split('\n')
    
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# return only both sentences in the pair are less than MAX_LENGTH
def filter_helper(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filter_helper(pair)]

In [30]:
def loadPrepareData(corpus_name, datafile, save_dir):
    print("Start preparing training data...")
    
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    
    print("Counting Words...")
    for pair in pairs:
        voc.addSentences(pair[0])
        voc.addSentences(pair[1])
    print("Counted Words:", voc.num_words)
    return voc, pairs

save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus_name, new_path, save_dir)

Start preparing training data...
Reading Lines ...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting Words...
Counted Words: 18008


<h3> To speed up training, we can also trim any sentences containing rare words. Skipping this for now, since the training dataset isn't too large after trimming large sentences.

In [143]:
def indexesfromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

# Note that this function below impicitly does a transpose -- putput becomes [MAX_LENGTh, Batch_size]
def zeroPadding(l, fillvalue = PAD_token):
    return list(itertools.zip_longest(*l, fillvalue = fillvalue))

# creates an input tensor
def inputVar(l, voc):
    indexes_batch = [indexesfromSentence(voc, sentence) for sentence in l]
    # indexes_batch has shape == [batch_size, sentence_length]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    # pad_var has shape [MAX_length, batch_size] -- the transpose is implicitly done by zeroPadding. 
    padVar = torch.LongTensor(padList)
    return padVar, lengths


def binaryMatrix(l, value = PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

# this returns padded targeet sequence tensor, with a padding mask and a max target length
def OutputVar(l, voc):
    indexes_batch = [indexesfromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask) # conver to tensor so can run tensor operations on mask.
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

In [144]:
# Get the Training Batches
def batch2TrainingData(voc, pair_batch):
    pair_batch.sort(key = lambda x: len(x[0].split(' ')), reverse = True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[0])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = OutputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

In [145]:
small_batch_size = 8
batches = batch2TrainingData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[   52,    79,   351,   654,    55,  7199,     9, 10131],
        [  120,    39,  1214,   343,   251,    23,  1504,     4],
        [   79,   580,     4,    79,   802,     6,     6,     2],
        [    6,   424,  2002,   304,     4,     2,     2,     0],
        [   73,   293,  8817,     4,     2,     0,     0,     0],
        [ 2748,   103,     4,     2,     0,     0,     0,     0],
        [    6,     9,     2,     0,     0,     0,     0,     0],
        [ 1956,  3067,     0,     0,     0,     0,     0,     0],
        [    6,     4,     0,     0,     0,     0,     0,     0],
        [    2,     2,     0,     0,     0,     0,     0,     0]])
lengths: tensor([10, 10,  7,  6,  5,  4,  4,  3])
target_variable: tensor([[   52,    79,   351,   654,    55,  7199,     9, 10131],
        [  120,    39,  1214,   343,   251,    23,  1504,     4],
        [   79,   580,     4,    79,   802,     6,     6,     2],
        [    6,   424,  2002,   304,     4,     2,     2, 

In [146]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers = 1, dropout= 0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # GRU takes an input_size and hidden_size -- for us since input_size is an embeddng of size hidden_size,
        # input_size and hidden_size are identical
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, 
                          dropout = (0 if n_layers == 1 else dropout), bidirectional = True)
        
    def forward(self, input_seq, input_lengths, hidden = None):
        embedded = self.embedding(input_seq)
        # Pack padded bactch of sequence for RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] # add bidirectional outputs. 
        return outputs, hidden

In [147]:
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [131]:
class LuongAttnDecoderRNN(nn.Module):
     def __init__(self, attn_model, embedding,  hidden_size, output_size, n_layers = 1, dropout= 0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout = (0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        
        self.attn = Attn(attn_model, hidden_size)
        
     def forward(self, input_step, last_hidden, encoder_outputs):
            
        # this is done one wrod at a time until you hit EOS
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
            
        rnn_output, hidden = self.gru(embedded, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_outputs)
        
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) #bmm does a batch matmul
        
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        
        
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
            
        output = self.out(concat_output)
        output = F.softmax(output, dim = 1)
            
        return output, hidden # need the previous hidden state for new 
            

<h3> Training
    
<h4> First step is to define the loss function -- this is where we need the binary matrix, as we want to compute the negative log likelihood only for the elements which have 1 in the binary matrix, not the paddings
   

In [148]:
def maskLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    return loss, nTotal.item()

In [150]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding, encoder_optimizer,
         decoder_optimizer, batch_size, clip, max_length= MAX_LENGTH):
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    loss = 0
    print_losses = []
    n_totals = 0
    
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
    
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_hidden = encoder_hidden[:decoder.n_layers]
    # Model uses teacher forcing. 
    for t in range(max_target_len):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, 
                                                 encoder_outputs)
        
        # set decoder input to be target variable -- Teacher Forcing
        decoder_input = target_variable[t].view(1, -1)
        
        mask_loss, nTotal = maskLoss(decoder_output, target_variable[t], mask[t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)
        n_totals += nTotal
    
    loss.backward() #backprop
    
    # gradient clipping
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)
    
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return sum(print_losses)/n_totals

<h3> Now iterate the training over many batches

In [154]:
def trainIters(model_name,voc, pairs, encoder, decoder,encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers,
              decoder_n_layers, n_iteration, batch_size, print_every, clip):
    
    # Load batches for each iteration
    training_batches = [batch2TrainingData(voc, [random.choice(pairs) for _ in range(batch_size)])
                       for _ in range(n_iteration)]

    # initialization
    print("Initializing ...")
    start_iteration = 1
    print_loss = 0
    
    # Training Loop
    print("Training ...")
    for iteration in range(start_iteration, n_iteration+1):
        training_batch = training_batches[iteration-1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding, encoder_optimizer,
         decoder_optimizer, batch_size, clip)
        print_loss +=loss
        
        if iteration % print_every == 0:
            print_loss_avg = print_loss/print_every
            print("Iteration: {}; Percent Complete {:.1f}%; Average Loss {:.4f}".format(iteration, iteration/n_iteration,
                                                                                       print_loss_avg))
            print_loss = 0

<h3> Let's run the model and see how it does. 
    
<h4> Note the production level code -- include a checkpoint saver in the 
     trainIters code.
     Specify Model names and parameters. The code below should be put into a different *task.py* file
     which loads in the model dependencies. 
     Use argparse to give the user flexibility in selecting what attention model he/she wants to use. 
    

In [152]:
model_name = 'chatbot_model'
attn_model = 'dot' # user can select this using argparse
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

print("Building Encoder and Decoder ...")
embedding = nn.Embedding(voc.num_words, hidden_size)
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
print("Models built and we're ready to go!!")

Building Encoder and Decoder ...
Models built and we're ready to go!!


In [156]:
clip = 50.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0 # decoder learns with a faster learning rate than encoder
n_iteration = 4000
print_every = 100
#save_every  = 500

encoder.train()
decoder.train()

print("Building Optimizers ...")
encoder_optimizer = optim.Adam(encoder.parameters(), lr = learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr = learning_rate * decoder_learning_ratio)

print("Starting Training...")
trainIters(model_name,voc, pairs, encoder, decoder,encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers,
              decoder_n_layers, n_iteration, batch_size, print_every, clip)


Building Optimizers ...
Starting Training...
Initializing ...
Training ...
Iteraton: 100; Percent Complete 0.0%; Average Loss 4.8209
Iteraton: 200; Percent Complete 0.1%; Average Loss 3.6346
Iteraton: 300; Percent Complete 0.1%; Average Loss 2.6633
Iteraton: 400; Percent Complete 0.1%; Average Loss 2.0513
Iteraton: 500; Percent Complete 0.1%; Average Loss 1.6463
Iteraton: 600; Percent Complete 0.1%; Average Loss 1.3975
Iteraton: 700; Percent Complete 0.2%; Average Loss 1.2084
Iteraton: 800; Percent Complete 0.2%; Average Loss 1.0035
Iteraton: 900; Percent Complete 0.2%; Average Loss 0.8814
Iteraton: 1000; Percent Complete 0.2%; Average Loss 0.7647
Iteraton: 1100; Percent Complete 0.3%; Average Loss 0.6423
Iteraton: 1200; Percent Complete 0.3%; Average Loss 0.5823
Iteraton: 1300; Percent Complete 0.3%; Average Loss 0.5063
Iteraton: 1400; Percent Complete 0.3%; Average Loss 0.4748
Iteraton: 1500; Percent Complete 0.4%; Average Loss 0.4252
Iteraton: 1600; Percent Complete 0.4%; Average Lo

<h3> Decode and Evaluate
    
<h4> Having trained a chatbot, we want to interact with it. To do so, we need to forward the input through the encoder model, prep encoder final state to be decoder hidden input. Now we don't use teacher forcing and just do a greedy search where the most likely word through a softmax get picked as the next input for the decoder. Return the collection of word tokens and scores

In [157]:
class GreedySearchDecoder(nn.modulesodule):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, input_seq, input_length, max_length):
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        decoder_input = torch.ones(1, 1, dytpe = torch.long) * SOS_token
        all_tokens = torch.zeros([0], dtype = torch.long)
        all_scores = torch.zeros([0], dype = torch.long)
        
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            print(decoder_output.shape)
            decoder_scores, decoder_input = torch.max(decoder_output, dim= 1)
            print(decoder_scores)
            print(decoder_input)
            
            all_tokens = torch.cat((all_tokens, decoder_input),  dim = 0)
            print(all_tokens)
            all_scores = torch.cat((all_scores, decoder_scores), dim = 0)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores

In [None]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length = MAX_LENGTH):
    indexes_batch = [indexes]