In [1]:
import torch
import torch.nn as nn 
from torch import optim
import torch.nn.functional as f
import torchvision

In [2]:
import ast
import codecs
import csv
import itertools
import os
import random
import re
import unicodedata

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device('cuda') if cuda else torch.device('cpu')

## Data Preprocessing 

In [4]:
# Both raw corpus files live in the same directory next to the notebook.
corpus_dir = 'cornell movie-dialogs corpus'
line_path, conv_path = (
    os.path.join(corpus_dir, file_name)
    for file_name in ('movie_lines.txt', 'movie_conversations.txt')
)

In [5]:
# Peek at the raw corpus as bytes so the first few records can be inspected
# before any decoding is attempted.
# The handle is deliberately not called `f`: that name is the alias of
# torch.nn.functional imported in the first cell, and rebinding it here breaks
# every later `f.softmax(...)` call.
with open(line_path, 'rb') as raw_line_file:
    lines = raw_line_file.readlines()

for line in lines[:10]:
    print(line.strip())

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go."
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie."
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?'


In [6]:
# Parse movie_lines.txt into {lineID: {field: value}}.
# Each record is ' +++$+++ '-separated; the 'text' field keeps its trailing
# newline here and is stripped later when question/answer pairs are built.
line_fields = ['lineID', 'charID', 'movieID', 'char', 'text']
lines = {}
# The handle is not named `f` so the torch.nn.functional alias `f` from the
# first cell is not overwritten.
with open(line_path, 'r', encoding='iso-8859-1') as line_file:
    for line in line_file:
        value = line.split(' +++$+++ ')
        # Index-based lookup keeps the IndexError on a malformed (short) record.
        lineobj = {field: value[i] for i, field in enumerate(line_fields)}
        lines[lineobj['lineID']] = lineobj

In [7]:
list(lines.values())[0]

{'lineID': 'L1045',
 'charID': 'u0',
 'movieID': 'm0',
 'char': 'BIANCA',
 'text': 'They do not!\n'}

In [8]:
# Peek at the first raw conversation records (bytes, undecoded).
# The handle is not named `f` so the torch.nn.functional alias `f` from the
# first cell is not overwritten.
with open(conv_path, 'rb') as raw_conv_file:
    ls = raw_conv_file.readlines()

for line in ls[:10]:
    print(line.strip())

b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L271', 'L272', 'L273', 'L274', 'L275']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L276', 'L277']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L280', 'L281']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L363', 'L364']"
b"u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L365', 'L366']"


In [9]:
# Parse movie_conversations.txt into a list of conversation dicts, each with
# its utterances resolved to the line objects built from movie_lines.txt.
conv_fields = ['char1ID', 'char2ID', 'movieID', 'utteranceID']
convs = []
# The handle is not named `f` so the torch.nn.functional alias `f` from the
# first cell is not overwritten.
with open(conv_path, 'r', encoding='iso-8859-1') as conv_file:
    for line in conv_file:
        value = line.split(' +++$+++ ')
        # Index-based lookup keeps the IndexError on a malformed (short) record.
        convobj = {field: value[i] for i, field in enumerate(conv_fields)}
        # The last field is a Python list literal such as "['L194', 'L195']".
        # literal_eval parses it without executing arbitrary code from the file
        # (the previous eval() would have run anything found there).
        lineIDs = ast.literal_eval(convobj['utteranceID'])
        convobj['lines'] = [lines[lineID] for lineID in lineIDs]
        convs.append(convobj)

In [10]:
len(convs)

83097

In [11]:
# Every pair of consecutive utterances in a conversation becomes one
# (question, answer) training example; pairs with an empty side are skipped.
qa_pairs = []
for conversation in convs:
    utterances = [entry['text'].strip() for entry in conversation['lines']]
    for question, answer in zip(utterances, utterances[1:]):
        if question and answer:
            qa_pairs.append([question, answer])

In [12]:
print(qa_pairs[0])

['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.', "Well, I thought we'd start with pronunciation, if that's okay with you."]


In [13]:
len(qa_pairs)

221282

In [14]:
# Persist the question/answer pairs as a tab-separated file, one pair per row.
datafile = os.path.join('cornell movie-dialogs corpus', 'formatted_movie_lines.txt')
# '\t' is already the tab character, so no unicode_escape decoding is needed.
delimiter = '\t'

# newline='' is what the csv module asks for: without it the writer's own
# '\r\n' row terminator is translated again on platforms whose line separator
# is not '\n', producing '\r\r\n'.
with open(datafile, 'w', encoding='utf-8', newline='') as out_file:
    writer = csv.writer(out_file, delimiter=delimiter)
    writer.writerows(qa_pairs)

In [15]:
# Sanity-check the file just written by printing its first raw rows.
# The handle is not named `f` so the torch.nn.functional alias `f` from the
# first cell is not overwritten.
with open(datafile, 'rb') as formatted_file:
    lines = formatted_file.readlines()

for line in lines[:5]:
    print(line)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"


