In [1]:
import os
import json
from torch.utils.data import DataLoader
from collections import Counter
import torch
from torchtext.vocab import Vocab
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn

In [2]:
dir_path = 'cornell movie-dialogs corpus'
convs_path = 'movie_conversations.txt'
lines_path = 'movie_lines.txt'

In [3]:

def getData():

    """
    Returns a list of (text, label) pairs
    """

    convs = []

    with open(os.path.join(dir_path, convs_path), 'r', encoding= 'iso-8859-1') as file:

        for line in file:
            li = line.split('+++$+++')
            li = li[3]
            convs.append(eval(li)) 

    lines = dict()

    with open(os.path.join(dir_path, lines_path), 'r', encoding = 'iso-8859-1') as file:

        for line in file:
            li = line.split(' +++$+++ ')
            lines[li[0]] = li[4].strip('\n')

    data = []

    for conv in convs:
        for i in range(len(conv) - 1):
            li = (lines[conv[i]], lines[conv[i + 1]])
            data.append(li)

    return data

#builds a vocabulary object from torchtext

def tokenizeData(data):

    tokenizer = get_tokenizer('basic_english')

    tokenized_data = []

    for (text, label) in data:
        tokenized_data.append((tokenizer(text), tokenizer(label)))

    return tokenized_data

def buildVocab(data):

    """
    builds a Vocab object with pretrained embeddings, which can be used later 
    in the decoder
    """

    counter = Counter()
    tokenizer = get_tokenizer('basic_english')

    for sentence, response in data:
        counter.update(sentence)
        counter.update(response)


    vocab = Vocab(counter, 
            specials=('<eos>', '<pad>', '<unk>', '<bos>'),
            vectors='fasttext.simple.300d'
            )

    return vocab

def MapData(data, vocab):

    mappedData = []

    for sentence, response in data:
        mappedSentence = torch.tensor([vocab[i] for i in sentence] + [vocab['<eos>']])
        mappedResponse = torch.tensor([vocab[i] for i in response] + [vocab['<eos>']])

        mappedData.append((mappedSentence, mappedResponse))

    return mappedData

def UnmapData(mappedData, vocab):
    
    unmappedData = []

    for(sentence, response) in mappedData:
        unmappedSentence = [vocab.itos[i] for i in sentence] 
        unmappedResponse = [vocab.itos[i] for i in response]

        unmappedData.append((unmappedSentence, unmappedResponse))

    return unmappedData


In [4]:
# get data from conversational dataset
data = getData()
# tokenize data
tokenized_data = tokenizeData(data)
# build pytorch vocab object w/ pretrained weights from tokenized data
vocab = buildVocab(tokenized_data)
# map the data to its vocab indices
mappedData = MapData(tokenized_data, vocab)

In [5]:
unmappedData = UnmapData(mappedData, vocab)
print(mappedData[0], unmappedData[0])


(tensor([   42,    24,   117,    27,   969,     9, 58718, 54606,    17,  4007,
         7896,    46,   415,    86,  3857, 23552,   959, 14055,    60,    41,
           10, 41318,     4,   194,     4,     0]), tensor([   68,     6,     8,   150,    24,     5,    77,   319,    44, 41205,
            6,    55,    16,     5,    13,   121,    44,     7,     4,     0])) (['can', 'we', 'make', 'this', 'quick', '?', 'roxanne', 'korrine', 'and', 'andrew', 'barrett', 'are', 'having', 'an', 'incredibly', 'horrendous', 'public', 'break-', 'up', 'on', 'the', 'quad', '.', 'again', '.', '<eos>'], ['well', ',', 'i', 'thought', 'we', "'", 'd', 'start', 'with', 'pronunciation', ',', 'if', 'that', "'", 's', 'okay', 'with', 'you', '.', '<eos>'])


In [6]:
def collate(batch):
    
    queries = []
    responses = []

    for query, response in batch:
        queries.append(query)
        responses.append(response)
    
    queries = pad_sequence(queries, batch_first= False)
    responses = pad_sequence(responses, batch_first= False)

    return (queries, responses)

def MakeDataset(mappedData, vocab):
    train_dataloader = DataLoader(mappedData, batch_size=16, sampler=torch.utils.data.RandomSampler(mappedData), collate_fn = collate)
    return train_dataloader

dataset = MakeDataset(mappedData, vocab)


In [7]:
### TODO make 
# since this is an encoder, do we really need x?

class Encoder(nn.Module):

    def __init__(self, vocab):
        super(Encoder, self).__init__()

        self.vocabLen = len(vocab)
        self.embedding = nn.Embedding.from_pretrained(vocab.vectors)
        self.rnn = nn.GRU(300, 1000, 1)  

    def forward(self, x):
        x = self.embedding(x)
        x, h = self.rnn(x)

        return x, h

class Decoder(nn.Module):
    def __init__(self, vocab):
        super(Decoder, self).__init__()

        #softmax with vocab size
        self.rnn = nn.GRU(1, 1000, 1)
        self.dense = nn.Linear(1000, len(vocab))


    def forward(self, x, h):
        
        x, h = self.rnn(x, h)
        x = self.dense(x)

        return x, h

In [22]:
data_test = next(iter(dataset))[0]

encoderModel = Encoder(vocab)
decoderModel = Decoder(vocab)

encoderOptimizer = torch.optim.Adam(encoderModel.parameters(), lr=0.0001)
decoderOptimizer = torch.optim.Adam(decoderModel.parameters(), lr=0.0001)

criterion = nn.CrossEntropyLoss()



In [23]:

def train(epochs):
    for i in range(epochs):
        for batchCount, batch in enumerate(dataset):

            #run batch through encoder
            query = batch[0]
            labels = batch[1]
            loss = 0

            x, h = encoderModel(query)
            #get the last hidden unit to use in decoder
            #set max length to max length of query for now
            max_len_out = query.shape[1]
            last_x = torch.zeros(1, 16, 1)
            last_h = h

            #output sentence for decoder
            output = []

            #generate the output sentence

            for j in range(max_len_out):
              

                last_x, last_h = decoderModel(last_x, last_h)
    
                loss += criterion(last_x.squeeze(0), labels[:][j])
                
                last_x = torch.argmax(last_x, dim=2, keepdim=True)

                last_x = last_x.type(torch.float32)

               
            print('epoch: ', i,  'loss: ', loss, ' iteration: ', batchCount, ' / ', 865)

            encoderOptimizer.zero_grad()
            decoderOptimizer.zero_grad()

            loss.backward()      

            encoderOptimizer.step()
            decoderOptimizer.step()

train(1)

RuntimeError: input must have 3 dimensions, got 2