# Preparations

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math


USE_CUDA = torch.cuda.is_available()
# Run on the GPU when one is present, otherwise fall back to the CPU
device = torch.device("cuda") if USE_CUDA else torch.device("cpu")

# Load and preprocess data

In [4]:
corpus_name = "cornell movie-dialogs corpus"
# Folder that contains the corpus directory. Set the CORPUS_ROOT environment variable
# to point elsewhere instead of editing this machine-specific default.
CORPUS_ROOT = os.environ.get(
    "CORPUS_ROOT",
    "D:\\MINE\\HKUST\\SecondSeminar\\2-6010U - Artificial Intelligence in Finance\\robot-Doraemon\\pytorch",
)
corpus = os.path.join(CORPUS_ROOT, corpus_name)

def printLines(file, n=10):
    """Print the first `n` raw lines of `file` as bytes objects.

    Lines are read lazily, so only the requested head of a large file is loaded.
    A negative `n` keeps list-slice semantics (all but the last |n| lines).
    """
    with open(file, 'rb') as datafile:
        if n >= 0:
            head = list(itertools.islice(datafile, n))
        else:
            # Slicing from the end requires the whole file
            head = datafile.readlines()[:n]
    for line in head:
        print(line)

# Peek at the raw corpus format (default: first 10 lines)
printLines(file=os.path.join(corpus, "movie_lines.txt"))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


# Create formatted data file

In [5]:
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
    """Parse a ' +++$+++ '-separated file into {lineID: {field: value}}.

    `fields` names the columns in order; the value for 'lineID' is used as the
    key. Values are kept verbatim, so the last column keeps its trailing newline.
    A line with fewer columns than `fields` raises IndexError.
    """
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for raw in f:
            parts = raw.split(" +++$+++ ")
            record = {field: parts[i] for i, field in enumerate(fields)}
            lines[record['lineID']] = record
    return lines


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    """Read movie_conversations-style rows and attach the referenced line dicts.

    Each row is split on ' +++$+++ ' into `fields`; the IDs found in the
    'utteranceIDs' column (e.g. "['L598485', 'L598486', ...]") are looked up in
    `lines` and stored, in order, under the new key "lines". An unknown ID
    raises KeyError.
    """
    # Pattern is loop-invariant, so compile it once
    utterance_id_pattern = re.compile('L[0-9]+')
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for raw in f:
            parts = raw.split(" +++$+++ ")
            convObj = {field: parts[i] for i, field in enumerate(fields)}
            line_ids = utterance_id_pattern.findall(convObj["utteranceIDs"])
            convObj["lines"] = [lines[line_id] for line_id in line_ids]
            conversations.append(convObj)
    return conversations


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    """Turn each conversation into [query, reply] pairs of consecutive lines.

    Text is whitespace-stripped; a pair is dropped when either side is empty.
    The last line of a conversation has no reply and so only appears as a reply.
    """
    qa_pairs = []
    for conversation in conversations:
        utterances = conversation["lines"]
        for current, following in zip(utterances, utterances[1:]):
            query = current["text"].strip()
            reply = following["text"].strip()
            if query and reply:
                qa_pairs.append([query, reply])
    return qa_pairs

In [6]:
# Define path to new file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

# Field delimiter for the formatted file. Decoding with "unicode_escape" means an
# escaped string such as '\\t' could be supplied here as well as a literal tab.
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Column names of the two raw corpus files
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

# Write new csv file.
# newline='' is required by the csv module. Without it, text mode on Windows turns
# the writer's '\n' terminator into '\r\n', while readVocs splits on '\n' only.
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8', newline='') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)


Processing corpus...

Loading conversations...

Writing newly formatted file...

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister. 

# Load and trim data

In [7]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    """Word <-> index vocabulary with per-word occurrence counts.

    Indexes 0-2 are reserved for PAD/SOS/EOS. Real words are numbered from 3
    in first-seen order.
    """
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self._reset()

    def _reset(self):
        """Drop every word, keeping only the three reserved tokens."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every single-space-separated word of `sentence`."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Give `word` the next free index, or bump its count if already known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        """Keep only words seen at least `min_count` times. Runs at most once.

        Kept words are re-indexed from 3 in their original first-seen order.
        Because each kept word is re-added once, its count restarts at 1.
        """
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = [word for word, count in self.word2count.items() if count >= min_count]

        total = len(self.word2index)
        # Guard the ratio so trimming an empty vocabulary doesn't raise ZeroDivisionError
        ratio = len(keep_words) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(len(keep_words), total, ratio))

        self._reset()
        for word in keep_words:
            self.addWord(word)

In [8]:
MAX_LENGTH = 10  # Maximum sentence length to consider


def unicodeToAscii(s):
    """Strip combining marks after NFD decomposition (e.g. 'é' -> 'e').

    Adapted from https://stackoverflow.com/a/518232/2809427
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)


