# Implementation of a multihead attention layer


**Q1 (5 points)** In this part of the assignment, your task is to implement a multihead attention (MHA) layer. Your implementation will be compared against `torch.nn.MultiheadAttention`: given the same inputs and parameters, its output should match that of the torch MHA layer.

You should put your work in `mha_implementation.py`. You can make changes to this notebook to debug your code, but your implementation should pass the test case in the released version of this notebook.
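For reference, here is the computation that `torch.nn.MultiheadAttention` performs for a single batch element when dropout is disabled (the default) and no masks are passed. $Q \in \mathbb{R}^{L_1 \times E}$, $K, V \in \mathbb{R}^{L_2 \times E}$, $h$ is `num_heads`, and $d_h = E / h$. The symbols $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_h \times E}$ (the $i$-th row block of each projection matrix), $b_i^Q, b_i^K, b_i^V$, $W^O$ and $b^O$ are our own notation. Weights are stored as (out_features, in_features), which is why they appear transposed:

$$
\text{head}_i = \operatorname{softmax}\!\left(\frac{\left(Q W_i^{Q\top} + b_i^Q\right)\left(K W_i^{K\top} + b_i^K\right)^{\top}}{\sqrt{d_h}}\right)\left(V W_i^{V\top} + b_i^V\right), \qquad i = 1, \dots, h
$$

$$
\operatorname{MHA}(Q, K, V) = \operatorname{Concat}\left(\text{head}_1, \dots, \text{head}_h\right) W^{O\top} + b^O
$$

The softmax is taken row-wise, that is, over the $L_2$ key positions. Each $\text{head}_i$ is $L_1 \times d_h$, so the concatenation is $L_1 \times E$.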

In [1]:
import torch

%load_ext autoreload
%autoreload 2
%autosave 180

Autosaving every 180 seconds


In [2]:
# First, initialize a set of random inputs for the MHA layer.
# This is the test case we will use. Your code should be correct
# if it passes this test case later.

input_dim = 6
num_heads = 2

batch_size = 5
length1 = 7
length2 = 11

query = torch.rand([length1, batch_size, input_dim])
key = torch.rand([length2, batch_size, input_dim])
value = torch.rand([length2, batch_size, input_dim])

In [3]:
# run the computation with the Torch implementation

mha = torch.nn.MultiheadAttention(input_dim, num_heads)

with torch.no_grad():
    torch_output, _ = mha(query, key, value)

In [4]:
# Then we extract parameters from the Torch MHA layer.

weight = mha.in_proj_weight
bias = mha.in_proj_bias

out_weight = mha.out_proj.weight
out_bias = mha.out_proj.bias

w_q, w_k, w_v = weight.chunk(3)
b_q, b_k, b_v = bias.chunk(3)

wqs = list(w_q.chunk(num_heads))
wks = list(w_k.chunk(num_heads))
wvs = list(w_v.chunk(num_heads))

bqs = list(b_q.chunk(num_heads))
bks = list(b_k.chunk(num_heads))
bvs = list(b_v.chunk(num_heads))


in_wbs = (wqs, wks, wvs, bqs, bks, bvs)
out_wbs = (out_weight, out_bias)
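Below is a minimal sketch of what `mha_137` in `mha_implementation.py` could look like. It is not the official solution. It assumes the default `(seq_len, batch, embed_dim)` layout (`batch_first=False`), no masks, no dropout, and the parameter packing built in the cell above: `in_wbs = (wqs, wks, wvs, bqs, bks, bvs)` holds per-head lists and `out_wbs = (out_weight, out_bias)`.

```python
import math

import torch
import torch.nn.functional as F


def mha_137(query, key, value, in_wbs, out_wbs):
    """Sketch of multihead attention for (seq_len, batch, embed_dim) inputs.

    Assumes no attention masks and no dropout, matching the notebook's test.
    """
    wqs, wks, wvs, bqs, bks, bvs = in_wbs
    out_weight, out_bias = out_wbs

    head_outputs = []
    for w_q, w_k, w_v, b_q, b_k, b_v in zip(wqs, wks, wvs, bqs, bks, bvs):
        # Per-head projections: (L, N, E) -> (L, N, head_dim).
        # F.linear computes x @ w.T + b, matching torch's (out, in) weight layout.
        q = F.linear(query, w_q, b_q)
        k = F.linear(key, w_k, b_k)
        v = F.linear(value, w_v, b_v)
        head_dim = q.shape[-1]

        # Scaled dot-product scores for every batch element: (N, L1, L2).
        scores = torch.einsum("lnd,snd->nls", q, k) / math.sqrt(head_dim)
        # Normalize over the key positions.
        attn = torch.softmax(scores, dim=-1)

        # Attention-weighted sum of the values: (L1, N, head_dim).
        head_outputs.append(torch.einsum("nls,snd->lnd", attn, v))

    # Concatenate heads in order along the feature axis: (L1, N, E).
    concat = torch.cat(head_outputs, dim=-1)
    # Final output projection.
    return F.linear(concat, out_weight, out_bias)
```

Looping over heads mirrors the per-head lists in `in_wbs`. Concatenating the heads in their original order matters, because torch splits the embedding dimension into contiguous per-head chunks, which is the same split `chunk(num_heads)` produces above.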


In [5]:
# now we test your implementation here.

from mha_implementation import mha_137

output = mha_137(query, key, value, in_wbs, out_wbs)

# If your implementation is correct, `diff` should be smaller than 1e-6.
with torch.no_grad():
    diff = torch.mean(torch.abs(output - torch_output)).numpy()

print("The mean entry-wise difference between the outputs of the two implementations should be below 1e-5.",
      "The difference from your implementation is ", diff)


The mean entry-wise difference between the outputs of the two implementations should be below 1e-5. The difference from your implementation is  8.235553e-09
