In [1]:
import tensorflow as tf
import operations as op

In [2]:

def attention(
    Q, K, V,
    Q_lengths, K_lengths,
    attention_type='dot',
    is_mask=True, mask_value=-2**32+1,
    drop_prob=None):
    '''Add attention layer.

    Scores every query position against every key position, optionally
    masks the scores using the sequence lengths, normalises them with a
    softmax and returns the attention-weighted sum of V.

    Args:
        Q: a tensor with shape [batch, Q_time, Q_dimension]
        K: a tensor with shape [batch, time, K_dimension]
        V: a tensor with shape [batch, time, V_dimension]

        Q_lengths: a tensor with shape [batch]
        K_lengths: a tensor with shape [batch]

        attention_type: 'dot' (uses op.dot_sim) or 'bilinear'
            (uses op.bilinear_sim).
        is_mask: when true, logits are blended with mask_value using the
            mask returned by op.mask before the softmax.
        mask_value: value given to masked-out logits. The default
            -2**32+1 evaluates to -4294967295.
        drop_prob: None to disable dropout; otherwise the probability of
            dropping an attention weight.

    Returns:
        a tensor with shape [batch, Q_time, V_dimension]

    Raises:
        AssertionError: if attention_type is not 'dot' or 'bilinear', or if
            Q_dimension not equal to K_dimension when attention type is dot.
    '''
    assert attention_type in ('dot', 'bilinear')

    if attention_type == 'dot':
        assert Q.shape[-1] == K.shape[-1]
        logits = op.dot_sim(Q, K) #[batch, Q_time, time]
    else:
        logits = op.bilinear_sim(Q, K)

    if is_mask:
        Q_time = Q.shape[1]
        K_time = K.shape[1]
        mask = op.mask(Q_lengths, K_lengths, Q_time, K_time) #[batch, Q_time, K_time]
        # Keep logits where mask is 1 and replace them with mask_value where
        # it is 0, so those positions get (near) zero weight after softmax.
        logits = mask * logits + (1 - mask) * mask_value

    weights = tf.nn.softmax(logits)

    if drop_prob is not None:
        print('use attention drop')
        # TF1's tf.nn.dropout takes the *keep* probability as its second
        # argument, so the drop probability has to be converted first.
        weights = tf.nn.dropout(weights, keep_prob=1.0 - drop_prob)

    return op.weighted_sum(weights, V)


In [13]:
# Toy dimensions used to build the placeholders below.
batch_size = 2
turns_len = 5
words = 10
dim =10

# Shapes are derived from the constants above so the graph follows them
# when they are changed, instead of repeating the literal values.
input_turns = tf.placeholder(tf.float32, [batch_size, turns_len, words, dim])
# NOTE: "respones" (sic) is kept as spelled because other cells refer to
# these names.
respones  = tf.placeholder(tf.float32, [batch_size, words, dim])
respones_len = tf.placeholder(tf.int32,[batch_size])

In [16]:
# Move the turn axis to the front: [batch, turns, words, dim] ->
# [turns, batch, words, dim]. The result gets its own name so that
# re-running this cell does not transpose an already transposed tensor.
turns_major = tf.transpose(input_turns, perm=[1, 0, 2, 3])
print(turns_major)

_turn_match = []

# tf.unstack yields one tensor per turn and removes only the turn axis,
# unlike a bare tf.squeeze, which would also drop any other size-1 axis.
for _t in tf.unstack(turns_major, axis=0):
    # NOTE(review): respones_len is passed for both the query and the key
    # lengths; confirm whether the turns need their own length tensor.
    _match_result = attention(respones, _t, _t, respones_len, respones_len)
    _turn_match.append(_match_result)

Tensor("transpose_11:0", shape=(5, 2, 10, 10), dtype=float32)


In [17]:
# Display the list built in the loop cell: one attention output tensor was
# appended per slice taken along the first axis of the transposed input.
_turn_match

[<tf.Tensor 'einsum_14/transpose_2:0' shape=(2, 10, 10) dtype=float32>,
 <tf.Tensor 'einsum_17/transpose_2:0' shape=(2, 10, 10) dtype=float32>,
 <tf.Tensor 'einsum_20/transpose_2:0' shape=(2, 10, 10) dtype=float32>,
 <tf.Tensor 'einsum_23/transpose_2:0' shape=(2, 10, 10) dtype=float32>,
 <tf.Tensor 'einsum_26/transpose_2:0' shape=(2, 10, 10) dtype=float32>]