def normalizeString(s):
    """Lowercase, trim and de-accent `s`; keep only letters and . ! ? tokens."""
    text = unicodeToAscii(s.lower().strip())
    # Put a space before sentence punctuation so it becomes its own token
    text = re.sub(r"([.!?])", r" \1", text)
    # Replace each run of any other character (digits, apostrophes, commas...) with a space
    text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)
    # Collapse repeated whitespace
    return re.sub(r"\s+", r" ", text).strip()

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    """Read tab-separated query/response lines from `datafile`.

    Returns:
        (voc, pairs): an empty Voc named `corpus_name`, and one list per line
        holding the normalizeString-ed tab-separated fields.
    """
    print("Reading lines...")
    # Context manager so the file handle is closed instead of leaking
    with open(datafile, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p, max_length=None):
    """Return True iff both p[0] and p[1] have fewer than `max_length` words.

    The limit is strict so one slot remains for the EOS token. `max_length`
    defaults to the module-level MAX_LENGTH, looked up at call time.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    query, response = p[0], p[1]
    return len(query.split(' ')) < max_length and len(response.split(' ')) < max_length

# Filter pairs using filterPair condition
def filterPairs(pairs, max_length=None):
    """Keep only the pairs accepted by filterPair (same `max_length` semantics)."""
    return [pair for pair in pairs if filterPair(pair, max_length)]

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    """Read `datafile`, drop over-long pairs, and count words into a new Voc.

    `corpus` and `save_dir` are accepted but not used here.
    Returns (voc, filtered_pairs).
    """
    print("Start preparing training data ...")
    voc, raw_pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(raw_pairs)))
    kept_pairs = filterPairs(raw_pairs)
    print("Trimmed to {!s} sentence pairs".format(len(kept_pairs)))
    print("Counting words...")
    for pair in kept_pairs:
        # Query first, then response, so word indexes follow first-seen order
        for side in (pair[0], pair[1]):
            voc.addSentence(side)
    print("Counted words:", voc.num_words)
    return voc, kept_pairs


# Load/Assemble voc and pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)

# Show the first ten pairs as a quick sanity check
print("\npairs:")
for pair in itertools.islice(pairs, 10):
    print(pair)

Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 18008

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [9]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    """Trim `voc` to words seen >= MIN_COUNT times and drop pairs that use any other word.

    Calls voc.trim(MIN_COUNT), then keeps a pair only if every space-separated
    word of both pair[0] and pair[1] is still in voc.word2index.
    Returns the kept pairs in their original order.
    """
    voc.trim(MIN_COUNT)

    def in_vocab(sentence):
        return all(word in voc.word2index for word in sentence.split(' '))

    keep_pairs = []
    for pair in pairs:
        input_sentence, output_sentence = pair[0], pair[1]
        if in_vocab(input_sentence) and in_vocab(output_sentence):
            keep_pairs.append(pair)

    # Guard the ratio so an empty `pairs` doesn't raise ZeroDivisionError
    ratio = len(keep_pairs) / len(pairs) if pairs else 0.0
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), ratio))
    return keep_pairs


# Drop rare words from the vocabulary, and every pair that uses one
pairs = trimRareWords(voc, pairs, MIN_COUNT=MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


# Prepare data for models

In [10]:
def indexesFromSentence(voc, sentence):
    """Map each space-separated word to its voc index and append EOS_token.

    Raises KeyError for a word missing from voc.word2index.
    """
    indexes = [voc.word2index[token] for token in sentence.split(' ')]
    indexes.append(EOS_token)
    return indexes


def zeroPadding(l, fillvalue=PAD_token):
    """Transpose a batch of sequences into time-major rows, padding with `fillvalue`.

    Returns a list of tuples: row t holds the t-th element of every sequence.
    """
    return [column for column in itertools.zip_longest(*l, fillvalue=fillvalue)]

def binaryMatrix(l, value=PAD_token):
    """Return a 0/1 mask with the same layout as `l`.

    An entry is 0 where the token equals `value` (padding; PAD_token by default)
    and 1 elsewhere.
    """
    # The original accepted `value` but always compared against PAD_token; honour it now.
    return [[0 if token == value else 1 for token in seq] for seq in l]

# Returns padded input sequence tensor and lengths
def inputVar(l, voc):
    """Convert sentences to a padded time-major LongTensor plus their lengths.

    Lengths count the appended EOS token. Returns (padVar, lengths).
    """
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(seq) for seq in indexes_batch])
    padVar = torch.LongTensor(zeroPadding(indexes_batch))
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    """Convert target sentences to a padded tensor, a bool mask and the longest length.

    Returns (padVar, mask, max_target_len); mask is True on non-padding entries.
    """
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max(len(seq) for seq in indexes_batch)
    padList = zeroPadding(indexes_batch)
    mask = torch.BoolTensor(binaryMatrix(padList))
    return torch.LongTensor(padList), mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    """Build all training tensors for one batch of [query, response] pairs.

    Pairs are ordered by query word count, longest first, which is what
    pack_padded_sequence expects by default. Returns
    (inp, lengths, output, mask, max_target_len).
    """
    # sorted() leaves the caller's list untouched; the original sorted it in place
    ordered = sorted(pair_batch, key=lambda pair: len(pair[0].split(" ")), reverse=True)
    input_batch = [pair[0] for pair in ordered]
    output_batch = [pair[1] for pair in ordered]
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation: build one small random batch and show its tensors
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

for caption, shown in (("input_variable:", input_variable),
                       ("lengths:", lengths),
                       ("target_variable:", target_variable),
                       ("mask:", mask),
                       ("max_target_len:", max_target_len)):
    print(caption, shown)

input_variable: tensor([[  83, 5729,   50,  115,   34],
        [ 349,    4,   37,   61,    4],
        [   4,    4,   53,   38,   53],
        [  25,    4,  341,  266,  164],
        [ 200, 6819,  519,    6,    4],
        [  82,  883,  828,    2,    2],
        [ 467,    4,    6,    0,    0],
        [ 196, 1886,    2,    0,    0],
        [   4,    6,    0,    0,    0],
        [   2,    2,    0,    0,    0]])
lengths: tensor([10, 10,  8,  6,  6])
target_variable: tensor([[  36,   27,   67,   61,  164],
        [  37,   14, 2062,   37,    6],
        [   9,  511,    4,  274,  124],
        [ 121,  572,    2,    4,    4],
        [  66,    4,    0,    2, 3350],
        [  66,    2,    0,    0, 5082],
        [   2,    0,    0,    0,    4],
        [   0,    0,    0,    0,    2]])
mask: tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True

# Define models: Seq2Seq model

In [11]:
class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over a padded batch of word-index sequences.

    Args:
        hidden_size: GRU hidden size. It is also the embedding width, because the
            GRU's input_size is set to hidden_size.
        embedding: an nn.Embedding (shared with the decoder in this notebook).
        n_layers: number of stacked GRU layers.
        dropout: inter-layer GRU dropout. Forced to 0 for a single layer, because
            nn.GRU only applies dropout between layers.
    """
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # input_size == hidden_size because the inputs are embeddings of width hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a padded batch.

        Args:
            input_seq: LongTensor of word indexes, time-major (seq_len, batch), as
                pack_padded_sequence expects with its default batch_first=False.
            input_lengths: true length of each sequence in decreasing order (required
                by the default enforce_sorted=True). Either a list, or a tensor on any device.
            hidden: optional initial GRU hidden state.

        Returns:
            outputs: (seq_len, batch, hidden_size), with the forward and backward outputs summed.
            hidden: final GRU hidden state, (n_layers * 2, batch, hidden_size).
        """
        embedded = self.embedding(input_seq)
        # pack_padded_sequence requires tensor lengths on the CPU. train() and evaluate()
        # move lengths to `device`, which fails on CUDA, so bring them back here.
        if torch.is_tensor(input_lengths):
            input_lengths = input_lengths.cpu()
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum the two directions so downstream layers see hidden_size features
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden

# Decoder

In [12]:
# Luong attention layer
class Attn(nn.Module):
    """Luong global-attention weights over encoder outputs.

    Args:
        method: scoring function, one of 'dot', 'general' or 'concat'.
        hidden_size: feature width shared by the decoder state and the encoder outputs.

    Raises:
        ValueError: if `method` is not one of the three supported names.
    """
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(n) returns *uninitialized* memory (arbitrary values, possibly NaN).
            # Give the scoring vector an explicit uniform(-1/sqrt(h), 1/sqrt(h)) init instead.
            bound = hidden_size ** -0.5
            self.v = nn.Parameter(torch.empty(hidden_size, dtype=torch.float).uniform_(-bound, bound))

    def dot_score(self, hidden, encoder_output):
        """score = h_t . h_s, reduced over the feature dimension."""
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        """score = h_t . (W h_s)."""
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        """score = v . tanh(W [h_t; h_s]); h_t is broadcast over encoder time steps."""
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        """Return softmax attention weights of shape (batch, 1, max_len).

        hidden: decoder state, (1, batch, hidden_size) when called from the
            single-step decoder in this notebook.
        encoder_outputs: (max_len, batch, hidden_size).
        """
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # (max_len, batch) -> (batch, max_len)
        attn_energies = attn_energies.t()

        # Normalize over encoder time steps and add the singleton dim used by bmm
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [13]:
class LuongAttnDecoderRNN(nn.Module):
    """Single-step GRU decoder with Luong global attention.

    Each call consumes one input token per batch element and returns a softmax
    distribution over `output_size` words, plus the new GRU hidden state.
    """
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Hyper-parameters kept as attributes for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Layers, registered in the original order so parameter order and init are unchanged
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        gru_dropout = 0 if n_layers == 1 else dropout
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=gru_dropout)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Embed the current input tokens, then apply dropout
        step_embedding = self.embedding_dropout(self.embedding(input_step))
        # One step of the unidirectional GRU
        gru_out, new_hidden = self.gru(step_embedding, last_hidden)
        # Attention weights over encoder time steps, computed from the GRU output
        weights = self.attn(gru_out, encoder_outputs)
        # Weighted sum of encoder outputs -> one context vector per batch element
        context = weights.bmm(encoder_outputs.transpose(0, 1)).squeeze(1)
        # Luong eq. 5: combine the GRU output with the context
        combined = torch.cat((gru_out.squeeze(0), context), 1)
        attentional = torch.tanh(self.concat(combined))
        # Luong eq. 6: project to the vocabulary and normalize
        probs = F.softmax(self.out(attentional), dim=1)
        return probs, new_hidden

# Define training procedure: masked loss

In [14]:
def maskNLLLoss(inp, target, mask):
    """Masked negative log-likelihood for one decoding step.

    inp: (batch, vocab) probabilities (the decoder in this notebook returns softmax output).
    target: gold word index for each batch element.
    mask: bool tensor, True where the target is a real (non-padding) token.

    Returns (mean NLL over the unmasked positions, moved to `device`;
    number of unmasked positions as a Python int).
    """
    gold_probs = inp.gather(1, target.view(-1, 1)).squeeze(1)
    token_nll = -gold_probs.log()
    mean_loss = token_nll.masked_select(mask).mean().to(device)
    return mean_loss, mask.sum().item()

# Single training iteration

In [15]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):
    """Run one forward/backward/optimizer step on a single batch.

    Args:
        input_variable: padded input indexes, time-major (max_input_len, batch), as built by inputVar.
        lengths: true input lengths, as built by inputVar.
        target_variable: padded target indexes, (max_target_len, batch).
        mask: bool mask with the same layout as target_variable; True on real tokens.
        max_target_len: number of decoding steps to run.
        encoder, decoder: the seq2seq models.
        embedding: unused; kept for interface compatibility.
        encoder_optimizer, decoder_optimizer: optimizers for the two models.
        batch_size: number of sequences in the batch.
        clip: max gradient norm passed to clip_grad_norm_.
        max_length: unused; kept for interface compatibility.

    Reads the module-level `teacher_forcing_ratio`. Once per batch it decides whether
    the decoder is fed the gold previous token or its own greedy prediction.

    Returns:
        Average masked NLL per real target token (float).
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_variable = input_variable.to(device)
    # pack_padded_sequence (inside the encoder) requires lengths on the CPU
    lengths = lengths.to("cpu")
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    loss = 0
    print_losses = []
    n_totals = 0

    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Every sequence starts decoding from SOS; shape (1, batch_size)
    decoder_input = torch.full((1, batch_size), SOS_token, dtype=torch.long, device=device)

    # Seed the decoder with the first n_layers of the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    use_teacher_forcing = random.random() < teacher_forcing_ratio

    # Decode one time step at a time
    for t in range(max_target_len):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        if use_teacher_forcing:
            # Teacher forcing: next input is the current gold target
            decoder_input = target_variable[t].view(1, -1)
        else:
            # Free running: feed back the argmax. topk indices carry no gradient and
            # are already on the model's device.
            _, topi = decoder_output.topk(1)
            decoder_input = topi.view(1, -1)
        mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)
        n_totals += nTotal

    # Perform backpropagation
    loss.backward()

    # Clip gradients in place
    nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