In [16]:
# Reserved indices shared by the vocabulary and the batching helpers.
pad_token = 0
s_token = 1
e_token = 2

class vocabulary:
    """Word <-> index mapping with occurrence counts.

    Indices 0-2 are reserved for the pad / sos / eos tokens; real words are
    numbered from 3 upwards in order of first appearance.
    """

    def __init__(self, name):
        self.name = name
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {pad_token: 'pad', s_token: 'sos', e_token: 'eos'}
        self.num_words = 3

    def addWord(self, word):
        """Register one occurrence of `word`, assigning a new index if unseen."""
        if word not in self.word2idx:
            self.word2idx[word] = self.num_words
            self.word2count[word] = 1
            self.idx2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def addSentence(self, sentence):
        """Register every whitespace-separated token of `sentence`."""
        for word in sentence.split():
            self.addWord(word)

    def trim(self, min_count):
        """Drop every word seen `min_count` times or fewer.

        Only `word2idx` and `idx2word` are pruned. `word2count` and
        `num_words` are left untouched, so surviving words keep their original
        indices and `num_words` remains an upper bound on any index.

        Safe to call repeatedly: words already removed by an earlier call are
        skipped (previously a second call raised KeyError because the word was
        still listed in `word2count` but no longer in `word2idx`).
        """
        for word, count in self.word2count.items():
            if count <= min_count and word in self.word2idx:
                del self.idx2word[self.word2idx[word]]
                del self.word2idx[word]

        print('Number of words in vocab: {}'.format(len(self.word2idx)))

In [17]:
def normalise(s):
    """Lower-case `s`, fold accents to ASCII and keep only letters and . ! ?

    Steps:
      1. lower-case and trim;
      2. decompose accented characters (NFD) and drop the combining marks, so
         e.g. 'é' becomes 'e' instead of being replaced by a space by step 4
         (which used to cut accented words in two);
      3. put a space before each of . ! ? so they become separate tokens;
      4. replace every run of other non-letter characters with one space;
      5. collapse whitespace.
    """
    s = s.lower().strip()
    s = ''.join(
        ch for ch in unicodedata.normalize('NFD', s)
        if unicodedata.category(ch) != 'Mn'
    )
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

In [18]:
normalise("helloWorld 1234?!!")

'helloworld ? ! !'

In [19]:
# Reload the tab-separated pairs written earlier and normalise both sides.
datafile = os.path.join('cornell movie-dialogs corpus', 'formatted_movie_lines.txt')

# A context manager closes the handle; the bare open(...).read() left it open.
with open(datafile, encoding='utf-8') as formatted_file:
    lines = formatted_file.read().strip().split('\n')

pairs = [[normalise(s) for s in pair.split('\t')] for pair in lines]

In [20]:
pairs[:5]

[['can we make this quick ? roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad . again .',
  'well i thought we d start with pronunciation if that s okay with you .'],
 ['well i thought we d start with pronunciation if that s okay with you .',
  'not the hacking and gagging and spitting part . please .'],
 ['not the hacking and gagging and spitting part . please .',
  'okay . . . then how bout we try out some french cuisine . saturday ? night ?'],
 ['you re asking me out . that s so cute . what s your name again ?',
  'forget it .'],
 ['no no it s my fault we didn t have a proper introduction', 'cameron .']]

In [21]:
len(pairs)

221282

In [22]:
def filterSents(s, max_len=13, min_len=1):
    """Keep the pairs whose two sentences both have an acceptable word count.

    Args:
        s: iterable of [question, answer] string pairs.
        max_len: exclusive upper bound on the number of words per sentence.
        min_len: exclusive lower bound on the number of words per sentence.

    Returns:
        A new list holding the pairs where both sentences satisfy
        min_len < word count < max_len.
    """
    newPairs = []
    # Iterate over the argument; the previous version ignored `s` and always
    # filtered the notebook-level `pairs` variable.
    for pair in s:
        word_counts = (len(pair[0].split()), len(pair[1].split()))
        if all(min_len < count < max_len for count in word_counts):
            newPairs.append(pair)

    return newPairs

In [23]:
newPairs = filterSents(pairs)

In [24]:
len(newPairs)

93385

In [25]:
vocab = vocabulary('cornell movie-dialogs corpus')

In [26]:
# Feed both sides of every filtered pair into the vocabulary.
for question, answer in newPairs:
    for sentence in (question, answer):
        vocab.addSentence(sentence)

print('Number of words', vocab.num_words)

Number of words 23621


In [27]:
def removeRare(vocab, pairs, thresh=3):
    """Trim rare words from `vocab`, then keep only fully in-vocabulary pairs.

    Calls `vocab.trim(thresh)` first; afterwards a pair survives only if every
    word of both its sentences is still present in `vocab.word2idx`.
    """
    vocab.trim(thresh)

    def _all_known(sentence):
        return all(word in vocab.word2idx for word in sentence.split())

    return [
        pair for pair in pairs
        if _all_known(pair[0]) and _all_known(pair[1])
    ]

