In [1]:
import ast
import codecs
import csv
import itertools
import os
import re  # regular expressions
import unicodedata

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

In [2]:
# Raw utterance file: one record per line, fields separated by " +++$+++ "
# (lineID, characterID, movieID, character, text).
lines_filepath = "./movie_lines.txt"

In [3]:
# Raw conversation file: one record per line, fields separated by " +++$+++ "
# (character1ID, character2ID, movieID, utteranceIDs).
conv_filepath = "./movie_conversations.txt"

In [4]:
# Peek at the raw utterance file: load every line, then show the first eight
# with surrounding whitespace removed.
with open(lines_filepath, mode='r', encoding='iso-8859-1') as raw_file:
    lines = raw_file.readlines()
preview = lines[:8]
for raw_line in preview:
    print(raw_line.strip())

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No


In [5]:
# Build an index of utterance records keyed by lineID.
# Each record maps the names in `line_fields` to the matching " +++$+++ "-separated
# column; values are stored verbatim, so the last column keeps its trailing newline.
line_fields = ["lineID","characterID","movieID","character","text"]
lines = {}
with open(lines_filepath, mode='r', encoding='iso-8859-1') as f:
    for raw_line in f:
        columns = raw_line.split(" +++$+++ ")
        # Indexing (rather than zip) keeps the IndexError on records with too few columns.
        record = {name: columns[pos] for pos, name in enumerate(line_fields)}
        lines[record["lineID"]] = record

In [6]:
# Parse the conversation file into a list of dicts. Each dict holds the columns
# named in `conv_fields` plus a "lines" entry: the utterance records (looked up
# in the `lines` index by lineID) in the order the conversation lists them.
conv_fields = ["character1ID","character2ID","movieID","utteranceIDs"]
conversations = []
with open(conv_filepath, 'r', encoding='iso-8859-1') as f:
    for line in f:
        values = line.split(" +++$+++ ")
        convObj = {field: values[i] for i, field in enumerate(conv_fields)}
        # The utteranceIDs column is the textual form of a Python list of strings,
        # e.g. "['L194', 'L195']\n". It comes from a data file, so parse it with
        # ast.literal_eval (literals only) instead of eval(), which would execute
        # arbitrary expressions embedded in the file.
        lineIds = ast.literal_eval(convObj["utteranceIDs"])
        # Raises KeyError if a conversation references a lineID missing from `lines`.
        convObj["lines"] = [lines[lineID] for lineID in lineIds]
        conversations.append(convObj)

In [7]:
# Sanity check: display the first parsed conversation and its resolved utterance records.
conversations[0]