# Training iterations

In [16]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):
    """Train for `n_iteration` iterations on randomly sampled batches.

    All batches are sampled up front with random.choice (with replacement).
    When `loadFilename` is set, training resumes after the iteration stored in
    the module-level `checkpoint` dict loaded in the model-building cell.
    The mean loss is printed every `print_every` iterations. Every `save_every`
    iterations a checkpoint is written to
    save_dir/model_name/corpus_name/{encoder_n_layers}-{decoder_n_layers}_{hidden_size}/.
    """
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                        for _ in range(n_iteration)]

    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Take the hidden size from the encoder itself. The original read a module-level
    # `hidden_size` global that this function never receives.
    directory = os.path.join(save_dir, model_name, corpus_name,
                             '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, encoder.hidden_size))

    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        input_variable, lengths, target_variable, mask, max_target_len = training_batches[iteration - 1]

        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        if iteration % save_every == 0:
            os.makedirs(directory, exist_ok=True)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

# Define evaluation

In [17]:
class GreedySearchDecoder(nn.Module):
    """Greedy decoding: at each step, feed back the single most likely token.

    Wraps an encoder and a decoder. The SOS input below has shape (1, 1), so this
    handles one input sentence (batch size 1).
    """
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        """Decode exactly `max_length` tokens (no early stop at EOS).

        Returns:
            all_tokens: LongTensor (max_length,) of the chosen word indexes.
            all_scores: tensor (max_length,) holding each chosen token's score (the max of
                the decoder output, which is a softmax probability for this notebook's decoder).
        """
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Seed the decoder with the first n_layers of the encoder's final hidden state.
        # Use the wrapped decoder; the original read the module-level `decoder` global.
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Most likely token and its score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Restore the time dimension, giving shape (1, 1), for the next step
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores

In [18]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    """Decode a reply to one normalized `sentence` and return it as a list of words.

    `encoder` and `decoder` are unused here; decoding goes through `searcher`.
    Raises KeyError if `sentence` contains a word missing from `voc`.
    """
    # Single-sentence batch: words -> indexes (EOS appended)
    indexes_batch = [indexesFromSentence(voc, sentence)]
    lengths = torch.tensor([len(seq) for seq in indexes_batch]).to(device)
    # (1, seq_len) -> (seq_len, 1): time-major layout expected by the encoder
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1).to(device)
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    return [voc.index2word[tok.item()] for tok in tokens]


def evaluateInput(encoder, decoder, searcher, voc):
    """Interactive chat loop: read a line, normalize it, decode a reply and print it.

    Type 'q' or 'quit' to stop. End of input (EOF) also ends the loop instead of
    raising. A sentence with a word missing from `voc` prints an error, and the loop
    continues.
    """
    while True:
        try:
            input_sentence = input('> ')
        except EOFError:
            # stdin closed (e.g. the notebook runs non-interactively): exit cleanly
            break
        if input_sentence == 'q' or input_sentence == 'quit':
            break
        try:
            input_sentence = normalizeString(input_sentence)
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
        except KeyError:
            # indexesFromSentence raises KeyError for out-of-vocabulary words
            print("Error: Encountered unknown word.")
            continue
        # Drop special tokens before showing the reply
        output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
        print('Bot:', ' '.join(output_words))

# Build model

In [19]:
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # map_location remaps tensors saved on another device onto the current `device`,
    # e.g. a GPU checkpoint loaded on a CPU-only machine.
    checkpoint = torch.load(loadFilename, map_location=device)
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


# Run training

In [20]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0  # decoder trains with learning_rate * decoder_learning_ratio
n_iteration = 4000
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# Keep restored optimizer state on the same device as the model parameters.
# - Move to `device` rather than calling `.cuda()`, which fails on a CPU-only
#   machine whenever a checkpoint has been loaded.
# - Leave Adam's 'step' entry alone: newer PyTorch keeps it as a CPU tensor and
#   some releases reject a CUDA 'step' when the optimizer is not `capturable`.
# A freshly built optimizer has empty state, so without a checkpoint this is a no-op.
for opt in (encoder_optimizer, decoder_optimizer):
    for state in opt.state.values():
        for k, v in state.items():
            if k != 'step' and isinstance(v, torch.Tensor):
                state[k] = v.to(device)

# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.9720
Iteration: 2; Percent complete: 0.1%; Average loss: 8.8476
Iteration: 3; Percent complete: 0.1%; Average loss: 8.6641
Iteration: 4; Percent complete: 0.1%; Average loss: 8.4524
Iteration: 5; Percent complete: 0.1%; Average loss: 8.0627
Iteration: 6; Percent complete: 0.1%; Average loss: 7.5146
Iteration: 7; Percent complete: 0.2%; Average loss: 6.7904
Iteration: 8; Percent complete: 0.2%; Average loss: 6.7096
Iteration: 9; Percent complete: 0.2%; Average loss: 6.8681
Iteration: 10; Percent complete: 0.2%; Average loss: 6.7067
Iteration: 11; Percent complete: 0.3%; Average loss: 6.3177
Iteration: 12; Percent complete: 0.3%; Average loss: 5.7914
Iteration: 13; Percent complete: 0.3%; Average loss: 5.9762
Iteration: 14; Percent complete: 0.4%; Average loss: 6.1447
Iteration: 15; Percent complete: 0.4%; Average loss: 5.6473
Iteration: 16; Percent complete: 0.4%