In [28]:
# Scratch check: list.remove deletes the first element equal to its argument,
# in place.
l = ['l','x','y']
l.remove('l')
print(l)

['x', 'y']


In [29]:
final_pairs = removeRare(vocab, newPairs)

Number of words in vocab: 8997


In [30]:
print(final_pairs[:5])

[['gosh if only we could find kat a boyfriend . . .', 'let me see what i can do .'], ['that s because it s such a nice one .', 'forget french .'], ['there .', 'where ?'], ['you have my word . as a gentleman', 'you re sweet .'], ['hi .', 'looks like things worked out tonight huh ?']]


In [31]:
len(final_pairs)

75020

## Preparing our data

In [32]:
def sent_idxs(vocab, sent):
    """Map each whitespace-separated word of `sent` to its index and append eos.

    Raises KeyError if a word is missing from `vocab.word2idx`.
    """
    indices = []
    for token in sent.split():
        indices.append(vocab.word2idx[token])
    indices.append(e_token)
    return indices

In [33]:
print(final_pairs[1][1], sent_idxs(vocab, final_pairs[1][1]))

forget french . [43, 44, 12, 2]


In [34]:
# Split the first ten pairs into questions and answers, then index the questions.
inp = [pair[0] for pair in final_pairs[:10]]
out = [pair[1] for pair in final_pairs[:10]]

indxs = []
for sentence in inp:
    indxs.append(sent_idxs(vocab, sentence))

In [35]:
print(inp[:5],indxs[:5])

['gosh if only we could find kat a boyfriend . . .', 'that s because it s such a nice one .', 'there .', 'you have my word . as a gentleman', 'hi .'] [[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, 2], [36, 37, 38, 39, 37, 40, 10, 41, 42, 12, 2], [53, 12, 2], [30, 59, 26, 60, 12, 61, 10, 62, 2], [64, 12, 2]]


In [36]:
def zeroPadding(l, val=0):
    """Transpose a list of sequences into time-major tuples, padding with `val`.

    Row t of the result holds the t-th element of every sequence; sequences
    shorter than the longest one contribute `val` at the missing positions.
    """
    columns = itertools.zip_longest(*l, fillvalue=val)
    return [column for column in columns]

In [37]:
test_res = zeroPadding(indxs)

In [38]:
test_res

[(3, 36, 53, 30, 64, 17, 52, 19, 17, 16),
 (4, 37, 12, 59, 12, 79, 86, 30, 93, 96),
 (5, 38, 2, 26, 2, 29, 12, 91, 30, 97),
 (6, 39, 0, 60, 0, 2, 12, 82, 94, 29),
 (7, 37, 0, 12, 0, 0, 12, 24, 95, 2),
 (8, 40, 0, 61, 0, 0, 2, 92, 82, 0),
 (9, 10, 0, 10, 0, 0, 0, 29, 34, 0),
 (10, 41, 0, 62, 0, 0, 0, 2, 96, 0),
 (11, 42, 0, 2, 0, 0, 0, 0, 97, 0),
 (12, 12, 0, 0, 0, 0, 0, 0, 98, 0),
 (12, 2, 0, 0, 0, 0, 0, 0, 12, 0),
 (12, 0, 0, 0, 0, 0, 0, 0, 2, 0),
 (2, 0, 0, 0, 0, 0, 0, 0, 0, 0)]

In [39]:
def binaryMatrix(l):
    """Return a 0/1 matrix shaped like `l`: 0 where the entry equals pad_token, 1 elsewhere."""
    return [
        [0 if token == pad_token else 1 for token in seq]
        for seq in l
    ]

In [40]:
test_b = binaryMatrix(test_res)

In [41]:
print(test_b)

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 0, 1, 0, 1, 1, 1, 1, 1], [1, 1, 0, 1, 0, 0, 1, 1, 1, 1], [1, 1, 0, 1, 0, 0, 1, 1, 1, 0], [1, 1, 0, 1, 0, 0, 0, 1, 1, 0], [1, 1, 0, 1, 0, 0, 0, 1, 1, 0], [1, 1, 0, 1, 0, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 0, 0, 0, 1, 0], [1, 1, 0, 0, 0, 0, 0, 0, 1, 0], [1, 0, 0, 0, 0, 0, 0, 0, 1, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]


In [42]:
def inputs(l, voc):
    """Turn a list of sentences into a padded, time-major index tensor.

    Args:
        l: list of normalised sentences.
        voc: vocabulary used to look the words up.

    Returns:
        (pad_tens, lengths): a LongTensor built from the padded, transposed
        index lists, and a tensor with each sentence's length including the
        trailing eos token.
    """
    # Use the vocabulary that was passed in; the previous version ignored
    # `voc` and always read the notebook-level `vocab`.
    indxs_batch = [sent_idxs(voc, sent) for sent in l]
    lengths = torch.tensor([len(indexes) for indexes in indxs_batch])
    padded = zeroPadding(indxs_batch)
    pad_tens = torch.LongTensor(padded)

    return pad_tens, lengths

