In [25]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

## Dot-product attention (Luong attention)

----
input embedding : Nx1이라고 가정 
hidden state : 5x1 이라고 가정   
$H = [h_1,h_2,\cdots , h_{N-1},h_N]$
$H \in \mathbb{R}^{5 \times N},\ s_t \in \mathbb{R}^{5 \times 1}$  
\begin{align*}
e^t &= H^T s_t \\
    &= [h_1^T s_t, h_2^T s_t, \cdots, h_{N-1}^T s_t, h_N^T s_t]^T \\
    &= [s_t^T h_1, s_t^T h_2, \cdots, s_t^T h_{N-1}, s_t^T h_N]^T \\
e^t &\in \mathbb{R}^N
\end{align*}

\begin{align*}
\alpha_t &= \mathrm{softmax}(e^t), \qquad \alpha_t \in \mathbb{R}^N \quad \text{(attention distribution)} \\
a_t &= \sum_{i=1}^N \alpha_t^i h_i \\
    &= H\alpha_t
\end{align*}

$H\alpha_t \in \mathbb{R}^5$

In [31]:
# Toy inputs for dot-product attention: hidden size 5 and N = 10 tokens.
# Seed the global RNG first so that Restart & Run All regenerates exactly the
# same tensors (and therefore the same attention weights) on every run.
torch.manual_seed(0)
st = torch.randn(5, 1, dtype=torch.float32)  # state vector s_t, shape (5, 1)
H = torch.randn(5, 10)  # hidden states h_1..h_N stacked as columns, shape (5, 10)
# Alternative with integer-valued entries that are easier to check by hand:
# H = torch.randint(1, 10, (5, 10), dtype=torch.float32)
st.shape, H.T.shape

(torch.Size([5, 1]), torch.Size([10, 5]))

In [37]:
# Attention scores e^t = H^T s_t: one dot product per token,
# (10, 5) @ (5, 1) -> (10, 1).
et = H.T @ st
et.size()

torch.Size([10, 1])

In [45]:
# Attention distribution: normalise the scores over dim 0 (the token axis)
# so the 10 weights sum to 1.
alpha = torch.softmax(et, dim=0)
alpha, alpha.size()

(tensor([[0.0195],
         [0.0350],
         [0.6956],
         [0.0208],
         [0.0302],
         [0.0035],
         [0.0952],
         [0.0506],
         [0.0347],
         [0.0148]]),
 torch.Size([10, 1]))

In [49]:
# Context vector a_t = H @ alpha_t: attention-weighted sum of the columns of H,
# (5, 10) @ (10, 1) -> (5, 1).
context_vector = H @ alpha

In [51]:
# The context vector and the state vector have matching shapes.
context_vector.size(), st.size()

(torch.Size([5, 1]), torch.Size([5, 1]))

In [65]:
# Stack s_t on top of the context vector along dim 0: (5, 1) + (5, 1) -> (10, 1).
vt = torch.cat((st, context_vector), dim=0)
vt.size()

torch.Size([10, 1])

# Bahdanau Attention (바다나우 어텐션)

In [66]:
# Random state with the same shape/dtype as st (presumably s_{t-1} -- confirm);
# H is reused from the dot-product attention section above.
s_t1 = torch.randn_like(st)
s_t1.size(), H.size()

(torch.Size([5, 1]), torch.Size([5, 10]))

In [None]:
# Q (query) and K (keys): each is projected by a bias-free linear map,
# i.e. nn.Linear(..., bias=False).

$W_a \in \mathbb{R}^{1\times5}$  
$W_b \in \mathbb{R}^{5\times5}$  
$W_c \in \mathbb{R}^{5\times5}$


In [99]:
# Bias-free projection matrices taken from freshly initialised linear layers.
# nn.Linear stores its weight as (out_features, in_features), so the shapes are
# W_a: (1, 5), W_b: (5, 5), W_c: (5, 5). The layers are built in the order
# W_a, W_b, W_c.
W_a, W_b, W_c = (
    nn.Linear(5, out_features, bias=False).weight for out_features in (1, 5, 5)
)

In [89]:
# Project the state and the hidden states with their own weight matrices.
a = W_b @ s_t1  # (5, 5) @ (5, 1)  -> (5, 1)
b = W_c @ H  # (5, 5) @ (5, 10) -> (5, 10)

In [100]:
# Additive attention scores: a (5, 1) broadcasts across the 10 columns of
# b (5, 10); W_a then reduces each column to a scalar -> (1, 10).
c = W_a @ torch.tanh(a + b)

In [105]:
# Normalise the (1, 10) scores over dim 1 (the token axis) into attention weights.
att = torch.softmax(c, dim=1)

In [114]:
# Shapes going into the weighted sum: hidden states and attention weights.
H.size(), att.size()

(torch.Size([5, 10]), torch.Size([1, 10]))

In [115]:
# Attention-weighted sum of the hidden states as a row vector:
# (1, 10) @ (10, 5) -> (1, 5).
d = torch.matmul(att, H.T)

In [116]:
# Shape of the resulting context row vector.
d.size()

torch.Size([1, 5])

$s_t = (W_a s_{t-1} + \mathrm{cat}(\text{context vector}, \text{input}))$