# 1. Understanding Attention

- Before running this Jupyter notebook, save a copy to your Drive **(`File` => `Save a copy in Drive`)**. *If you skip this step, you may lose your code and progress.*
- For this notebook, fill in the line(s) directly after each `# TODO:` comment with your answers.
- For the submission of the assignment, please download this notebook as a **Python file**, named `A2S1.py`.

## Imports and Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
torch.manual_seed(447)


def _unit_rows(n_rows, n_cols):
    """Draw an (n_rows, n_cols) standard-normal matrix, rescale every row to
    unit L2 norm, then round the entries to two decimals.

    Because rounding happens after normalisation, the returned rows are only
    approximately unit length.
    """
    draws = torch.randn(n_rows, n_cols)
    unit = draws / draws.norm(dim=1, keepdim=True)
    return unit.round(decimals=2)


# Draw order matters for reproducibility: key consumes the RNG first, then value.
key = _unit_rows(4, 3)
value = _unit_rows(4, 3)

print(f'key:\n{key}')
print(f'value:\n{value}')

key:
tensor([[ 0.4700,  0.6500,  0.6000],
        [ 0.6400,  0.5000, -0.5900],
        [-0.0300, -0.4800, -0.8800],
        [ 0.4300, -0.8300,  0.3500]])
value:
tensor([[-0.0700, -0.8800,  0.4700],
        [ 0.3700, -0.9300, -0.0700],
        [-0.2500, -0.7500,  0.6100],
        [ 0.9400,  0.2000,  0.2800]])


In [None]:
def attention(query, key, value):
    """
    Note that we remove scaling for simplicity.
    """
    return F.scaled_dot_product_attention(query, key, value, scale=1)


def check_query(query, target, key=key, value=value):
    """
    Run `attention(query, key, value)` and print how far the result is from `target`.

    The printed number is the largest absolute element-wise difference between
    `target` and the attention output. The subtraction broadcasts, and 0 means
    an exact match. The `key` and `value` defaults are bound to the tensors of
    those names that are in scope when this function is defined.
    """
    output = attention(query, key, value)
    worst_gap = torch.abs(target - output).max()
    print("maximum absolute element-wise difference:", worst_gap)

## 1.2. Selection via Attention

In [None]:
# Define a query vector to "select" the first value vector,
# i.e. the attention output should equal value[0].

# TODO:
query121 =

# compare output of attention with desired output;
# check_query prints the largest absolute element-wise difference (0 = exact match)
print(query121)
check_query(query121, value[0])

In [None]:
# Define a query matrix which results in an identity mapping – select all the value vectors
# (the target is `value` itself, so output row i should match value[i]).

# TODO:
query122 =

# compare output of attention with desired output
print(query122)
check_query(query122, value)

## 1.3. Averaging via Attention

In [None]:
# Define a query vector which averages all the value vectors

# TODO:
query131 =

# compare output of attention with desired output
print(query131)
# target: mean of all four value rows, flattened from shape (1, 3) to (3,)
target = torch.reshape(value.mean(0, keepdims=True), (3,))  # reshape to a vector
check_query(query131, target)

In [None]:
# Define a query vector which averages the first two value vectors

# TODO:
query132 =

# compare output of attention with desired output
print(query132)
# target: mean of value[0] and value[1], flattened from shape (1, 3) to (3,)
target = torch.reshape(value[(0, 1),].mean(0, keepdims=True), (3,))  # reshape to a vector
check_query(query132, target)

## 1.4. Interactions within Attention

In [None]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) averages the first three value vectors.
# m_key is a copy, so the original `key` (check_query's default) stays untouched.
m_key = key.clone()

# TODO:
m_key[2] =

# compare output of attention with desired output
# target: mean of value[0], value[1], value[2]; it keeps shape (1, 3), and the
# comparison inside check_query relies on broadcasting against the attention output
check_query(query132, value[(0, 1, 2),].mean(0, keepdims=True), key=m_key)

In [None]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) returns the third value vector v[2].
# m_key is a copy, so the original `key` (check_query's default) stays untouched.
m_key = key.clone()

# TODO:
m_key[2] =
m_key[2] /= m_key[2].norm()  # rescale to unit L2 norm: only the row's direction matters

# compare output of attention with desired output
check_query(query132, value[2], key=m_key)