In [1]:
from collections import Counter
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import math
import torch.nn.functional as F

In [2]:
corpus_movie_conv = 'data/cornell movie-dialogs corpus/movie_conversations.txt'
corpus_movie_lines = 'data/cornell movie-dialogs corpus/movie_lines.txt'
max_len = 25

In [3]:
with open(corpus_movie_conv, 'r', encoding='utf-8', errors='ignore') as c:
    conv = c.readlines()

In [4]:
with open(corpus_movie_lines, 'r', encoding='utf-8', errors='ignore') as l:
    lines = l.readlines()

In [5]:
conv

["u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L271', 'L272', 'L273', 'L274', 'L275']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L276', 'L277']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L280', 'L281']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L363', 'L364']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L365', 'L366']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L367', 'L368']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L401', 'L402', 'L403']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L404', 'L405', 'L406', 'L407']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L575', 'L576']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L577', 'L578']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L662', 'L663']\n",
 "u0 +++$+++ u2 

In [6]:
lines

['L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n',
 'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n',
 'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n',
 'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n',
 "L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n",
 'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n',
 "L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n",
 'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n',
 'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n',
 'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n',
 'L868 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ The "real you".\n',
 'L867 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ What good stuff?\n',
 "L866 +++$+++ u2 +++$+++ m0 +++$+++ CAME

In [7]:
lines_dict = {}
for line in lines:
    line = line.split(' +++$+++ ')
    lines_dict[line[0]] = line[-1]

lines_dict

{'L1045': 'They do not!\n',
 'L1044': 'They do to!\n',
 'L985': 'I hope so.\n',
 'L984': 'She okay?\n',
 'L925': "Let's go.\n",
 'L924': 'Wow\n',
 'L872': "Okay -- you're gonna need to learn how to lie.\n",
 'L871': 'No\n',
 'L870': 'I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n',
 'L869': 'Like my fear of wearing pastels?\n',
 'L868': 'The "real you".\n',
 'L867': 'What good stuff?\n',
 'L866': "I figured you'd get to the good stuff eventually.\n",
 'L865': 'Thank God!  If I had to hear one more story about your coiffure...\n',
 'L864': "Me.  This endless ...blonde babble. I'm like, boring myself.\n",
 'L863': 'What crap?\n',
 'L862': 'do you listen to this crap?\n',
 'L861': 'No...\n',
 'L860': 'Then Guillermo says, "If you go any lighter, you\'re gonna look like an extra on 90210."\n',
 'L699': 'You always been this selfish?\n',
 'L698': 'But\n',
 'L697': "Then that's all you had to say.\n",
 'L696': 'Well, no...\n',
 'L695

In [8]:
lines_dict['L1045']

'They do not!\n'

In [9]:
def remove_punc(string):
    punc = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    for char in string:
        if char in punc:
            string = string.replace(char, "")
    return string.lower()

In [10]:
ele = remove_punc(lines_dict['L3317'])
ele

'ill make some for us\n'

In [11]:
pairs = []
for con in conv:
    
    con = eval(con.split(' +++$+++ ')[-1])
    
    for i in range(len(con)):
        qa_pairs = []

        if i==len(con)-1:
            break

        qa_pairs.append(remove_punc(lines_dict[(con[i])].strip()).split()[:max_len])
        qa_pairs.append(remove_punc(lines_dict[(con[i+1])].strip()).split()[:max_len])
        pairs.append(qa_pairs)

pairs

[[['can',
   'we',
   'make',
   'this',
   'quick',
   'roxanne',
   'korrine',
   'and',
   'andrew',
   'barrett',
   'are',
   'having',
   'an',
   'incredibly',
   'horrendous',
   'public',
   'break',
   'up',
   'on',
   'the',
   'quad',
   'again'],
  ['well',
   'i',
   'thought',
   'wed',
   'start',
   'with',
   'pronunciation',
   'if',
   'thats',
   'okay',
   'with',
   'you']],
 [['well',
   'i',
   'thought',
   'wed',
   'start',
   'with',
   'pronunciation',
   'if',
   'thats',
   'okay',
   'with',
   'you'],
  ['not',
   'the',
   'hacking',
   'and',
   'gagging',
   'and',
   'spitting',
   'part',
   'please']],
 [['not',
   'the',
   'hacking',
   'and',
   'gagging',
   'and',
   'spitting',
   'part',
   'please'],
  ['okay',
   'then',
   'how',
   'bout',
   'we',
   'try',
   'out',
   'some',
   'french',
   'cuisine',
   'saturday',
   'night']],
 [['youre',
   'asking',
   'me',
   'out',
   'thats',
   'so',
   'cute',
   'whats',
   'your',
   

In [12]:
len(pairs)

221616

In [13]:
word_count = Counter()
for pair in pairs:
    for word in pair[0]:
        word_count[word] += 1
    for word in pair[1]:
        word_count[word] += 1

In [14]:
word_count

Counter({'can': 14103,
         'we': 25914,
         'make': 5821,
         'this': 30508,
         'quick': 310,
         'roxanne': 1,
         'korrine': 1,
         'and': 52151,
         'andrew': 49,
         'barrett': 20,
         'are': 21717,
         'having': 1081,
         'an': 8828,
         'incredibly': 49,
         'horrendous': 4,
         'public': 308,
         'break': 799,
         'up': 14318,
         'on': 23915,
         'the': 120915,
         'quad': 2,
         'again': 2807,
         'well': 16283,
         'i': 137674,
         'thought': 4202,
         'wed': 548,
         'start': 1459,
         'with': 21401,
         'pronunciation': 2,
         'if': 16734,
         'thats': 14834,
         'okay': 5947,
         'you': 169718,
         'not': 26500,
         'hacking': 18,
         'gagging': 9,
         'spitting': 15,
         'part': 1260,
         'please': 3259,
         'then': 7533,
         'how': 14004,
         'bout': 393,
         'try

In [15]:
word_freq = Counter()
for pair in pairs:
    word_freq.update(pair[0])
    word_freq.update(pair[1])

In [16]:
min_word_freq = 5
words = [word for word in word_freq.keys() if (word_freq[word] > min_word_freq)]
word_map = {k: v+1 for v, k in enumerate(words)}
word_map['<unk>'] = len(word_map)+1
word_map['<pad>'] = 0
word_map['<start>'] = len(word_map)+1
word_map['<end>'] = len(word_map)+1


In [17]:
len(word_map)

18190

In [18]:
with open('data/WORDMAP_corpus.json', 'w') as j:
    json.dump(word_map, j)

In [19]:
def encode_question(words, word_map):
    enc_c = [word_map.get(word, word_map['<unk>']) for word in words] + [word_map['<pad>']]*(max_len-len(words))
    # enc_c = torch.LongTensor(enc_c)
    return enc_c

In [20]:
def encode_reply(words, word_map):
    enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in words] + [word_map['<end>']] + [word_map['<pad>']]*(max_len-len(words))
    # enc_c = torch.LongTensor(enc_c)
    return enc_c

In [21]:
pairs[0][0]

['can',
 'we',
 'make',
 'this',
 'quick',
 'roxanne',
 'korrine',
 'and',
 'andrew',
 'barrett',
 'are',
 'having',
 'an',
 'incredibly',
 'horrendous',
 'public',
 'break',
 'up',
 'on',
 'the',
 'quad',
 'again']

In [22]:
encode_question(pairs[0][0], word_map)

[1,
 2,
 3,
 4,
 5,
 18187,
 18187,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 18187,
 13,
 14,
 15,
 16,
 17,
 18187,
 18,
 0,
 0,
 0]

In [23]:
pairs_encoded = []
for pair in pairs:
    question = encode_question(pair[0], word_map)
    reply = encode_reply(pair[1], word_map)
    pairs_encoded.append([question, reply])

In [24]:
pairs_encoded

[[[1,
   2,
   3,
   4,
   5,
   18187,
   18187,
   6,
   7,
   8,
   9,
   10,
   11,
   12,
   18187,
   13,
   14,
   15,
   16,
   17,
   18187,
   18,
   0,
   0,
   0],
  [18189,
   19,
   20,
   21,
   22,
   23,
   24,
   18187,
   25,
   26,
   27,
   24,
   28,
   18190,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0]],
 [[19,
   20,
   21,
   22,
   23,
   24,
   18187,
   25,
   26,
   27,
   24,
   28,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [18189,
   29,
   17,
   30,
   6,
   31,
   6,
   32,
   33,
   34,
   18190,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0]],
 [[29,
   17,
   30,
   6,
   31,
   6,
   32,
   33,
   34,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [18189,
   27,
   35,
   36,
   37,
   2,
   38,
   39,
   40,
   41,
   42,
   43,
   44,
   18190,
   0,
   0,
   0,
   0,
   0,
 

In [25]:
with open('pairs_encoded.json', 'w') as w:
    json.dump(pairs_encoded, w)

In [26]:
class Dataset(Dataset):
    def __init__(self):
        self.pairs = json.load(open('pairs_encoded.json', 'r'))
        self.dataset_size = len(self.pairs)
        
    def __len__(self):
        return self.dataset_size
    
    def __getitem__(self, index):
        question = torch.LongTensor(self.pairs[index][0])
        reply = torch.LongTensor(self.pairs[index][1])
        return question, reply
        

In [27]:
train_loader = torch.utils.data.DataLoader(Dataset(),
                                            batch_size=100,
                                            shuffle=True,
                                            pin_memory=True)

In [28]:
question, reply = next(iter(train_loader))

In [30]:
question.shape

torch.Size([100, 25])

In [36]:
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [33]:
def create_masks(question, reply_input, reply_target):

    def subsequent_mask(size):
        "Mask out subsequent positions."
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)
    
    question_mask = (question!=0).to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1) # (batch_size, 1, 1, max_words)

    reply_input_mask = (reply_input!=0).to(device)
    reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data) # (batch_size, max_words, max_words)
    reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words, max_words)
    reply_target_mask = (reply_target!=0).to(device)

    return question_mask, reply_input_mask, reply_target_mask



In [32]:
# How subsequent_mask works
size = 5
t = torch.ones(size, size)
t_triu = torch.triu(t)
t_triu.T # transpose

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

In [35]:
question[0] !=0

tensor([ True,  True,  True,  True,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False])