# In [None]:
import tensorflow as tf
import numpy as np


# Batch dimension of the random demo input; the script also feeds it to the model's `batch_size` placeholder.
_BATCH_SIZE = 2
# Time dimension (number of steps) of the random demo input.
_TIME1 = 5
# NOTE(review): not referenced in the visible code -- presumably the intended decoder sequence length; TODO confirm.
_TIME2 = 3


class DynRnn:
	"""Toy encoder/decoder LSTM graph with multiplicative attention (TF1 graph mode).

	Constructing an instance immediately builds the graph (see `model`) in the
	default TensorFlow graph. After construction:

	* `X` / `Y` are float32 placeholders of shape [batch, time, 2] feeding the
	  encoder and the decoder LSTM respectively.
	* `batch_size` is a scalar int32 placeholder. It is kept so existing
	  feed_dicts keep working, but the attention computation no longer reads it.
	* `attn` holds the attention weights tensor of shape [B, T1, T2].

	`op`, `loss`, `train` and `alphas` are declared but never assigned by this class.
	"""

	def __init__(self):
		self.X = None
		self.Y = None
		self.op = None
		self.loss = None
		self.train = None
		self.alphas = None
		self.attn = None
		self.batch_size = None
		self.model()

	def _linear(
			self, inputs, n_output, name, bias_init=tf.zeros_initializer,
			weight_init=tf.random_normal_initializer, activation=None):
		"""Build a dense layer `activation(inputs @ w + b)` inside variable scope `name`.

		Args:
			inputs: Tensor whose last dimension is statically known; it is passed
				straight to `tf.matmul` against a 2-D weight matrix.
			n_output: Number of output units.
			name: Variable scope holding the `w` and `b` variables.
			bias_init: Initializer for the bias vector.
			weight_init: Initializer for the weight matrix.
			activation: Optional callable applied to the affine output.

		Returns:
			The output tensor of the layer.

		Raises:
			ValueError: If the last dimension of `inputs` is not statically known.
		"""
		# Indexing a TensorShape yields a Dimension object, which is never None
		# itself; the static size (or None when unknown) lives in `.value`.
		n_input = inputs.get_shape()[-1].value
		if n_input is None:
			raise ValueError("The last dimension of `inputs` must be statically known.")

		with tf.variable_scope(name):
			weights = tf.get_variable('w', [n_input, n_output], initializer=weight_init)
			bias = tf.get_variable('b', [n_output], initializer=bias_init)
			op = tf.matmul(inputs, weights) + bias

			if activation is not None:
				op = activation(op)
		return op

	def model(self):
		"""Create the placeholders, both LSTMs and the attention weights (`self.attn`)."""
		self.X = tf.placeholder(dtype=tf.float32, shape=[None, None, 2], name='X')
		self.Y = tf.placeholder(dtype=tf.float32, shape=[None, None, 2], name='Y')
		# Retained for backward compatibility with callers that feed it.
		self.batch_size = tf.placeholder(dtype=tf.int32, shape=[], name="batch_size")

		cell_enc = tf.nn.rnn_cell.BasicLSTMCell(num_units=10, name="enc_cell")
		cell_dec = tf.nn.rnn_cell.BasicLSTMCell(num_units=10, name="dec_cell")

		# Batch-major outputs: [B, T1, 10] for the encoder, [B, T2, 10] for the decoder.
		enc_outputs, enc_state = tf.nn.dynamic_rnn(cell=cell_enc, inputs=self.X, dtype=tf.float32)
		dec_outputs, dec_state = tf.nn.dynamic_rnn(cell=cell_dec, inputs=self.Y, dtype=tf.float32)

		self.attn = self._attention(enc_outputs, dec_outputs)

	def _attention(self, encoder_states, decoder_states, time_major=False, return_alphas=False):
		"""Multiplicative attention of every decoder step over all encoder steps.

		The score is `enc @ w_omega @ dec^T`, normalised with a softmax over the
		encoder time axis.

		Args:
			encoder_states: [B, T1, D1] tensor, or a tuple of such tensors (e.g.
				forward/backward Bi-RNN outputs) which is concatenated on the last axis.
			decoder_states: [B, T2, D2] tensor.
			time_major: If True, both inputs are [T, B, D] and are transposed to
				batch-major first.
			return_alphas: If False (default) only the attention weights are
				returned; this is the behaviour `model` relies on. If True, a
				`(context, alphas)` pair is returned.

		Returns:
			`alphas` of shape [B, T1, T2], or `(context, alphas)` where `context`
			has shape [B, T2, D1] and holds, for every decoder step, the
			attention-weighted sum of the encoder states.
		"""
		if isinstance(encoder_states, tuple):
			# In case of Bi-RNN, concatenate the forward and the backward RNN outputs.
			encoder_states = tf.concat(encoder_states, 2)

		if time_major:
			# (T,B,D) => (B,T,D)
			encoder_states = tf.transpose(encoder_states, [1, 0, 2])
			decoder_states = tf.transpose(decoder_states, [1, 0, 2])

		encoder_n = encoder_states.shape[2].value  # D1 - hidden size of the encoder RNN layer
		decoder_n = decoder_states.shape[2].value  # D2 - hidden size of the decoder RNN layer

		# Trainable parameters
		w_omega = tf.Variable(tf.random_normal([encoder_n, decoder_n], stddev=0.1), name="w_omega")  # [D1, D2]

		with tf.variable_scope("Score"):
			# Take the batch size from the tensor itself so the result cannot be
			# corrupted by a wrongly fed `batch_size` placeholder.
			dyn_batch = tf.shape(encoder_states)[0]
			enc_reshape = tf.reshape(encoder_states, [-1, encoder_n], name="enc_reshape")  # [(B*T1), D1]
			h1 = tf.matmul(enc_reshape, w_omega)  # [(B*T1), D1][D1, D2] = [(B*T1), D2]
			h1_reshape = tf.reshape(h1, tf.stack([dyn_batch, -1, decoder_n]), name="h1_reshape")  # [B, T1, D2]
			dec_transpose = tf.transpose(decoder_states, [0, 2, 1])  # [B, D2, T2]
			score = tf.matmul(h1_reshape, dec_transpose)  # [B, T1, D2][B, D2, T2] = [B, T1, T2]

		with tf.variable_scope("align"):
			# Normalise over the encoder time axis: for each decoder step the
			# weights over the T1 encoder steps sum to one.
			alphas = tf.nn.softmax(score, axis=1, name='alphas')  # [B, T1, T2]

		if not return_alphas:
			return alphas

		with tf.variable_scope("context"):
			# [B, T2, T1][B, T1, D1] = [B, T2, D1]
			outputs = tf.matmul(alphas, encoder_states, transpose_a=True, name="context")
		return outputs, alphas
	
	
if __name__ == '__main__':
	# Smoke test: build the graph, feed one random batch to both the encoder and
	# the decoder input, and print the resulting attention weights.
	my_rnn = DynRnn()

	inp = np.random.rand(_BATCH_SIZE, _TIME1, 2)

	# The context manager closes the session and releases its resources even
	# if initialisation or the run raises.
	with tf.Session() as sess:
		sess.run(tf.global_variables_initializer())

		print(inp, _BATCH_SIZE)

		att = sess.run(
			[my_rnn.attn],
			feed_dict={my_rnn.X: inp, my_rnn.Y: inp, my_rnn.batch_size: _BATCH_SIZE})
	print(att)