In [43]:
def outputs(l, voc):
    """Turn a list of target sentences into a padded tensor plus padding mask.

    Args:
        l: list of normalised sentences.
        voc: vocabulary used to look the words up.

    Returns:
        (pad_tens, mask, max_len): the padded, transposed index LongTensor, a
        boolean mask of the same layout that is False at padding positions,
        and the longest target length including the eos token.
    """
    # Use the vocabulary that was passed in; the previous version ignored
    # `voc` and always read the notebook-level `vocab`.
    indxs_batch = [sent_idxs(voc, sent) for sent in l]
    max_len = max([len(indexes) for indexes in indxs_batch])
    padded = zeroPadding(indxs_batch)
    pad_tens = torch.LongTensor(padded)
    # Boolean mask: uint8 (ByteTensor) masks are deprecated for masked_select.
    mask = torch.BoolTensor(binaryMatrix(padded))

    return pad_tens, mask, max_len

In [44]:
def batch2data(vocab, pair_batch):
    """Convert a batch of [question, answer] pairs into model-ready tensors.

    The pairs are ordered by decreasing question length (in words) before
    being converted, so the lengths returned by `inputs` are in descending
    order.

    Returns:
        (inp, lengths, out, mask, max_target_len) as produced by `inputs` and
        `outputs`.
    """
    # sorted() works on a copy, so the caller's list is no longer reordered
    # as a side effect (list.sort mutated it in place). Both sorts are stable,
    # so the resulting order is the same.
    ordered = sorted(pair_batch, key=lambda x: len(x[0].split()), reverse=True)
    input_batch = [pair[0] for pair in ordered]
    output_batch = [pair[1] for pair in ordered]
    inp, lengths = inputs(input_batch, vocab)
    out, mask, max_target_len = outputs(output_batch, vocab)

    return inp, lengths, out, mask, max_target_len

In [45]:
# Build one tiny random batch and show every tensor batch2data produces.
small_batch = 5
sample_pairs = [random.choice(final_pairs) for _ in range(small_batch)]
batches = batch2data(vocab, sample_pairs)
inp, lengths, out, mask, max_target_len = batches

for label, value in (
    ('Input: ', inp),
    ('Lengths: ', lengths),
    ('Output: ', out),
    ('Mask: ', mask),
    ('Max Len: ', max_target_len),
):
    print(label, value)

Input:  tensor([[   17,   571,    45,    16, 10439],
        [  166,    30,    37,    29,  3323],
        [  212,   881,   112,   109,    12],
        [   82,    12,  2765,     2,     2],
        [  854,    96,  7935,     0,     0],
        [   30,   191,    29,     0,     0],
        [  199,    12,     2,     0,     0],
        [  108,     2,     0,     0,     0],
        [  360,     0,     0,     0,     0],
        [  126,     0,     0,     0,     0],
        [   12,     0,     0,     0,     0],
        [    2,     0,     0,     0,     0]])
Lengths:  tensor([12,  8,  7,  4,  4])
Output:  tensor([[ 146,   18,   80,  116, 3323],
        [  82,   17, 1779,   37,   12],
        [  95,   15,  344,  818,   16],
        [  96,   30,   12, 2290,   37],
        [ 510,  288,  139,  509,  116],
        [ 711,   29,   34,   12,  319],
        [ 706,    2, 7999,   13,   82],
        [  12,    0,   12,   37,   90],
        [   2,    0,   66,   19,   33],
        [   0,    0,  179,  267,  904],
   

## Encoder

In [46]:
class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over a padded, time-major batch of word indices.

    The GRU's input size equals `hidden_size`, so `embedding` must produce
    vectors with `hidden_size` features.
    """

    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        # Dropout is only passed to the GRU when there is more than one layer.
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout),
                          bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a batch.

        Args:
            input_seq: (max_len, batch_size) index tensor.
            input_lengths: per-sentence lengths (list or tensor), in the order
                pack_padded_sequence expects.
            hidden: optional initial GRU state.

        Returns:
            outputs: (max_len, batch_size, hidden_size) -- the forward and
                backward GRU outputs summed together.
            hidden: (n_layers * 2, batch_size, hidden_size) final GRU state.
        """
        embedded = self.embedding(input_seq)

        # pack_padded_sequence requires the lengths on the CPU; callers in
        # this notebook move them to `device`, which fails when that is a GPU.
        if torch.is_tensor(input_lengths):
            input_lengths = input_lengths.cpu()
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)

        outputs, hidden = self.gru(packed, hidden)

        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # The last dimension holds [forward | backward]; sum the two halves.
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]

        return outputs, hidden

## Decoder and Attention

