In [3]:
import torch.nn as nn
import torch
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
    def forward(self, x):
        keys = x @ self.W_query
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        print(attn_weights)
        context_vec = attn_weights @ values
        return context_vec

In [4]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3) 
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x*6)
)
d_in = inputs.shape[1] # the input embedding size
d_out = 2 # the output embedding size

In [5]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.1615, 0.2204, 0.2168, 0.1283, 0.1161, 0.1569],
        [0.1565, 0.2404, 0.2353, 0.1154, 0.1016, 0.1508],
        [0.1567, 0.2395, 0.2344, 0.1160, 0.1023, 0.1511],
        [0.1631, 0.2065, 0.2040, 0.1380, 0.1286, 0.1598],
        [0.1630, 0.2008, 0.1988, 0.1421, 0.1348, 0.1606],
        [0.1615, 0.2187, 0.2153, 0.1295, 0.1178, 0.1571]],
       grad_fn=<SoftmaxBackward0>)
tensor([[0.2997, 0.8094],
        [0.3067, 0.8265],
        [0.3063, 0.8256],
        [0.2947, 0.7960],
        [0.2925, 0.7900],
        [0.2991, 0.8077]], grad_fn=<MmBackward0>)


# Self Attention V2


In [6]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [7]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