Iteration: 136; Percent complete: 3.4%; Average loss: 4.4568
Iteration: 137; Percent complete: 3.4%; Average loss: 4.3569
Iteration: 138; Percent complete: 3.5%; Average loss: 4.1795
Iteration: 139; Percent complete: 3.5%; Average loss: 4.5491
Iteration: 140; Percent complete: 3.5%; Average loss: 4.2565
Iteration: 141; Percent complete: 3.5%; Average loss: 4.1996
Iteration: 142; Percent complete: 3.5%; Average loss: 4.3889
Iteration: 143; Percent complete: 3.6%; Average loss: 3.9929
Iteration: 144; Percent complete: 3.6%; Average loss: 4.3470
Iteration: 145; Percent complete: 3.6%; Average loss: 4.3232
Iteration: 146; Percent complete: 3.6%; Average loss: 4.2612
Iteration: 147; Percent complete: 3.7%; Average loss: 4.0000
Iteration: 148; Percent complete: 3.7%; Average loss: 4.1665
Iteration: 149; Percent complete: 3.7%; Average loss: 4.3501
Iteration: 150; Percent complete: 3.8%; Average loss: 4.2167
Iteration: 151; Percent complete: 3.8%; Average loss: 4.1819
Iteration: 152; Percent 

Iteration: 271; Percent complete: 6.8%; Average loss: 3.9964
Iteration: 272; Percent complete: 6.8%; Average loss: 3.9649
Iteration: 273; Percent complete: 6.8%; Average loss: 4.0699
Iteration: 274; Percent complete: 6.9%; Average loss: 3.8320
Iteration: 275; Percent complete: 6.9%; Average loss: 3.8823
Iteration: 276; Percent complete: 6.9%; Average loss: 4.0651
Iteration: 277; Percent complete: 6.9%; Average loss: 3.7798
Iteration: 278; Percent complete: 7.0%; Average loss: 3.8114
Iteration: 279; Percent complete: 7.0%; Average loss: 3.5968
Iteration: 280; Percent complete: 7.0%; Average loss: 4.0786
Iteration: 281; Percent complete: 7.0%; Average loss: 3.9365
Iteration: 282; Percent complete: 7.0%; Average loss: 3.7722
Iteration: 283; Percent complete: 7.1%; Average loss: 4.0894
Iteration: 284; Percent complete: 7.1%; Average loss: 3.7693
Iteration: 285; Percent complete: 7.1%; Average loss: 3.9470
Iteration: 286; Percent complete: 7.1%; Average loss: 3.8241
Iteration: 287; Percent 

Iteration: 406; Percent complete: 10.2%; Average loss: 3.7418
Iteration: 407; Percent complete: 10.2%; Average loss: 3.5283
Iteration: 408; Percent complete: 10.2%; Average loss: 3.8312
Iteration: 409; Percent complete: 10.2%; Average loss: 4.0476
Iteration: 410; Percent complete: 10.2%; Average loss: 4.1199
Iteration: 411; Percent complete: 10.3%; Average loss: 3.9256
Iteration: 412; Percent complete: 10.3%; Average loss: 3.7066
Iteration: 413; Percent complete: 10.3%; Average loss: 3.9826
Iteration: 414; Percent complete: 10.3%; Average loss: 3.9383
Iteration: 415; Percent complete: 10.4%; Average loss: 4.0740
Iteration: 416; Percent complete: 10.4%; Average loss: 3.7946
Iteration: 417; Percent complete: 10.4%; Average loss: 3.4952
Iteration: 418; Percent complete: 10.4%; Average loss: 3.7914
Iteration: 419; Percent complete: 10.5%; Average loss: 4.0743
Iteration: 420; Percent complete: 10.5%; Average loss: 3.7636
Iteration: 421; Percent complete: 10.5%; Average loss: 3.8718
Iteratio

Iteration: 539; Percent complete: 13.5%; Average loss: 3.5928
Iteration: 540; Percent complete: 13.5%; Average loss: 3.6554
Iteration: 541; Percent complete: 13.5%; Average loss: 3.6482
Iteration: 542; Percent complete: 13.6%; Average loss: 3.6627
Iteration: 543; Percent complete: 13.6%; Average loss: 3.9232
Iteration: 544; Percent complete: 13.6%; Average loss: 3.5846
Iteration: 545; Percent complete: 13.6%; Average loss: 3.6127
Iteration: 546; Percent complete: 13.7%; Average loss: 3.8205
Iteration: 547; Percent complete: 13.7%; Average loss: 3.8851
Iteration: 548; Percent complete: 13.7%; Average loss: 3.7689
Iteration: 549; Percent complete: 13.7%; Average loss: 3.7659
Iteration: 550; Percent complete: 13.8%; Average loss: 3.5871
Iteration: 551; Percent complete: 13.8%; Average loss: 3.4859
Iteration: 552; Percent complete: 13.8%; Average loss: 3.6284
Iteration: 553; Percent complete: 13.8%; Average loss: 3.7728
Iteration: 554; Percent complete: 13.9%; Average loss: 3.5703
Iteratio

Iteration: 672; Percent complete: 16.8%; Average loss: 3.7183
Iteration: 673; Percent complete: 16.8%; Average loss: 3.5338
Iteration: 674; Percent complete: 16.9%; Average loss: 3.3489
Iteration: 675; Percent complete: 16.9%; Average loss: 3.5327
Iteration: 676; Percent complete: 16.9%; Average loss: 3.7174
Iteration: 677; Percent complete: 16.9%; Average loss: 3.6501
Iteration: 678; Percent complete: 17.0%; Average loss: 3.6783
Iteration: 679; Percent complete: 17.0%; Average loss: 3.6744
Iteration: 680; Percent complete: 17.0%; Average loss: 3.6823
Iteration: 681; Percent complete: 17.0%; Average loss: 3.6504
Iteration: 682; Percent complete: 17.1%; Average loss: 3.2152
Iteration: 683; Percent complete: 17.1%; Average loss: 3.4757
Iteration: 684; Percent complete: 17.1%; Average loss: 3.5309
Iteration: 685; Percent complete: 17.1%; Average loss: 3.8375
Iteration: 686; Percent complete: 17.2%; Average loss: 3.4298
Iteration: 687; Percent complete: 17.2%; Average loss: 3.8613
Iteratio

Iteration: 805; Percent complete: 20.1%; Average loss: 3.8599
Iteration: 806; Percent complete: 20.2%; Average loss: 3.4034
Iteration: 807; Percent complete: 20.2%; Average loss: 3.3948
Iteration: 808; Percent complete: 20.2%; Average loss: 3.2869
Iteration: 809; Percent complete: 20.2%; Average loss: 3.5936
Iteration: 810; Percent complete: 20.2%; Average loss: 3.6189
Iteration: 811; Percent complete: 20.3%; Average loss: 3.6886
Iteration: 812; Percent complete: 20.3%; Average loss: 3.6435
Iteration: 813; Percent complete: 20.3%; Average loss: 3.3750
Iteration: 814; Percent complete: 20.3%; Average loss: 3.5358
Iteration: 815; Percent complete: 20.4%; Average loss: 3.4885
Iteration: 816; Percent complete: 20.4%; Average loss: 3.2812
Iteration: 817; Percent complete: 20.4%; Average loss: 3.5414
Iteration: 818; Percent complete: 20.4%; Average loss: 3.3202
Iteration: 819; Percent complete: 20.5%; Average loss: 3.6395
Iteration: 820; Percent complete: 20.5%; Average loss: 3.6640
Iteratio

Iteration: 938; Percent complete: 23.4%; Average loss: 3.3114
Iteration: 939; Percent complete: 23.5%; Average loss: 3.5513
Iteration: 940; Percent complete: 23.5%; Average loss: 3.4950
Iteration: 941; Percent complete: 23.5%; Average loss: 3.4664
Iteration: 942; Percent complete: 23.5%; Average loss: 3.3757
Iteration: 943; Percent complete: 23.6%; Average loss: 3.4534
Iteration: 944; Percent complete: 23.6%; Average loss: 3.4993
Iteration: 945; Percent complete: 23.6%; Average loss: 3.3854
Iteration: 946; Percent complete: 23.6%; Average loss: 3.7461
Iteration: 947; Percent complete: 23.7%; Average loss: 3.5259
Iteration: 948; Percent complete: 23.7%; Average loss: 3.5501
Iteration: 949; Percent complete: 23.7%; Average loss: 3.5392
Iteration: 950; Percent complete: 23.8%; Average loss: 3.3685
Iteration: 951; Percent complete: 23.8%; Average loss: 3.1466
Iteration: 952; Percent complete: 23.8%; Average loss: 3.4582
Iteration: 953; Percent complete: 23.8%; Average loss: 3.5643
Iteratio

