In [1]:
import numpy as np
from math import sqrt
import torch
from torch import nn

In [3]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Every token embedding is projected to a query, a key and a value by three
    independent linear layers. The output for a token is the softmax-weighted
    sum of the value vectors of all tokens in the same sequence.

    Queries and keys share the width ``dim_k`` because they are combined with
    a dot product; values may use a different width ``dim_v``. Usually
    ``input_dim == dim_k == dim_v``.

    Args:
        input_dim: Length of each input token embedding.
        dim_k: Width of the query and key projections.
        dim_v: Width of the value projection, and therefore of the output.
    """

    def __init__(self, input_dim, dim_k, dim_v):
        super().__init__()

        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        # Scale factor applied to the raw Q.K^T scores before the softmax.
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        """Apply self-attention over the second-to-last axis of ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, input_dim)``. The usual case
                is ``(batch_size, seq_len, input_dim)``, but an unbatched
                ``(seq_len, input_dim)`` input or additional leading batch
                dimensions are accepted as well.

        Returns:
            Tensor of shape ``(..., seq_len, dim_v)``.
        """
        Q = self.q(x)  # (..., seq_len, dim_k)
        K = self.k(x)  # (..., seq_len, dim_k)
        V = self.v(x)  # (..., seq_len, dim_v)

        # matmul + transpose(-2, -1) batches over any leading dimensions,
        # whereas bmm + permute(0, 2, 1) only accepts exactly 3-D input.
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self._norm_fact
        atten = torch.softmax(scores, dim=-1)  # (..., seq_len, seq_len)

        return torch.matmul(atten, V)  # (..., seq_len, dim_v)

In [14]:
# Fix the RNG seed so re-running this cell always yields the same input tensor.
torch.manual_seed(0)
x = torch.randn(2, 4, 5)  # batch_size, seq_len, input_dim
print(x)

tensor([[[-0.0866,  0.6850,  1.9453, -0.2394, -0.4573],
         [-0.2590, -0.4115, -2.2949,  0.5432, -0.3393],
         [ 0.6740, -0.4641,  1.9441,  0.3750,  1.6446],
         [-0.6165,  0.0645, -0.5151,  1.7490, -0.0051]],

        [[-1.7339,  1.7558,  0.9879, -1.3552, -2.1753],
         [ 1.2276, -0.4604,  1.2258,  0.9501,  0.1888],
         [-1.3072, -0.2401, -0.0225,  1.2445,  0.0437],
         [-1.0297,  1.6629,  0.9111, -0.3878,  2.3639]]])


In [15]:
# Build the attention layer; embedding, query/key and value widths are all 5.
self_atten = SelfAttention(input_dim=5, dim_k=5, dim_v=5)

# Run the sample batch `x` through the layer.
res = self_atten(x)

# Show the output shape first, then the output values on the following line.
print(res.shape, res, sep="\n")

torch.Size([2, 4, 5])
tensor([[[ 1.6241e-01,  3.9731e-01,  2.1878e-01, -1.7698e-01, -1.8610e-01],
         [ 1.8993e-01,  3.1154e-01,  1.2097e-01, -1.7529e-01, -1.8005e-01],
         [ 1.0812e-01,  6.5263e-01,  5.0273e-01, -1.7204e-01, -1.9294e-01],
         [ 2.1346e-01,  2.8184e-01,  1.3961e-01, -1.7362e-01, -2.1780e-01]],

        [[-2.1695e-01,  6.6055e-01, -2.8112e-01,  1.8890e-01, -1.2270e+00],
         [ 4.1564e-01, -8.4745e-02,  7.4896e-02, -7.1744e-02, -7.9181e-01],
         [ 1.1171e-01,  3.7604e-01,  2.2371e-02,  1.8171e-02, -8.2058e-01],
         [-3.4695e-02,  5.7254e-01,  5.3644e-07,  1.3719e-01, -1.1470e+00]]],
       grad_fn=<BmmBackward0>)
