In [1]:
%config IPCompleter.greedy=True

import torch
import os
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import itertools
import math
import pandas as pd
import unicodedata
import codecs
import itertools


from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import numpy as np


In [2]:
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

# DATA


CSV version

In [3]:
xy = pd.read_csv('./prep_student_input.csv')

pastconvo = xy['Past Convo']
pastdf = pd.DataFrame(data=pastconvo)

tutor = xy['Tutor Response']
tutordf = pd.DataFrame(data=tutor)

pastconvo = pastdf.values.flatten().tolist()
response = tutordf.values.flatten().tolist()

pairs = [[pastconvo[i], response[i]] for i in range(len(pastconvo))]

# VOCABUOLARY


In [4]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Vocabulary:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold # CHANGE probably shouldn't do this
    def trim(self, min_count):
        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)        
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))
        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

In [5]:
# make data simple
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

In [6]:
# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# preprocessing data

In [7]:
pairs = [[normalizeString(line) for line in p] for p in pairs]

voc = Vocabulary("dialog")

In [8]:
#getting  rid of rare words
for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1])
print("counted words:", voc.num_words)


counted words: 1014


# prepare data

In [9]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


In [10]:
inp = []
out = []
for pair in pairs[:10]: #assume batch size 10
    inp.append(pair[0])
    out.append(pair[1])
indexes = [indexesFromSentence(voc, sentence) for  sentence in inp]
indexes

[[3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  12,
  15,
  7,
  16,
  17,
  7,
  18,
  5,
  19,
  20,
  21,
  3,
  22,
  7,
  7,
  7,
  19,
  5,
  23,
  24,
  16,
  25,
  7,
  26,
  27,
  28,
  29,
  30,
  12,
  15,
  21,
  3,
  31,
  5,
  30,
  13,
  5,
  32,
  33,
  16,
  34,
  35,
  32,
  33,
  6,
  23,
  24,
  21,
  3,
  23,
  13,
  5,
  36,
  10,
  13,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  16,
  5,
  40,
  45,
  10,
  29,
  34,
  35,
  32,
  33,
  6,
  23,
  24,
  21,
  3,
  23,
  13,
  5,
  36,
  10,
  13,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  7,
  16,
  46,
  18,
  47,
  13,
  45,
  48,
  21,
  2],
 [3,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  12,
  15,
  7,
  16,
  26,
  27,
  63,
  29,
  5,
  64,
  10,
  13,
  21,
  3,
  65,
  5,
  64,
  10,
  13,
  5,
  32,
  66,
  67,
  16,
  32,
  66,
  67,
  68,
  6,
  5,
  69,
  70,
  71,
  7,
  3,
  17,
  28,
  72,
  73,
  63,
  43,
  74,
  7,
  75,
  5,
  76,
  16,
  77,
  17,
  78,
  71,
  2

In [11]:
#make column consistent and transpose
def zeroPadding(l, fillvalue= 0):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

In [12]:
leng =  [len(ind) for ind in indexes]
max(leng)

114

In [13]:
test_result =  zeroPadding(indexes)


In [14]:
def binaryMatrix(l, value=0):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

In [15]:
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        
        output_batch.append(pair[1])
    print(input_batch)
    print(output_batch)
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len, input_batch, output_batch

In [16]:
# Example for validation ??????????what is mask


# MODEL

In [17]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        #   because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden

Attention using dot

In [18]:
# Luong attention layer
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        self.hidden_size = hidden_size

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        attn_energies = self.dot_score(hidden, encoder_outputs)
        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()
        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [19]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden

# Training

In [20]:
# negative log lokelihood loss
def maskNLLLoss(inp, target, mask): 
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [49]:
num_iters = 3000

small_batch_size = 3
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len, input_batch, output_batch = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

attn_model = 'dot'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
embedding = nn.Embedding(voc.num_words, hidden_size)

encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model,  embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
#pass models to device 
encoder = encoder.to(device)
decoder = decoder.to(device)

encoder.train()
decoder.train()

#optimizers
encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.0001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.0001)
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()


# start training
print("START TRAINING")
print("\n")

for iteration in range(num_iters):
    
    print(" Iteration:", iteration)
    
#     batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
#     input_variable, lengths, target_variable, mask, max_target_len, input_batch, output_batch = batches
    
    #train one iter
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mast = mask.to(device)

    loss = 0
    print_losses = []
    n_totals = 0

    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
#     print("encoder output shape", encoder_outputs.shape)
#     print("last encoder hidden shape", encoder_hidden.shape)

    decoder_input = torch.LongTensor([[SOS_token for _ in range (small_batch_size)]])
    decoder_input = decoder_input.to(device)
#     print("initial decoder input shape", decoder_input.shape)
#     print(decoder_input)

    decoder_hidden =  encoder_hidden[:decoder.n_layers]
#     print("initial decoder hidden state shape", decoder_hidden.shape)

    for t in range(max_target_len):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)

        decoder_input - target_variable[t].view(1,-1)

        mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)

        n_totals+= nTotal

    loss.backward(retain_graph=True)

    encoder_optimizer.step()
    decoder_optimizer.step()

    returned_loss = sum(print_losses) / n_totals
    print("returned loss", returned_loss)
    print("\n")
    



