In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from math import sqrt
import math

### tenor implementation

In [51]:
x = torch.tensor([
    [1, 0, 1, 0],
    [0, 2, 0, 2],
    [1, 1, 1, 1]
]).float()
print(x.shape)

torch.Size([3, 4])


In [52]:
w_key = torch.tensor([[0, 0, 1],
                   [1, 1, 0],
                   [0, 1, 0],
                   [1, 1, 0]]).float()

w_query = torch.tensor([[1, 0, 1],
                   [1, 0, 0],
                   [0, 0, 1],
                   [0, 1, 1]]).float()
w_value = torch.tensor([[0, 2, 0],
                   [0, 3, 0],
                   [1, 0, 3],
                   [1, 1, 0]]).float()

print(w_key.shape)


torch.Size([4, 3])


In [58]:
keys = x @ w_key
print('keys', keys)

querys = x @ w_query
print('querys', keys)

values = x @ w_value
print('values', values)

print('shape', keys.shape)

keys tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]])
querys tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]])
values tensor([[1., 2., 3.],
        [2., 8., 0.],
        [2., 6., 3.]])
shape torch.Size([3, 3])


In [59]:
attn_scores = querys @ keys.t()
print(attn_scores)
attn_scores_softmax = F.softmax(attn_scores, dim=1)
print(attn_scores_softmax)
print(attn_scores_softmax.shape)

tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])
tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
        [6.0337e-06, 9.8201e-01, 1.7986e-02],
        [2.9539e-04, 8.8054e-01, 1.1917e-01]])
torch.Size([3, 3])


In [60]:
weighted_values = attn_scores_softmax @ values
print(weighted_values)
print(weighted_values.shape)

tensor([[1.9366, 6.6831, 1.5951],
        [2.0000, 7.9640, 0.0540],
        [1.9997, 7.7599, 0.3584]])
torch.Size([3, 3])


            @ w_query(4,3) -> Q(3,3)
                                     -> softmat(Q@K.t()) -> attn_scores(3,3)
input(3,4)  @ w_key(4,3)   -> K(3,3)
                                                                        -> attn_scores @ V -> weighted_values(3,3)
            @ w_value(4,3) -> V(3,3) ----------------------------------

### Pytorch implementtation

In [4]:
class Self_Attention(nn.Module):
    # input : (batch_size,(seq_len,input_dim)
    # q : (batch_size,input_dim,dim_k)
    # k : (batch_size,input_dim,dim_k)
    # v : (batch_size,input_dim,dim_v)
    def __init__(self,input_dim, dim_k,dim_v):
        super(Self_Attention,self).__init__()
        self.q = nn.Linear(input_dim,dim_k)
        self.k = nn.Linear(input_dim,dim_k)
        self.v = nn.Linear(input_dim,dim_v)
        self._norm_fact = 1 / sqrt(dim_k)
        
    
    def forward(self,x):
        Q = self.q(x) # Q: (batch_size,seq_len,dim_k)
        K = self.k(x) # K: (batch_size,seq_len,dim_k)
        V = self.v(x) # V: (batch_size,seq_len,dim_v)
         
        score = torch.bmm(Q,K.permute(0,2,1)) * self._norm_fact #(batch_size, seq_len, seq_len)
        score = F.softmax(score, dim=-1) # (batch_size, seq_len, seq_len)
        atten = torch.bmm(score,V) # (batch_size, seq_len, dim_v)

        return atten

In [5]:
batch_size = 4
seq_len = 3
input_dim = 2
x = torch.randn(batch_size, seq_len, input_dim)
print('input-size', x.size())

dim_k = 4
dim_v = 5
self_attn = Self_Attention(input_dim, dim_k, dim_v)
output = self_attn(x)
print('output-size', output.size())

input-size torch.Size([4, 3, 2])
output-size torch.Size([4, 3, 5])


### Muti-headed attention

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout = 0.1):
        super().__init__()
        
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, q, k, v, mask=None):
        
        bs = q.size(0)
        
        # perform linear operation and split into h heads
        
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        
        # transpose to get dimensions bs * h * sl * d_model
       
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)

	# calculate attention using function we will define next
        scores = self.attention(q, k, v, self.d_k, mask, self.dropout)
        
        # concatenate heads and put through final linear layer
        concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)
        
        output = self.out(concat)
    
        return output
    
    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        scores = torch.matmul(q, k.transpose(-2, -1)) /  math.sqrt(d_k)

        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        
        if dropout is not None:
            scores = dropout(scores)
            
        output = torch.matmul(scores, v)
        return output

In [9]:
heads = 2
d_models = 4

x = torch.randn(3, 4, 4)
multihead_attn = MultiHeadAttention(heads, d_models)
output = multihead_attn(x, x, x)
print(output.size())


torch.Size([3, 4, 4])