In [47]:
class Attention(nn.Module):
    """Dot-product attention.

    `method` is stored but not consulted: only the dot-product score is
    implemented, whatever value is passed.
    """

    def __init__(self, method, hidden_size):
        super(Attention, self).__init__()
        self.method = method
        self.hidden_size = hidden_size

    def ele_prod(self, emb_out, dec_hidd):
        """Dot-product score: multiply element-wise (broadcasting) and sum over dim 2."""
        return torch.sum(emb_out * dec_hidd, dim=2)

    def forward(self, emb_out, dec_hidd):
        """Return attention weights over the time axis.

        The two arguments are broadcast against each other, so one may be
        (1, batch_size, hidden_size) and the other
        (max_len, batch_size, hidden_size); the score is symmetric in them.

        Returns:
            (batch_size, 1, max_len) tensor whose last dimension sums to 1.
        """
        attn_energies = self.ele_prod(emb_out, dec_hidd)  # (max_len, batch_size)
        attn_energies = attn_energies.t()                 # (batch_size, max_len)

        # torch.softmax is used rather than the notebook alias `f`, which the
        # data-loading cells rebind to a file handle.
        return torch.softmax(attn_energies, dim=1).unsqueeze(1)

In [48]:
class DecoderRNN(nn.Module):
    """GRU decoder with attention that emits one token distribution per call."""

    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers = 1, dropout = 0.1):
        super(DecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.embedding = embedding
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.embedding_dropout = nn.Dropout(dropout)
        # Only pass dropout to the GRU when there is more than one layer. The
        # guard previously compared against 0, so it never took effect; it now
        # matches the encoder.
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size*2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attention(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, enc_out):
        """Run one decoding step.

        Args:
            input_step: (1, batch_size) indices of the current input tokens.
            last_hidden: previous GRU hidden state.
            enc_out: encoder outputs, (max_len, batch_size, hidden_size).

        Returns:
            output: (batch_size, output_size) softmax probabilities.
            hidden: updated GRU hidden state.
        """
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)

        rnn_output, hidden = self.gru(embedded, last_hidden)

        # Weight the encoder outputs by their attention score to get a
        # context vector: (batch, 1, max_len) x (batch, max_len, hidden).
        attn_weights = self.attn(rnn_output, enc_out)
        context = attn_weights.bmm(enc_out.transpose(0,1))

        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)

        concat_inp = torch.cat((rnn_output, context), 1)
        concat_out = torch.tanh(self.concat(concat_inp))

        output = self.out(concat_out)
        # torch.softmax is used rather than the notebook alias `f`, which the
        # data-loading cells rebind to a file handle.
        output = torch.softmax(output, dim=1)

        return output, hidden
                 

In [49]:
def maskedNLL(dec_out, target, mask):
    """Mean negative log-likelihood over the non-padded targets of one time step.

    Args:
        dec_out: (batch_size, vocab_size) probabilities (already softmaxed).
        target: (batch_size,) target indices for this step.
        mask: (batch_size,) mask, truthy where the target is a real token.

    Returns:
        (loss, n_total): the mean loss over the unmasked positions, on the
        same device as `dec_out`, and the number of unmasked positions as a
        Python int.
    """
    # Accept the legacy uint8 mask as well as a bool one.
    mask = mask.bool()
    nTotal = mask.sum()

    # gather yields (batch_size, 1). It must be squeezed to (batch_size,)
    # before masking: otherwise masked_select broadcasts it against the
    # (batch_size,) mask into a (batch_size, batch_size) grid and the padded
    # positions end up contributing to the mean.
    picked = torch.gather(dec_out, 1, target.view(-1, 1)).squeeze(1)
    crossEnt = -torch.log(picked)

    # The loss already lives on dec_out's device, so no explicit transfer to
    # the notebook-level `device` is needed.
    loss = crossEnt.masked_select(mask).mean()
    return loss, nTotal.item()

In [50]:
import torch.nn.functional as f

In [51]:
# Walk through a single teacher-forced training step by hand, printing the
# shapes at every decoder time step. No backward pass or optimizer step is
# performed here.
small_batch = 5
batches = batch2data(vocab, [random.choice(final_pairs) for _ in range(small_batch)])
inp, lengths, out, mask, max_target_len = batches

print('Input: ', inp)
print('Lengths: ', lengths)
print('Output: ', out)
print('Mask: ', mask)
print('Max Len: ', max_target_len)

# Defining the parameters:
hidden_size = 500
n_enc_layers = 2
n_dec_layers = 2
dropout = 0.1
attn_model = 'dot'
embedding = nn.Embedding(vocab.num_words, hidden_size)

# Defining Encoder and Decoder
encoder = EncoderRNN(hidden_size, embedding, n_enc_layers, dropout)
decoder = DecoderRNN(attn_model, embedding, hidden_size, vocab.num_words, n_dec_layers, dropout)

encoder.to(device)
decoder.to(device)

encoder.train()
decoder.train()

# Zero gradients
encoder_opt = optim.Adam(encoder.parameters(), lr=0.0001)
decoder_opt = optim.Adam(decoder.parameters(), lr=0.0001)
encoder_opt.zero_grad()
decoder_opt.zero_grad()

# Set device options
input_variable = inp.to(device)
lengths = lengths.to(device)
target_variable = out.to(device)
mask = mask.to(device)

# Initialize variables
loss = 0
print_losses = []
n_totals = 0

