# MultiHead Attention

<img src="./image/multihead_attention.png" width="500" height="400">

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``n_head``
    heads of width ``d_model // n_head``, attended per head, re-joined and
    passed through a final output projection.

    Args:
        d_model: Width of the input and output features.
        n_head: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()

        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )

        self.d_model = d_model
        self.n_head = n_head
        self.n_d = d_model // n_head  # per-head feature width

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Apply attention.

        Args:
            q: Query tensor of shape ``(batch, q_len, d_model)``.
            k: Key tensor of shape ``(batch, k_len, d_model)``.
            v: Value tensor of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_head, q_len, k_len)``. Positions where the mask
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size, q_len, _ = q.shape
        # Keys/values may have a different length than the queries
        # (e.g. cross-attention), so size them independently.
        k_len = k.shape[1]

        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        q = q.view(batch_size, q_len, self.n_head, self.n_d).permute(0, 2, 1, 3)
        k = k.view(batch_size, k_len, self.n_head, self.n_d).permute(0, 2, 1, 3)
        v = v.view(batch_size, k_len, self.n_head, self.n_d).permute(0, 2, 1, 3)

        score = q @ k.transpose(2, 3) / math.sqrt(self.n_d)
        if mask is not None:
            # Use the caller's mask and push masked logits to the most negative
            # representable value so softmax assigns them zero weight. The
            # finite value (rather than -inf) avoids NaN for fully masked rows.
            mask = mask.to(score.device)
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        out = F.softmax(score, dim=-1) @ v
        # (batch, head, q_len, n_d) -> (batch, q_len, d_model)
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, q_len, self.d_model)
        out = self.w_o(out)

        return out

# Smoke test: run self-attention over a random (batch, seq, d_model) tensor.
d_model, n_head = 512, 8
X = torch.rand(128, 64, d_model)

attention = MultiHeadAttention(d_model, n_head)
output = attention(X, X, X)
# Uncomment to inspect the result:
# print(output, output.shape)

# Embedding
## TokenEmbedding

In [None]:
class TokenEmbedding(nn.Module):
    """Lookup table that maps token ids to ``d_model``-dimensional vectors.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size, self.d_model = vocab_size, d_model

        self.token_embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
        )

    def forward(self, x):
        """Return the embedding vectors for the integer ids in ``x``."""
        embedded = self.token_embedding(x)
        return embedded

## PositionalEmbedding

<img src="./image/positional_encoding.jpg" width="500" height="400">

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding table.

    For position ``pos`` and even feature index ``i``::

        pe[pos, i]     = sin(pos / 10000 ** (i / d_model))
        pe[pos, i + 1] = cos(pos / 10000 ** (i / d_model))

    Args:
        d_model: Width of each positional vector.
        max_len: Number of positions to precompute.
        device: Device on which the table is initially created.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEmbedding, self).__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.device = device

        position = torch.arange(max_len, dtype=torch.float, device=device).unsqueeze(1)
        # Even feature indices 0, 2, 4, ...; each sin/cos pair shares one frequency.
        even_idx = torch.arange(0, d_model, 2, dtype=torch.float, device=device)
        div_term = torch.pow(10000.0, even_idx / d_model)

        pe = torch.zeros(max_len, d_model, device=device)
        pe[:, 0::2] = torch.sin(position / div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position / div_term[: d_model // 2])

        # Non-persistent buffer: no gradient, follows ``module.to(...)``, and is
        # not written to the state dict.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """Return the first ``x.shape[1]`` rows of the table.

        Args:
            x: Tensor whose second dimension is the sequence length.

        Returns:
            Tensor of shape ``(seq_len, d_model)``.
        """
        seq_len = x.shape[1]
        return self.pe[:seq_len, :]

## Total Embedding

In [None]:
class TransformerEmbedding(nn.Module):
    """Sum of token and positional embeddings, followed by dropout.

    Args:
        vocab_size: Vocabulary size passed to ``TokenEmbedding``.
        d_model: Embedding width shared by both sub-embeddings.
        max_len: Number of positions passed to ``PositionalEmbedding``.
        drop_prob: Dropout probability applied to the summed embedding.
        device: Device passed to ``PositionalEmbedding``.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.position_embedding = PositionalEmbedding(d_model, max_len, device)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed ``x`` and add the positional table for its sequence length."""
        combined = self.token_embedding(x) + self.position_embedding(x)
        return self.drop_out(combined)

<img src="./image/layer_norm.jpg" width="500" height="400">


Batch Norm 对每个通道（特征）在 batch 维度上做归一化，即把所有样本在同一通道上的值放在一起计算均值和方差。
Layer Norm 对每个样本单独做归一化，即在该样本自身的特征维度（最后一维）上计算均值和方差。

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable affine.

    Args:
        d_model: Size of the last dimension; also the size of ``gamma``/``beta``.
        eps: Constant added to the variance before taking the square root.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.d_model, self.eps = d_model, eps

        # Scale starts at one and shift at zero.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Normalise ``x`` along its last dimension, then scale and shift."""
        centered = x - x.mean(dim=-1, keepdim=True)
        # Biased (population) variance.
        variance = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.gamma * normalized + self.beta

# FFN

<img src="./image/positionwise_feed_forward.jpg" width="500" height="400">

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature width.
        hidden: Width of the intermediate layer.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, hidden, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.d_model = d_model
        self.hidden = hidden

        self.fc1 = nn.Linear(d_model, hidden)
        # ``self.dropout`` is the dropout module; the probability is available
        # as ``self.dropout.p``.
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden, d_model)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.fc1(x)
        x = self.dropout(F.relu(x))
        x = self.fc2(x)

        return x

# Encoder Layer

<img src="./image/transformer_resideual_layer_norm_3.png" width="500" height="400">

In [None]:
class EncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(residual + Dropout(sublayer(x)))``.

    Args:
        d_model: Feature width of the layer's input and output.
        ffn_hidden: Hidden width of the feed-forward sub-layer.
        n_head: Number of attention heads.
        drop_prob: Dropout probability for both residual branches and for the
            feed-forward sub-layer.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(EncoderLayer, self).__init__()
        self.d_model = d_model
        self.ffn_hidden = ffn_hidden
        self.n_head = n_head
        self.drop_prob = drop_prob

        self.attention = MultiHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.drop1 = nn.Dropout(drop_prob)

        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm2 = LayerNorm(d_model)
        self.drop2 = nn.Dropout(drop_prob)

    def forward(self, x, mask=None):
        """Run the layer.

        Args:
            x: Input tensor; used as query, key and value for self-attention.
            mask: Optional mask forwarded unchanged to the attention module.

        Returns:
            Tensor produced by the second add-and-norm step.
        """
        # Self-attention sub-layer. Dropout is applied to the sub-layer output
        # before the residual add, so the normalised output is not itself
        # zeroed out.
        residual = x
        x = self.attention(x, x, x, mask)
        x = self.norm1(residual + self.drop1(x))

        # Feed-forward sub-layer, same residual pattern.
        residual = x
        x = self.ffn(x)
        x = self.norm2(residual + self.drop2(x))

        return x



