In [46]:
from Vocabulary import normalizeString, vocabulary, unicodetoascii
import itertools
import torch
import random

In [24]:
PAD = 0
SOS = 1
EOS = 2

In [12]:
datafile = "data/formatted_movie_lines.txt"

In [13]:
lines = open(datafile, encoding="utf-8").read().strip().split("\n\n")
pairs = [[normalizeString(s) for s in pair.split("\t")] for pair in lines]

In [14]:
len(pairs)

221282

In [15]:
corpus = vocabulary("Cornell Movie Dialogues")

In [16]:
#a bit more cleaning, so well remove any sentances that are too long
def filterpair(p, max_length=10):
    return len(p[0].split()) <= max_length and len(p[1].split()) <= max_length

pairs = [pair for pair in pairs if filterpair(pair)]

In [17]:
len(pairs)

75026

In [18]:
pairs[:10]

[['that s because it s such a nice one .', 'forget french .'],
 ['there .', 'where ?'],
 ['you have my word . as a gentleman', 'you re sweet .'],
 ['hi .', 'looks like things worked out tonight huh ?'],
 ['you know chastity ?', 'i believe we share an art instructor'],
 ['have fun tonight ?', 'tons'],
 ['well no . . .', 'then that s all you had to say .'],
 ['then that s all you had to say .', 'but'],
 ['but', 'you always been this selfish ?'],
 ['do you listen to this crap ?', 'what crap ?']]

In [19]:
def trimRareWords(vocab, pairs, min_count = 3):
    
    vocab.trim(min_count=min_count)
    keep_pairs = []
    for pair in pairs:
        input_ = pair[0]
        reply_ = pair[1]
        keepinput, keepreply = True, True
        for word in input_.split(" "):
            if word not in vocab.word2index:
                keepinput = False
                break
        for word in reply_.split(" "):
            if word not in vocab.word2index:
                keepreply = False
                break
        if keepinput and keepreply:
            keep_pairs.append(pair)

    print(f"After trimming kept {len(keep_pairs)} out of {len(pairs)}")
    
    return keep_pairs

In [20]:
for pair in pairs:
    corpus.addSentance(pair[0])
    corpus.addSentance(pair[1])

print(corpus.num_words)

20093


In [21]:
cleaned_pairs = trimRareWords(corpus, pairs)

After trimming kept 62810 out of 75026


In [22]:
cleaned_pairs[:10]

[['that s because it s such a nice one .', 'forget french .'],
 ['there .', 'where ?'],
 ['you have my word . as a gentleman', 'you re sweet .'],
 ['hi .', 'looks like things worked out tonight huh ?'],
 ['have fun tonight ?', 'tons'],
 ['well no . . .', 'then that s all you had to say .'],
 ['then that s all you had to say .', 'but'],
 ['but', 'you always been this selfish ?'],
 ['do you listen to this crap ?', 'what crap ?'],
 ['what good stuff ?', 'the real you .']]

In [27]:
def indexfromSentance(vocab:vocabulary, sentance:str):
    return [vocab.word2index[word] for word in sentance.split(" ")] + [EOS]

In [29]:
indexfromSentance(corpus, cleaned_pairs[1][0])

[14, 11, 2]

In [31]:
inputs = []
for pair in cleaned_pairs[:10]:
    inputs.append(indexfromSentance(corpus, pair[0]))

In [33]:
def zeropading(l, fillvalue = 0):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

In [34]:
def binarymatrix(l, value=0):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD:
                m[i].append(0)
            else:
                m[i].append(1)

    return m

In [35]:
l = zeropading(inputs)

In [37]:
inputs

[[3, 4, 5, 6, 4, 7, 8, 9, 10, 11, 2],
 [14, 11, 2],
 [17, 18, 19, 20, 11, 21, 8, 22, 2],
 [25, 11, 2],
 [18, 40, 31, 16, 2],
 [42, 43, 11, 11, 11, 2],
 [44, 3, 4, 45, 17, 46, 47, 48, 11, 2],
 [49, 2],
 [54, 17, 55, 47, 52, 56, 16, 2],
 [57, 58, 59, 16, 2]]

In [36]:
l