# Forward pass through encoder
encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

# Create initial decoder input (start with SOS tokens for each sentence)
decoder_input = torch.LongTensor([[s_token for _ in range(small_batch)]])
decoder_input = decoder_input.to(device)

# Set initial decoder hidden state to the encoder's final hidden state
decoder_hidden = encoder_hidden[:decoder.n_layers]
print("Initial Decoder Hidden Shape: ", decoder_hidden.shape)
print('\n')

for t in range(max_target_len):
    print('---------------------------------------- \n')
    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
    # Report the shape of the decoder OUTPUT; this line used to print
    # decoder_input.shape under the "Decoder Output Shape" label.
    print("Decoder Output Shape:", decoder_output.shape)
    print("Decoder Hidden Shape:", decoder_hidden.shape)
    # Teacher forcing: next input is current target
    decoder_input = target_variable[t].view(1, -1)
    print("Target Variable 1: ", target_variable[t])
    print("Target Variable 1 Shape: ", target_variable[t].shape)
    print("Decoder Input Shape: ", decoder_input.shape)
    # Calculate and accumulate loss
    print("Mask at the present time step: ", mask[t])
    mask_loss, nTotal = maskedNLL(decoder_output, target_variable[t], mask[t])
    loss += mask_loss
    print(loss,'\n')
    # Weight each step's mean loss by its number of real tokens.
    print_losses.append(mask_loss.item() * nTotal)
    n_totals += nTotal
    print(print_losses)
    print('---------------------------------------- \n')


Input:  tensor([[ 729,  162, 1158,   17,  185],
        [4478,   30, 5647,   19,  107],
        [ 109,   34,   39,  225,   12],
        [  45, 7249,   37, 1004,    2],
        [  34,  101, 1421,  570,    0],
        [ 729, 9432, 1421,   12,    0],
        [ 162, 2580, 2199,    2,    0],
        [ 303,   29,   12,    0,    0],
        [  29,    2,    2,    0,    0],
        [   2,    0,    0,    0,    0]])
Lengths:  tensor([10,  9,  9,  7,  4])
Output:  tensor([[  394,     4,   202,    30,  1215],
        [  996,    53,  1158,   219,    12],
        [   12,   229,    58,    14,   175],
        [  394,    10,    37,   106,   208],
        [   12,  7249, 13126,  2667,    30],
        [   45,    17,    29,    12,    72],
        [  162,   307,    45,     2,    39],
        [   30,    57,   162,     0,    12],
        [   29,    34,    30,     0,     2],
        [    2,    42,    29,     0,     0],
        [    0,    12,     2,     0,     0],
        [    0,     2,     0,     0,     0]])
Ma



In [52]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=13):
    """Run one optimisation step on a single batch and return its average loss.

    The batch is encoded, then decoded one time step at a time. With
    probability `teacher_forcing_ratio` (a notebook-level setting) the whole
    batch is decoded with teacher forcing, i.e. the ground-truth token is fed
    as the next input; otherwise the decoder's own top-1 prediction is fed
    back. Gradients of both models are clipped to `clip` before the update.

    `embedding` and `max_length` are accepted but not used in the body.

    Returns:
        The sum of per-step losses weighted by their token counts, divided by
        the total token count.
    """
    for optimizer in (encoder_optimizer, decoder_optimizer):
        optimizer.zero_grad()

    # Move the batch onto the training device.
    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    # Every sentence starts decoding from the SOS token.
    decoder_input = torch.LongTensor([[s_token] * batch_size]).to(device)

    # Seed the decoder with the first decoder.n_layers entries of the
    # encoder's final hidden state.
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    # One draw decides the feeding strategy for the whole batch.
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    loss = 0
    print_losses = []
    n_totals = 0

    for t in range(max_target_len):
        decoder_output, decoder_hidden = decoder(
            decoder_input, decoder_hidden, encoder_outputs
        )
        if use_teacher_forcing:
            # Next input is the current ground-truth token.
            decoder_input = target_variable[t].view(1, -1)
        else:
            # Next input is the decoder's own most likely token.
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)

        mask_loss, nTotal = maskedNLL(decoder_output, target_variable[t], mask[t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)
        n_totals += nTotal

    loss.backward()

    # Gradients are clipped in place.
    nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [53]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):
    """Train for `n_iteration` random batches, logging and checkpointing on the way.

    All batches are sampled up front (with replacement) from `pairs`. Every
    `print_every` iterations the average loss since the last report is
    printed; every `save_every` iterations a checkpoint holding the model,
    optimizer, vocabulary and embedding state is written under
    save_dir/model_name/corpus_name/<enc_layers>-<dec_layers>_<hidden_size>/.
    """
    # Load batches for each iteration
    training_batches = [batch2data(voc, [random.choice(pairs) for _ in range(batch_size)]) for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        # NOTE(review): `checkpoint` is the notebook-level variable set by the
        # model-configuration cell when loadFilename is given; it is not
        # passed in as an argument.
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            # Take the hidden size from the encoder being trained instead of
            # the notebook-level `hidden_size`, which may have been rebound by
            # another cell since the model was built.
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, encoder.hidden_size))
            # exist_ok avoids the check-then-create race.
            os.makedirs(directory, exist_ok=True)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

