# Reference

[visualized attention mechanism](https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853)

[transformer line-by-line example](https://paul-hyun.github.io/transformer-01/)

In [47]:
import torch
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Build the vocabulary from the tokenized WikiText2 training split.
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
# "<unk>" is registered as a special token when the vocabulary is built.
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
# Make "<unk>" the default index, i.e. the id used for tokens not in the vocabulary.
vocab.set_default_index(vocab["<unk>"])

def data_process(raw_text_iter):
    """Turn an iterable of raw text items into one flat tensor of token ids.

    Each item is tokenized with the module-level ``tokenizer`` and mapped to
    ids with the module-level ``vocab``. Items that yield no tokens are
    dropped before the per-item tensors are concatenated.

    Args:
        raw_text_iter: Iterable of text items accepted by ``tokenizer``.

    Returns:
        A 1-D ``torch.long`` tensor holding the ids of all items in order.
    """
    encoded = []
    for item in raw_text_iter:
        ids = torch.tensor(vocab(tokenizer(item)), dtype=torch.long)
        encoded.append(ids)
    # Empty tensors (e.g. from blank lines) are skipped before concatenation.
    non_empty = tuple(ids for ids in encoded if ids.numel() > 0)
    return torch.cat(non_empty)

# Create the three dataset splits. `train_iter` is created again here because
# the earlier one was already iterated to build the vocabulary
# (presumably it is exhausted after one pass — TODO confirm for this torchtext version).
train_iter, val_iter, test_iter = WikiText2()
# Each split becomes one flat 1-D tensor of token ids.
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
train_data

tensor([   9, 3849, 3869,  ..., 2442, 4810,    3])

* narrow 역할
    * 깔끔하게 나누어 떨어지지 않는 추가적인 부분(나머지들) 은 잘라냅니다.

In [3]:
def batchify(data, bsz):
    """Arrange a flat token tensor into ``bsz`` parallel columns.

    Args:
        data: Tensor whose first dimension is the token axis.
        bsz: Number of columns (batch size) to split the data into.

    Returns:
        A contiguous tensor of shape ``(len(data) // bsz, bsz)`` moved to the
        module-level ``device``.
    """
    # Number of rows each column will hold.
    rows_per_column = data.size(0) // bsz
    # Drop the trailing tokens that do not fill a complete row.
    usable = data[:rows_per_column * bsz]
    # Split evenly into bsz sequences, then put the sequence axis first.
    columns = usable.view(bsz, -1).t().contiguous()
    return columns.to(device)

# Training uses 20 parallel columns; validation and test use 10.
batch_size = 20
eval_batch_size = 10
# Replace each flat token tensor with its (rows, batch) layout on `device`.
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [4]:
train_data.shape

torch.Size([102499, 20])

* get_batch() 
    * 함수는 트랜스포머 모델을 위한 입력과 타겟 시퀀스를 생성합니다. 
    * 이 함수는 소스 데이터를 bptt 길이를 가진 덩어리로 세분화 합니다. 
    * 언어 모델링 과제를 위해서, 모델은 다음 단어인 Target 이 필요 합니다. 
    * 예를 들어, bptt 의 값이 2 라면, 우리는 i = 0 일 때 다음의 2 개의 변수(Variable) 를 얻을 수 있습니다


* BPTT
    * BPTT를 통해 RNN의 가중치 행렬의 미분을 계산해보면 아래와 같이 최종적으로 미분의 곱으로 이루어진 항이 계산된다
    * 시퀀스 길이가 길어지는 경우 BPTT가 불안정해지므로 길이를 끊는 것이 필요하다. 이 방법을 Truncated BPTT라고 부른다.
    

In [5]:
# Maximum number of rows (time steps) in one chunk.
bptt = 35


def get_batch(source, i):
    """Return the input chunk starting at row ``i`` and its next-token targets.

    Args:
        source: Tensor whose first dimension is the sequence axis.
        i: Index of the first row of the chunk.

    Returns:
        ``(data, target)`` where ``data`` is ``source[i:i + seq_len]`` and
        ``target`` is the same window shifted forward by one row. ``seq_len``
        is ``bptt``, shortened near the end so the target stays in range.
    """
    rows_left = len(source) - 1 - i
    seq_len = bptt if bptt < rows_left else rows_left
    stop = i + seq_len
    inputs = source[i:stop]
    shifted = source[i + 1:stop + 1]  # not flattened; same shape as inputs
    return inputs, shifted

# Start row of every bptt-sized chunk along the sequence axis of train_data.
bptt_list = list(range(0, train_data.size(0) - 1, bptt))
# Inputs and next-token targets of the first chunk.
data, targets = get_batch(train_data, bptt_list[0])

# Transpose

(bptt, batch) -> (batch,bptt)

In [9]:
# Swap the two axes so the batch axis comes first: (bptt, batch) -> (batch, bptt).
data = data.T

In [10]:
from torch import nn
import math 
import numpy as np
# Token embedding table: one ninp-dimensional vector per vocabulary entry.
ntokens = len(vocab) # the size of vocabulary
emsize = 512 # embedding dimension
ninp = emsize
encoder = nn.Embedding(ntokens, ninp)


In [11]:
# Answer to the earlier "question" note: the embeddings are multiplied by
# sqrt(ninp), following the embedding scaling described in the original
# Transformer paper. It is commonly explained as keeping the embeddings from
# being dwarfed by the positional encoding that is normally added next.
src = encoder(data) * math.sqrt(ninp)  # (batch, bptt, ninp)

In [12]:
data.shape , src.shape

(torch.Size([20, 35]), torch.Size([20, 35, 512]))

# Scaled Dot-Product Attention

* Q와 K 사이를 내적하여 어텐션을 softmax를 통해 구하고, 그 후에 V를 내적하여 중요한 부분(Attention)을 더 살린다는 의미 내포

## mask
* self attention에서는 time sequence와 같이 적용해야하므로 필수

## return 
* context
    * softmax 결과와 V를 내적하여 context vector 구성
* Attention

In [31]:
def generate_square_subsequent_mask(sz):
    """Build a ``(sz, sz)`` additive causal mask.

    Args:
        sz: Side length of the square mask (sequence length).

    Returns:
        A float32 tensor that is ``0.0`` on and below the main diagonal and
        ``-inf`` strictly above it, so row ``i`` only leaves columns ``<= i``
        unmasked.
    """
    # True strictly above the diagonal, i.e. at the "future" positions.
    future = torch.triu(torch.ones(sz, sz), diagonal=1) == 1
    allowed = torch.zeros(sz, sz, dtype=torch.float32)
    return allowed.masked_fill(future, float('-inf'))

In [150]:
# Causal mask sized to one full chunk: (bptt, bptt).
src_mask = generate_square_subsequent_mask(bptt)

In [20]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        d_k: Value whose square root divides the raw scores
            (normally the key dimension).
    """

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask=None):
        """Attend from ``Q`` over ``K`` and ``V``.

        Args:
            Q: Queries, ``[..., len_q, d]``.
            K: Keys, ``[..., len_k, d]``.
            V: Values, ``[..., len_k, d_v]``.
            attn_mask: Optional mask broadcastable to ``[..., len_q, len_k]``.
                A floating-point mask is *added* to the scores (``0`` keeps a
                position, ``-inf`` blocks it), which is the format produced by
                ``generate_square_subsequent_mask``. Any other dtype is
                treated as a boolean mask where non-zero / ``True`` marks a
                blocked position.

        Returns:
            ``(context, attn)``: the weighted values ``[..., len_q, d_v]`` and
            the attention weights ``[..., len_q, len_k]``.
        """
        # scores: [..., len_q, len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.d_k ** 0.5)
        if attn_mask is not None:
            if attn_mask.is_floating_point():
                # Additive mask: masked_fill_ only accepts boolean masks, so a
                # 0/-inf float mask has to be added instead.
                scores = scores + attn_mask.to(device=scores.device, dtype=scores.dtype)
            else:
                blocked = attn_mask.to(device=scores.device).bool()
                scores = scores.masked_fill(blocked, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn

torch.Size([20, 512, 35])

In [21]:
from torch import Tensor
class AttentionHead(nn.Module):
    """One attention head: linear Q/K/V projections followed by scaled dot-product attention.

    Args:
        dim_in: Size of the last dimension of the incoming tensors.
        dim_k: Output size of the query and key projections.
        dim_v: Output size of the value projection.
    """

    def __init__(self, dim_in: int, dim_k: int, dim_v: int):
        super().__init__()
        self.q = nn.Linear(dim_in, dim_k)
        self.k = nn.Linear(dim_in, dim_k)
        self.v = nn.Linear(dim_in, dim_v)
        self.scaledot = ScaledDotProductAttention(dim_k)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_mask: Tensor = None,
    ) -> Tensor:
        """Project the inputs and run attention.

        Returns whatever ``ScaledDotProductAttention`` returns, i.e. the
        ``(context, attention weights)`` pair rather than a single tensor.
        """
        projected_q = self.q(query)
        projected_k = self.k(key)
        projected_v = self.v(value)
        return self.scaledot(projected_q, projected_k, projected_v, attn_mask)
    

In [22]:
# Split the model dimension evenly across the heads: each head gets
# dim_model // num_heads features for keys and values.
dim_model = ninp
num_heads = 8
dim_k = dim_v = dim_model // num_heads
print(dim_model, num_heads , dim_k , dim_v)

512 8 64 64


In [28]:
# Unprojected self-attention directly on the embeddings (no mask).
# NOTE(review): the scores are scaled by sqrt(dim_k) = sqrt(64) although src's
# last dimension is 512 here — fine for a shape check, confirm if reused.
attention_value = ScaledDotProductAttention(dim_k)(src,src,src)
# Index 1 is the attention weights, index 0 the context tensor.
attention_value[1].shape , attention_value[0].shape

(torch.Size([20, 35, 35]), torch.Size([20, 35, 512]))

In [29]:
# A single head without a mask; index 0 of the result is the context tensor.
attention = AttentionHead(dim_model , dim_k , dim_v)
attention(src,src,src)[0].shape

torch.Size([20, 35, 64])

In [35]:
# Rebuild the causal mask from the actual sequence length of src (axis 1).
src_mask = generate_square_subsequent_mask(src.shape[1])
src.shape , src_mask.shape

(torch.Size([20, 35, 512]), torch.Size([35, 35]))

In [42]:
# A single head, this time with the causal mask passed as attn_mask.
attention1head = AttentionHead(dim_model , dim_k , dim_v)
attention1head(src,src,src,src_mask)[0].shape

torch.Size([20, 35, 64])

In [43]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``AttentionHead`` modules.

    The context tensors of all heads are concatenated along the last
    dimension and mapped back to ``dim_in`` by a final linear layer.

    Args:
        num_heads: Number of attention heads.
        dim_in: Input (and output) feature size.
        dim_k: Query/key size of each head.
        dim_v: Value size of each head.
    """

    def __init__(self, num_heads: int, dim_in: int, dim_k: int, dim_v: int):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(AttentionHead(dim_in, dim_k, dim_v))
        self.heads = nn.ModuleList(head_modules)
        concat_dim = num_heads * dim_v
        self.linear = nn.Linear(concat_dim, dim_in)
        # Debug output: head count, concatenated width, output width.
        print(len(self.heads), concat_dim, dim_in)

    def forward(self, query: Tensor, key: Tensor, value: Tensor,attn_mask:Tensor=None) -> Tensor:
        """Run every head, concatenate the contexts and project to ``dim_in``."""
        per_head_context = []
        for head in self.heads:
            # Each head returns a pair; only its first element (context) is used.
            per_head_context.append(head(query, key, value, attn_mask)[0])
        attn_cat = torch.cat(per_head_context, dim=-1)
        # Debug output: shape of the concatenated contexts.
        print(attn_cat.shape)
        return self.linear(attn_cat)

In [44]:
num_heads

8

In [45]:
# Build the per-head multi-head module and run it with a causal mask.
# The constructor and forward both print debug information (recorded below).
print(dim_v)
multiheadattention = MultiHeadAttention(num_heads,dim_model,dim_k,dim_v)
src_mask = generate_square_subsequent_mask(src.shape[1])
multiheadattention(src,src,src,src_mask)


64
8 512 512
torch.Size([20, 35, 512])


tensor([[[-1.0696e+01,  4.0669e+00,  4.1515e+00,  ..., -8.7903e+00,
          -8.7064e+00, -9.1512e-01],
         [ 1.0199e+01, -9.8666e+00, -7.7445e+00,  ..., -1.1177e+01,
           2.1795e+00, -4.6826e+00],
         [-7.1328e+00,  4.3087e-02, -1.6991e+00,  ..., -2.2577e+00,
          -1.2382e-01, -1.0623e+01],
         ...,
         [ 1.1241e+01, -3.3353e-01, -1.1624e+00,  ..., -1.2642e+01,
          -3.4864e+00, -1.3560e+00],
         [ 9.1558e+00, -1.0391e+01, -1.4705e+01,  ...,  5.9991e+00,
          -9.5669e+00, -4.5261e+00],
         [ 4.5109e+00, -6.5855e+00,  5.5173e+00,  ..., -2.8333e+00,
          -8.1005e-01, -1.4307e+01]],

        [[-1.0836e+01, -1.9000e+00,  6.5398e+00,  ..., -5.4715e+00,
          -3.5084e+00,  3.6073e+00],
         [-2.1476e-01,  5.7531e+00, -1.0472e+00,  ...,  7.5718e+00,
           1.3334e+00,  5.6058e+00],
         [ 1.6913e+00,  5.1038e+00, -1.0241e+01,  ...,  1.3705e+01,
          -1.8139e+00,  5.1536e+00],
         ...,
         [ 2.8922e+00,  3

In [54]:
# Self-attention: queries, keys and values are all the same embedded input.
Q = src
K = src
V = src


In [58]:
src_mask.shape , Q.shape ,  K.size(1)

(torch.Size([35, 35]), torch.Size([20, 35, 512]), 35)

In [64]:
# src_mask2 = src_mask.repeat(20,1,1)
# src_mask2.eq(0).unsqueeze(1).expand(Q.size(0), Q.size(1), K.size(1))

In [70]:
batch_size

20

In [73]:
# Step-by-step walk-through of the fused multi-head projection:
# one linear layer per Q/K/V produces all heads at once (num_heads * dim_k features).
dim_in = dim_model
W_Q = nn.Linear(dim_in, num_heads * dim_k)
W_K = nn.Linear(dim_in, num_heads * dim_k)
W_V = nn.Linear(dim_in, num_heads * dim_k)

# (bs, n_seq, n_head * d_head)
q_s = W_Q(Q)
print(q_s.size())
# Split the feature axis into heads.
# (bs, n_seq, n_head, d_head)
q_s = q_s.view(batch_size, -1, num_heads, dim_k)
print(q_s.size())
# Move the head axis in front of the sequence axis.
# (bs, n_head, n_seq, d_head)
q_s = q_s.transpose(1,2)
print(q_s.size())

torch.Size([20, 35, 512])
torch.Size([20, 35, 8, 64])
torch.Size([20, 8, 35, 64])


In [74]:
# The same project -> split heads -> transpose steps, chained for Q, K and V.
# (bs, n_head, n_seq, d_head)
q_s = W_Q(Q).view(batch_size, -1, num_heads, dim_k).transpose(1,2)
# (bs, n_head, n_seq, d_head)
k_s = W_K(K).view(batch_size, -1, num_heads, dim_k).transpose(1,2)
# (bs, n_head, n_seq, d_head)
v_s = W_V(V).view(batch_size, -1, num_heads, dim_k).transpose(1,2)
print(q_s.size(), k_s.size(), v_s.size())

torch.Size([20, 8, 35, 64]) torch.Size([20, 8, 35, 64]) torch.Size([20, 8, 35, 64])


In [75]:
# print(attn_mask.size())
# attn_mask = attn_mask.unsqueeze(1).repeat(1, n_head, 1, 1)
# print(attn_mask.size())

In [76]:
# Attention over all heads at once (no mask): the head axis is treated as an
# extra batch axis by the matrix multiplications.
scaled_dot_attn = ScaledDotProductAttention(dim_k)
context, attn_prob = scaled_dot_attn(q_s, k_s, v_s, attn_mask= None)
print(context.size())
print(attn_prob.size())

torch.Size([20, 8, 35, 64])
torch.Size([20, 8, 35, 35])


In [69]:
# NOTE(review): this cell carries an earlier execution count (In [69]) than the
# cells around it, and its recorded output is a 3-D tensor. With the 4-D
# attn_prob printed above, (20, 8, 35, 35), multiplying by the 3-D V,
# (20, 35, 512), would presumably fail to broadcast, and re-running it would
# overwrite the per-head `context` that the next cell reshapes. Confirm whether
# this cell should be removed or should use v_s instead of V.
context = torch.matmul(attn_prob, V)
print(context.size())

torch.Size([20, 35, 512])


In [78]:
# Merge the heads back: (bs, n_head, n_seq, d_head) -> (bs, n_seq, n_head * d_head).
# contiguous() is needed because view() cannot be applied to the transposed layout.
context = context.transpose(1, 2).contiguous().view(batch_size, -1, num_heads * dim_k)
print(context.size())

torch.Size([20, 35, 512])


In [79]:
# Final output projection from the concatenated heads back to dim_in features.
linear = nn.Linear(num_heads * dim_k, dim_in)
# (bs, n_seq, d_hidn)
output = linear(context)
print(output.size())

torch.Size([20, 35, 512])


In [82]:
""" multi head attention """
class MultiHeadAttention(nn.Module):
    """Multi-head attention using one fused linear projection each for Q, K and V.

    Args:
        dim_in: Feature size of the inputs and of the output.
        num_heads: Number of attention heads.
        dim_k: Feature size of each head (used for queries, keys and values).
    """

    def __init__(self, dim_in, num_heads, dim_k):
        super().__init__()
        self.dim_in = dim_in
        self.num_heads = num_heads
        self.dim_k = dim_k

        self.W_Q = nn.Linear(dim_in, num_heads * dim_k)
        self.W_K = nn.Linear(dim_in, num_heads * dim_k)
        self.W_V = nn.Linear(dim_in, num_heads * dim_k)
        self.scaled_dot_attn = ScaledDotProductAttention(dim_k)
        self.linear = nn.Linear(num_heads * dim_k, dim_in)

    def _split_heads(self, projected, batch_size):
        """Reshape (bs, n_seq, n_head * d_head) into (bs, n_head, n_seq, d_head)."""
        return projected.view(batch_size, -1, self.num_heads, self.dim_k).transpose(1, 2)

    def forward(self, Q, K, V, attn_mask=None):
        """Compute multi-head attention.

        Args:
            Q: Queries, (bs, n_q_seq, dim_in).
            K: Keys, (bs, n_k_seq, dim_in).
            V: Values, (bs, n_k_seq, dim_in).
            attn_mask: Optional mask of shape (n_q_seq, n_k_seq),
                (bs, n_q_seq, n_k_seq) or (bs, n_head, n_q_seq, n_k_seq).
                It is handed to the scaled dot-product attention module,
                which defines how the mask values are interpreted.

        Returns:
            ``(output, attn_prob)`` with shapes (bs, n_q_seq, dim_in) and
            (bs, n_head, n_q_seq, n_k_seq).
        """
        batch_size = Q.size(0)
        # (bs, n_head, n_q_seq, d_head)
        q_s = self._split_heads(self.W_Q(Q), batch_size)
        # (bs, n_head, n_k_seq, d_head)
        k_s = self._split_heads(self.W_K(K), batch_size)
        # (bs, n_head, n_v_seq, d_head)
        v_s = self._split_heads(self.W_V(V), batch_size)

        # A per-batch 3-D mask gets a singleton head axis so that it broadcasts
        # over all heads. 2-D and 4-D masks already broadcast against the
        # (bs, n_head, n_q_seq, n_k_seq) scores as they are.
        if attn_mask is not None and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(1)

        # (bs, n_head, n_q_seq, d_head), (bs, n_head, n_q_seq, n_k_seq)
        context, attn_prob = self.scaled_dot_attn(q_s, k_s, v_s, attn_mask=attn_mask)
        # (bs, n_q_seq, n_head * d_head)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.dim_k)
        # (bs, n_q_seq, dim_in)
        output = self.linear(context)
        # (bs, n_q_seq, dim_in), (bs, n_head, n_q_seq, n_k_seq)
        return output, attn_prob

In [84]:
# Instantiate the fused multi-head module and check the output shape (no mask).
multiatten = MultiHeadAttention(dim_in=dim_in,num_heads=num_heads,dim_k=dim_k)
multiatten(Q,K,V,None)[0].shape

torch.Size([20, 35, 512])

In [85]:
# Index 1 of the result is the per-head attention weights.
multiatten(Q,K,V,None)[1].shape

torch.Size([20, 8, 35, 35])