In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

The encoder maps an input sequence $x=(x_1, \cdots, x_n)$ to a sequence of continuous representations $z=(z_1, \cdots, z_n)$; given $z$, the decoder then generates an output sequence $y=(y_1, \cdots, y_m)$ one element at a time.

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention (Transformer paper, Sec. 3.2.1).

    Computes ``softmax(Q K^T / sqrt(d_k)) V`` for batched 3-D inputs.

    Args:
        d_k: dimensionality of the query/key vectors; used for the
            1/sqrt(d_k) scaling factor.
    """
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k
        # Normalise over the key axis (dim 2) of the (B, n_q, n_k) score tensor.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Attend from queries to keys and mix the corresponding values.

        Args:
            q: queries, shape (B, n_q, d_k)
            k: keys, shape (B, n_k, d_k)
            v: values, shape (B, n_k, d_v)
            mask: optional bool tensor broadcastable to (B, n_q, n_k).
                Positions that are True are excluded from attention: their
                scores are set to -inf before the softmax.

        Returns:
            output: (B, n_q, d_v), the attention-weighted sum of the values
            attn: (B, n_q, n_k), the attention weights (each row sums to 1)
        """
        # (B, n_q, d_k) @ (B, d_k, n_k) -> (B, n_q, n_k)
        attn = torch.bmm(q, k.transpose(1, 2))
        # Why scale?
        # For large values of d_k, the dot products grow large in magnitude,
        # pushing the softmax function into regions where it has extremely
        # small gradients. To counteract this effect, the dot products are
        # scaled by 1/sqrt(d_k). To see why the dot products get large, run
        # the function 'check_dotproduct_dist'.
        # self.d_k is a plain Python number, so use ** 0.5: torch.sqrt only
        # accepts tensors.
        attn = attn / (self.d_k ** 0.5)

        if mask is not None:
            # -inf scores become exactly 0 weight after the softmax.
            # NOTE: a query row whose keys are *all* masked produces NaN.
            attn = attn.masked_fill(mask, float('-inf'))

        attn = self.softmax(attn)
        # (B, n_q, n_k) @ (B, n_k, d_v) -> (B, n_q, d_v)
        output = torch.bmm(attn, v)
        return output, attn

In [49]:
def check_dotproduct_dist(d_k, sampling_size=1, batch_size=10000, generator=None):
    """
    to check Paper page 4, annotation 4
    -------------------------------
    To illustrate why the dot products get large,
    assume that the components of q and k are independent random variables
    with mean 0 and variance 1.
    Then their dot product has mean 0 and variance d_k (i.e. std sqrt(d_k)).

    Args:
        d_k: length of the sampled q and k vectors.
        sampling_size: number of independent (q, k) pairs to draw (>= 1).
            With a single sample the reported std is always 0, so use a
            large value (e.g. 100000) for a meaningful estimate.
        batch_size: number of pairs drawn per vectorised step. This bounds
            peak memory at roughly 2 * batch_size * d_k floats.
        generator: optional torch.Generator for reproducible sampling.
            When None, torch's global RNG is used.

    Returns:
        (mean, std) of the sampled dot products as Python floats. The std is
        the population std, matching np.std. The summary is also printed.

    Raises:
        ValueError: if sampling_size or batch_size is less than 1.
    """
    if sampling_size < 1:
        raise ValueError('sampling_size must be >= 1, got {}'.format(sampling_size))
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    chunks = []
    # Sample in batches: one (sampling_size, d_k) tensor would need ~800 MB
    # for sampling_size=100000, d_k=1000.
    for start in range(0, sampling_size, batch_size):
        n = min(batch_size, sampling_size - start)
        # Each row is one independent (q, k) pair with N(0, 1) components.
        q = torch.randn(n, d_k, generator=generator)
        k = torch.randn(n, d_k, generator=generator)
        chunks.append((q * k).sum(dim=1))  # row-wise dot products, shape (n,)
    # Promote to float64 before reducing, like the original per-sample .item() floats.
    dots = torch.cat(chunks).double().numpy()

    mean, std = float(np.mean(dots)), float(np.std(dots))
    print('size of vector d_k is {}, sampling result, dot product distribution has \n - mean: {}, \n - std: {}\n'.\
          format(d_k, mean, std))
    return mean, std

In [50]:
# Seed torch's global RNG so the sampled statistics below are reproducible
# on Restart & Run All. The expected std is sqrt(d_k): ~3.16, 10, ~31.6.
torch.manual_seed(0)
for d_k in [10, 100, 1000]:
    check_dotproduct_dist(d_k, sampling_size=100000)

size of vector d_k is 10, sampling result, dot product distribution has 
 - mean: -0.021629880513213576, 
 - std: 3.1659752059733126

size of vector d_k is 100, sampling result, dot product distribution has 
 - mean: -0.016482602213025093, 
 - std: 10.009568765651423

size of vector d_k is 1000, sampling result, dot product distribution has 
 - mean: 0.04095287548005581, 
 - std: 31.607727087367685

