In [1]:
import tensorflow as tf

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = tf.cast(tf.shape(K)[-1], tf.float32)
    scores = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(d_k)
    
    if mask is not None:
        scores += (mask * -1e9)
        
    weights = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(weights, V)

In [2]:
def multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads):
    batch_size = tf.shape(X)[0]
    seq_len = tf.shape(X)[1]
    d_model = X.shape[-1]
    depth = d_model // num_heads

    def split_heads(x):
        x = tf.reshape(x, (batch_size, seq_len, num_heads, depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    Q = split_heads(tf.matmul(X, Wq))
    K = split_heads(tf.matmul(X, Wk))
    V = split_heads(tf.matmul(X, Wv))

    attention_output = scaled_dot_product_attention(Q, K, V)
    attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
    concat_attention = tf.reshape(attention_output, (batch_size, seq_len, d_model))
    
    return tf.matmul(concat_attention, Wo)

In [3]:
def add_norm(x, sublayer_out, gamma, beta):
    residual = x + sublayer_out
    
    mean = tf.reduce_mean(residual, axis=-1, keepdims=True)
    var = tf.reduce_mean(tf.square(residual - mean), axis=-1, keepdims=True)
    norm = (residual - mean) / tf.sqrt(var + 1e-6)
    
    return gamma * norm + beta

In [4]:
def ffn(x, d_ff=128):
    d_model = x.shape[-1]
    
    w1 = tf.Variable(tf.random.normal((d_model, d_ff)), trainable=True)
    b1 = tf.Variable(tf.zeros([d_ff]), trainable=True)
    
    w2 = tf.Variable(tf.random.normal((d_ff, d_model)), trainable=True)
    b2 = tf.Variable(tf.zeros([d_model]), trainable=True)
    
    return tf.matmul(tf.nn.relu(tf.matmul(x, w1) + b1), w2) + b2

In [5]:
def run_training_example(encoder_inputs, decoder_inputs, true_targets, all_vars):
    optimizer = tf.optimizers.Adam(1e-3)

    with tf.GradientTape() as tape:
        # Encoder
        enc_input = encoder_inputs
        enc_out = multi_head_attention(enc_input, Wq1, Wk1, Wv1, Wo1, num_heads)
        enc_out = add_norm(enc_input, enc_out, gamma1, beta1)

        # Decoder
        dec_input = decoder_inputs
        dec_out = multi_head_attention(dec_input, Wq2, Wk2, Wv2, Wo2, num_heads)
        dec_out = add_norm(dec_input, dec_out, gamma2, beta2)

        # prediction 
        logits = dec_out

        # Loss
        loss = tf.reduce_mean(tf.square(logits - true_targets))

    grads = tape.gradient(loss, all_vars)
    optimizer.apply_gradients(zip(grads, all_vars))

    print(f"Loss : {loss.numpy():.4f}")

In [6]:
# Hyperparameters
batch_size = 2
seq_len = 5
d_model = 64
num_heads = 8

# Dummy inputs and targets
input_embeddings = tf.random.normal((batch_size, seq_len, d_model))
input_pos = tf.random.normal((batch_size, seq_len, d_model))
encoder_inputs = tf.add(input_embeddings, input_pos)

target_embeddings = tf.random.normal((batch_size, seq_len, d_model))
target_pos = tf.random.normal((batch_size, seq_len, d_model))
decoder_inputs = tf.add(target_embeddings, target_pos)

true_targets = tf.random.normal((batch_size, seq_len, d_model))

# Initialize weights
Wq1 = tf.Variable(tf.random.normal((d_model, d_model)), trainable=True)
Wk1 = tf.Variable(tf.random.normal((d_model, d_model)), trainable=True)
Wv1 = tf.Variable(tf.random.normal((d_model, d_model)), trainable=True)
Wo1 = tf.Variable(tf.random.normal((d_model, d_model)), trainable=True)

Wq2 = tf.Variable(tf.random.normal((d_model, d_model)), trainable=True)
Wk2 = tf.Variable(tf.random.normal((d_model, d_model)), trainable=True)
Wv2 = tf.Variable(tf.random.normal((d_model, d_model)), trainable=True)
Wo2 = tf.Variable(tf.random.normal((d_model, d_model)), trainable=True)

# gamma/beta for add_norm
gamma1 = tf.Variable(tf.ones([d_model]), trainable=True)
beta1 = tf.Variable(tf.zeros([d_model]), trainable=True)
gamma2 = tf.Variable(tf.ones([d_model]), trainable=True)
beta2 = tf.Variable(tf.zeros([d_model]), trainable=True)

# All trainable variables
all_vars = [Wq1, Wk1, Wv1, Wo1, Wq2, Wk2, Wv2, Wo2, gamma1, beta1, gamma2, beta2]

In [7]:
# Run training
for _ in range(10):
    run_training_example(encoder_inputs, decoder_inputs, true_targets, all_vars)




Loss : 2.1840
Loss : 2.1612
Loss : 2.1306
Loss : 2.0946
Loss : 2.0700
Loss : 2.0575
Loss : 2.0472
Loss : 2.0379
Loss : 2.0266
Loss : 2.0131