{'character1ID': 'u0',
 'character2ID': 'u2',
 'movieID': 'm0',
 'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n",
 'lines': [{'lineID': 'L194',
   'characterID': 'u0',
   'movieID': 'm0',
   'character': 'BIANCA',
   'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
  {'lineID': 'L195',
   'characterID': 'u2',
   'movieID': 'm0',
   'character': 'CAMERON',
   'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
  {'lineID': 'L196',
   'characterID': 'u0',
   'movieID': 'm0',
   'character': 'BIANCA',
   'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
  {'lineID': 'L197',
   'characterID': 'u2',
   'movieID': 'm0',
   'character': 'CAMERON',
   'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]}

In [8]:
# Turn each conversation into [input, response] pairs of consecutive utterances.
# Texts are whitespace-stripped, and a pair is dropped when either side is empty.
qa_pairs = []
for conversation in conversations:
    utterances = conversation["lines"]
    # zip with the list shifted by one walks every adjacent (i, i+1) pair.
    for current, following in zip(utterances, utterances[1:]):
        inputline = current['text'].strip()
        targetline = following['text'].strip()
        if not (inputline and targetline):
            continue
        qa_pairs.append([inputline, targetline])

In [9]:
# Sanity check: display the first [input, response] pair.
qa_pairs[0]

['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.',
 "Well, I thought we'd start with pronunciation, if that's okay with you."]

In [10]:
# Write the [input, response] pairs to a tab-separated file, one pair per row.
datafile = './formatted_movie_lines.txt'
delimeter = '\t'
# Interpret backslash escapes in the delimiter; a literal tab passes through unchanged.
delimeter = str(codecs.decode(delimeter,"unicode_escape"))

# newline='' lets the csv module control row terminators itself. Without it,
# platforms that translate '\n' on write (Windows) would emit '\r\r\n' row endings.
with open(datafile, 'w', encoding='utf-8', newline='') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimeter)
    writer.writerows(qa_pairs)
        

In [11]:
# Read the formatted file back as raw bytes and show the first eight rows, so the
# delimiter and row terminators are visible in the output.
# NOTE: this rebinds `lines` (previously the lineID -> record dict) to a list of bytes.
with open(datafile, mode='rb') as formatted_file:
    lines = formatted_file.readlines()
head_rows = lines[:8]
for raw_row in head_rows:
    print(raw_row)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\n"
b'Why?\tUnsolved myster

In [12]:
# Reserved token indices. They must be distinct: they are used as keys of
# Vocabulary.index2word and occupy the first three vocabulary slots.
PAD_token = 0  # "PAD" entry
SOS_token = 1  # "SOS" entry
EOS_token = 2  # "EOS" entry


class Vocabulary:
    """Word <-> index mapping with per-word occurrence counts.

    Attributes:
        name: Label given at construction time.
        word2index: Maps each known word to its integer index (starting at 3).
        word2count: Maps each known word to the number of times it was added.
        index2word: Maps indices back to words; pre-populated with the three
            reserved tokens PAD, SOS and EOS.
        num_words: Total number of entries, reserved tokens included.
    """

    def __init__(self, name):
        """Create an empty vocabulary holding only the reserved tokens."""
        self.name = name
        self._reset()

    def _reset(self):
        """Reset all mappings to the initial state (reserved tokens only)."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # the three reserved tokens above

    def addSentence(self, sentence):
        """Add every single-space-separated word of ``sentence`` to the vocabulary."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register ``word``: assign the next index if it is new, else bump its count."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, mincount):
        """Drop every word seen fewer than ``mincount`` times and re-index the rest.

        Prints how many words were kept. Kept words are re-added through
        ``addWord``, so they receive fresh consecutive indices and their counts
        restart at 1; a second ``trim`` with ``mincount > 1`` would therefore
        remove every word.
        """
        keep_words = [word for word, count in self.word2count.items() if count >= mincount]
        total = len(self.word2index)
        # Guard against an empty vocabulary, which would otherwise divide by zero.
        ratio = len(keep_words) / total if total else 0.0
        print('keep words {} / {} = {:.4f}'.format(len(keep_words), total, ratio))
        self._reset()
        for word in keep_words:
            self.addWord(word)

In [13]:
def unicodetoascii(s):
    """Return ``s`` with combining marks removed.

    The string is decomposed with Unicode NFD normalization, then every
    character whose category is 'Mn' is dropped, so an accented letter is
    reduced to its base letter. Characters without such a decomposition are
    left unchanged.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = []
    for ch in decomposed:
        if unicodedata.category(ch) == 'Mn':
            continue
        kept.append(ch)
    return ''.join(kept)

In [14]:
def normalizeString(s):
    """Normalize a sentence to lowercase letters plus separated . ! ? tokens.

    Steps: lowercase and strip the input, remove accents via ``unicodetoascii``,
    put whitespace before each of ``.``, ``!`` and ``?``, replace every run of
    characters other than letters and those three marks with one space, then
    collapse whitespace and strip.

    Example: "Hello!  How ARE you??" -> "hello ! how are you ? ?"
    """
    # str.lower must be called; `s.lower.strip()` raised AttributeError because
    # it looked up `.strip` on the bound method object itself.
    s = unicodetoascii(s.lower().strip())
    s = re.sub(r"([.!?])", r"  \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    # Collapses the extra spaces introduced above into single separators.
    s = re.sub(r"\s+", r" ", s).strip()
    return s