…tensorflow

Committing tensorflow 1.3
ritheshkumar95 committed Oct 20, 2017
2 parents b297e2c + 540a685 commit efe57a3
Showing 4 changed files with 34 additions and 35 deletions.
8 changes: 4 additions & 4 deletions attention.py
@@ -40,8 +40,8 @@


loss = tf.reshape(tf.nn.sparse_softmax_cross_entropy_with_logits(
- tf.reshape(logits,[-1,V]),
- tf.reshape(seqs[:,1:],[-1])
+ logits=tf.reshape(logits,[-1,V]),
+ labels=tf.reshape(seqs[:,1:],[-1])
), [tf.shape(X)[0], -1])

mask_mult = tf.to_float(mask[:,1:])
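Side note (not part of the commit itself): in TF 1.x, tf.nn.sparse_softmax_cross_entropy_with_logits rejects positional arguments, which is why the hunk above names logits= and labels= explicitly. A minimal, self-contained sketch of the same pattern with made-up shapes (batch of 2, 5 timesteps, vocabulary of 100):

    import tensorflow as tf

    # Dummy tensors standing in for the model's logits and target token ids.
    logits = tf.random_normal([2, 5, 100])
    labels = tf.zeros([2, 5], dtype=tf.int32)

    # TF 1.x requires the keyword form; positional arguments raise an error.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=tf.reshape(logits, [-1, 100]),
        labels=tf.reshape(labels, [-1]))
    loss = tf.reshape(loss, [tf.shape(logits)[0], -1])  # back to (batch, time)

    with tf.Session() as sess:
        print(sess.run(tf.reduce_mean(loss)))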
@@ -132,7 +132,7 @@ def score(set='valid',batch_size=32):
# init = tf.initialize_all_variables()
sess.run(init)
saver = tf.train.Saver()
- saver.restore(sess,'./weights_best.ckpt')
+ # saver.restore(sess,'./weights_best.ckpt')
## start the tensorflow QueueRunner's
# tf.train.start_queue_runners(sess=sess)
## start our custom queue runner's threads
@@ -171,7 +171,7 @@ def score(set='valid',batch_size=32):
val_loss, val_perp = score('valid',BATCH_SIZE)
if val_perp < best_perp:
best_perp = val_perp
- # saver.save(sess,"weights_best.ckpt")
+ saver.save(sess,"weights_best.ckpt")
print "\tBest Perplexity Till Now! Saving state!"
else:
lr = lr * 0.5
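Side note (not part of the diff): the two edits above only toggle which tf.train.Saver call is live — restoring from weights_best.ckpt is commented out for a fresh run, and saving on a new best validation perplexity is re-enabled. A minimal save/restore sketch; the variable name and checkpoint path below are placeholders:

    import tensorflow as tf

    w = tf.get_variable('w_demo', shape=[2], initializer=tf.zeros_initializer())
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './weights_demo.ckpt')     # write a checkpoint
        saver.restore(sess, './weights_demo.ckpt')  # reload the same variables
        print(sess.run(w))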
1 change: 0 additions & 1 deletion data_loaders.py
@@ -3,7 +3,6 @@
import threading
import numpy as np
import re
- import cv2
import glob
from PIL import Image

2 changes: 1 addition & 1 deletion tflib/network.py
@@ -69,7 +69,7 @@ def vgg16(X,num_feats=64):

return X

- def im2latex_cnn(X, num_feats, bn, train_mode='True'):
+ def im2latex_cnn(X, num_feats, bn, train_mode=True):
X = X-128.
X = X/128.

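Side note (not part of the diff): the old default train_mode='True' was a string, and any non-empty string is truthy, so both 'True' and 'False' would enable code guarded by "if train_mode:". The new boolean default avoids that trap:

    # Non-empty strings are always truthy, so the string default was misleading.
    print(bool('True'))   # True
    print(bool('False'))  # True
    print(bool(False))    # False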
58 changes: 29 additions & 29 deletions tflib/ops.py
@@ -137,7 +137,7 @@ def Linear(
else:
reshaped_inputs = tf.reshape(inputs, [-1, input_dim])
result = tf.matmul(reshaped_inputs, weight)
- result = tf.reshape(result, tf.pack(tf.unpack(tf.shape(inputs))[:-1] + [output_dim]))
+ result = tf.reshape(result, tf.stack(tf.unstack(tf.shape(inputs))[:-1] + [output_dim]))

if bias:
b = tflib.param(
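Side note (not part of the diff): tf.pack and tf.unpack were renamed to tf.stack and tf.unstack in TF 1.0 with identical semantics. A small sketch of the reshape idiom used above, with made-up shapes:

    import tensorflow as tf

    x = tf.ones([3, 4, 5])
    dims = tf.unstack(tf.shape(x))                     # three scalar tensors: 3, 4, 5
    new_shape = tf.stack(dims[:-1] + [2 * dims[-1]])   # target shape (3, 4, 10)
    y = tf.reshape(tf.tile(x, [1, 1, 2]), new_shape)

    with tf.Session() as sess:
        print(sess.run(tf.shape(y)))                   # [ 3  4 10]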
@@ -204,7 +204,7 @@ def conv2d(
out = tf.nn.bias_add(out,b,data_format='NCHW')

if batchnorm:
- out = tf.contrib.layers.batch_norm(out,scope=scope,is_training=is_training,data_format='NCHW')
+ out = tf.contrib.layers.batch_norm(inputs=out,scope=scope,is_training=is_training,data_format='NCHW')

return out
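Side note (not part of the diff): tf.contrib.layers.batch_norm takes the tensor as its first parameter, named inputs, so the keyword call above is equivalent to the old positional one. A minimal NHWC sketch with placeholder shapes and a placeholder scope name (the repository's NCHW/data_format handling is omitted here):

    import tensorflow as tf

    x = tf.random_normal([8, 32, 32, 16])  # NHWC dummy batch
    y = tf.contrib.layers.batch_norm(inputs=x, is_training=True, scope='bn_demo')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(tf.shape(y)))  # [ 8 32 32 16]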

@@ -266,19 +266,19 @@ def __call__(self, inputs, state, scope=None):
gates = tf.nn.sigmoid(
tflib.ops.Linear(
self._name+'.Gates',
- tf.concat(1, [inputs, state]),
+ tf.concat(axis=1, values=[inputs, state]),
self._n_in + self._n_hid,
2 * self._n_hid
)
)

- update, reset = tf.split(1, 2, gates)
+ update, reset = tf.split(axis=1, num_or_size_splits=2, value=gates)
scaled_state = reset * state

candidate = tf.tanh(
tflib.ops.Linear(
self._name+'.Candidate',
- tf.concat(1, [inputs, scaled_state]),
+ tf.concat(axis=1, values=[inputs, scaled_state]),
self._n_in + self._n_hid,
self._n_hid
)
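Side note (not part of the diff): TF 1.0 swapped the argument order of tf.concat (values first, axis second) and changed tf.split to take the tensor first; passing keywords, as the updated lines do, sidesteps the ambiguity. A tiny self-contained sketch:

    import tensorflow as tf

    a = tf.ones([2, 3])
    b = tf.zeros([2, 3])

    merged = tf.concat(axis=1, values=[a, b])                            # shape (2, 6)
    left, right = tf.split(axis=1, num_or_size_splits=2, value=merged)   # two (2, 3) halves

    with tf.Session() as sess:
        print(sess.run(tf.shape(merged)))  # [2 6]
        print(sess.run(tf.shape(left)))    # [2 3]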
@@ -303,7 +303,7 @@ def GRU(
"""
h0 = tflib.param(name+'.h0', np.zeros(n_hid, dtype='float32'))
batch_size = tf.shape(inputs)[0]
- h0 = tf.reshape(tf.tile(h0, tf.pack([batch_size])), tf.pack([batch_size, n_hid]))
+ h0 = tf.reshape(tf.tile(h0, tf.stack([batch_size])), tf.stack([batch_size, n_hid]))
return tf.nn.dynamic_rnn(GRUCell(name, n_in, n_hid), inputs, initial_state=h0, swap_memory=True)[0]

class LSTMCell(tf.nn.rnn_cell.RNNCell):
@@ -322,21 +322,21 @@ def output_size(self):
return self._n_hid

def __call__(self, inputs, state, scope=None):
- c_tm1, h_tm1 = tf.split(1,2,state)
+ c_tm1, h_tm1 = tf.split(axis=1,num_or_size_splits=2,value=state)
gates = tflib.ops.Linear(
self._name+'.Gates',
- tf.concat(1, [inputs, h_tm1]),
+ tf.concat(axis=1, values=[inputs, h_tm1]),
self._n_in + self._n_hid,
4 * self._n_hid,
activation='sigmoid'
)

- i_t,f_t,o_t,g_t = tf.split(1, 4, gates)
+ i_t,f_t,o_t,g_t = tf.split(axis=1, num_or_size_splits=4, value=gates)

c_t = tf.nn.sigmoid(f_t+self._forget_bias)*c_tm1 + tf.nn.sigmoid(i_t)*tf.tanh(g_t)
h_t = tf.nn.sigmoid(o_t)*tf.tanh(c_t)

- new_state = tf.concat(1, [c_t,h_t])
+ new_state = tf.concat(axis=1, values=[c_t,h_t])

return h_t,new_state

@@ -357,7 +357,7 @@ def LSTM(
batch_size = tf.shape(inputs)[0]
if h0 is None:
h0 = tflib.param(name+'.init.h0', np.zeros(2*n_hid, dtype='float32'))
- h0 = tf.reshape(tf.tile(h0_1, tf.pack([batch_size])), tf.pack([batch_size, 2*n_hid]))
+ h0 = tf.reshape(tf.tile(h0_1, tf.stack([batch_size])), tf.stack([batch_size, 2*n_hid]))

return tf.nn.dynamic_rnn(LSTMCell(name, n_in, n_hid), inputs, initial_state=h0, swap_memory=True)

@@ -382,19 +382,19 @@ def BiLSTM(
batch_size = tf.shape(inputs)[0]
if h0_1 is None:
h0_1 = tflib.param(name+'.init.h0_1', np.zeros(2*n_hid, dtype='float32'))
- h0_1 = tf.reshape(tf.tile(h0_1, tf.pack([batch_size])), tf.pack([batch_size, 2*n_hid]))
+ h0_1 = tf.reshape(tf.tile(h0_1, tf.stack([batch_size])), tf.stack([batch_size, 2*n_hid]))

if h0_2 is None:
h0_2 = tflib.param(name+'.init.h0_2', np.zeros(2*n_hid, dtype='float32'))
- h0_2 = tf.reshape(tf.tile(h0_2, tf.pack([batch_size])), tf.pack([batch_size, 2*n_hid]))
+ h0_2 = tf.reshape(tf.tile(h0_2, tf.stack([batch_size])), tf.stack([batch_size, 2*n_hid]))


cell1 = LSTMCell(name+'_fw', n_in, n_hid)
cell2 = LSTMCell(name+'_bw', n_in, n_hid)

seq_len = tf.tile(tf.expand_dims(tf.shape(inputs)[1],0),[batch_size])
outputs = tf.nn.bidirectional_dynamic_rnn(cell1, cell2, inputs, sequence_length=seq_len, initial_state_fw=h0_1, initial_state_bw=h0_2, swap_memory=True)
- return tf.concat(2,[outputs[0][0],outputs[0][1]])
+ return tf.concat(axis=2,values=[outputs[0][0],outputs[0][1]])

'''
Attentional Decoder as proposed in HarvardNLp paper (https://arxiv.org/pdf/1609.04938v1.pdf)
@@ -420,17 +420,17 @@ def output_size(self):

def __call__(self, _input, state, scope=None):

- h_tm1, c_tm1, output_tm1 = tf.split(1,3,state)
+ h_tm1, c_tm1, output_tm1 = tf.split(axis=1,num_or_size_splits=3,value=state)

gates = tflib.ops.Linear(
self._name+'.Gates',
- tf.concat(1, [_input, output_tm1]),
+ tf.concat(axis=1, values=[_input, output_tm1]),
self._n_in + self._n_hid,
4 * self._n_hid,
activation='sigmoid'
)

- i_t,f_t,o_t,g_t = tf.split(1, 4, gates)
+ i_t,f_t,o_t,g_t = tf.split(axis=1, num_or_size_splits=4, value=gates)

## removing forget_bias
c_t = tf.nn.sigmoid(f_t)*c_tm1 + tf.nn.sigmoid(i_t)*tf.tanh(g_t)
@@ -439,7 +439,7 @@ def __call__(self, _input, state, scope=None):

target_t = tf.expand_dims(tflib.ops.Linear(self._name+'.target_t',h_t,self._n_hid,self._n_hid,bias=False),2)
# target_t = tf.expand_dims(h_t,2) # (B, HID, 1)
- a_t = tf.nn.softmax(tf.batch_matmul(self._ctx,target_t)[:,:,0],name='a_t') # (B, H*W, D) * (B, D, 1)
+ a_t = tf.nn.softmax(tf.matmul(self._ctx,target_t)[:,:,0],name='a_t') # (B, H*W, D) * (B, D, 1)
print a_t.name

def _debug_bkpt(val):
Expand All @@ -453,20 +453,20 @@ def _debug_bkpt(val):
a_t = tf.identity(a_t,name='a_t_debug')

a_t = tf.expand_dims(a_t,1) # (B, 1, H*W)
- z_t = tf.batch_matmul(a_t,self._ctx)[:,0]
+ z_t = tf.matmul(a_t,self._ctx)[:,0]
# a_t = tf.expand_dims(a_t,2)
# z_t = tf.reduce_sum(a_t*self._ctx,1)

output_t = tf.tanh(tflib.ops.Linear(
self._name+'.output_t',
- tf.concat(1,[h_t,z_t]),
+ tf.concat(axis=1,values=[h_t,z_t]),
self._D+self._n_hid,
self._n_hid,
bias=False,
activation='tanh'
))

- new_state = tf.concat(1,[h_t,c_t,output_t])
+ new_state = tf.concat(axis=1,values=[h_t,c_t,output_t])

return output_t,new_state
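Side note (not part of the diff): tf.batch_matmul was folded into tf.matmul in TF 1.0, which now batches over leading dimensions automatically; that is exactly the attention-score pattern in the decoder above. A sketch with dummy shapes (B=4, H*W=30, D=128):

    import tensorflow as tf

    ctx = tf.random_normal([4, 30, 128])        # (B, H*W, D) annotation grid
    target = tf.random_normal([4, 128, 1])      # (B, D, 1) query vector

    scores = tf.matmul(ctx, target)[:, :, 0]    # batched matmul -> (B, H*W)
    alpha = tf.nn.softmax(scores)               # attention weights, sum to 1 per row

    with tf.Session() as sess:
        print(sess.run(tf.reduce_sum(alpha, axis=1)))  # ~[1. 1. 1. 1.]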

@@ -553,20 +553,20 @@ def output_size(self):
return self._n_out

def __call__(self, _input, state, scope=None):
- h_tm1, c_tm1, output_tm1 = tf.split(1,3,state[:,:3*self._n_hid])
+ h_tm1, c_tm1, output_tm1 = tf.split(axis=1,num_or_size_splits=3,value=state[:,:3*self._n_hid])
_input = tf.argmax(state[:,3*self._n_hid:],axis=1)
_input = tflib.ops.Embedding('Embedding',self._n_out,self._n_in,_input)


gates = tflib.ops.Linear(
self._name+'.Gates',
- tf.concat(1, [_input, output_tm1]),
+ tf.concat(axis=1, values=[_input, output_tm1]),
self._n_in + self._n_hid,
4 * self._n_hid,
activation='sigmoid'
)

- i_t,f_t,o_t,g_t = tf.split(1, 4, gates)
+ i_t,f_t,o_t,g_t = tf.split(axis=1, num_or_size_splits=4, value=gates)

## removing forget_bias
c_t = tf.nn.sigmoid(f_t)*c_tm1 + tf.nn.sigmoid(i_t)*tf.tanh(g_t)
Expand All @@ -575,23 +575,23 @@ def __call__(self, _input, state, scope=None):

target_t = tf.expand_dims(tflib.ops.Linear(self._name+'.target_t',h_t,self._n_hid,self._n_hid,bias=False),2)
# target_t = tf.expand_dims(h_t,2) # (B, HID, 1)
- a_t = tf.nn.softmax(tf.batch_matmul(self._ctx,target_t)[:,:,0],name='a_t') # (B, H*W, D) * (B, D, 1)
+ a_t = tf.nn.softmax(tf.matmul(self._ctx,target_t)[:,:,0],name='a_t') # (B, H*W, D) * (B, D, 1)
a_t = tf.expand_dims(a_t,1) # (B, 1, H*W)
- z_t = tf.batch_matmul(a_t,self._ctx)[:,0]
+ z_t = tf.matmul(a_t,self._ctx)[:,0]
# a_t = tf.expand_dims(a_t,2)
# z_t = tf.reduce_sum(a_t*self._ctx,1)

output_t = tf.tanh(tflib.ops.Linear(
self._name+'.output_t',
- tf.concat(1,[h_t,z_t]),
+ tf.concat(axis=1,values=[h_t,z_t]),
self._D+self._n_hid,
self._n_hid,
bias=False,
activation='tanh'
))

logits = tf.nn.softmax(tflib.ops.Linear('MLP.1',output_t,self._n_hid,self._n_out))
- new_state = tf.concat(1,[h_t,c_t,output_t,logits])
+ new_state = tf.concat(axis=1,values=[h_t,c_t,output_t,logits])

return logits,new_state

@@ -646,7 +646,7 @@ def fn(prev_out,i):

V_t = tf.reshape(tf.transpose(V_cap,[1,0,2,3]),[batch_size,-1,ENC_DIM*2]) # (B, L, ENC_DIM)

- h0_dec = tf.concat(1,[tf.tile(tflib.param(
+ h0_dec = tf.concat(axis=1,values=[tf.tile(tflib.param(
name+'.Decoder.init.h0',
np.zeros((1,3*DEC_DIM)).astype('float32')
),[batch_size,1]),tf.reshape(tf.one_hot(500,output_dim),(batch_size,output_dim))])