['tutor is behind the is e dietro il . please try to fill in the blank in italian . student what is blue in italian ? tutor can you give me your best guess ? student blueo tutor remember that is behind the is e dietro il student e dietro il blueo cato tutor hmm . . . is behind the is e dietro il student e dietro il tutor hmm . . . cat is gatto student e dietro il gatto', 'tutor please try to fill in the blank in italian . student il cane . . . tutor hmm . . . is inside of the is e dentro la student il cane e dentro la', 'tutor please try to fill in the blank in italian . student how do i say tree in italian ?']
['all right here s a hint the word for blue is blu .', 'what is the word for bag ?', 'tree is all albero .']
input_variable: tensor([[  3,   3,   3],
        [  5,   8,   8],
        [ 30,   9,   9],
        [ 13,  10,  10],
        [  5,  11,  11],
        [ 32,  12,  12],
        [ 33,  13,  13],
        [ 34,  14,  14],
        [  7,  12,  12],
        [  8,  15,  15],
      

returned loss 1.6195268694843565


 Iteration: 95
returned loss 1.5972818295870508


 Iteration: 96
returned loss 1.5810175378407751


 Iteration: 97
returned loss 1.5665570124983788


 Iteration: 98
returned loss 1.5545053662998336


 Iteration: 99
returned loss 1.536722019314766


 Iteration: 100
returned loss 1.5197557349290167


 Iteration: 101
returned loss 1.5159795773880822


 Iteration: 102
returned loss 1.4978774477328574


 Iteration: 103
returned loss 1.531127427305494


 Iteration: 104
returned loss 1.516079374722072


 Iteration: 105
returned loss 1.504314046885286


 Iteration: 106
returned loss 1.5140800209982055


 Iteration: 107
returned loss 1.4968442523053713


 Iteration: 108
returned loss 1.5056497209838458


 Iteration: 109
returned loss 1.4973555047597205


 Iteration: 110
returned loss 1.5948621502944402


 Iteration: 111
returned loss 1.5851133380617415


 Iteration: 112
returned loss 1.567395303930555


 Iteration: 113
returned loss 1.551497561591012


 Iterat

returned loss 0.5249376044209514


 Iteration: 257
returned loss 0.5207515132746526


 Iteration: 258
returned loss 0.5224532806979758


 Iteration: 259
returned loss 0.5153019018471241


 Iteration: 260
returned loss 0.5167729298451117


 Iteration: 261
returned loss 0.5052760655858687


 Iteration: 262
returned loss 0.5100371060626847


 Iteration: 263
returned loss 0.5071013827941248


 Iteration: 264
returned loss 0.5029951395200831


 Iteration: 265
returned loss 0.4957552400550672


 Iteration: 266
returned loss 0.49515996421022074


 Iteration: 267
returned loss 0.4882732047034161


 Iteration: 268
returned loss 0.4885840700673206


 Iteration: 269
returned loss 0.48279196582734585


 Iteration: 270
returned loss 0.487875835703952


 Iteration: 271
returned loss 0.48375758874629227


 Iteration: 272
returned loss 0.46783155389130116


 Iteration: 273
returned loss 0.4745282283318894


 Iteration: 274
returned loss 0.4646844414195844


 Iteration: 275
returned loss 0.464670896796

returned loss 0.3346051994178976


 Iteration: 417
returned loss 0.3249136454292706


 Iteration: 418
returned loss 0.3197374814855201


 Iteration: 419
returned loss 0.3093463701329061


 Iteration: 420
returned loss 0.31290607314024654


 Iteration: 421
returned loss 0.31019849942198824


 Iteration: 422
returned loss 0.3113698727850403


 Iteration: 423
returned loss 0.32399158499070574


 Iteration: 424
returned loss 0.2948451087410961


 Iteration: 425
returned loss 0.2932581356061356


 Iteration: 426
returned loss 0.2842305810855968


 Iteration: 427
returned loss 0.2796121859656913


 Iteration: 428
returned loss 0.2801796014287642


 Iteration: 429
returned loss 0.2818301687283175


 Iteration: 430
returned loss 0.2730109207332134


 Iteration: 431
returned loss 0.2688636441848108


 Iteration: 432
returned loss 0.2902828944580896


 Iteration: 433
returned loss 0.26862613963229315


 Iteration: 434
returned loss 0.26323569166873184


 Iteration: 435
returned loss 0.2619323243

returned loss 0.13686112726905517


 Iteration: 576
returned loss 0.1323820206203631


 Iteration: 577
returned loss 0.13047304137476853


 Iteration: 578
returned loss 0.13846617923783405


 Iteration: 579
returned loss 0.1326154371989625


 Iteration: 580
returned loss 0.13789354437696083


 Iteration: 581
returned loss 0.13221666043890373


 Iteration: 582
returned loss 0.12773527138467347


 Iteration: 583
returned loss 0.13759320748171636


 Iteration: 584
returned loss 0.1435551926759737


 Iteration: 585
returned loss 0.12584652392459766


 Iteration: 586
returned loss 0.12589994671621493


 Iteration: 587
returned loss 0.12693818685199534


 Iteration: 588
returned loss 0.12659915629774332


 Iteration: 589
returned loss 0.1312118543844138


 Iteration: 590
returned loss 0.13638241150017297


 Iteration: 591
returned loss 0.12374145683965512


 Iteration: 592
returned loss 0.12327570893934794


 Iteration: 593
returned loss 0.12623868164207255


 Iteration: 594
returned loss 0.

returned loss 0.07402870205364057


 Iteration: 734
returned loss 0.07409365574962326


 Iteration: 735
returned loss 0.07340899987944535


 Iteration: 736
returned loss 0.07250182709789701


 Iteration: 737
returned loss 0.072337880797152


 Iteration: 738
returned loss 0.07197891628103596


 Iteration: 739
returned loss 0.07234801358676382


 Iteration: 740
returned loss 0.07258075269471322


 Iteration: 741
returned loss 0.07127913732879929


 Iteration: 742
returned loss 0.0724958214642746


 Iteration: 743
returned loss 0.0714536937219756


 Iteration: 744
returned loss 0.07204309564882092


 Iteration: 745
returned loss 0.07022767666993397


 Iteration: 746
returned loss 0.07018023190487709


 Iteration: 747
returned loss 0.06955648433150989


 Iteration: 748
returned loss 0.06992646100531731


 Iteration: 749
returned loss 0.06952932417126638


 Iteration: 750
returned loss 0.06938495879460659


 Iteration: 751
returned loss 0.06917004673076528


 Iteration: 752
returned loss 0.

returned loss 0.04551794150445078


 Iteration: 892
returned loss 0.04564271218675588


 Iteration: 893
returned loss 0.04506664121124361


 Iteration: 894
returned loss 0.044808863229783516


 Iteration: 895
returned loss 0.04476321398812745


 Iteration: 896
returned loss 0.04469572912369456


 Iteration: 897
returned loss 0.04565739771351218


 Iteration: 898
returned loss 0.044689876991989355


 Iteration: 899
returned loss 0.0450357556941786


 Iteration: 900
returned loss 0.044484470871144106


 Iteration: 901
returned loss 0.04429524069252823


 Iteration: 902
returned loss 0.04610555582413716


 Iteration: 903
returned loss 0.044035052614552636


 Iteration: 904
returned loss 0.04421484952659479


 Iteration: 905
returned loss 0.043591578324724524


 Iteration: 906
returned loss 0.043941570878294964


 Iteration: 907
returned loss 0.043407089509336014


 Iteration: 908
returned loss 0.0434198816240366


 Iteration: 909
returned loss 0.043537090898358395


 Iteration: 910
return

returned loss 0.030658848217821548


 Iteration: 1048
returned loss 0.03054576420358249


 Iteration: 1049
returned loss 0.030417265703103373


 Iteration: 1050
returned loss 0.030385438718699982


 Iteration: 1051
returned loss 0.030458078919244663


 Iteration: 1052
returned loss 0.030414201019863998


 Iteration: 1053
returned loss 0.03014070983044803


 Iteration: 1054
returned loss 0.03028420588400747


 Iteration: 1055
returned loss 0.030350001229505454


 Iteration: 1056
returned loss 0.03074621462396213


 Iteration: 1057
returned loss 0.029610614691461836


 Iteration: 1058
returned loss 0.030209082178771496


 Iteration: 1059
returned loss 0.03089546149463526


 Iteration: 1060
returned loss 0.029600356040256365


 Iteration: 1061
returned loss 0.029545220247070705


 Iteration: 1062
returned loss 0.02948858535715512


 Iteration: 1063
returned loss 0.02946111692913941


 Iteration: 1064
returned loss 0.0296706572241549


 Iteration: 1065
returned loss 0.031184869319466606




returned loss 0.021791284571268728


 Iteration: 1201
returned loss 0.022465992186750685


 Iteration: 1202
returned loss 0.022522173256480267


 Iteration: 1203
returned loss 0.021693664469889233


 Iteration: 1204
returned loss 0.021634910754593357


 Iteration: 1205
returned loss 0.0213586738266583


 Iteration: 1206
returned loss 0.02200906569071646


 Iteration: 1207
returned loss 0.02181848614210529


 Iteration: 1208
returned loss 0.022042898111976683


 Iteration: 1209
returned loss 0.02142805468091475


 Iteration: 1210
returned loss 0.021219184372707138


 Iteration: 1211
returned loss 0.02168439225559788


 Iteration: 1212
returned loss 0.02140722856191652


 Iteration: 1213
returned loss 0.021669104445858726


 Iteration: 1214
returned loss 0.021125863705362593


 Iteration: 1215
returned loss 0.021662745724565217


 Iteration: 1216
returned loss 0.021352617319540253


 Iteration: 1217
returned loss 0.021103080112620125


 Iteration: 1218
returned loss 0.02142924823731716



returned loss 0.016433972898604616


 Iteration: 1354
returned loss 0.016479399071873298


 Iteration: 1355
returned loss 0.016491333371959627


 Iteration: 1356
returned loss 0.016214061744644175


 Iteration: 1357
returned loss 0.016348236157292768


 Iteration: 1358
returned loss 0.01666535180993378


 Iteration: 1359
returned loss 0.016224872703397914


 Iteration: 1360
returned loss 0.016437122249044478


 Iteration: 1361
returned loss 0.016305144992657006


 Iteration: 1362
returned loss 0.016170592895442887


 Iteration: 1363
returned loss 0.016169013671710024


 Iteration: 1364
returned loss 0.016091259395969764


 Iteration: 1365
returned loss 0.016327199892007878


 Iteration: 1366
returned loss 0.016072535338545486


 Iteration: 1367
returned loss 0.015975877127077962


 Iteration: 1368
returned loss 0.01593295639447336


 Iteration: 1369
returned loss 0.016007843510513858


 Iteration: 1370
returned loss 0.015990949269118055


 Iteration: 1371
returned loss 0.01588244516668

returned loss 0.012844840247583176


 Iteration: 1507
returned loss 0.012728075596636959


 Iteration: 1508
returned loss 0.012768163744892393


 Iteration: 1509
returned loss 0.012663905858062208


 Iteration: 1510
returned loss 0.01273918422936861


 Iteration: 1511
returned loss 0.012768302844571216


 Iteration: 1512
returned loss 0.012675695603580348


 Iteration: 1513
returned loss 0.012661281607246824


 Iteration: 1514
returned loss 0.012672834946507854


 Iteration: 1515
returned loss 0.012573127834392446


 Iteration: 1516
returned loss 0.012707337071853024


 Iteration: 1517
returned loss 0.012671962596609123


 Iteration: 1518
returned loss 0.012543980298297746


 Iteration: 1519
returned loss 0.01264493927426104


 Iteration: 1520
returned loss 0.012596285785548389


 Iteration: 1521
returned loss 0.012552415602840483


 Iteration: 1522
returned loss 0.012465746241754718


 Iteration: 1523
returned loss 0.012704738565454525


 Iteration: 1524
returned loss 0.01249752080600

returned loss 0.010345345363020897


 Iteration: 1660
returned loss 0.010424653791622924


 Iteration: 1661
returned loss 0.01022693227943299


 Iteration: 1662
returned loss 0.01024033680524943


 Iteration: 1663
returned loss 0.010233663621225528


 Iteration: 1664
returned loss 0.010700868675485253


 Iteration: 1665
returned loss 0.010188524328571345


 Iteration: 1666
returned loss 0.010151414608117193


 Iteration: 1667
returned loss 0.010091891229551817


 Iteration: 1668
returned loss 0.010222497387855713


 Iteration: 1669
returned loss 0.01030522259783798


 Iteration: 1670
returned loss 0.010150010306720756


 Iteration: 1671
returned loss 0.010111175261304848


 Iteration: 1672
returned loss 0.010230969853832253


 Iteration: 1673
returned loss 0.010122699779458344


 Iteration: 1674
returned loss 0.010148025001399219


 Iteration: 1675
returned loss 0.01007654656755871


 Iteration: 1676
returned loss 0.010077760164027236


 Iteration: 1677
returned loss 0.0100495068389656

returned loss 0.008361074749180781


 Iteration: 1813
returned loss 0.008465682979606624


 Iteration: 1814
returned loss 0.008405639291075724


 Iteration: 1815
returned loss 0.008445538971240498


 Iteration: 1816
returned loss 0.008459059662917363


 Iteration: 1817
returned loss 0.008290014038461127


 Iteration: 1818
returned loss 0.008277146054232227


 Iteration: 1819
returned loss 0.008288164631397064


 Iteration: 1820
returned loss 0.00825683111491214


 Iteration: 1821
returned loss 0.008364400543671633


 Iteration: 1822
returned loss 0.008254476465351348


 Iteration: 1823
returned loss 0.008264043118937739


 Iteration: 1824
returned loss 0.008239854967022049


 Iteration: 1825
returned loss 0.008221105664103692


 Iteration: 1826
returned loss 0.008297887381299265


 Iteration: 1827
returned loss 0.008253990572744183


 Iteration: 1828
returned loss 0.008167928758276892


 Iteration: 1829
returned loss 0.008232484604897243


 Iteration: 1830
returned loss 0.0082122418929

returned loss 0.007124924308819962


 Iteration: 1965
returned loss 0.006981030114859875


 Iteration: 1966
returned loss 0.007033139142939555


 Iteration: 1967
returned loss 0.006965608693592783


 Iteration: 1968
returned loss 0.006888679519761354


 Iteration: 1969
returned loss 0.00695281963895208


 Iteration: 1970
returned loss 0.006930535875393876


 Iteration: 1971
returned loss 0.0069076714072642586


 Iteration: 1972
returned loss 0.006983189885171929


 Iteration: 1973
returned loss 0.006861045424427305


 Iteration: 1974
returned loss 0.006908695438011948


 Iteration: 1975
returned loss 0.007083907008304128


 Iteration: 1976
returned loss 0.006842811433931014


 Iteration: 1977
returned loss 0.006848689193637776


 Iteration: 1978
returned loss 0.006834879767016641


 Iteration: 1979
returned loss 0.006825266065009471


 Iteration: 1980
returned loss 0.006865343270224652


 Iteration: 1981
returned loss 0.006822580662888608


 Iteration: 1982
returned loss 0.006916621242

returned loss 0.0058869223950231185


 Iteration: 2117
returned loss 0.005886004684725776


 Iteration: 2118
returned loss 0.005827956961833739


 Iteration: 2119
returned loss 0.0058655865390651995


 Iteration: 2120
returned loss 0.005847432122599068


 Iteration: 2121
returned loss 0.005870367034471461


 Iteration: 2122
returned loss 0.005862603205189641


 Iteration: 2123
returned loss 0.005844004019828779


 Iteration: 2124
returned loss 0.005817146240068334


 Iteration: 2125
returned loss 0.005794739999276187


 Iteration: 2126
returned loss 0.005838124193750056


 Iteration: 2127
returned loss 0.00584984542053592


 Iteration: 2128
returned loss 0.005762782879173756


 Iteration: 2129
returned loss 0.005794857250293717


 Iteration: 2130
returned loss 0.005729369894002697


 Iteration: 2131
returned loss 0.005775638643951554


 Iteration: 2132
returned loss 0.0057130040617526644


 Iteration: 2133
returned loss 0.005751983270913895


 Iteration: 2134
returned loss 0.0057572913

returned loss 0.005079463095171377


 Iteration: 2269
returned loss 0.004990397116801303


 Iteration: 2270
returned loss 0.004961527584652815


 Iteration: 2271
returned loss 0.004978425243669855


 Iteration: 2272
returned loss 0.005013727034176034


 Iteration: 2273
returned loss 0.005058085429482162


 Iteration: 2274
returned loss 0.00507167519702177


 Iteration: 2275
returned loss 0.004961888808923375


 Iteration: 2276
returned loss 0.005073749269026199


 Iteration: 2277
returned loss 0.004994677375569674


 Iteration: 2278
returned loss 0.004935145872877911


 Iteration: 2279
returned loss 0.004953867759987978


 Iteration: 2280
returned loss 0.005003145816902231


 Iteration: 2281
returned loss 0.004976621849761743


 Iteration: 2282
returned loss 0.004930144026210266


 Iteration: 2283
returned loss 0.004942088575002604


 Iteration: 2284
returned loss 0.00495800789212808


 Iteration: 2285
returned loss 0.0049527056606686005


 Iteration: 2286
returned loss 0.0049021566519

returned loss 0.004324234074114689


 Iteration: 2421
returned loss 0.004312677915939795


 Iteration: 2422
returned loss 0.00427310068542803


 Iteration: 2423
returned loss 0.004344028158811852


 Iteration: 2424
returned loss 0.004275468510708639


 Iteration: 2425
returned loss 0.0042865142750088125


 Iteration: 2426
returned loss 0.004278137847515089


 Iteration: 2427
returned loss 0.004297688092005306


 Iteration: 2428
returned loss 0.0042583651125564104


 Iteration: 2429
returned loss 0.004236977044326652


 Iteration: 2430
returned loss 0.004241141049923109


 Iteration: 2431
returned loss 0.004290979908546433


 Iteration: 2432
returned loss 0.004262281835378546


 Iteration: 2433
returned loss 0.004268692153605765


 Iteration: 2434
returned loss 0.004182472554800499


 Iteration: 2435
returned loss 0.004236543127002993


 Iteration: 2436
returned loss 0.004266319226839447


 Iteration: 2437
returned loss 0.004216083487595564


 Iteration: 2438
returned loss 0.00425377819

returned loss 0.0037319869568039265


 Iteration: 2573
returned loss 0.0037090696485912694


 Iteration: 2574
returned loss 0.0037995093485473524


 Iteration: 2575
returned loss 0.0037268256377761383


 Iteration: 2576
returned loss 0.0036748538217839916


 Iteration: 2577
returned loss 0.003668175922939554


 Iteration: 2578
returned loss 0.0037252704579649226


 Iteration: 2579
returned loss 0.0036600107352049755


 Iteration: 2580
returned loss 0.003668421787941562


 Iteration: 2581
returned loss 0.003681870156599741


 Iteration: 2582
returned loss 0.003727870947581583


 Iteration: 2583
returned loss 0.0036601730140059124


 Iteration: 2584
returned loss 0.0036634369004916933


 Iteration: 2585
returned loss 0.003692861047706434


 Iteration: 2586
returned loss 0.0036666637003820923


 Iteration: 2587
returned loss 0.003639376144357292


 Iteration: 2588
returned loss 0.003694520302815363


 Iteration: 2589
returned loss 0.0036364453221072574


 Iteration: 2590
returned loss 0.0

returned loss 0.003224350865431396


 Iteration: 2724
returned loss 0.003233600890130869


 Iteration: 2725
returned loss 0.0032559009518341292


 Iteration: 2726
returned loss 0.003225240873039833


 Iteration: 2727
returned loss 0.0032662045559845865


 Iteration: 2728
returned loss 0.0032263033042129663


 Iteration: 2729
returned loss 0.0032136312677591506


 Iteration: 2730
returned loss 0.003248293522379494


 Iteration: 2731
returned loss 0.003202891459555498


 Iteration: 2732
returned loss 0.0032250995489968254


 Iteration: 2733
returned loss 0.0032010900280770977


 Iteration: 2734
returned loss 0.0032364632601716687


 Iteration: 2735
returned loss 0.003208416177325749


 Iteration: 2736
returned loss 0.0032124822235865785


 Iteration: 2737
returned loss 0.0032118176708796193


 Iteration: 2738
returned loss 0.003194535812196721


 Iteration: 2739
returned loss 0.003177591674362442


 Iteration: 2740
returned loss 0.003196191753626668


 Iteration: 2741
returned loss 0.003

returned loss 0.002883710206203562


 Iteration: 2875
returned loss 0.0028473180926604463


 Iteration: 2876
returned loss 0.0028218700546338888


 Iteration: 2877
returned loss 0.002847025141818449


 Iteration: 2878
returned loss 0.002835635989738096


 Iteration: 2879
returned loss 0.002827947524825244


 Iteration: 2880
returned loss 0.0028187097881787588


 Iteration: 2881
returned loss 0.0028278161626076326


 Iteration: 2882
returned loss 0.002810727630276233


 Iteration: 2883
returned loss 0.0028039233522057267


 Iteration: 2884
returned loss 0.0028206500781899585


 Iteration: 2885
returned loss 0.0028729802267792237


 Iteration: 2886
returned loss 0.002817282760848424


 Iteration: 2887
returned loss 0.0028088151136346695


 Iteration: 2888
returned loss 0.0027954816323472187


 Iteration: 2889
returned loss 0.0027964480416682947


 Iteration: 2890
returned loss 0.002811937784177384


 Iteration: 2891
returned loss 0.0027934731146420483


 Iteration: 2892
returned loss 0.0

In [50]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores
    

In [51]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc, input_sentence, max_length):
    input_sentence = normalizeString(input_sentence)
    # Evaluate sentence
    output_words = evaluate(encoder, decoder, searcher, voc, input_sentence, max_length)
    # Format and print response sentence
    output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
    print('Bot:', ' '.join(output_words))


In [54]:
#evaluate model

encoder.eval()
decoder.eval()

searcher = GreedySearchDecoder(encoder, decoder)

inputs = input_batch
outputs = output_batch

for i in range(3):
    input_sentence = inputs[i]
    max_length = len(normalizeString(outputs[i]).split()) 
    print("past convo:", input_sentence)
    print("\n")
    evaluateInput(encoder, decoder, searcher, voc, input_sentence,40)
    print("\n")
    print("Correct: ", outputs[i])
    print("\n")

past convo: tutor is behind the is e dietro il . please try to fill in the blank in italian . student what is blue in italian ? tutor can you give me your best guess ? student blueo tutor remember that is behind the is e dietro il student e dietro il blueo cato tutor hmm . . . is behind the is e dietro il student e dietro il tutor hmm . . . cat is gatto student e dietro il gatto


Bot: all all right right right right here here here here here here here here here here here here here here here here here here here here here here here here here here here here here here here here here here


Correct:  all right here s a hint the word for blue is blu .


past convo: tutor please try to fill in the blank in italian . student il cane . . . tutor hmm . . . is inside of the is e dentro la student il cane e dentro la


Bot: what what is is is the the the the the the the the the what what what what what what what what what what what what what what what what what what what what what what what what w