In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Record the installed PyTorch version in the notebook output.
print(torch.__version__)

1.10.2


In [2]:
# Use the first CUDA device when PyTorch reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cpu


## Scaled Dot-Product Attention

- $X \in \mathbb{R}^{n \times d}$
    - $n$: sequence length, i.e. the number of tokens (words) in the input
    - $d$: dimension of each word-embedding vector (the model dimension), typically 512
    - Complexity per Layer of Self-Attention: $O(n^2\cdot d)$
- $Q, K \in \mathbb{R}^{n \times d_K}$
- $V \in \mathbb{R}^{n \times d_V}$

$\text{Attention}(Q,K,V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_K}} \right)V \in \mathbb{R}^{n \times d_V}$

In [29]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_K)) V``.

    The module holds no learnable parameters. Only the last two dimensions of
    the inputs take part in the matrix products, so any leading dimensions
    (batch, heads, ...) are carried through unchanged.
    """

    def forward(self, Q, K, V, mask=None):
        """Compute the attention output and the attention weights.

        Args:
            Q: queries of shape ``(..., n_Q, d_K)``.
            K: keys of shape ``(..., n_K, d_K)``.
            V: values of shape ``(..., n_K, d_V)``.
            mask: optional tensor broadcastable to ``(..., n_Q, n_K)``.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(output, attention)`` where ``output`` has shape
            ``(..., n_Q, d_V)`` and ``attention`` has shape ``(..., n_Q, n_K)``
            with each row summing to 1 over the last dimension.
        """
        d_K = K.size(-1)

        # Scale by sqrt(d_K) so the logits do not grow with the key dimension.
        scores = Q.matmul(K.transpose(-2, -1)) / np.sqrt(d_K)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype.
            # A hard-coded -1e9 is not representable in float16 and makes
            # masked_fill raise an overflow error; -inf would instead produce
            # NaN for rows in which every position is masked.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        attention = F.softmax(scores, dim=-1)
        output = attention.matmul(V)
        return output, attention


# The rest of the notebook uses the name `SDPA` for a ready-to-use instance.
# The class has its own name so that this binding no longer shadows it.
SDPA = ScaledDotProductAttention()


In [30]:
# SDPA on one batch of sequences

n_batch = 3

d_K = 128  # d_K (= d_Q) does not have to equal d_V
d_V = 256

n_Q = 30
n_K = 50  # n_K must be equal to n_V
n_V = 50

Q = torch.rand(n_batch, n_Q, d_K)
K = torch.rand(n_batch, n_K, d_K)
V = torch.rand(n_batch, n_V, d_V)

# Ignoring the batch dimension:
# Q.shape = (n_Q, d_K)
# K.shape = (n_K, d_K)
# matmul(Q, K.T).shape = attention.shape = (n_Q, n_K) = (30, 50)
# V.shape = (n_V, d_V) = (50, 256)
# matmul(attention, V).shape = output.shape = (n_Q, d_V) = (30, 256)

# Call the module itself rather than .forward() so that nn.Module hooks run.
output, attention = SDPA(Q, K, V, mask=None)

# Check the shapes derived in the comments above.
assert output.shape == (n_batch, n_Q, d_V)
assert attention.shape == (n_batch, n_Q, n_K)

# list(shape) prints "[3, 30, 128]"; slicing str(shape) left a stray ")" behind.
print(f'SDPA: Q{list(Q.shape)} K{list(K.shape)} V{list(V.shape)} \n=> output{list(output.shape)} attention{list(attention.shape)}')


SDPA: Q[3, 30, 128]) K[3, 50, 128]) V[3, 50, 256]) 
=> output[3, 30, 256]) attention[3, 30, 50])


In [31]:
# Multi-Head SDPA: the extra head dimension is carried through like a batch dimension

n_batch = 3

n_head = 5

d_K = 128
d_V = 256

n_Q = 30
n_K = 50
n_V = 50

Q = torch.rand(n_batch, n_head, n_Q, d_K)
K = torch.rand(n_batch, n_head, n_K, d_K)
V = torch.rand(n_batch, n_head, n_V, d_V)

# Call the module itself rather than .forward() so that nn.Module hooks run.
output, attention = SDPA(Q, K, V, mask=None)

# Attention is computed independently for every (batch, head) pair.
assert output.shape == (n_batch, n_head, n_Q, d_V)
assert attention.shape == (n_batch, n_head, n_Q, n_K)

# list(shape) prints "[3, 5, 30, 128]"; slicing str(shape) left a stray ")" behind.
print(f'Multi-Head SDPA: Q{list(Q.shape)} K{list(K.shape)} V{list(V.shape)} \n=> output{list(output.shape)} attention{list(attention.shape)}')


Multi-Head SDPA: Q[3, 5, 30, 128]) K[3, 5, 50, 128]) V[3, 5, 50, 256]) 
=> output[3, 5, 30, 256]) attention[3, 5, 30, 50])
