In [21]:
import torch

In [22]:
# Toy 3-d embeddings, one row per token of the sentence
# "Your journey starts with one step" (row i is x_{i+1}).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     -> x_1
    [0.55, 0.87, 0.66],  # journey  -> x_2
    [0.57, 0.85, 0.64],  # starts   -> x_3
    [0.22, 0.58, 0.33],  # with     -> x_4
    [0.77, 0.25, 0.10],  # one      -> x_5
    [0.05, 0.80, 0.55],  # step     -> x_6
])

In [23]:
x_2 = inputs[1]          # embedding of the second token, "journey"
d_in = inputs.size(1)    # input embedding size (3 for these toy inputs)
d_out = 2                # size of the projected query/key/value vectors

In [24]:
# Fixed (requires_grad=False) projection matrices of shape (d_in, d_out),
# drawn with torch.rand under a fixed seed.
# NOTE: the draw order here is query, value, key, whereas SelfAttention_v1
# draws query, key, value. With the same seed, W_key / W_value below equal
# that class's W_value / W_key, which is why the manual context vectors
# differ from the class output further down.
torch.manual_seed(123)
W_query, W_value, W_key = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [25]:
# x_2 is the "journey" embedding; project it into query, key and value space.
query_2, key_2, value_2 = (x_2 @ weight for weight in (W_query, W_key, W_value))
print(query_2)

tensor([0.4306, 1.4551])


In [26]:
# Project every token embedding at once: (6, 3) @ (3, 2) -> (6, 2).
queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value

for _label, _tensor in (("keys", keys), ("values", values), ("queries", queries)):
    print(f"{_label}.shape: ", _tensor.shape)

keys.shape:  torch.Size([6, 2])
values.shape:  torch.Size([6, 2])
queries.shape:  torch.Size([6, 2])


In [27]:
# Unnormalized attention score between the "journey" query and its own key.
keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)
print(attn_score_22)

tensor(1.6307)


In [28]:
# query_2 is a 1-D tensor of shape (2,) and keys.T is (2, 6), so the product
# is the 6 attention scores of the "journey" query against every key.
attn_score_2 = torch.matmul(query_2, keys.T)
print(attn_score_2)

tensor([1.3621, 1.6307, 1.5975, 0.9023, 0.5511, 1.2828])


In [29]:
# queries (6, 2) @ transposed keys (2, 6) -> (6, 6) score matrix.
# Row i holds the dot products of query i with every key.
attn_scores = torch.matmul(queries, keys.transpose(0, 1))
print(attn_scores)

tensor([[1.0092, 1.1920, 1.1677, 0.6576, 0.4014, 0.9366],
        [1.3621, 1.6307, 1.5975, 0.9023, 0.5511, 1.2828],
        [1.3437, 1.6096, 1.5769, 0.8908, 0.5441, 1.2663],
        [0.7477, 0.8950, 0.8768, 0.4952, 0.3025, 0.7041],
        [0.6339, 0.7769, 0.7612, 0.4320, 0.2642, 0.6123],
        [0.9758, 1.1587, 1.1351, 0.6400, 0.3907, 0.9109]])


In [31]:
# Turn scores into attention weights: divide by sqrt(d_k), where d_k is the
# key dimension (keys_dim = 2 here), then softmax over the last axis so each
# row sums to 1.
keys_dim = keys.size(-1)
attn_weights = (attn_scores / keys_dim**0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1774, 0.2019, 0.1984, 0.1384, 0.1154, 0.1685],
        [0.1779, 0.2151, 0.2101, 0.1285, 0.1002, 0.1682],
        [0.1777, 0.2145, 0.2096, 0.1290, 0.1010, 0.1683],
        [0.1742, 0.1933, 0.1908, 0.1457, 0.1271, 0.1689],
        [0.1718, 0.1900, 0.1879, 0.1489, 0.1322, 0.1692],
        [0.1766, 0.2010, 0.1977, 0.1393, 0.1168, 0.1687]])


In [32]:
# Each context vector is the attention-weighted sum of the value vectors:
# (6, 6) weights @ (6, 2) values -> (6, 2).
context_vector = torch.matmul(attn_weights, values)
print(context_vector)

tensor([[0.3507, 0.8808],
        [0.3566, 0.8973],
        [0.3563, 0.8966],
        [0.3464, 0.8692],
        [0.3446, 0.8644],
        [0.3502, 0.8795]])


In [34]:
class SelfAttention_v1(torch.nn.Module):
    """Single-head scaled dot-product self-attention with trainable weights.

    The query, key and value projections are plain ``(d_in, d_out)`` parameter
    matrices initialised with ``torch.rand``. They are created in the order
    query, key, value, so results are reproducible under a fixed seed.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the projected query/key/value (and output) vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order matters: it fixes which random draws each matrix gets.
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: token embeddings of shape ``(num_tokens, d_in)``, or batched
               ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        # Project the inputs into key, value and query space.
        keys = x @ self.W_key
        values = x @ self.W_value
        queries = x @ self.W_query

        # Swap only the last two axes. ``keys.T`` reverses *all* dimensions,
        # which is deprecated and gives the wrong product for batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k), then normalise each query's scores over the keys.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        return attn_weights @ values

In [36]:
# Re-seed so the parameters drawn in __init__ are reproducible.
# Calling the module (not .forward directly) runs forward() via nn.Module.__call__.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