In [54]:
class GreedySearchDecoder(nn.Module):
    """Greedy decoding: at every step pick the single most probable token."""

    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        """Decode `max_length` tokens for one input sentence.

        Returns:
            (all_tokens, all_scores): 1-D tensors with the chosen token index
            and its softmax probability for every step. Decoding does not
            stop early at eos.
        """
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the
        # decoder. Use this module's own decoder: the previous code read the
        # notebook-level `decoder`, ignoring the one passed to __init__.
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * s_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

In [55]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=13):
    """Decode a reply to one normalised sentence and return it as a list of words.

    `encoder` and `decoder` are accepted but not used in the body; the
    decoding is delegated to `searcher`. Raises KeyError if `sentence`
    contains a word missing from `voc`.
    """
    # Words -> indices (with trailing eos), as a batch of exactly one sentence.
    token_ids = sent_idxs(voc, sentence)
    lengths = torch.tensor([len(token_ids)])
    # The models expect time-major input: (seq_len, batch=1).
    input_batch = torch.LongTensor([token_ids]).transpose(0, 1)
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)

    tokens, scores = searcher(input_batch, lengths, max_length)

    # Indices -> words.
    decoded_words = []
    for token in tokens:
        decoded_words.append(voc.idx2word[token.item()])
    return decoded_words


def evaluateInput(encoder, decoder, searcher, voc):
    """Interactive chat loop: read a line, print the bot's reply, repeat.

    Typing 'q' or 'quit' ends the loop. A sentence containing a word missing
    from the vocabulary is reported and the loop continues.
    """
    input_sentence = ''
    while True:
        try:
            input_sentence = input('> ')
            if input_sentence in ('q', 'quit'):
                break
            input_sentence = normalise(input_sentence)
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            # Hide the eos / pad markers from the printed reply.
            visible_words = [word for word in output_words if word not in ('eos', 'pad')]
            print('Bot: ', ' '.join(visible_words))

        except KeyError:
            print("Error: Encountered unknown word.")


In [56]:
# Configure models
# Configure models
model_name = 'cb_model'
# Only dot-product attention is implemented by the Attention class above.
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

save_dir = 'model_weights'
corpus_name = 'cornell_movie'

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # NOTE(review): torch.load unpickles the file, which can execute code;
    # only load checkpoints you created yourself.
    # If loading on same machine the model was trained on
    checkpoint = torch.load(loadFilename)
    # If loading a model trained on GPU to CPU
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    # Restore into `vocab`, the vocabulary object this notebook builds; the
    # name `voc` is not defined at notebook level and raised NameError here.
    vocab.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings (shared by encoder and decoder)