Iteration: 1070; Percent complete: 26.8%; Average loss: 3.4467
Iteration: 1071; Percent complete: 26.8%; Average loss: 3.5187
Iteration: 1072; Percent complete: 26.8%; Average loss: 3.4735
Iteration: 1073; Percent complete: 26.8%; Average loss: 3.5839
Iteration: 1074; Percent complete: 26.9%; Average loss: 3.2261
Iteration: 1075; Percent complete: 26.9%; Average loss: 3.4356
Iteration: 1076; Percent complete: 26.9%; Average loss: 3.5031
Iteration: 1077; Percent complete: 26.9%; Average loss: 3.2258
Iteration: 1078; Percent complete: 27.0%; Average loss: 3.6031
Iteration: 1079; Percent complete: 27.0%; Average loss: 3.3530
Iteration: 1080; Percent complete: 27.0%; Average loss: 3.3731
Iteration: 1081; Percent complete: 27.0%; Average loss: 3.4405
Iteration: 1082; Percent complete: 27.1%; Average loss: 3.4696
Iteration: 1083; Percent complete: 27.1%; Average loss: 3.4158
Iteration: 1084; Percent complete: 27.1%; Average loss: 3.5547
Iteration: 1085; Percent complete: 27.1%; Average loss:

Iteration: 1201; Percent complete: 30.0%; Average loss: 3.2570
Iteration: 1202; Percent complete: 30.0%; Average loss: 3.3990
Iteration: 1203; Percent complete: 30.1%; Average loss: 3.4135
Iteration: 1204; Percent complete: 30.1%; Average loss: 3.3578
Iteration: 1205; Percent complete: 30.1%; Average loss: 3.3978
Iteration: 1206; Percent complete: 30.1%; Average loss: 3.2787
Iteration: 1207; Percent complete: 30.2%; Average loss: 3.3066
Iteration: 1208; Percent complete: 30.2%; Average loss: 3.5528
Iteration: 1209; Percent complete: 30.2%; Average loss: 3.4361
Iteration: 1210; Percent complete: 30.2%; Average loss: 3.1863
Iteration: 1211; Percent complete: 30.3%; Average loss: 3.6693
Iteration: 1212; Percent complete: 30.3%; Average loss: 3.2723
Iteration: 1213; Percent complete: 30.3%; Average loss: 3.5762
Iteration: 1214; Percent complete: 30.3%; Average loss: 3.5534
Iteration: 1215; Percent complete: 30.4%; Average loss: 3.4710
Iteration: 1216; Percent complete: 30.4%; Average loss:

Iteration: 1332; Percent complete: 33.3%; Average loss: 3.0602
Iteration: 1333; Percent complete: 33.3%; Average loss: 3.4509
Iteration: 1334; Percent complete: 33.4%; Average loss: 3.5283
Iteration: 1335; Percent complete: 33.4%; Average loss: 3.6431
Iteration: 1336; Percent complete: 33.4%; Average loss: 3.2876
Iteration: 1337; Percent complete: 33.4%; Average loss: 3.2986
Iteration: 1338; Percent complete: 33.5%; Average loss: 3.3539
Iteration: 1339; Percent complete: 33.5%; Average loss: 3.1783
Iteration: 1340; Percent complete: 33.5%; Average loss: 3.3391
Iteration: 1341; Percent complete: 33.5%; Average loss: 3.4322
Iteration: 1342; Percent complete: 33.6%; Average loss: 3.4478
Iteration: 1343; Percent complete: 33.6%; Average loss: 3.4667
Iteration: 1344; Percent complete: 33.6%; Average loss: 3.3647
Iteration: 1345; Percent complete: 33.6%; Average loss: 3.0886
Iteration: 1346; Percent complete: 33.7%; Average loss: 3.4341
Iteration: 1347; Percent complete: 33.7%; Average loss:

Iteration: 1463; Percent complete: 36.6%; Average loss: 3.2614
Iteration: 1464; Percent complete: 36.6%; Average loss: 3.4189
Iteration: 1465; Percent complete: 36.6%; Average loss: 3.3117
Iteration: 1466; Percent complete: 36.6%; Average loss: 3.2764
Iteration: 1467; Percent complete: 36.7%; Average loss: 3.3991
Iteration: 1468; Percent complete: 36.7%; Average loss: 3.4279
Iteration: 1469; Percent complete: 36.7%; Average loss: 3.4384
Iteration: 1470; Percent complete: 36.8%; Average loss: 3.3649
Iteration: 1471; Percent complete: 36.8%; Average loss: 3.1393
Iteration: 1472; Percent complete: 36.8%; Average loss: 3.2994
Iteration: 1473; Percent complete: 36.8%; Average loss: 3.2885
Iteration: 1474; Percent complete: 36.9%; Average loss: 3.2753
Iteration: 1475; Percent complete: 36.9%; Average loss: 3.2802
Iteration: 1476; Percent complete: 36.9%; Average loss: 3.0745
Iteration: 1477; Percent complete: 36.9%; Average loss: 3.4322
Iteration: 1478; Percent complete: 37.0%; Average loss:

Iteration: 1594; Percent complete: 39.9%; Average loss: 3.1777
Iteration: 1595; Percent complete: 39.9%; Average loss: 2.9457
Iteration: 1596; Percent complete: 39.9%; Average loss: 3.2500
Iteration: 1597; Percent complete: 39.9%; Average loss: 3.1967
Iteration: 1598; Percent complete: 40.0%; Average loss: 3.1571
Iteration: 1599; Percent complete: 40.0%; Average loss: 3.4109
Iteration: 1600; Percent complete: 40.0%; Average loss: 3.3546
Iteration: 1601; Percent complete: 40.0%; Average loss: 3.5809
Iteration: 1602; Percent complete: 40.1%; Average loss: 3.3274
Iteration: 1603; Percent complete: 40.1%; Average loss: 3.4118
Iteration: 1604; Percent complete: 40.1%; Average loss: 3.3986
Iteration: 1605; Percent complete: 40.1%; Average loss: 3.2810
Iteration: 1606; Percent complete: 40.2%; Average loss: 3.0650
Iteration: 1607; Percent complete: 40.2%; Average loss: 3.2731
Iteration: 1608; Percent complete: 40.2%; Average loss: 3.3131
Iteration: 1609; Percent complete: 40.2%; Average loss:

Iteration: 1725; Percent complete: 43.1%; Average loss: 3.0063
Iteration: 1726; Percent complete: 43.1%; Average loss: 3.1457
Iteration: 1727; Percent complete: 43.2%; Average loss: 3.1996
Iteration: 1728; Percent complete: 43.2%; Average loss: 3.0807
Iteration: 1729; Percent complete: 43.2%; Average loss: 3.1039
Iteration: 1730; Percent complete: 43.2%; Average loss: 3.2881
Iteration: 1731; Percent complete: 43.3%; Average loss: 3.2732
Iteration: 1732; Percent complete: 43.3%; Average loss: 2.9659
Iteration: 1733; Percent complete: 43.3%; Average loss: 3.2650
Iteration: 1734; Percent complete: 43.4%; Average loss: 3.5122
Iteration: 1735; Percent complete: 43.4%; Average loss: 3.1278
Iteration: 1736; Percent complete: 43.4%; Average loss: 3.1145
Iteration: 1737; Percent complete: 43.4%; Average loss: 3.2513
Iteration: 1738; Percent complete: 43.5%; Average loss: 3.3307
Iteration: 1739; Percent complete: 43.5%; Average loss: 3.0031
Iteration: 1740; Percent complete: 43.5%; Average loss:

Iteration: 1856; Percent complete: 46.4%; Average loss: 3.3573
Iteration: 1857; Percent complete: 46.4%; Average loss: 3.0636
Iteration: 1858; Percent complete: 46.5%; Average loss: 2.8284
Iteration: 1859; Percent complete: 46.5%; Average loss: 3.1034
Iteration: 1860; Percent complete: 46.5%; Average loss: 3.0079
Iteration: 1861; Percent complete: 46.5%; Average loss: 3.1794
Iteration: 1862; Percent complete: 46.6%; Average loss: 3.2284
Iteration: 1863; Percent complete: 46.6%; Average loss: 3.1870
Iteration: 1864; Percent complete: 46.6%; Average loss: 3.3800
Iteration: 1865; Percent complete: 46.6%; Average loss: 3.3191
Iteration: 1866; Percent complete: 46.7%; Average loss: 3.0316
Iteration: 1867; Percent complete: 46.7%; Average loss: 3.0467
Iteration: 1868; Percent complete: 46.7%; Average loss: 3.1225
Iteration: 1869; Percent complete: 46.7%; Average loss: 3.2560
Iteration: 1870; Percent complete: 46.8%; Average loss: 3.0039
Iteration: 1871; Percent complete: 46.8%; Average loss:

Iteration: 1987; Percent complete: 49.7%; Average loss: 3.0527
Iteration: 1988; Percent complete: 49.7%; Average loss: 3.1506
Iteration: 1989; Percent complete: 49.7%; Average loss: 3.3598
Iteration: 1990; Percent complete: 49.8%; Average loss: 2.7594
Iteration: 1991; Percent complete: 49.8%; Average loss: 3.2899
Iteration: 1992; Percent complete: 49.8%; Average loss: 2.8921
Iteration: 1993; Percent complete: 49.8%; Average loss: 3.2072
Iteration: 1994; Percent complete: 49.9%; Average loss: 3.0970
Iteration: 1995; Percent complete: 49.9%; Average loss: 3.2572
Iteration: 1996; Percent complete: 49.9%; Average loss: 3.2721
Iteration: 1997; Percent complete: 49.9%; Average loss: 2.9786
Iteration: 1998; Percent complete: 50.0%; Average loss: 3.1428
Iteration: 1999; Percent complete: 50.0%; Average loss: 3.1157
Iteration: 2000; Percent complete: 50.0%; Average loss: 3.0571
Iteration: 2001; Percent complete: 50.0%; Average loss: 3.0690
Iteration: 2002; Percent complete: 50.0%; Average loss:

Iteration: 2118; Percent complete: 52.9%; Average loss: 3.2185
Iteration: 2119; Percent complete: 53.0%; Average loss: 2.9819
Iteration: 2120; Percent complete: 53.0%; Average loss: 2.9591
Iteration: 2121; Percent complete: 53.0%; Average loss: 3.1380
Iteration: 2122; Percent complete: 53.0%; Average loss: 2.8525
Iteration: 2123; Percent complete: 53.1%; Average loss: 3.3128
Iteration: 2124; Percent complete: 53.1%; Average loss: 3.2825
Iteration: 2125; Percent complete: 53.1%; Average loss: 2.8848
Iteration: 2126; Percent complete: 53.1%; Average loss: 3.0909
Iteration: 2127; Percent complete: 53.2%; Average loss: 3.2368
Iteration: 2128; Percent complete: 53.2%; Average loss: 3.0520
Iteration: 2129; Percent complete: 53.2%; Average loss: 3.0759
Iteration: 2130; Percent complete: 53.2%; Average loss: 3.0504
Iteration: 2131; Percent complete: 53.3%; Average loss: 3.0459
Iteration: 2132; Percent complete: 53.3%; Average loss: 3.0605
Iteration: 2133; Percent complete: 53.3%; Average loss:

Iteration: 2249; Percent complete: 56.2%; Average loss: 2.9938
Iteration: 2250; Percent complete: 56.2%; Average loss: 2.9703
Iteration: 2251; Percent complete: 56.3%; Average loss: 2.8122
Iteration: 2252; Percent complete: 56.3%; Average loss: 3.1138
Iteration: 2253; Percent complete: 56.3%; Average loss: 3.1107
Iteration: 2254; Percent complete: 56.4%; Average loss: 2.9172
Iteration: 2255; Percent complete: 56.4%; Average loss: 2.9170
Iteration: 2256; Percent complete: 56.4%; Average loss: 3.0846
Iteration: 2257; Percent complete: 56.4%; Average loss: 3.0918
Iteration: 2258; Percent complete: 56.5%; Average loss: 2.7814
Iteration: 2259; Percent complete: 56.5%; Average loss: 3.1205
Iteration: 2260; Percent complete: 56.5%; Average loss: 2.8746
Iteration: 2261; Percent complete: 56.5%; Average loss: 3.0230
Iteration: 2262; Percent complete: 56.5%; Average loss: 3.0753
Iteration: 2263; Percent complete: 56.6%; Average loss: 3.3424
Iteration: 2264; Percent complete: 56.6%; Average loss:

Iteration: 2380; Percent complete: 59.5%; Average loss: 3.0428
Iteration: 2381; Percent complete: 59.5%; Average loss: 2.8286
Iteration: 2382; Percent complete: 59.6%; Average loss: 3.2355
Iteration: 2383; Percent complete: 59.6%; Average loss: 3.3846
Iteration: 2384; Percent complete: 59.6%; Average loss: 2.8527
Iteration: 2385; Percent complete: 59.6%; Average loss: 3.2904
Iteration: 2386; Percent complete: 59.7%; Average loss: 2.9695
Iteration: 2387; Percent complete: 59.7%; Average loss: 2.8500
Iteration: 2388; Percent complete: 59.7%; Average loss: 2.9584
Iteration: 2389; Percent complete: 59.7%; Average loss: 3.2639
Iteration: 2390; Percent complete: 59.8%; Average loss: 3.3981
Iteration: 2391; Percent complete: 59.8%; Average loss: 3.0275
Iteration: 2392; Percent complete: 59.8%; Average loss: 3.2056
Iteration: 2393; Percent complete: 59.8%; Average loss: 3.1981
Iteration: 2394; Percent complete: 59.9%; Average loss: 2.8115
Iteration: 2395; Percent complete: 59.9%; Average loss:

Iteration: 2511; Percent complete: 62.8%; Average loss: 3.2266
Iteration: 2512; Percent complete: 62.8%; Average loss: 3.2086
Iteration: 2513; Percent complete: 62.8%; Average loss: 2.6301
Iteration: 2514; Percent complete: 62.8%; Average loss: 2.9649
Iteration: 2515; Percent complete: 62.9%; Average loss: 3.0533
Iteration: 2516; Percent complete: 62.9%; Average loss: 2.7859
Iteration: 2517; Percent complete: 62.9%; Average loss: 2.9036
Iteration: 2518; Percent complete: 62.9%; Average loss: 3.2191
Iteration: 2519; Percent complete: 63.0%; Average loss: 2.9557
Iteration: 2520; Percent complete: 63.0%; Average loss: 2.9669
Iteration: 2521; Percent complete: 63.0%; Average loss: 3.0599
Iteration: 2522; Percent complete: 63.0%; Average loss: 2.9210
Iteration: 2523; Percent complete: 63.1%; Average loss: 2.7176
Iteration: 2524; Percent complete: 63.1%; Average loss: 3.0353
Iteration: 2525; Percent complete: 63.1%; Average loss: 3.0141
Iteration: 2526; Percent complete: 63.1%; Average loss:

Iteration: 2642; Percent complete: 66.0%; Average loss: 2.7933
Iteration: 2643; Percent complete: 66.1%; Average loss: 3.2014
Iteration: 2644; Percent complete: 66.1%; Average loss: 3.1226
Iteration: 2645; Percent complete: 66.1%; Average loss: 3.0639
Iteration: 2646; Percent complete: 66.1%; Average loss: 3.1319
Iteration: 2647; Percent complete: 66.2%; Average loss: 3.0172
Iteration: 2648; Percent complete: 66.2%; Average loss: 2.9176
Iteration: 2649; Percent complete: 66.2%; Average loss: 3.0982
Iteration: 2650; Percent complete: 66.2%; Average loss: 3.0388
Iteration: 2651; Percent complete: 66.3%; Average loss: 3.0875
Iteration: 2652; Percent complete: 66.3%; Average loss: 2.9796
Iteration: 2653; Percent complete: 66.3%; Average loss: 2.9420
Iteration: 2654; Percent complete: 66.3%; Average loss: 2.9028
Iteration: 2655; Percent complete: 66.4%; Average loss: 3.1046
Iteration: 2656; Percent complete: 66.4%; Average loss: 2.9090
Iteration: 2657; Percent complete: 66.4%; Average loss:

Iteration: 2773; Percent complete: 69.3%; Average loss: 2.6716
Iteration: 2774; Percent complete: 69.3%; Average loss: 2.8578
Iteration: 2775; Percent complete: 69.4%; Average loss: 3.0072
Iteration: 2776; Percent complete: 69.4%; Average loss: 2.9613
Iteration: 2777; Percent complete: 69.4%; Average loss: 2.8467
Iteration: 2778; Percent complete: 69.5%; Average loss: 2.8762
Iteration: 2779; Percent complete: 69.5%; Average loss: 2.6935
Iteration: 2780; Percent complete: 69.5%; Average loss: 3.0984
Iteration: 2781; Percent complete: 69.5%; Average loss: 3.2245
Iteration: 2782; Percent complete: 69.5%; Average loss: 3.4061
Iteration: 2783; Percent complete: 69.6%; Average loss: 3.0881
Iteration: 2784; Percent complete: 69.6%; Average loss: 2.8157
Iteration: 2785; Percent complete: 69.6%; Average loss: 2.9265
Iteration: 2786; Percent complete: 69.7%; Average loss: 3.1012
Iteration: 2787; Percent complete: 69.7%; Average loss: 2.9638
Iteration: 2788; Percent complete: 69.7%; Average loss:

Iteration: 2904; Percent complete: 72.6%; Average loss: 2.8717
Iteration: 2905; Percent complete: 72.6%; Average loss: 2.8171
Iteration: 2906; Percent complete: 72.7%; Average loss: 2.7477
Iteration: 2907; Percent complete: 72.7%; Average loss: 2.9503
Iteration: 2908; Percent complete: 72.7%; Average loss: 3.1021
Iteration: 2909; Percent complete: 72.7%; Average loss: 2.8746
Iteration: 2910; Percent complete: 72.8%; Average loss: 2.8965
Iteration: 2911; Percent complete: 72.8%; Average loss: 2.7799
Iteration: 2912; Percent complete: 72.8%; Average loss: 3.0810
Iteration: 2913; Percent complete: 72.8%; Average loss: 2.7951
Iteration: 2914; Percent complete: 72.9%; Average loss: 2.9198
Iteration: 2915; Percent complete: 72.9%; Average loss: 2.9214
Iteration: 2916; Percent complete: 72.9%; Average loss: 2.8732
Iteration: 2917; Percent complete: 72.9%; Average loss: 2.7352
Iteration: 2918; Percent complete: 73.0%; Average loss: 3.0069
Iteration: 2919; Percent complete: 73.0%; Average loss:

Iteration: 3035; Percent complete: 75.9%; Average loss: 2.8520
Iteration: 3036; Percent complete: 75.9%; Average loss: 2.8864
Iteration: 3037; Percent complete: 75.9%; Average loss: 2.7873
Iteration: 3038; Percent complete: 75.9%; Average loss: 2.8743
Iteration: 3039; Percent complete: 76.0%; Average loss: 2.9537
Iteration: 3040; Percent complete: 76.0%; Average loss: 2.8886
Iteration: 3041; Percent complete: 76.0%; Average loss: 2.7153
Iteration: 3042; Percent complete: 76.0%; Average loss: 2.9191
Iteration: 3043; Percent complete: 76.1%; Average loss: 2.7082
Iteration: 3044; Percent complete: 76.1%; Average loss: 2.9093
Iteration: 3045; Percent complete: 76.1%; Average loss: 3.0274
Iteration: 3046; Percent complete: 76.1%; Average loss: 3.0044
Iteration: 3047; Percent complete: 76.2%; Average loss: 2.6755
Iteration: 3048; Percent complete: 76.2%; Average loss: 2.7917
Iteration: 3049; Percent complete: 76.2%; Average loss: 2.7110
Iteration: 3050; Percent complete: 76.2%; Average loss:

Iteration: 3166; Percent complete: 79.1%; Average loss: 2.7810
Iteration: 3167; Percent complete: 79.2%; Average loss: 2.8667
Iteration: 3168; Percent complete: 79.2%; Average loss: 2.8160
Iteration: 3169; Percent complete: 79.2%; Average loss: 2.9178
Iteration: 3170; Percent complete: 79.2%; Average loss: 2.9338
Iteration: 3171; Percent complete: 79.3%; Average loss: 2.7450
Iteration: 3172; Percent complete: 79.3%; Average loss: 2.9124
Iteration: 3173; Percent complete: 79.3%; Average loss: 2.7214
Iteration: 3174; Percent complete: 79.3%; Average loss: 3.0226
Iteration: 3175; Percent complete: 79.4%; Average loss: 2.7108
Iteration: 3176; Percent complete: 79.4%; Average loss: 2.8203
Iteration: 3177; Percent complete: 79.4%; Average loss: 2.7733
Iteration: 3178; Percent complete: 79.5%; Average loss: 2.7898
Iteration: 3179; Percent complete: 79.5%; Average loss: 2.7160
Iteration: 3180; Percent complete: 79.5%; Average loss: 2.8156
Iteration: 3181; Percent complete: 79.5%; Average loss:

Iteration: 3297; Percent complete: 82.4%; Average loss: 2.9245
Iteration: 3298; Percent complete: 82.5%; Average loss: 2.7174
Iteration: 3299; Percent complete: 82.5%; Average loss: 2.8048
Iteration: 3300; Percent complete: 82.5%; Average loss: 2.8628
Iteration: 3301; Percent complete: 82.5%; Average loss: 2.8480
Iteration: 3302; Percent complete: 82.5%; Average loss: 2.6786
Iteration: 3303; Percent complete: 82.6%; Average loss: 3.0449
Iteration: 3304; Percent complete: 82.6%; Average loss: 2.6203
Iteration: 3305; Percent complete: 82.6%; Average loss: 2.6813
Iteration: 3306; Percent complete: 82.7%; Average loss: 2.7710
Iteration: 3307; Percent complete: 82.7%; Average loss: 2.8645
Iteration: 3308; Percent complete: 82.7%; Average loss: 2.8431
Iteration: 3309; Percent complete: 82.7%; Average loss: 2.7433
Iteration: 3310; Percent complete: 82.8%; Average loss: 2.4644
Iteration: 3311; Percent complete: 82.8%; Average loss: 2.8601
Iteration: 3312; Percent complete: 82.8%; Average loss:

Iteration: 3428; Percent complete: 85.7%; Average loss: 2.6561
Iteration: 3429; Percent complete: 85.7%; Average loss: 2.6832
Iteration: 3430; Percent complete: 85.8%; Average loss: 2.6571
Iteration: 3431; Percent complete: 85.8%; Average loss: 2.7456
Iteration: 3432; Percent complete: 85.8%; Average loss: 2.8709
Iteration: 3433; Percent complete: 85.8%; Average loss: 2.6253
Iteration: 3434; Percent complete: 85.9%; Average loss: 2.7101
Iteration: 3435; Percent complete: 85.9%; Average loss: 2.7229
Iteration: 3436; Percent complete: 85.9%; Average loss: 2.5742
Iteration: 3437; Percent complete: 85.9%; Average loss: 2.7504
Iteration: 3438; Percent complete: 86.0%; Average loss: 2.8957
Iteration: 3439; Percent complete: 86.0%; Average loss: 2.8797
Iteration: 3440; Percent complete: 86.0%; Average loss: 2.6783
Iteration: 3441; Percent complete: 86.0%; Average loss: 2.7421
Iteration: 3442; Percent complete: 86.1%; Average loss: 2.8530
Iteration: 3443; Percent complete: 86.1%; Average loss:

Iteration: 3559; Percent complete: 89.0%; Average loss: 2.7428
Iteration: 3560; Percent complete: 89.0%; Average loss: 2.6139
Iteration: 3561; Percent complete: 89.0%; Average loss: 2.7547
Iteration: 3562; Percent complete: 89.0%; Average loss: 2.5987
Iteration: 3563; Percent complete: 89.1%; Average loss: 2.6525
Iteration: 3564; Percent complete: 89.1%; Average loss: 2.7120
Iteration: 3565; Percent complete: 89.1%; Average loss: 2.6715
Iteration: 3566; Percent complete: 89.1%; Average loss: 2.7913
Iteration: 3567; Percent complete: 89.2%; Average loss: 2.6037
Iteration: 3568; Percent complete: 89.2%; Average loss: 2.6614
Iteration: 3569; Percent complete: 89.2%; Average loss: 2.6353
Iteration: 3570; Percent complete: 89.2%; Average loss: 2.7229
Iteration: 3571; Percent complete: 89.3%; Average loss: 2.8644
Iteration: 3572; Percent complete: 89.3%; Average loss: 2.6706
Iteration: 3573; Percent complete: 89.3%; Average loss: 2.5364
Iteration: 3574; Percent complete: 89.3%; Average loss:

Iteration: 3690; Percent complete: 92.2%; Average loss: 2.6444
Iteration: 3691; Percent complete: 92.3%; Average loss: 2.7666
Iteration: 3692; Percent complete: 92.3%; Average loss: 2.6945
Iteration: 3693; Percent complete: 92.3%; Average loss: 2.5283
Iteration: 3694; Percent complete: 92.3%; Average loss: 2.5202
Iteration: 3695; Percent complete: 92.4%; Average loss: 2.6611
Iteration: 3696; Percent complete: 92.4%; Average loss: 2.6893
Iteration: 3697; Percent complete: 92.4%; Average loss: 2.7212
Iteration: 3698; Percent complete: 92.5%; Average loss: 2.6727
Iteration: 3699; Percent complete: 92.5%; Average loss: 2.9559
Iteration: 3700; Percent complete: 92.5%; Average loss: 2.4996
Iteration: 3701; Percent complete: 92.5%; Average loss: 2.5677
Iteration: 3702; Percent complete: 92.5%; Average loss: 2.7213
Iteration: 3703; Percent complete: 92.6%; Average loss: 2.5900
Iteration: 3704; Percent complete: 92.6%; Average loss: 2.5570
Iteration: 3705; Percent complete: 92.6%; Average loss:

Iteration: 3821; Percent complete: 95.5%; Average loss: 2.4433
Iteration: 3822; Percent complete: 95.5%; Average loss: 2.7975
Iteration: 3823; Percent complete: 95.6%; Average loss: 2.5391
Iteration: 3824; Percent complete: 95.6%; Average loss: 2.6078
Iteration: 3825; Percent complete: 95.6%; Average loss: 2.5561
Iteration: 3826; Percent complete: 95.7%; Average loss: 2.7470
Iteration: 3827; Percent complete: 95.7%; Average loss: 2.5275
Iteration: 3828; Percent complete: 95.7%; Average loss: 2.8659
Iteration: 3829; Percent complete: 95.7%; Average loss: 2.5353
Iteration: 3830; Percent complete: 95.8%; Average loss: 2.5676
Iteration: 3831; Percent complete: 95.8%; Average loss: 2.5115
Iteration: 3832; Percent complete: 95.8%; Average loss: 2.6103
Iteration: 3833; Percent complete: 95.8%; Average loss: 2.6334
Iteration: 3834; Percent complete: 95.9%; Average loss: 2.5181
Iteration: 3835; Percent complete: 95.9%; Average loss: 2.5909
Iteration: 3836; Percent complete: 95.9%; Average loss:

Iteration: 3952; Percent complete: 98.8%; Average loss: 2.6482
Iteration: 3953; Percent complete: 98.8%; Average loss: 2.5892
Iteration: 3954; Percent complete: 98.9%; Average loss: 2.7300
Iteration: 3955; Percent complete: 98.9%; Average loss: 2.6444
Iteration: 3956; Percent complete: 98.9%; Average loss: 2.6642
Iteration: 3957; Percent complete: 98.9%; Average loss: 2.5366
Iteration: 3958; Percent complete: 99.0%; Average loss: 2.5876
Iteration: 3959; Percent complete: 99.0%; Average loss: 2.4824
Iteration: 3960; Percent complete: 99.0%; Average loss: 2.5843
Iteration: 3961; Percent complete: 99.0%; Average loss: 2.6340
Iteration: 3962; Percent complete: 99.1%; Average loss: 2.6694
Iteration: 3963; Percent complete: 99.1%; Average loss: 2.6230
Iteration: 3964; Percent complete: 99.1%; Average loss: 2.7474
Iteration: 3965; Percent complete: 99.1%; Average loss: 2.6347
Iteration: 3966; Percent complete: 99.2%; Average loss: 2.5243
Iteration: 3967; Percent complete: 99.2%; Average loss:

# Run evaluation

In [22]:
# Switch both trained models to inference mode. Module.eval() disables dropout
# (and other train-only layer behavior) so replies are not randomly perturbed.
encoder.eval()
decoder.eval()

# Wrap the trained encoder/decoder in the greedy decoding strategy defined
# earlier in this notebook (see GreedySearchDecoder for how tokens are chosen).
searcher = GreedySearchDecoder(encoder, decoder)

# Begin chatting. NOTE: this call blocks on interactive console input until
# evaluateInput's loop returns, so comment it out for a non-interactive
# "Restart & Run All". Generating replies never needs gradients, so autograd
# is disabled to avoid building a computation graph for every utterance.
with torch.no_grad():
    evaluateInput(encoder, decoder, searcher, voc)

> hi
Bot: hi . you ? you ?
> i'm fine
Bot: you re a good girl .
> do you know quantitative trade
Error: Encountered unknown word.
> finance
Error: Encountered unknown word.
> math
Bot: it s not your fault .
> what's your hobby
Bot: i m sorry . i .
> i very sad
Bot: i know . you re crazy .
> hong kong
Error: Encountered unknown word.
> Hong Kong
Error: Encountered unknown word.
> you know what
Bot: i know . i .
> what you can do
Bot: i don t know . i
> bye
Bot: bye . you my wife .
> see you
Bot: i m not . you ?
> ok
Bot: i ll see you later .
> time
Bot: what ? what ? what ?
> you are so cute
Bot: i m sorry . you ?
> color
Bot: i m sorry . you .
> hello
Bot: hello . you ? you ?
> nice to meet you
Bot: i m sorry . you ?
> who are you
Bot: i m dying . you ?
> who are you
Bot: i m dying . you ?
> who are you?
Bot: i m martha mama . you .
> what's your name
Bot: victor . you re my friend .
> job
Bot: you re a girl . you ?
> i'm
Bot: i m sorry . you ?
> i am
Bot: is that right ? you ?
> good
