In [12]:
from math import sqrt
import torch
import torch.nn.functional as F
import torch.nn as nn

In [19]:
class Self_Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with three independent
    linear layers and returns ``softmax(Q @ K^T / sqrt(dim_k)) @ V``.

    Shapes (``...`` = any leading batch dims, e.g. ``batch_size``; may be empty):
        x:      (..., seq_len, input_dim)
        Q, K:   (..., seq_len, dim_k)
        V:      (..., seq_len, dim_v)
        output: (..., seq_len, dim_v)

    Args:
        input_dim: size of the last dimension of the input.
        dim_k: size of the query/key projections; also sets the 1/sqrt(dim_k) scale.
        dim_v: size of the value projection, i.e. the output feature size.
    """

    def __init__(self, input_dim, dim_k, dim_v):
        super(Self_Attention, self).__init__()
        self.q = nn.Linear(input_dim, dim_k)  # input_dim -> dim_k (queries)
        self.k = nn.Linear(input_dim, dim_k)  # input_dim -> dim_k (keys)
        self.v = nn.Linear(input_dim, dim_v)  # input_dim -> dim_v (values)
        # Temperature for the attention logits: keeps the variance of Q·K from
        # growing with dim_k, which would otherwise saturate the softmax.
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (..., seq_len, input_dim).

        Returns a tensor of shape (..., seq_len, dim_v).
        """
        Q = self.q(x)  # (..., seq_len, dim_k)
        K = self.k(x)  # (..., seq_len, dim_k)
        V = self.v(x)  # (..., seq_len, dim_v)

        # Scale the logits BEFORE the softmax. Scaling afterwards only shrinks
        # the output by 1/sqrt(dim_k), and the weights no longer sum to 1.
        # matmul/transpose(-2, -1) equals bmm/permute(0, 2, 1) for 3-D input
        # but also accepts unbatched or multi-batch-dim input.
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self._norm_fact  # (..., seq_len, seq_len)
        attn_weights = F.softmax(scores, dim=-1)  # each row sums to 1
        return torch.matmul(attn_weights, V)  # (..., seq_len, dim_v)

In [20]:
# Seed the global RNG so the random example input (and the nn.Linear
# initialisation in the next cell) is identical on every Restart & Run All.
torch.manual_seed(0)

batch_size = 4
seq_len = 3
input_dim = 2
x = torch.randn(batch_size, seq_len, input_dim)  # (batch_size, seq_len, input_dim)
print(x.size())

torch.Size([4, 3, 2])


In [22]:
# Project each input_dim-sized token to dim_k-sized queries/keys and
# dim_v-sized values; the result should be (batch_size, seq_len, dim_v).
dim_k, dim_v = 4, 5
self_attn = Self_Attention(input_dim=input_dim, dim_k=dim_k, dim_v=dim_v)
res = self_attn(x)
print(res.shape)

torch.Size([4, 3, 5])
