# 1. Understanding Attention

- Before running the jupyter notebook, don't forget to copy it into your drive **(`File` => `Save a copy in Drive`)**. *Failing to do this step may result in losing the progress of your code.*
- For this notebook, please fill in the line(s) directly after a `#TODO` comment with your answers.
- For the submission of the assignment, please download this notebook as a **Python file**, named `A2S1.py`.

## Imports and Setup

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
torch.manual_seed(447)

key = torch.randn(4, 3)
key /= torch.norm(key, dim=1, keepdim=True)
key.round_(decimals=2)

value = torch.randn(4, 3)
value /= torch.norm(value, dim=1, keepdim=True)
value.round_(decimals=2)

print(f'key:\n{key}')
print(f'value:\n{value}')

key:
tensor([[ 0.4700,  0.6500,  0.6000],
        [ 0.6400,  0.5000, -0.5900],
        [-0.0300, -0.4800, -0.8800],
        [ 0.4300, -0.8300,  0.3500]])
value:
tensor([[-0.0700, -0.8800,  0.4700],
        [ 0.3700, -0.9300, -0.0700],
        [-0.2500, -0.7500,  0.6100],
        [ 0.9400,  0.2000,  0.2800]])


In [4]:
def attention(query, key, value):
    """
    Note that we remove scaling for simplicity.
    """
    return F.scaled_dot_product_attention(query, key, value, scale=1)


def check_query(query, target, key=key, value=value):
    """
    Helper function for you to check if your query is close to the required target matrix.
    """
    a_out = attention(query, key, value)
    print("maximum absolute element-wise difference:", (target - a_out).abs().max())

## 1.2. Selection via Attention

In [62]:
# Define a query vector to ”select” the first value vector

# TODO:
S = 1e5
query121 = S * key[0].unsqueeze(0)

attn_weight = query121 @ key.transpose(-2, -1)
print("A: ",attn_weight)
attn_weight = torch.softmax(attn_weight, dim=-1)
print("A': ", attn_weight)
output = attn_weight @ value
print("O: ", output)

# compare output of attention with desired output
print(query121)
check_query(query121, value[0])

A:  tensor([[100340.0000,  27180.0000, -85410.0000, -12739.9941]])
A':  tensor([[1., 0., 0., 0.]])
O:  tensor([[-0.0700, -0.8800,  0.4700]])
tensor([[47000.0000, 64999.9961, 60000.0039]])
maximum absolute element-wise difference: tensor(0.)


In [63]:
# Define a query matrix which results in an identity mapping – select all the value vectors

# TODO:
S = 1e5
query122 = S * key

attn_weight = query122 @ key.transpose(-2, -1)
print("A: ",attn_weight)
attn_weight = torch.softmax(attn_weight, dim=-1)
print("A': ", attn_weight)
output = attn_weight @ value
print("O: ", output)

# compare output of attention with desired output
print(query122)
check_query(query122, value)

A:  tensor([[100340.0000,  27180.0000, -85410.0000, -12739.9951],
        [ 27180.0000, 100770.0000,  25999.9961, -34630.0000],
        [-85410.0000,  25999.9980, 100570.0000,   7750.0005],
        [-12739.9951, -34630.0000,   7750.0000,  99630.0000]])
A':  tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
O:  tensor([[-0.0700, -0.8800,  0.4700],
        [ 0.3700, -0.9300, -0.0700],
        [-0.2500, -0.7500,  0.6100],
        [ 0.9400,  0.2000,  0.2800]])
tensor([[ 47000.0000,  64999.9961,  60000.0039],
        [ 64000.0000,  50000.0000, -58999.9961],
        [ -3000.0000, -48000.0000, -88000.0000],
        [ 43000.0000, -83000.0000,  35000.0000]])
maximum absolute element-wise difference: tensor(0.)


## 1.3. Averaging via Attention

In [78]:
# define a query vector which averages all the value vectors

# TODO:
S = 1e-2
query131 = torch.mean(S * key, dim=0).unsqueeze(0)

attn_weight = query131 @ key.transpose(-2, -1)
print("A: ",attn_weight)
attn_weight = torch.softmax(attn_weight, dim=-1)
print("A': ", attn_weight)
output = attn_weight @ value
print("O: ", output)

# compare output of attention with desired output
print(query131)
target = torch.reshape(value.mean(0, keepdims=True), (3,))  # reshape to a vector
check_query(query131, target)

A:  tensor([[0.0007, 0.0030, 0.0012, 0.0015]])
A':  tensor([[0.2498, 0.2503, 0.2499, 0.2500]])
O:  tensor([[ 0.2476, -0.5901,  0.3223]])
tensor([[ 0.0038, -0.0004, -0.0013]])
maximum absolute element-wise difference: tensor(0.0002)


In [100]:
# define a query vector which averages the first two value vectors

# TODO:
S = 8
query132 = torch.mean(S * key[:2], dim=0).unsqueeze(0)

attn_weight = query132 @ key.transpose(-2, -1)
print("A: ",attn_weight)
attn_weight = torch.softmax(attn_weight, dim=-1)
print("A': ", attn_weight)
output = attn_weight @ value
print("O: ", output)

# compare output of attention with desired output
print(query132)
target = torch.reshape(value[(0, 1),].mean(0, keepdims=True), (3,))  # reshape to a vector
check_query(query132, target)

A:  tensor([[ 5.1008,  5.1180, -2.3764, -1.8948]])
A':  tensor([[4.9534e-01, 5.0393e-01, 2.8028e-04, 4.5368e-04]])
O:  tensor([[ 0.1521, -0.9047,  0.1978]])
tensor([[4.4400, 4.6000, 0.0400]])
maximum absolute element-wise difference: tensor(0.0022)


## 1.4. Interactions within Attention

In [None]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) averages the first three value vectors.
m_key = key.clone()

# TODO:
m_key[2] =

# compare output of attention with desired output
check_query(query132, value[(0, 1, 2),].mean(0, keepdims=True), key=m_key)

In [None]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) returns the third value vector v[2].
m_key = key.clone()

# TODO:
m_key[2] =
m_key[2] /= m_key[2].norm()

# compare output of attention with desired output
check_query(query132, value[2], key=m_key)