[(3, 14, 17, 25, 18, 42, 44, 49, 54, 57),
 (4, 11, 18, 11, 40, 43, 3, 2, 17, 58),
 (5, 2, 19, 2, 31, 11, 4, 0, 55, 59),
 (6, 0, 20, 0, 16, 11, 45, 0, 47, 16),
 (4, 0, 11, 0, 2, 11, 17, 0, 52, 2),
 (7, 0, 21, 0, 0, 2, 46, 0, 56, 0),
 (8, 0, 8, 0, 0, 0, 47, 0, 16, 0),
 (9, 0, 22, 0, 0, 0, 48, 0, 2, 0),
 (10, 0, 2, 0, 0, 0, 11, 0, 0, 0),
 (11, 0, 0, 0, 0, 0, 2, 0, 0, 0),
 (2, 0, 0, 0, 0, 0, 0, 0, 0, 0)]

In [38]:
binary = binarymatrix(l)

In [39]:
binary

[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 0, 1, 1],
 [1, 0, 1, 0, 1, 1, 1, 0, 1, 1],
 [1, 0, 1, 0, 1, 1, 1, 0, 1, 1],
 [1, 0, 1, 0, 0, 1, 1, 0, 1, 0],
 [1, 0, 1, 0, 0, 0, 1, 0, 1, 0],
 [1, 0, 1, 0, 0, 0, 1, 0, 1, 0],
 [1, 0, 1, 0, 0, 0, 1, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

In [42]:
#Returns padded input sequence tensor as well as tensor of lengths for each of the padded seq in the batc
def inputVar(l:list, vocab:vocabulary):
    indexes_batch = [indexfromSentance(vocab, sentance) for sentance in l]
    lengths = torch.tensor([len(index_array) for index_array in indexes_batch])
    padlist = zeropading(indexes_batch)
    padvar = torch.LongTensor(padlist)
    return padvar, lengths

In [44]:
# Returns padded target sequence tensor, padding mask and maax target length
def outputVar(l:list, vocab:vocabulary):
    indexes_batch = [indexfromSentance(vocab, sentance) for sentance in l]
    max_target_len = max([len(index_array) for index_array in indexes_batch])
    padlist = zeropading(indexes_batch)
    mask = binarymatrix(padlist)
    mask = torch.ByteTensor(mask)
    padvar = torch.LongTensor(padlist)
    return padvar, mask, max_target_len

In [45]:
#Prepares the data for training for a given batch of pairs
def batch2traindata(vocab, pair_batch):
    #Sort the question answers pairs in descending order
    pair_batch.sort(key=lambda x:len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, vocab)
    output, mask, max_target_len = outputVar(output_batch, vocab)
    return inp, lengths, output, mask, max_target_len

In [47]:
#Validation of preprocessing steps
batch_size = 5
input_seq, lengths, target_seq, target_mask, max_target_length = batch2traindata(corpus, [random.choice(cleaned_pairs) for _ in range(batch_size)])

In [48]:
print(input_seq)

tensor([[ 153,  113,   34,  550, 1483],
        [  34,   34,  108, 6394,   16],
        [ 101,   67,  285,   16,    2],
        [ 102,  882,    6,    2,    0],
        [ 307, 1114,  158,    0,    0],
        [  82,  225,   11,    0,    0],
        [  60,   16,    2,    0,    0],
        [ 246,    2,    0,    0,    0],
        [  11,    0,    0,    0,    0],
        [   2,    0,    0,    0,    0]])


In [49]:
print(lengths)

tensor([10,  8,  7,  4,  3])


In [50]:
print(target_seq)

tensor([[ 860,  344,  327,   66,  266],
        [8058, 1184, 1544, 3183,   32],
        [ 640,   11, 2398,   67,   11],
        [  11,   17,   27, 6263,  143],
        [   2,  543,   93,   73,   10],
        [   0,  183,   11,    2,   16],
        [   0,  522,    2,    0,    2],
        [   0,   47,    0,    0,    0],
        [   0, 1704,    0,    0,    0],
        [   0,   11,    0,    0,    0],
        [   0,    2,    0,    0,    0]])


In [51]:
print(target_mask)

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [0, 1, 1, 0, 1],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0]], dtype=torch.uint8)


In [52]:
print(max_target_length)

11
