In [1]:
import tensorflow as tf

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: Query tensor. Its last dimension must match the last dimension of K
            so that ``Q @ K^T`` is defined.
        K: Key tensor. The size of its last dimension is used as ``d_k``.
        V: Value tensor, multiplied by the attention weights.
        mask: Optional tensor that broadcasts against the score tensor.
            Entries equal to 1 (or True) receive a large negative penalty
            before the softmax, entries equal to 0 (or False) are left
            untouched. Any dtype that ``tf.cast`` can convert to the score
            dtype is accepted.

    Returns:
        A tuple ``(output, weights)`` where ``weights`` is the softmax over the
        last axis of the scaled scores and ``output`` is ``weights @ V``.
    """
    scores = tf.matmul(Q, K, transpose_b=True)

    # Use the dtype of the scores for d_k instead of a hard-coded float32:
    # TensorFlow does not promote dtypes implicitly, so a float32 divisor
    # fails for float64 / float16 inputs.
    d_k = tf.cast(tf.shape(K)[-1], scores.dtype)
    scores = scores / tf.math.sqrt(d_k)

    if mask is not None:
        # -1e9 for float32/float64 (unchanged behaviour); clamped to the most
        # negative finite value for narrower dtypes so that 0 * penalty stays
        # 0 instead of becoming NaN.
        penalty = max(-1e9, scores.dtype.min)
        scores += tf.cast(mask, scores.dtype) * penalty

    weights = tf.nn.softmax(scores, axis=-1)
    output = tf.matmul(weights, V)
    return output, weights

In [2]:
def multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads, mask=None):
    """Apply multi-head self-attention to ``X``.

    Args:
        X: Input tensor of shape [batch, seq_len, d_model]. The last dimension
            must be statically known because it is used in integer arithmetic.
        Wq, Wk, Wv: Projection matrices multiplied with ``X`` to produce the
            queries, keys and values; each result is reshaped to
            [batch, seq_len, num_heads, depth].
        Wo: Output projection applied to the concatenated heads.
        num_heads: Number of attention heads; must be positive and divide
            d_model evenly.
        mask: Optional mask forwarded unchanged to
            ``scaled_dot_product_attention``. It must broadcast against the
            score tensor of shape [batch, num_heads, seq_len, seq_len].

    Returns:
        Tensor holding the concatenated heads ([batch, seq_len, d_model])
        multiplied by ``Wo``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            d_model evenly.
    """
    d_model = X.shape[-1]
    # Fail early with a readable message; otherwise the mismatch only shows up
    # as an element-count error inside tf.reshape.
    if num_heads <= 0 or d_model % num_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by a positive "
            f"num_heads ({num_heads})"
        )
    depth = d_model // num_heads

    # Batch and sequence sizes are read dynamically so they may vary per call.
    batch_size = tf.shape(X)[0]
    seq_len = tf.shape(X)[1]

    def split_heads(x):
        # [batch, seq_len, d_model] -> [batch, heads, seq_len, depth]
        x = tf.reshape(x, (batch_size, seq_len, num_heads, depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    # Linear projections, each split into heads.
    Q = split_heads(tf.matmul(X, Wq))
    K = split_heads(tf.matmul(X, Wk))
    V = split_heads(tf.matmul(X, Wv))

    # Scaled dot-product attention per head: [batch, heads, seq_len, depth]
    attention_output, _ = scaled_dot_product_attention(Q, K, V, mask)

    # Move heads next to depth again and merge them back into d_model.
    attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
    concat_attention = tf.reshape(attention_output, (batch_size, seq_len, d_model))

    # Final linear layer
    return tf.matmul(concat_attention, Wo)

In [3]:
def init_weights(d_model, num_heads):
    """Create the four square projection matrices used by multi-head attention.

    Args:
        d_model: Model width; every matrix has shape (d_model, d_model).
        num_heads: Not used by the initialisation; kept in the signature so
            existing callers keep working.

    Returns:
        A tuple ``(Wq, Wk, Wv, Wo)`` of ``tf.Variable`` objects, each drawn
        independently from ``tf.random_normal_initializer()`` with its default
        settings.
    """
    initializer = tf.random_normal_initializer()
    shape = (d_model, d_model)

    # Four separate draws, created in the order Wq, Wk, Wv, Wo.
    Wq, Wk, Wv, Wo = (tf.Variable(initializer(shape=shape)) for _ in range(4))
    return Wq, Wk, Wv, Wo

In [10]:
# Demo configuration: small sizes, with d_model evenly divisible by num_heads
# (64 // 8 = 8 features per head).
batch_size = 2
seq_len = 5
d_model = 64
num_heads = 8

# Random input batch of shape (batch_size, seq_len, d_model) plus freshly
# initialised projection matrices.
# NOTE(review): no random seed is set in the visible cells, so these values
# presumably differ on every run — confirm whether a seed is set elsewhere.
X = tf.random.normal((batch_size, seq_len, d_model))
Wq, Wk, Wv, Wo = init_weights(d_model, num_heads)

# Self-attention over X (no mask is passed).
output = multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads)

In [12]:
# Echo the attention result; the tensor repr shows its shape, dtype and values.
output

<tf.Tensor: shape=(2, 5, 64), dtype=float32, numpy=
array([[[-5.18431477e-02, -1.60861850e-01, -1.38540715e-01,
         -7.72450045e-02,  7.39555731e-02,  3.23570408e-02,
          5.09622358e-02,  1.19967684e-01, -7.36276358e-02,
         -2.49388609e-02, -3.18397433e-02, -3.55422348e-02,
         -9.96151846e-03, -6.43898919e-02,  3.35455276e-02,
         -5.07098343e-03,  1.49656832e-01, -5.12253428e-05,
         -7.46125262e-03, -4.35261615e-02,  3.36350012e-03,
          3.43101025e-02, -6.94895759e-02,  9.08236355e-02,
         -2.24284917e-01, -1.12803847e-01,  4.31766122e-04,
         -9.61931348e-02, -7.04454333e-02, -1.23298973e-01,
          2.22615190e-02,  9.05817747e-02,  5.73067665e-02,
          3.44301648e-02, -3.80388610e-02,  6.65489808e-02,
          2.19334047e-02, -8.69811848e-02,  3.06643825e-02,
          4.11964953e-02, -4.77124145e-03, -1.81447398e-02,
         -1.19961388e-01, -2.95224208e-02,  3.60641070e-02,
         -1.03311233e-01, -1.61343962e-01,  3.79