### Attention

##### Example 1

In [1]:
import torch

In [7]:
import torch.nn.functional as F

In [3]:
encoder_hidden_states = torch.tensor([
                            [0.21, 0.51, 1.25],
                            [4.13, -41.3, 0.31],
                            [3.2, 0.02, 9.44]])

In [4]:
encoder_hidden_states

tensor([[ 2.1000e-01,  5.1000e-01,  1.2500e+00],
        [ 4.1300e+00, -4.1300e+01,  3.1000e-01],
        [ 3.2000e+00,  2.0000e-02,  9.4400e+00]])

In [9]:
scores = torch.tensor([13, 9, 9], dtype=torch.float)

In [11]:
softmax_scores = F.softmax(scores, dim=0)

In [14]:
output = softmax_scores * encoder_hidden_states

In [15]:
output

tensor([[ 2.0258e-01,  9.0109e-03,  2.2086e-02],
        [ 3.9841e+00, -7.2971e-01,  5.4772e-03],
        [ 3.0869e+00,  3.5337e-04,  1.6679e-01]])

In [16]:
output.sum(dim=-1)

tensor([0.2337, 3.2598, 3.2541])

##### Example 2: Compute keys

In [54]:
v_1 = torch.tensor([1, 0, 1, 0])

In [55]:
v_2 = torch.tensor([0, 2, 0, 2])

In [56]:
v_3 = torch.tensor([1, 1, 1, 1])

In [60]:
W_k1 = torch.tensor([0, 1, 0, 1])

In [61]:
W_k2 = torch.tensor([0, 1, 1, 1])

In [62]:
W_k3 = torch.tensor([1, 0, 0, 0])

In [59]:
W_k

tensor([[0, 0, 1],
        [1, 1, 0],
        [0, 1, 0],
        [1, 1, 0]])

`v_1`, `v_2`, `v_3` are embedding vectors of three words

In [95]:
v_1, v_2, v_3

(tensor([1, 0, 1, 0]), tensor([0, 2, 0, 2]), tensor([1, 1, 1, 1]))

`W_k1`, `W_k2`, `W_k3` are weights vectors for keys

In [96]:
W_k1, W_k2, W_k3

(tensor([0, 1, 0, 1]), tensor([0, 1, 1, 1]), tensor([1, 0, 0, 0]))

**Self-Attention**: Write a function that computes the key for each vector as bellow from scratch

In [97]:
def compute_keys(vs, ws):
    v = torch.stack([*vs])
    w = torch.stack([*ws], dim=1)
    
    return v@w

In [98]:
keys = compute_keys([v_1, v_2, v_3], [W_k1, W_k2, W_k3])

In [99]:
keys

tensor([[0, 1, 1],
        [4, 4, 0],
        [2, 3, 1]])

`keys[0]` is the key for `v_1`

In [100]:
keys[0]

tensor([0, 1, 1])

##### Example 3

In [101]:
keys = compute_keys([v_1, v_2, v_3], [W_k1, W_k2, W_k3])

`W_k` is a matrix that contains the weight keys for `v1`, `v2`, `v3` respectively

In [108]:
v = torch.stack([v_1, v_2, v_3])

In [110]:
W_k = torch.tensor([[0, 0, 1],
                    [1, 1, 0],
                    [0, 1, 0],
                    [1, 1, 0]])

In [111]:
W_q = torch.tensor([[1, 0, 1],
                    [1, 0, 0],
                    [0, 0, 1],
                    [0, 1, 1]])

In [112]:
W_v = torch.tensor([[0, 2, 0],
                    [0, 3, 0],
                    [1, 0, 3],
                    [1, 1, 0]])

In [113]:
keys = v @ W_k

In [114]:
keys

tensor([[0, 1, 1],
        [4, 4, 0],
        [2, 3, 1]])

In [115]:
queries = v @ W_q

In [119]:
queries

tensor([[1, 0, 2],
        [2, 2, 2],
        [2, 1, 3]])

In [116]:
values = v @ W_v

In [118]:
values

tensor([[1, 2, 3],
        [2, 8, 0],
        [2, 6, 3]])

In [123]:
queries[0] * keys.t()

tensor([[0, 0, 4],
        [1, 0, 6],
        [1, 0, 2]])

In [158]:
queries

tensor([[1, 0, 2],
        [2, 2, 2],
        [2, 1, 3]])

In [147]:
keys.t()

tensor([[0, 4, 2],
        [1, 4, 3],
        [1, 0, 1]])

In [143]:
queries[0][1] * keys[0]

tensor([0, 0, 0])

In [153]:
attention_score = queries[0] @ keys.t()

In [154]:
attention_score

tensor([2, 4, 4])

In [156]:
queries[0] @ keys.t()

tensor([2, 4, 4])

In [157]:
F.softmax(attention_score.float(), dim=0)

tensor([0.0634, 0.4683, 0.4683])

##### Example 2

In [161]:
v = torch.tensor([[1, 0, 1, 0],
                  [0, 2, 0, 2],
                  [1, 1, 1, 1]], dtype=torch.float)

In [163]:
v

tensor([[1., 0., 1., 0.],
        [0., 2., 0., 2.],
        [1., 1., 1., 1.]])

In [176]:
v[0]

tensor([1., 0., 1., 0.])

They the weights of keys of `v[0]` is `w_key[:, 0]`

In [174]:
w_key = torch.tensor([[0, 0, 1],
                      [1, 1, 0],
                      [0, 1, 0],
                      [1, 1, 0]], dtype=torch.float)

In [175]:
w_query = torch.tensor([[1, 0, 1],
                        [1, 0, 0],
                        [0, 0, 1],
                        [0, 1, 1]], dtype=torch.float)

In [170]:
w_value = torch.tensor([[0, 2, 0],
                        [0, 3, 0],
                        [1, 0, 3],
                        [1, 1, 0]], dtype=torch.float)

In [178]:
keys = v @ w_key

In [179]:
queries = v @ w_query

In [180]:
values = v @ w_value

In [181]:
queries

tensor([[1., 0., 2.],
        [2., 2., 2.],
        [2., 1., 3.]])

In [182]:
keys

tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]])

In [183]:
queries @ keys

tensor([[ 4.,  7.,  3.],
        [12., 16.,  4.],
        [10., 15.,  5.]])