Let us first define all relevant file paths.

In [1]:
import os, itertools
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from collections import namedtuple

# Folder holding the training, validation and vocab files -- change this to match your setup.
DATA_DIR = 'ubuntu-data'

# Each split has three line-aligned files: <split>.context, <split>.response and <split>.flag
train_context_file, train_response_file, train_flag_file = (
    os.path.join(DATA_DIR, 'train.' + field) for field in ('context', 'response', 'flag'))
valid_context_file, valid_response_file, valid_flag_file = (
    os.path.join(DATA_DIR, 'valid.' + field) for field in ('context', 'response', 'flag'))

# Vocabulary file, read by lookup_ops.index_table_from_file further down
vocab_file = os.path.join(DATA_DIR, 'vocab.txt')

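The input pipeline below should look familiar from the [dataset_ops notebook](https://github.com/vineetm/tensorflow-notes/blob/master/siamese/notebooks/dataset_ops.ipynb). The one addition is that the iterator's initializer is returned as an extra field, `init`, so the same iterator can be restarted — for example before every validation pass.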

In [2]:
#Notice the new (sixth) field `init`: it holds the iterator's initializer op
class DataIterator(namedtuple('DataIterator', 'init context len_context response len_response flag')):
    """Tensors produced by one batched dataset iterator.

    Fields:
        init: initializer op of the underlying iterator; run it to start a fresh pass.
        context, len_context: padded batch of context word indexes and their unpadded lengths.
        response, len_response: padded batch of response word indexes and their unpadded lengths.
        flag: batch of values parsed from the flag file (used as labels by the model).
    """
    # A namedtuple subclass needs no per-instance __dict__; the tuple storage is enough.
    __slots__ = ()

def text_to_word_indexes(text_file, vocab_table):
    """Stream `text_file` line by line as vectors of vocabulary indexes.

    Each line is split into words with tf.string_split (default delimiter: space),
    and every word is mapped through `vocab_table.lookup`. Returns a tf.data.Dataset
    with one index vector per input line.
    """
    def tokenize(line):
        return tf.string_split([line]).values

    def to_indexes(words):
        return vocab_table.lookup(words)

    return tf.data.TextLineDataset(text_file).map(tokenize).map(to_indexes)

def create_dataset_iterator(vocab_table, context_file, response_file, flag_file, batch_size,
                            max_context_len=160):
    """Build a batched, re-initializable iterator over (context, response, flag) examples.

    Args:
        vocab_table: word -> index lookup table (e.g. from
            `lookup_ops.index_table_from_file(..., default_value=0)`), created in the
            same graph as this pipeline. It is used as given for both context and response.
        context_file, response_file, flag_file: text files whose lines are zipped
            together, so line i of each file belongs to the same example.
        batch_size: number of examples per batch.
        max_context_len: only the last `max_context_len` tokens of each context are kept.

    Returns:
        A DataIterator whose tensors yield one padded batch per `sess.run`.
        Run `.init` before the first batch, and again to restart from the beginning.
    """
    # Context: sentence -> word indexes, truncated to the most recent tokens
    context_dataset = text_to_word_indexes(context_file, vocab_table)
    context_dataset = context_dataset.map(lambda words: words[-max_context_len:])

    # Response: sentence -> word indexes
    response_dataset = text_to_word_indexes(response_file, vocab_table)

    # Flag: one number per line; tf.string_to_number parses it as float32 by default
    flag_dataset = tf.data.TextLineDataset(flag_file)
    flag_dataset = flag_dataset.map(lambda line: tf.string_to_number(line))

    # Join the three files line by line
    dataset = tf.data.Dataset.zip((context_dataset, response_dataset, flag_dataset))

    # Attach unpadded lengths; the model passes them to dynamic_rnn as sequence_length
    dataset = dataset.map(lambda context, response, flag:
                          (context, tf.size(context), response, tf.size(response), flag))

    sequence, scalar = tf.TensorShape([None]), tf.TensorShape([])
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(sequence, scalar, sequence, scalar, scalar))
    iterator = dataset.make_initializable_iterator()

    context, len_context, response, len_response, flag = iterator.get_next()

    return DataIterator(iterator.initializer, context, len_context, response, len_response, flag)

In [3]:
from tensorflow.contrib.learn import ModeKeys

class Model:
    """Dual-encoder LSTM scorer for (context, response) pairs.

    Context and response are embedded with one shared word-embedding matrix and
    encoded by the same LSTM (shared weights). The logit for a pair is
    c . (r M), trained with sigmoid cross-entropy against the iterator's flag.
    In TRAIN mode a gradient-clipped Adam `train_step` is built. In EVAL mode,
    `eval` computes the mean loss over one full pass of the iterator.
    """
    def __init__(self, V, d, iterator, mode, lr=0.001):
        """
        Args:
            V: vocabulary size (number of rows of the embedding matrix).
            d: embedding size, also used as the LSTM hidden size.
            iterator: DataIterator supplying padded batches.
            mode: ModeKeys.TRAIN or ModeKeys.EVAL.
            lr: Adam learning rate (only used in TRAIN mode).
        """
        #Mode is either Train or evaluation!
        self.mode = mode

        self.iterator = iterator

        #Common parts to train and evaluation. Every graph built from this class creates
        #variables with the same names, which lets the eval graph restore train checkpoints.
        self.W = tf.get_variable(name='word_embeddings', shape=[V, d])

        #Context -> c: final LSTM hidden state of the context sequence
        self.context = tf.nn.embedding_lookup(self.W, self.iterator.context)
        rnn_cell = tf.contrib.rnn.BasicLSTMCell(d)
        with tf.variable_scope('rnn'):
            _, state_context = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=self.context,
                                                 sequence_length=self.iterator.len_context,
                                                 dtype=tf.float32)
        c = state_context.h

        #Response -> r: the same LSTM weights (reuse=True) applied to the response
        response = tf.nn.embedding_lookup(self.W, self.iterator.response)
        with tf.variable_scope('rnn', reuse=True):
            _, state_response = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=response,
                                                  sequence_length=self.iterator.len_response,
                                                  dtype=tf.float32)
        r = state_response.h

        #Interaction matrix, initialized to identity so the initial logit is c . r
        self.M = tf.Variable(tf.eye(d), name='M')

        #For checkpoint. Created before the optimizer, so it saves only the model weights
        #(no Adam slot variables): enough to copy weights into the eval graph, but not
        #enough to resume training with identical optimizer state.
        self.saver = tf.train.Saver(tf.global_variables())

        self.logits = tf.reduce_sum(tf.multiply(c, tf.matmul(r, self.M)), axis=1)
        self.batch_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.iterator.flag,
                                                                  logits=self.logits)
        self.loss = tf.reduce_mean(self.batch_loss)
        #Number of examples in the current batch (the last batch may be smaller than
        #batch_size); eval uses it to weight the per-batch mean losses.
        self.batch_size = tf.size(self.batch_loss)

        if self.mode == ModeKeys.TRAIN:
            opt = tf.train.AdamOptimizer(lr)
            params = tf.trainable_variables()
            print('Trainable params: %s'%params)
            gradients = tf.gradients(self.loss, params)
            #Clip by global norm 5.0 to keep LSTM gradients from exploding
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
            self.train_step = opt.apply_gradients(zip(clipped_gradients, params))

    def eval(self, sess):
        """Return the mean per-example loss over one full pass of the iterator.

        The iterator is re-initialized first, so every call starts from the beginning.
        Raises ValueError if the dataset yields no examples.
        """
        assert self.mode == ModeKeys.EVAL

        #Initialize iterator
        sess.run(self.iterator.init)

        total_loss = 0.0
        num_examples = 0
        while True:
            try:
                batch_mean, batch_count = sess.run([self.loss, self.batch_size])
            except tf.errors.OutOfRangeError:
                break
            #Weight each batch mean by its size so a short final batch doesn't skew the result
            total_loss += batch_mean * batch_count
            num_examples += batch_count

        if num_examples == 0:
            raise ValueError('Evaluation dataset produced no examples')
        return float(total_loss / num_examples)

In [4]:
#Vocabulary size = number of lines in `vocab.txt` (30430 for the Ubuntu data used here).
#index_table_from_file maps each line to its line number, and those indexes are used as
#rows of the embedding matrix, so V must match the file. Count it rather than hard-coding it.
#Binary mode: only newlines are counted, so no text decoding is needed.
with open(vocab_file, 'rb') as vocab_fh:
    V = sum(1 for _ in vocab_fh)
#Embedding size, also used as the LSTM hidden size
d = 128

In [5]:
train_graph = tf.Graph()
with train_graph.as_default():
    # The training pipeline and model live in their own graph; unknown words map to index 0.
    vocab_table = lookup_ops.index_table_from_file(vocab_file, default_value=0)
    train_iterator = create_dataset_iterator(vocab_table, train_context_file, train_response_file,
                                             train_flag_file, batch_size=16)
    train_model = Model(V, d, train_iterator, ModeKeys.TRAIN)

    train_sess = tf.Session(graph=train_graph)
    # Initialize variables and lookup tables, then position the iterator at the start of the data.
    train_sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    train_sess.run(train_iterator.init)

Trainable params: [<tf.Variable 'word_embeddings:0' shape=(30430, 128) dtype=float32_ref>, <tf.Variable 'rnn/rnn/basic_lstm_cell/kernel:0' shape=(256, 512) dtype=float32_ref>, <tf.Variable 'rnn/rnn/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'M:0' shape=(128, 128) dtype=float32_ref>]


In [6]:
valid_graph = tf.Graph()
with valid_graph.as_default():
    # Evaluation gets a separate graph; its weights are later restored from training checkpoints.
    vocab_table = lookup_ops.index_table_from_file(vocab_file, default_value=0)
    valid_iterator = create_dataset_iterator(vocab_table, valid_context_file, valid_response_file,
                                             valid_flag_file, batch_size=16)
    valid_model = Model(V, d, valid_iterator, ModeKeys.EVAL)

    valid_sess = tf.Session(graph=valid_graph)
    # The iterator is not initialized here: Model.eval re-initializes it on every call.
    valid_sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

In [7]:
#Define checkpoint directory: training writes checkpoints here and the eval graph restores them
CKPT_DIR = 'saved_model/'
if not tf.gfile.Exists(CKPT_DIR):
    print('Creating {}'.format(CKPT_DIR))
    #MakeDirs also creates missing parent directories (MkDir fails for nested paths like 'runs/exp1/')
    tf.gfile.MakeDirs(CKPT_DIR)

In [None]:
#Checkpoint file prefix inside CKPT_DIR: files are written as model-<step>.index, model-<step>.data-*, ...
ckpt_prefix = os.path.join(CKPT_DIR, 'model')
NUM_EPOCHS = 1   # passes over the training data
EVAL_EVERY = 10  # steps between checkpoint + full validation pass

step = 0
for epoch in range(NUM_EPOCHS):
    #Restart the training data at the beginning of each epoch (also makes re-running this cell safe)
    train_sess.run(train_iterator.init)
    while True:
        try:
            _, loss = train_sess.run([train_model.train_step, train_model.loss])
        except tf.errors.OutOfRangeError:
            print('Finished epoch {} ({} steps so far)'.format(epoch, step))
            break
        if step % EVAL_EVERY == 0:
            print('Step: {} Train_Loss: {}'.format(step, loss))
            #Copy the current weights into the eval graph through a checkpoint, then evaluate
            ckpt_path = train_model.saver.save(train_sess, ckpt_prefix, global_step=step)
            valid_model.saver.restore(valid_sess, ckpt_path)
            val_loss = valid_model.eval(valid_sess)
            print('Step: {} Valid_Loss: {}'.format(step, val_loss))
        step += 1

Step: 0 Train_Loss: 0.6931326389312744
INFO:tensorflow:Restoring parameters from saved_model/-0
