In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

In [3]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input embeddings are projected into queries, keys and values by three
    bias-free linear layers, and the module returns
    ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    Args:
        d_model: Width of the embedding (feature) dimension. It is used as
            both the input and the output size of the Q/K/V projections.
        m: Index of the row (token) dimension of the input. Defaults to -2.
        n: Index of the column (feature) dimension of the input. Defaults to
            -1. ``nn.Linear`` always projects along the last dimension, so
            ``n`` should refer to the last dimension of the input.

    With the defaults the module accepts ``(seq, d_model)`` as well as batched
    ``(..., seq, d_model)`` inputs. Passing ``m=0, n=1`` is equivalent for
    plain 2-D inputs.
    """

    def __init__(self, d_model, m=-2, n=-1):
        super().__init__()
        self.m = m
        self.n = n
        self.q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

    def forward(self, embds):
        """Apply self-attention to ``embds``.

        Args:
            embds: Tensor whose dimension ``m`` indexes tokens and whose
                dimension ``n`` holds the ``d_model`` features.

        Returns:
            Tensor of the same shape as ``embds`` where every row is a
            softmax-weighted combination of the value vectors.
        """
        q = self.q(embds)
        k = self.k(embds)
        v = self.v(embds)

        # Similarity of every query with every key: Q @ K^T.
        sims = torch.matmul(q, k.transpose(self.m, self.n))

        # Scale by sqrt(d_k) using a plain Python float. Wrapping the value in
        # torch.tensor(...) would create a float32 CPU tensor on every call and
        # round the factor to float32 precision even for float64 inputs.
        scale = k.size(self.n) ** 0.5
        sims = sims / scale

        # Normalise over the key dimension so each query's weights sum to 1.
        weights = F.softmax(sims, dim=self.n)

        return torch.matmul(weights, v)


In [5]:
# Seed first so the randomly initialised Q/K/V weights are reproducible.
torch.manual_seed(1)

# 3x3 input that is fed to the module as `embds`.
t = torch.tensor([
    [0.2, 0.3, 0.5],
    [0.1, 0.2, 0.1],
    [0.9, 0.1, 0.7],
])

# For this 2-D input, rows are dim 0 and columns are dim 1.
sa = SelfAttention(d_model=3, m=0, n=1)

sa(t)

tensor([[ 0.1450, -0.2672,  0.0806],
        [ 0.1443, -0.2664,  0.0807],
        [ 0.1469, -0.2696,  0.0806]], grad_fn=<MmBackward0>)

In [12]:
# Display the input again for reference while checking the steps by hand.
t

tensor([[0.2000, 0.3000, 0.5000],
        [0.1000, 0.2000, 0.1000],
        [0.9000, 0.1000, 0.7000]])

In [13]:
# nn.Linear stores its weight as (out_features, in_features); transposed, this
# is the matrix W_q for which sa.q(t) equals t @ W_q.
sa.q.weight.transpose(0, 1)

tensor([[ 0.2975,  0.2710, -0.1188],
        [-0.2548, -0.5435,  0.2937],
        [-0.1119,  0.3462,  0.0803]], grad_fn=<TransposeBackward0>)

In [14]:
# Transposed key weight: the matrix W_k for which sa.k(t) equals t @ W_k.
sa.k.weight.transpose(0, 1)

tensor([[-0.0707,  0.2109, -0.0520],
        [ 0.1601, -0.2250,  0.0837],
        [ 0.0285, -0.0421, -0.0023]], grad_fn=<TransposeBackward0>)

In [15]:
# Transposed value weight: the matrix W_v for which sa.v(t) equals t @ W_v.
sa.v.weight.transpose(0, 1)

tensor([[ 0.5047, -0.3487, -0.1850],
        [ 0.1797, -0.0968,  0.0276],
        [-0.2150, -0.2490,  0.3442]], grad_fn=<TransposeBackward0>)

In [11]:
# Queries: the input passed through the bias-free q projection.
sa.q(t)

tensor([[-0.0729,  0.0643,  0.1045],
        [-0.0324, -0.0470,  0.0549],
        [ 0.1639,  0.4319, -0.0213]], grad_fn=<MmBackward0>)

In [16]:
# Keys: the input passed through the bias-free k projection.
sa.k(t)

tensor([[ 0.0481, -0.0464,  0.0136],
        [ 0.0278, -0.0281,  0.0113],
        [-0.0277,  0.1378, -0.0400]], grad_fn=<MmBackward0>)

In [18]:
# Unscaled similarity scores: queries times transposed keys (Q @ K^T),
# mirroring the first step of SelfAttention.forward for m=0, n=1.
sims = torch.matmul(sa.q(t), sa.k(t).transpose(dim0=0, dim1=1))
sims

tensor([[-0.0051, -0.0027,  0.0067],
        [ 0.0014,  0.0010, -0.0078],
        [-0.0124, -0.0078,  0.0558]], grad_fn=<MmBackward0>)

In [24]:
# Size of the keys along dim sa.n (dim 1 here); its square root is the scaling factor.
sa.k(t).size(sa.n)

3

In [25]:
# Square root of that size as a 0-dim tensor: the factor the scores are divided by.
torch.tensor(sa.k(t).size(sa.n) **0.5)

tensor(1.7321)

In [27]:
# Divide the raw scores (`sims`, computed in an earlier cell) by the scaling factor.
scaled_sims = sims / torch.tensor(sa.k(t).size(sa.n) **0.5) # sqrt of dimension

scaled_sims

tensor([[-0.0029, -0.0015,  0.0039],
        [ 0.0008,  0.0006, -0.0045],
        [-0.0072, -0.0045,  0.0322]], grad_fn=<DivBackward0>)

In [29]:
# Softmax along dim sa.n (dim 1 here), so each row of weights sums to 1.
_attention = F.softmax(scaled_sims, dim=sa.n)
_attention

tensor([[0.3324, 0.3329, 0.3347],
        [0.3339, 0.3339, 0.3322],
        [0.3286, 0.3295, 0.3419]], grad_fn=<SoftmaxBackward0>)

In [30]:
# Attention output: each row is a weighted sum of the value rows. The displayed
# result matches the output of sa(t) shown earlier in the notebook.
torch.matmul(_attention, sa.v(t))

tensor([[ 0.1450, -0.2672,  0.0806],
        [ 0.1443, -0.2664,  0.0807],
        [ 0.1469, -0.2696,  0.0806]], grad_fn=<MmBackward0>)