In [42]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Layer, Dense

In [15]:
batch = 128
seq_q = 10
seq_len = 10
d_k = 512
q = tf.constant(np.random.randn(batch, seq_q, d_k))
k, v = tf.constant(np.random.randn(batch, seq_len, d_k)), tf.constant(np.random.randn(batch, seq_len, d_k))
mask = tf.constant(np.random.randint(2, size=(batch, seq_q, seq_len)))

In [65]:
def attention(q, k, v, mask=None):
    # q: [batch, seq_q, d_k]
    # k, v: [batch, seq_len, d_k]
    # mask: [batch, seq_q, seq_len]
    # output: [batch, seq_q, d_k]
    d_k = q.shape[-1]
    # tf.einsum 交换最后两个纬度
    scores = tf.matmul(q, tf.einsum('...ij->...ji', k)) / np.sqrt(d_k) # batch, seq_q, seq_len
    if mask is not None:
        scores = tf.where(tf.equal(mask, 0), 1e-9, scores)
    att_scores = tf.nn.softmax(scores) # batch, seq_q, seq_len
    output = tf.matmul(att_scores, v)
    return output

In [66]:
attention(q, k, v, mask=mask).shape

TensorShape([128, 10, 512])

In [70]:
class MultiHeadedAttention(Layer):
    def __init__(self, h, d_model):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = [Dense(d_model) for _ in range(4)]

    def call(self, query, key, value, mask=None):
        
        batch = query.shape[0]
        # 1) Do all the linear projections in batch from d_model => h x d_k
        # batch, h, seq_len, d_k
        q, k, v = [
            tf.transpose(tf.reshape(dense(x), shape=[batch, -1, self.h, self.d_k]), [0, 2, 1, 3])
            for dense, x in zip(self.linears[:3], (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        out = attention(q, k, v, mask=mask) # batch, h, seq_len, d_k
        print(out.shape)
        # 3) "Concat" using a view and apply a final linear.
        x = tf.reshape(tf.transpose(out, [0, 2, 1, 3]), shape=[batch, -1, self.h * self.d_k])
        return self.linears[-1](x)

In [71]:
multiheadatt = MultiHeadedAttention(8, 512)

In [74]:
multiheadatt(q, k, v).shape

(128, 8, 10, 64)


TensorShape([128, 10, 512])