In [31]:
import numpy as np
import nltk
import os
import ast
import pickle
import tensorflow as tf
import random

In [2]:
cwd = os.getcwd()
#print(cwd)
corpusDir = os.path.join(cwd, 'data/cornell')
#print(corpusDir)

In [3]:
lines = {}
with open(os.path.join(corpusDir, 'movie_lines.txt'), 'r', encoding='iso-8859-1') as f:
    for line in f:
        #print(line)
        fields = line.split(' +++$+++ ')
        #print(fields)
        obj = {}
        obj['lineID'] = fields[0]
        obj['characterID'] = fields[1]
        obj['movieID'] = fields[2]
        obj['characterName'] = fields[3]
        obj['text'] = fields[4]
        lines[fields[0]] = obj
#print(lines)

In [4]:
conversations = []
with open(os.path.join(corpusDir, 'movie_conversations.txt'), 'r', encoding='iso-8859-1') as f:
    for line in f:
        #print(line)
        fields = line.split(' +++$+++ ')
        #print(fields)
        obj = {}
        obj['character1ID'] = fields[0]
        obj['character2ID'] = fields[1]
        obj['movieID'] = fields[2]
        #obj['lineIDs'] = fields[3]
        #print(obj)
        lineIDs = ast.literal_eval(fields[3])
        #print(lineIDs)
        obj['lineIDs'] = lineIDs
        #print(obj)
        obj['lines'] = []
        for lineID in lineIDs:
            #print(lineID, "--", lines[lineID])
            obj['lines'].append(lines[lineID])
        conversations.append(obj)
#print(conversations)

In [5]:
wordIDMap = {}
IDWordMap = {}
unknownToken = -1
trainingSamples = []

In [6]:
def getWordID(word, shouldAddToDict=True):
    word = word.lower()
    wordID = wordIDMap.get(word, -1)
    if wordID == -1:
        if shouldAddToDict:
            wordID = len(wordIDMap)
            wordIDMap[word] = wordID
            IDWordMap[wordID] = word
        else:
            wordID = unknownToken
    return wordID

In [7]:
sentMaxLength = 10 #maximum length of an input or output sentence
def getWordsFromLine(line, isReply=False):
    '''Returns the word IDs from the vovabulary'''
    words = []
    sentences = nltk.sent_tokenize(line)
    #print(sentences)
    # Since we are limited by a maxmimum length of sentences, we keep the last lines if the statement is a question/input
    # and the first few lines if the statement is an answer/reply
    for i in range(len(sentences)):
        if not isReply:
            i = len(sentences) - 1 - i
        tokensFromCurrSent = nltk.word_tokenize(sentences[i])
        #print(tokensFromCurrSent)
        if len(words) + len(tokensFromCurrSent) > sentMaxLength:
            break
        else:
            temp = []
            for token in tokensFromCurrSent:
                temp.append(getWordID(token))
            if isReply:
                words = words + temp
            else:
                words = temp + words # Append in the reverse order because we're considering the last few lines
    return words

In [8]:
for conversation in conversations:
    #print(conversation)
    for i in range(len(conversation['lines']) - 1):
        #print(conversation['lines'][i])
        inputStatement = conversation['lines'][i]
        #print(inputStatement)
        replyStatement = conversation['lines'][i + 1]
        inputWords = getWordsFromLine(inputStatement['text'])
        replyWords = getWordsFromLine(replyStatement['text'], True)
        #print(inputWords)
        #print(replyWords)
        
        if inputWords and replyWords:
            trainingSamples.append([inputWords, replyWords])
#print(trainingSamples)

In [9]:
print("Saving dataset samples ...")
with open(os.path.join(cwd, 'data/samples', 'sampleData.pkl'), 'wb') as f:
    data = {
        'wordIDMap': wordIDMap,
        'IDWordMap': IDWordMap,
        'trainingSamples': trainingSamples
    }
    pickle.dump(data, f, -1)
print('Done')

Saving dataset samples ...
Done


