In [1]:
import numpy as np
from math import sqrt
import torch
from torch import nn

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Each token of ``x`` is projected to a query, key and value by three
    separate linear layers, and the result is
    ``softmax(Q @ K^T * (1 / sqrt(dim_k))) @ V`` with the softmax taken over
    the key positions.

    Args:
        input_dim: length of each token embedding (the last dim of ``x``).
        dim_k: width of the query and key projections. Q and K must share
            this width because they are multiplied together.
        dim_v: width of the value projection. It may differ from ``dim_k``
            and sets the last dim of the output. A common choice is
            ``input_dim == dim_k == dim_v``.
    """

    def __init__(self, input_dim, dim_k, dim_v):
        super().__init__()

        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        # Scaling factor of scaled dot-product attention
        # ("Attention Is All You Need").
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        """Apply self-attention across the sequence dimension of ``x``.

        Args:
            x: tensor of shape ``(..., seq_len, input_dim)``. This covers the
                usual ``(batch_size, seq_len, input_dim)`` layout, an unbatched
                ``(seq_len, input_dim)`` input, and extra leading batch dims.

        Returns:
            Tensor of shape ``(..., seq_len, dim_v)``.
        """
        Q = self.q(x)  # ..., seq_len, dim_k
        K = self.k(x)  # ..., seq_len, dim_k
        V = self.v(x)  # ..., seq_len, dim_v

        # matmul + transpose(-2, -1) rather than bmm + permute(0, 2, 1), so any
        # number of leading batch dims (including none) is accepted; for 3-D
        # inputs the result is the same as before.
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self._norm_fact  # ..., seq_len, seq_len
        atten = torch.softmax(scores, dim=-1)  # normalise over key positions

        return torch.matmul(atten, V)  # ..., seq_len, dim_v

In [3]:
# Seed torch's global RNG so Restart & Run All reproduces this input and the
# randomly initialised SelfAttention weights created in the next cell.
torch.manual_seed(0)
x = torch.randn(2, 4, 5)  # batch_size, seq_len, input_dim
x

tensor([[[ 2.0458,  1.9726, -0.1512,  1.9982, -0.2152],
         [-0.0645, -1.7582, -0.5415, -0.1181,  0.2643],
         [-1.0900,  0.4482, -0.3582,  0.1351, -1.2927],
         [ 0.1714,  0.8654, -1.1385,  0.0436, -0.9536]],

        [[ 2.2399, -1.1780, -3.1951,  0.6134, -2.3315],
         [ 0.9681,  0.6806,  0.3951,  0.3766,  0.3298],
         [-0.7222, -0.5608,  1.9757,  0.6914, -0.6765],
         [ 0.7648, -0.1378, -1.1170, -0.4576,  1.1171]]])


In [4]:
# Derive the layer width from x instead of hard-coding 5, so changing x's
# shape in the previous cell cannot silently mismatch the layer.
# dim_k = dim_v = input_dim here.
input_dim = x.shape[-1]
self_atten = SelfAttention(input_dim, input_dim, input_dim)

res = self_atten(x)
# Batch and sequence dims are preserved; the last dim becomes dim_v.
assert res.shape == (*x.shape[:-1], input_dim)

print(res.shape)
print(res)

torch.Size([2, 4, 5])
tensor([[[-0.1035,  0.4128, -0.2861, -0.3445,  0.2451],
         [-0.0300,  0.3427, -0.2526, -0.3070,  0.2649],
         [-0.0386,  0.2272, -0.2532, -0.3822,  0.1499],
         [-0.0010,  0.3377, -0.2378, -0.2772,  0.2957]],

        [[ 0.3935,  0.0149,  0.2075, -0.3303,  0.7374],
         [ 0.5318, -0.0816,  0.1414, -0.4798,  0.6215],
         [ 0.7085, -0.2124, -0.0448, -0.5835,  0.3781],
         [ 0.3008,  0.0723,  0.1411, -0.1314,  0.7126]]],
       grad_fn=<BmmBackward0>)
