In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Don't use all the VRAM!
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

In [3]:
# Data variables
seq_length = 16
out_seq_length = 16
batch_size = 32
vocab_size = 26 + 1               # 0 for padding
embedding_dim = 26

# Network hyperparameters
memory_dim = 200
num_layers = 1

# Training variables
epochs = 10

First build input placeholders and constants. The `seq2seq` API generally deals with lists of tensors, where each tensor represents a single timestep. An input to an embedding encoder, for example, would be a list of `seq_length` tensors, each of which is of dimension `batch_size` (specifying the embedding indices to input at a particular timestep).

We allocate a `labels` placeholder using the same convention. A `weights` constant specifies cross-entropy weights for each label at each timestep.

In [4]:
enc_inp = [tf.placeholder(tf.int32, shape=(None,), name="inp%i" % t) for t in range(seq_length)]
labels = [tf.placeholder(tf.int32, shape=(None,), name="labels%i" % t) for t in range(out_seq_length)]
weights = [tf.ones_like(labels_t, dtype=tf.float32) for labels_t in labels]

# Decoder input: prepend some "GO" token and drop the final
# token of the decoder output
dec_inp = ([tf.zeros_like(enc_inp[0], dtype=np.int32, name="GO")] +
           [tf.placeholder(tf.int32, shape=(None,), name="dec_inp%i" % t) for t in range(out_seq_length - 1)])

# Initial memory value for recurrence.
prev_mem = tf.zeros((batch_size, memory_dim))

Build the sequence-to-sequence graph.

There is a **lot** of complexity hidden in these two calls, and it's certainly worth digging into both in order to really understand how this is working.

In [5]:
constituent_cell = tf.nn.rnn_cell.BasicLSTMCell(memory_dim)

if num_layers > 1:
    cell = tf.nn.rnn_cell.MultiRNNCell([constituent_cell] * num_layers)
else:
    cell = constituent_cell

# Without teacher forcing, with attention
ntf_dec_outputs, ntf_dec_memory = tf.nn.seq2seq.embedding_attention_seq2seq(enc_inp, dec_inp, cell, vocab_size+1, vocab_size+1, embedding_dim, feed_previous=True)



Build a standard sequence loss function: mean cross-entropy over each item of each sequence.

In [6]:
ntf_loss = tf.nn.seq2seq.sequence_loss(ntf_dec_outputs, labels, weights, vocab_size + 1)

Build an optimizer.

In [7]:
learning_rate = 0.05
momentum = 0.9
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(ntf_loss)

In [8]:
saver = tf.train.Saver(tf.all_variables(), max_to_keep=5)

# Restore variables
Optionally restore variables

In [9]:
resume_at = 0

if resume_at > 0:
    saver.restore(sess, 'checkpoints/saved-model-1off-attn-' + str(resume_at))

# Train

Do not initialize variables if restoring from a saved file.  
**Warning:** epoch numbers start from 0, and *will* overwrite your old saves!

In [10]:
if resume_at == 0:
    sess.run(tf.initialize_all_variables())

In [11]:
# Load data
train_x = np.load('data/mutated-train.npy')
train_y = np.load('data/fixes-train.npy')

assert(len(train_x) == len(train_y))
num_train = len(train_x)
print num_train

valid_x = np.load('data/mutated-validation.npy')
valid_y = np.load('data/fixes-validation.npy')

assert(len(valid_x) == len(valid_y))
num_validation = len(valid_x)
print num_validation

2080
640