In [10]:
#Parameters
globalStep = 85
cellUnitCount = 512
numOfLayers = 2
embeddingSize = 64
learningRate = 0.02
batchSize = 256
dropout = 0.9
softmaxSamples = 0
numOfEpochs = 2

In [11]:
def make_lstm_cell():
    encoderDecoderCell = tf.contrib.rnn.BasicLSTMCell(cellUnitCount)
    encoderDecoderCell = tf.contrib.rnn.DropoutWrapper(encoderDecoderCell, input_keep_prob=1.0, output_keep_prob=dropout)
    return encoderDecoderCell

In [12]:
#Expand the list comprehension below
encoderDecoderCell = tf.contrib.rnn.MultiRNNCell(
    [make_lstm_cell() for _ in range(numOfLayers)],
)

In [13]:
with tf.name_scope('encoder'):
    encoderInputs = [tf.placeholder(tf.int32, [None, ]) for _ in range(sentMaxLength)]
with tf.name_scope('decoder'):
    decoderInputs = [tf.placeholder(tf.int32, [None, ], name="inputs") for _ in range(sentMaxLength + 2)]
    decoderTargets = [tf.placeholder(tf.int32, [None, ], name="targets") for _ in range(sentMaxLength + 2)]
    decoderWeights = [tf.placeholder(tf.float32, [None, ], name="weights") for _ in range(sentMaxLength + 2)]

In [14]:
#Verify this - is different from the existing
decoderOutput, state = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
    encoderInputs,
    decoderInputs,
    encoderDecoderCell,
    len(wordIDMap),
    len(wordIDMap),
    embeddingSize,
    output_projection=None,
    feed_previous=False
)

In [15]:
lossFunc = tf.contrib.legacy_seq2seq.sequence_loss(
    decoderOutput,
    decoderTargets,
    decoderWeights,
    len(wordIDMap),
    softmax_loss_function=None
)
tf.summary.scalar('loss', lossFunc)

<tf.Tensor 'loss:0' shape=() dtype=string>

In [16]:
optimizer = tf.train.AdamOptimizer(
    learning_rate=learningRate,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-08
)
optimizationOperation = optimizer.minimize(lossFunc)

In [17]:
writer = tf.summary.FileWriter('seq2seq')
saver = tf.train.Saver(max_to_keep=200)

In [18]:
sess = tf.Session(
    config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False
    )
)
sess.run(tf.global_variables_initializer())

In [19]:
#Change variable scope name
with tf.variable_scope("embedding_rnn_seq2seq/rnn/embedding_wrapper", reuse=True):
    in_embedding = tf.get_variable("embedding")
with tf.variable_scope("embedding_rnn_seq2seq/embedding_rnn_decoder", reuse=True):
    out_embedding = tf.get_variable("embedding")

embedding_vars = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
embedding_vars.remove(in_embedding)
embedding_vars.remove(out_embedding)

'''
if globalStep != 0:
    return
'''

'\nif globalStep != 0:\n    return\n'

In [20]:
with open(os.path.join(cwd, 'data/word2vec/GoogleNews-vectors-negative300.bin'), "rb", 0) as f:
    header = f.readline().split()
    #print(header)
    vocabulary_size = int(header[0])
    word_vector_size = int(header[1])
    #print('{}, {}'.format(vocabulary_size, word_vector_size))
    binary_length = np.dtype('float32').itemsize * word_vector_size
    #print(binary_length)
    initial_weights = np.random.uniform(-0.25, 0.25, (len(wordIDMap), word_vector_size))
    #print(initial_weights)
    for line in range(word_vector_size):
        word = []
        while True:
            ch = f.read(1)
            if ch == b' ':
                word = b''.join(word).decode('utf-8')
                break
            if ch != b'\n':
                word.append(ch)
        if word in wordIDMap:
            initial_weights[wordIDMap[word]] = np.fromstring(f.read(binary_length), dtype='float32')
        else:
            f.read(binary_length)

