In [None]:
import keras
import einops

In [None]:
class SingleHeadAttention(keras.layers.Layer):
    """
    Computes the output of a single head of attention, inclusive of the output projection.
    The shape of the input tensors is expected to be (batch_size, time_steps, model_dim).
    The shape of the output tensor will be (batch_size, time_steps, model_dim), where
    model_dim is the embedding size of the query input.
    """

    def __init__(self, d_k:int, d_v:int=None, causal:bool=False, **kwargs):
        """
        Constructor for the SingleHeadAttention layer.

        d_k: int, the dimension of the key and query vectors.
        d_v: int, the dimension of the value vectors. If None, it is set to d_k.
        causal: bool, whether to apply a causal mask to the attention scores.

        Note: the learnable weights are added to the layer in the `build` method, once we know dimension of the embeddings.
        """
        super().__init__(**kwargs)
        # Store the parameters for the layer.
        self.d_k = d_k
        self.d_v = d_k if d_v is None else d_v
        self.causal = causal
        # Precompute the scaling factor (1/sqrt(d_k)) for the attention scores.
        self.scaling_factor = 1.0 / d_k ** 0.5

    def build(self, batch_input_shape_list):
        """
        Add the learnable weights to the layer.

        batch_input_shape_list: list of 2 or 3 input shapes, in the order Q, K[, V].
            Only the last (embedding) dimension of each shape is used.
        """
        shapes = list(batch_input_shape_list)
        q_dim = shapes[0][-1]
        # K falls back to Q's size and V to K's size, mirroring the K/V fallback in `call`.
        k_dim = shapes[1][-1] if len(shapes) > 1 else q_dim
        v_dim = shapes[2][-1] if len(shapes) > 2 else k_dim

        # Add weights in the order WQ, WK, WV, WO (this is the order get/set_weights use).
        self.WQ = self.add_weight(shape=(q_dim, self.d_k), initializer='he_normal', name='WQ')
        self.WK = self.add_weight(shape=(k_dim, self.d_k), initializer='he_normal', name='WK')
        self.WV = self.add_weight(shape=(v_dim, self.d_v), initializer='he_normal', name='WV')
        # The output projection maps the mixed values back to the query's embedding size.
        self.WO = self.add_weight(shape=(self.d_v, q_dim), initializer='he_normal', name='WO')

    def call(self, inputs:list):
        """
        Inputs should be a list of 2 or 3 tensors, in the order Q, K, V.
        If only two tensors are provided, the second tensor is used for both K and V.

        Returns a tensor of shape (batch_size, t_q, model_dim).
        Raises ValueError if `inputs` does not contain 2 or 3 tensors.
        """
        if len(inputs) == 3:
            Q, K, V = inputs
        elif len(inputs) == 2:
            Q, K = inputs
            V = K
        else:
            raise ValueError(
                f"Expected a list of 2 or 3 tensors (Q, K[, V]), got {len(inputs)}."
            )

        # Project the inputs. The query is scaled before the dot product, which is
        # mathematically the same as scaling the scores by 1/sqrt(d_k).
        q = keras.ops.matmul(Q, self.WQ) * self.scaling_factor  # (b, t_q, d_k)
        k = keras.ops.matmul(K, self.WK)                        # (b, t_k, d_k)
        v = keras.ops.matmul(V, self.WV)                        # (b, t_k, d_v)

        # Compute the scores using matrix multiplications.
        score = keras.ops.matmul(q, keras.ops.swapaxes(k, -1, -2))  # (b, t_q, t_k)

        if self.causal:
            # Apply the causal mask: a (t_q, t_k) tensor that is -1e9 strictly above the
            # diagonal and 0 elsewhere, so query i cannot attend to keys j > i.
            # It broadcasts over the batch axis when added to the scores.
            t_q, t_k = keras.ops.shape(score)[-2:]
            mask = keras.ops.triu(keras.ops.full((t_q, t_k), -1e9, dtype=score.dtype), k=1)
            score = score + mask

        # Compute the attention weights using softmax over the key axis.
        attn_weights = keras.ops.softmax(score, axis=-1)

        # Take linear combinations of the values and apply the output projection.
        mixed_v = keras.ops.matmul(attn_weights, v)  # (b, t_q, d_v)
        return keras.ops.matmul(mixed_v, self.WO)    # (b, t_q, model_dim)

In [None]:
class SimpleMultiHeadAttention(keras.layers.Layer):
    """
    Simple implementation of multi-head attention.
    It uses $h$ instances of SingleHeadAttention and sums their outputs in a loop.
    """

    def __init__(self, d_k, d_v=None, num_heads=1, causal=False, **kwargs):
        """
        Constructor for the SimpleMultiHeadAttention layer.

        d_k: int, the dimension of the key and query vectors of each head.
        d_v: int, the dimension of the value vectors of each head. If None, it is set to d_k.
        num_heads: int, the number of attention heads (must be at least 1).
        causal: bool, whether every head applies a causal mask to its attention scores.

        Raises ValueError if num_heads < 1.
        """
        super().__init__(**kwargs)
        if num_heads < 1:
            raise ValueError(f"num_heads must be at least 1, got {num_heads}.")
        # Store the parameters for the layer.
        self.d_k = d_k
        self.d_v = d_k if d_v is None else d_v
        self.num_heads = num_heads
        self.causal = causal

        # Create the list of SingleHeadAttention instances. The heads are kept in
        # creation order so their weights are listed head by head
        # (WQ, WK, WV, WO of head 0, then head 1, ...).
        self.heads = [
            SingleHeadAttention(d_k=self.d_k, d_v=self.d_v, causal=self.causal)
            for _ in range(self.num_heads)
        ]

    def build(self, batch_input_shape_list):
        """
        Build every head from the input shapes so all weights exist after this layer is built.

        batch_input_shape_list: list of 2 or 3 input shapes, in the order Q, K[, V].
        """
        for head in self.heads:
            head.build(batch_input_shape_list)

    def call(self, inputs):
        """
        Inputs should be a list of 2 or 3 tensors, in the order Q, K, V.
        If only two tensors are provided, the second tensor is used for both K and V.

        Returns the sum of the per-head outputs, shape (batch_size, t_q, model_dim).
        """
        # Each head already includes its own output projection, so multi-head
        # attention is simply the sum of the head outputs.
        output = self.heads[0](inputs)
        for head in self.heads[1:]:
            output = output + head(inputs)
        return output

### Some Tests

#### Test for `SingleHeadAttention`

Execute the cell below to test the `SingleHeadAttention` class. 
This is done by comparing the results with that of Keras' built-in `MultiHeadAttention` class.
The weights are copied from the Keras class to ensure that the results are comparable.

You should see fairly small differences (hopefully zero) between the two outputs. 

In [None]:
key_dim = 16
num_heads = 1
causal = False # Try both False and True
my_attention = SingleHeadAttention(d_k=key_dim, causal=causal)
keras_attention = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim, use_bias=False)

# Generate random data, batch of 100, 10 time steps, 32 features
X = keras.random.normal(shape=(100,10,32))
keras_output = keras_attention(X, X, X, use_causal_mask=causal)
print(f"Shape of keras output: {keras.ops.shape(keras_output)}")
print("Shape of keras weights")
for w in keras_attention.get_weights():
    print(keras.ops.shape(w))

# The first call builds the layer so its weights can be inspected and overwritten below.
my_output = my_attention([X, X, X])
print(f"Shape of SingleHeadAttention output: {keras.ops.shape(my_output)}")
print("Shape of SingleHeadAttention weights")
for w in my_attention.get_weights():
    print(keras.ops.shape(w))

# With a single head, the Keras kernels carry a size-1 head axis; squeeze it out
# so the shapes match the 2-D weights of SingleHeadAttention.
squeezed_weights = [keras.ops.squeeze(w) for w in keras_attention.get_weights()]
my_attention.set_weights(squeezed_weights)
my_output = my_attention([X, X, X])
print(f"Largest absolute difference between outputs: {keras.ops.max(keras.ops.abs(my_output - keras_output))}")
print(f"Largest relative difference between outputs: {keras.ops.max(keras.ops.abs(my_output - keras_output) / keras.ops.abs(keras_output))}")


#### Test for `SimpleMultiHeadAttention`

Execute the cell below to test the `SimpleMultiHeadAttention` class. 
This is done by comparing the results with that of Keras' built-in `MultiHeadAttention` class.
The weights are copied from the Keras class to ensure that the results are comparable.

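You should see fairly small differences between the two outputs, although the relative differences may be somewhat larger.
This is likely because the per-head outputs are summed one at a time in a loop, so the floating-point additions happen in a different order than in the Keras implementation. Relative differences are also inflated wherever an output entry is close to zero.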

In [None]:
key_dim = 16
num_heads = 2
causal = True # Try both False and True
my_attention = SimpleMultiHeadAttention(num_heads=num_heads, d_k=key_dim, causal=causal)
keras_attention = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim, use_bias=False)
X = keras.random.normal(shape=(200,10,32))
keras_output = keras_attention(X, X, X, use_causal_mask=causal)
print(f"Shape of keras output: {keras.ops.shape(keras_output)}")
print("Shape of keras weights")
for w in keras_attention.get_weights():
    print(keras.ops.shape(w))

# The first call builds the layer so its weights can be inspected and overwritten below.
my_output = my_attention([X, X, X])
print(f"Shape of SimpleMultiHeadAttention output: {keras.ops.shape(my_output)}")
print("Shape of SimpleMultiHeadAttention weights")
for w in my_attention.get_weights():
    print(keras.ops.shape(w))

# Put the weights in the right shape and order for SimpleMultiHeadAttention:
# for each head, slice WQ, WK, WV along axis 1 and WO along axis 0 (the head axes).
*qkv_weights, out_weight = keras_attention.get_weights()
weights_for_heads = []
for idx in range(num_heads):
    weights_for_heads.extend(w[:, idx, :] for w in qkv_weights)
    weights_for_heads.append(out_weight[idx, :, :])

my_attention.set_weights(weights_for_heads)
my_output = my_attention([X, X, X])
print(f"Largest absolute difference between outputs: {keras.ops.max(keras.ops.abs(my_output - keras_output))}")
print(f"Largest relative difference between outputs: {keras.ops.max(keras.ops.abs(my_output - keras_output) / keras.ops.abs(keras_output))}")
#assert keras.ops.all(keras.ops.isclose(keras_output, my_output)), "Outputs are not close enough"


### MultiHead Attention with Einstein Summation

Implement a second version of `SimpleMultiHeadAttention` using Einstein summation.
You should not rely on `SingleHeadAttention` for this implementation.

In [None]:
class EinopsMultiHeadAttention(keras.layers.Layer):
    """
    Implementation of multi-head attention using einops.einsum.

    No (explicit) loops are used in this implementation.
    """

    def __init__(self, d_k, d_v=None, num_heads=1, causal=False, **kwargs):
        """
        Constructor for the EinopsMultiHeadAttention layer.

        d_k: int, the dimension of the key and query vectors.
        d_v: int, the dimension of the value vectors. If None, it is set to d_k.
        num_heads: int, the number of attention heads (must be at least 1).
        causal: bool, whether to apply a causal mask to the attention scores.

        Note: the learnable weights are added to the layer in the `build` method, once we know dimension of the embeddings.
        Raises ValueError if num_heads < 1.
        """
        # Call the parent constructor
        super().__init__(**kwargs)
        if num_heads < 1:
            raise ValueError(f"num_heads must be at least 1, got {num_heads}.")
        # Store the parameters for the layer.
        self.d_k = d_k
        self.d_v = d_k if d_v is None else d_v
        self.num_heads = num_heads
        self.causal = causal
        # Precompute the scaling factor (1/sqrt(d_k)) for the attention scores.
        self.scaling_factor = 1.0 / d_k ** 0.5

    def build(self, batch_input_shape_list):
        """
        Add the learnable weights to the layer. The shapes of the weights are inferred from the input shapes,
        and are identical to the shape of the weights in the Keras implementation.

        batch_input_shape_list: list of 2 or 3 input shapes, in the order Q, K[, V].
            Only the last (embedding) dimension of each shape is used.
        """
        shapes = list(batch_input_shape_list)
        q_dim = shapes[0][-1]
        # K falls back to Q's size and V to K's size, mirroring the K/V fallback in `call`.
        k_dim = shapes[1][-1] if len(shapes) > 1 else q_dim
        v_dim = shapes[2][-1] if len(shapes) > 2 else k_dim
        h = self.num_heads

        # Add weights in the order WQ, WK, WV, WO.
        # The weights have the shape (model_dim, num_heads, d_k) for Q, K and (model_dim, num_heads, d_v) for V.
        # The output projection has shape (num_heads, d_v, model_dim), mapping back to Q's embedding size.
        self.WQ = self.add_weight(shape=(q_dim, h, self.d_k), initializer='he_normal', name='WQ')
        self.WK = self.add_weight(shape=(k_dim, h, self.d_k), initializer='he_normal', name='WK')
        self.WV = self.add_weight(shape=(v_dim, h, self.d_v), initializer='he_normal', name='WV')
        self.WO = self.add_weight(shape=(h, self.d_v, q_dim), initializer='he_normal', name='WO')

    def call(self, inputs):
        """
        Inputs should be a list of 2 or 3 tensors, in the order Q, K, V.
        If only two tensors are provided, the second tensor is used for both K and V.

        Returns a tensor of shape (batch_size, t_q, model_dim).
        Raises ValueError if `inputs` does not contain 2 or 3 tensors.
        """
        if len(inputs) == 3:
            Q, K, V = inputs
        elif len(inputs) == 2:
            Q, K = inputs
            V = K
        else:
            raise ValueError(
                f"Expected a list of 2 or 3 tensors (Q, K[, V]), got {len(inputs)}."
            )

        # einops picks its backend from the argument types, so hand it native backend
        # tensors rather than Keras variable objects.
        WQ, WK, WV, WO = (
            keras.ops.convert_to_tensor(w) for w in (self.WQ, self.WK, self.WV, self.WO)
        )

        # Compute the queries, keys and values for each head using einsum.
        # The query is scaled here, which is equivalent to scaling the scores by 1/sqrt(d_k).
        qs = einops.einsum(Q, WQ, "b tq dm, dm h dk -> b h tq dk") * self.scaling_factor  # (b, h, t_q, d_k)
        ks = einops.einsum(K, WK, "b tk dm, dm h dk -> b h tk dk")                        # (b, h, t_k, d_k)
        vs = einops.einsum(V, WV, "b tk dm, dm h dv -> b h tk dv")                        # (b, h, t_k, d_v)

        # Compute the (unnormalized) scores using einsum, shape (b, h, t_q, t_k).
        score = einops.einsum(qs, ks, "b h tq dk, b h tk dk -> b h tq tk")
        if self.causal:
            # Apply the causal mask: a (t_q, t_k) tensor that is -1e9 strictly above the
            # diagonal and 0 elsewhere; it broadcasts over the batch and head axes.
            t_q, t_k = keras.ops.shape(score)[-2:]
            mask = keras.ops.triu(keras.ops.full((t_q, t_k), -1e9, dtype=score.dtype), k=1)
            score = score + mask

        # Compute the attention weights by applying softmax along the last (key) axis.
        attn_weights = keras.ops.softmax(score, axis=-1)

        # Linear combinations of the "rows" in vs, shape (b, h, t_q, d_v).
        mixed_vs = einops.einsum(attn_weights, vs, "b h tq tk, b h tk dv -> b h tq dv")
        # Output projection, summing over heads and d_v, shape (b, t_q, d_model).
        result = einops.einsum(mixed_vs, WO, "b h tq dv, h dv dm -> b tq dm")
        return result

#### Test for `EinopsMultiHeadAttention`

Execute the cell below to test the `EinopsMultiHeadAttention` class. 
This is done by comparing the results with that of Keras' built-in `MultiHeadAttention` class.
The weights are copied from the Keras class to ensure that the results are comparable.

You should see very small differences (ideally zero) between the two outputs.

In [None]:
key_dim = 64
num_heads = 4
causal = True # Try both False and True
my_attention = EinopsMultiHeadAttention(num_heads=num_heads, d_k=key_dim, causal=causal)
keras_attention = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim, use_bias=False)
X = keras.random.normal(shape=(200,10,32))
keras_output = keras_attention(X, X, X, use_causal_mask=causal)
print(f"Shape of keras output: {keras.ops.shape(keras_output)}")
print("Shape of keras weights")
for w in keras_attention.get_weights():
    print(keras.ops.shape(w))

# The first call builds the layer so its weights can be inspected and overwritten below.
my_output = my_attention([X, X, X])
print(f"Shape of EinopsMultiHeadAttention output: {keras.ops.shape(my_output)}")
print("Shape of EinopsMultiHeadAttention weights")
for w in my_attention.get_weights():
    print(keras.ops.shape(w))

# Copy the weights from keras to the einops version (same shapes and order, no reshaping needed)
my_attention.set_weights(keras_attention.get_weights())
my_output = my_attention([X, X, X])
print(f"Largest absolute difference between outputs: {keras.ops.max(keras.ops.abs(my_output - keras_output))}")
print(f"Largest relative difference between outputs: {keras.ops.max(keras.ops.abs(my_output - keras_output) / keras.ops.abs(keras_output))}")