In [12]:
def validate_batch(batch_id):
    X = valid_x[batch_id*batch_size:(batch_id+1)*batch_size]
    Y = valid_y[batch_id*batch_size:(batch_id+1)*batch_size]
    
    # Dimshuffle to seq_len * batch_size
    X = np.array(X).T
    Y = np.array(Y).T

    feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}
    feed_dict.update({labels[t]: Y[t] for t in range(out_seq_length)})
    feed_dict.update({dec_inp[t]: Y[t] for t in range(out_seq_length - 1)})

    loss_t = sess.run([ntf_loss], feed_dict)
    dec_outputs_batch = sess.run(ntf_dec_outputs, feed_dict)
    Y_hat = [logits_t.argmax(axis=1) for logits_t in dec_outputs_batch]
    
    accuracy = float(np.count_nonzero(np.equal(Y, Y_hat)))/np.prod(np.shape(Y))
    
    return loss_t, accuracy

In [13]:
def train_batch(batch_id):
    X = train_x[batch_id*batch_size:(batch_id+1)*batch_size]
    Y = train_y[batch_id*batch_size:(batch_id+1)*batch_size]
    
    # Dimshuffle to seq_len * batch_size
    X = np.array(X).T
    Y = np.array(Y).T

    feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}
    feed_dict.update({labels[t]: Y[t] for t in range(out_seq_length)})
    feed_dict.update({dec_inp[t]: Y[t] for t in range(out_seq_length - 1)})

    _, loss_t = sess.run([train_op, ntf_loss], feed_dict)
    return loss_t

In [14]:
step = 0

for t in range(epochs):
    # Training
    train_loss = []
    
    for i in range(num_train/batch_size):
        train_loss.append(train_batch(i))
        step += 1
        
        if i % 1 == 0:
            print "Step: " + str(step) + "\tEpoch: " + str(t + resume_at + float(i)/(num_train/batch_size)) + "\tTraining: " + str(train_loss[-1])
            
        if step % 200 == 0:
            saver.save(sess, 'checkpoints/saved-model-1off-attn', global_step=step)
        
    train_loss = np.mean(train_loss)
    
    # Validation
    valid_loss = []
    valid_acc = []
    
    for i in range(num_validation/batch_size):
        this_loss, this_acc = validate_batch(i)
        valid_loss.append(this_loss)
        valid_acc.append(this_acc)
        #print this_acc
        
    valid_loss = np.mean(valid_loss)
    valid_acc = np.mean(valid_acc)
    
    print "Step: " + str(step) + "\tEpoch: " + str(t + resume_at + 1) + "\tTraining: " + str(train_loss) + "\tValidation loss: " + str(valid_loss) + "\tValidation acc: " + str(valid_acc)
    saver.save(sess, 'checkpoints/saved-model-1off-attn', global_step=step)

Step: 1	Epoch: 0.0	Training: 3.39436
Step: 2	Epoch: 0.0153846153846	Training: 3.22046
Step: 3	Epoch: 0.0307692307692	Training: 3.13181
Step: 4	Epoch: 0.0461538461538	Training: 2.88012
Step: 5	Epoch: 0.0615384615385	Training: 2.8712
Step: 6	Epoch: 0.0769230769231	Training: 2.60116
Step: 7	Epoch: 0.0923076923077	Training: 2.60308
Step: 8	Epoch: 0.107692307692	Training: 2.60628
Step: 9	Epoch: 0.123076923077	Training: 2.74845
Step: 10	Epoch: 0.138461538462	Training: 2.82863
Step: 11	Epoch: 0.153846153846	Training: 2.75379
Step: 12	Epoch: 0.169230769231	Training: 2.53044
Step: 13	Epoch: 0.184615384615	Training: 2.73766
Step: 14	Epoch: 0.2	Training: 2.5248
Step: 15	Epoch: 0.215384615385	Training: 2.67623
Step: 16	Epoch: 0.230769230769	Training: 2.55156
Step: 17	Epoch: 0.246153846154	Training: 2.54372
Step: 18	Epoch: 0.261538461538	Training: 2.45925
Step: 19	Epoch: 0.276923076923	Training: 2.32565
Step: 20	Epoch: 0.292307692308	Training: 2.47849
Step: 21	Epoch: 0.307692307692	Training: 2.5045

In [15]:
sess.close()