In [21]:
if embeddingSize < word_vector_size:
    u, s, vt = np.linalg.svd(initial_weights, full_matrices=False)
    S = np.zeros((word_vector_size, word_vector_size), dtype=complex)
    S[:word_vector_size, :word_vector_size] = np.diag(s)
    initial_weights = np.dot(u[:, :embeddingSize], S[:embeddingSize, :embeddingSize])

In [24]:
sess.run(in_embedding.assign(initial_weights))
sess.run(out_embedding.assign(initial_weights))

  nparray = values.astype(dtype.as_numpy_dtype)


array([[-0.03674802, -0.26551384, -0.23002574, ..., -0.1891018 ,
         0.0316099 , -0.06309055],
       [-0.15077463,  0.10474775,  0.01148877, ...,  0.04616332,
         0.20235911,  0.08956377],
       [-1.10908699,  0.10511734,  0.01103363, ..., -0.00897589,
         0.06301182, -0.07193553],
       ..., 
       [-0.11646231, -0.08408494,  0.06960518, ...,  0.13072075,
        -0.16817857, -0.03008695],
       [ 0.21607585,  0.01507915,  0.27167937, ..., -0.11331001,
        -0.12463232, -0.0332513 ],
       [ 0.24922076, -0.23746543, -0.04036949, ..., -0.16037728,
        -0.00980208, -0.04190104]], dtype=float32)

In [26]:
def generateNextSample():
    for i in range(0, len(trainingSamples), batchSize):
        yield trainingSamples[i:min(i + batchSize, len(trainingSamples))]

In [27]:
class Batch:
    def __init__(self):
        self.encoderSeqs = []
        self.decoderSeqs = []
        self.targetSeqs = []
        self.weights = []

In [32]:
# Training Loop
completeSummary = tf.summary.merge_all()
if globalStep == 0:
    writer.add_graph(sess.graph)
try:
    for epoch in range(numOfEpochs):
        print("\nEpoch {}".format(epoch+1))
        random.shuffle(trainingSamples)
        
        batches = []
        for samples in generateNextSample():
            batch = Batch()
            batchSize = len(samples)
            for i in range(batchSize):
                sample = samples[i]
                #print(sample)
                batch.encoderSeqs.append(list(reversed(sample[0])))
                
except (KeyboardInterrupt, SystemExit):
    print('Exiting')


Epoch 1
[[18, 156, 51, 52, 2, 993, 555, 347, 71, 1], [97, 62, 18, 93, 279, 1]]
[[68, 69, 308, 51, 154, 14], [68, 264, 18, 225, 51, 317, 68]]
[[18, 93, 415, 163, 12, 261, 62, 2, 51, 1], [10, 3, 1323, 42, 51, 114, 163, 45, 261, 14]]
[[97, 1], [80, 51, 70, 36, 10401, 2603, 14]]
[[51, 19, 20, 165, 154, 271, 26, 36, 1], [18, 41, 42, 739, 18, 253, 1]]
[[311, 68], [206, 10, 124, 12, 824, 62, 17573, 14]]
[[25919, 3256, 1, 3, 1261, 18, 621, 3589, 1], [376, 907, 203, 126, 391, 1]]
[[1686, 29, 1084, 1], [51, 662, 36, 101, 436, 177, 3, 1273, 14]]
[[126, 100, 43, 98, 51, 27, 9430, 39, 14], [42, 20, 67, 98, 14419, 1]]
[[342, 11, 176, 14, 32], [97, 1, 72, 3, 1231, 1]]
[[45, 570, 11, 2291, 1], [288, 72, 1]]
[[18, 42, 20, 114, 10, 75, 42, 1], [42, 20, 42, 739, 1]]
[[26, 16, 177, 166, 14], [164, 19, 20, 476, 51, 62, 2810, 1]]
[[18, 1479, 1475, 1], [10, 14, 51, 1479, 1475, 14, 165, 206, 14]]
[[51, 52, 256, 111, 1], [311, 1]]
[[1363, 1], [18, 225, 16, 124, 5443, 1]]
[[54, 10, 14], [3, 1075, 1]]
[[32, 3, 