embedding = nn.Embedding(vocab.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = DecoderRNN(attn_model, embedding, hidden_size, vocab.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


In [57]:
import warnings
warnings.filterwarnings('ignore')

In [61]:
# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
# The decoder's learning rate is the encoder's multiplied by this factor
decoder_learning_ratio = 5.0
n_iteration = 500
print_every = 1
save_every = 500

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# Keep any restored optimizer state on the same device as the models.
# `.to(device)` is used instead of an unconditional `.cuda()` so that resuming
# from a checkpoint also works on a machine without a GPU.
for optimizer in (encoder_optimizer, decoder_optimizer):
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

# Run training iterations
print("Starting Training!")
trainIters(model_name, vocab, final_pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.2%; Average loss: 3.0061
Iteration: 2; Percent complete: 0.4%; Average loss: 3.0825
Iteration: 3; Percent complete: 0.6%; Average loss: 3.2505
Iteration: 4; Percent complete: 0.8%; Average loss: 3.1688
Iteration: 5; Percent complete: 1.0%; Average loss: 3.1822
Iteration: 6; Percent complete: 1.2%; Average loss: 3.2671
Iteration: 7; Percent complete: 1.4%; Average loss: 3.2452
Iteration: 8; Percent complete: 1.6%; Average loss: 3.3508
Iteration: 9; Percent complete: 1.8%; Average loss: 3.2280
Iteration: 10; Percent complete: 2.0%; Average loss: 3.0722
Iteration: 11; Percent complete: 2.2%; Average loss: 3.0252
Iteration: 12; Percent complete: 2.4%; Average loss: 3.1291
Iteration: 13; Percent complete: 2.6%; Average loss: 3.2354
Iteration: 14; Percent complete: 2.8%; Average loss: 3.3503
Iteration: 15; Percent complete: 3.0%; Average loss: 3.2176
Iteration: 16; Percent complete: 3.2%

Iteration: 135; Percent complete: 27.0%; Average loss: 3.0716
Iteration: 136; Percent complete: 27.2%; Average loss: 3.1019
Iteration: 137; Percent complete: 27.4%; Average loss: 3.1014
Iteration: 138; Percent complete: 27.6%; Average loss: 3.2437
Iteration: 139; Percent complete: 27.8%; Average loss: 3.0540
Iteration: 140; Percent complete: 28.0%; Average loss: 3.0099
Iteration: 141; Percent complete: 28.2%; Average loss: 2.9319
Iteration: 142; Percent complete: 28.4%; Average loss: 3.0535
Iteration: 143; Percent complete: 28.6%; Average loss: 2.9843
Iteration: 144; Percent complete: 28.8%; Average loss: 3.3624
Iteration: 145; Percent complete: 29.0%; Average loss: 2.9518
Iteration: 146; Percent complete: 29.2%; Average loss: 3.4636
Iteration: 147; Percent complete: 29.4%; Average loss: 2.9907
Iteration: 148; Percent complete: 29.6%; Average loss: 3.4317
Iteration: 149; Percent complete: 29.8%; Average loss: 2.8928
Iteration: 150; Percent complete: 30.0%; Average loss: 3.2364
Iteratio

Iteration: 268; Percent complete: 53.6%; Average loss: 3.1332
Iteration: 269; Percent complete: 53.8%; Average loss: 3.1626
Iteration: 270; Percent complete: 54.0%; Average loss: 2.9570
Iteration: 271; Percent complete: 54.2%; Average loss: 2.9705
Iteration: 272; Percent complete: 54.4%; Average loss: 2.9556
Iteration: 273; Percent complete: 54.6%; Average loss: 3.2195
Iteration: 274; Percent complete: 54.8%; Average loss: 2.9680
Iteration: 275; Percent complete: 55.0%; Average loss: 3.1331
Iteration: 276; Percent complete: 55.2%; Average loss: 2.9338
Iteration: 277; Percent complete: 55.4%; Average loss: 3.0770
Iteration: 278; Percent complete: 55.6%; Average loss: 3.0924
Iteration: 279; Percent complete: 55.8%; Average loss: 3.2919
Iteration: 280; Percent complete: 56.0%; Average loss: 2.8221
Iteration: 281; Percent complete: 56.2%; Average loss: 3.0760
Iteration: 282; Percent complete: 56.4%; Average loss: 3.3516
Iteration: 283; Percent complete: 56.6%; Average loss: 2.8948
Iteratio

Iteration: 401; Percent complete: 80.2%; Average loss: 3.0397
Iteration: 402; Percent complete: 80.4%; Average loss: 2.8966
Iteration: 403; Percent complete: 80.6%; Average loss: 3.1734
Iteration: 404; Percent complete: 80.8%; Average loss: 2.9103
Iteration: 405; Percent complete: 81.0%; Average loss: 3.0778
Iteration: 406; Percent complete: 81.2%; Average loss: 3.0499
Iteration: 407; Percent complete: 81.4%; Average loss: 2.7739
Iteration: 408; Percent complete: 81.6%; Average loss: 2.7251
Iteration: 409; Percent complete: 81.8%; Average loss: 2.7543
Iteration: 410; Percent complete: 82.0%; Average loss: 2.8402
Iteration: 411; Percent complete: 82.2%; Average loss: 3.2567
Iteration: 412; Percent complete: 82.4%; Average loss: 2.8393
Iteration: 413; Percent complete: 82.6%; Average loss: 3.0754
Iteration: 414; Percent complete: 82.8%; Average loss: 3.2501
Iteration: 415; Percent complete: 83.0%; Average loss: 3.0644
Iteration: 416; Percent complete: 83.2%; Average loss: 2.7800
Iteratio

In [62]:
# Persist the trained encoder and decoder as whole pickled modules.
# NOTE(review): saving full modules (rather than state_dicts) ties the files to
# the class definitions in this notebook; kept as-is so existing loaders of
# these files keep working.
enc_path = 'model_weights/enc_weights500'
dec_path = 'model_weights/dec_weights500'
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(enc_path), exist_ok=True)
torch.save(encoder, enc_path)
torch.save(decoder, dec_path)

In [65]:
# Put both networks into eval mode (dropout layers stop dropping units)
for _net in (encoder, decoder):
    _net.eval()

# Build the search module that wraps the encoder/decoder pair
searcher = GreedySearchDecoder(encoder, decoder)

# Start the prompt/response loop shown in the output below
evaluateInput(encoder, decoder, searcher, vocab)

> hi
Bot:  hi .
> how are you?
Bot:  i m not going to go .
> go where?
Bot:  i m not going to go .
> what?
Bot:  you re not going to go .
> i am not
Bot:  you re not going to go .
> okay
Bot:  i m not going to go .
> okay
Bot:  i m not going to go .
> do not go
Bot:  i m not going to go .
> that is good
Bot:  i m not going to go .
> say something good
Bot:  i m not going to go .
> else
Bot:  i m not going to go .
> bye
Bot:  i m not going to go .
> byeee
Error: Encountered unknown word.
> byeeeee
Error: Encountered unknown word.
> goodbye
Bot:  i m not going to go .
> quit
