# 제작한 Transformer 훈련하기

## Transformer 불러오기

In [1]:
import torch
import torch.nn as nn

class selfAttention(nn.Module):
    def __init__(self, embed_size, heads) -> None:
        """Multi-head scaled dot-product attention.

        Args:
            embed_size: model (embedding) dimension; 512 in the original paper.
            heads: number of attention heads; 8 in the original paper.

        Each head works on an ``embed_size // heads`` slice of the input
        (64 in the paper). The query/key/value projections map
        ``head_dim -> head_dim`` and are shared by all heads; ``fc_out``
        mixes the concatenated heads back to ``embed_size``.

        Raises:
            ValueError: if ``embed_size`` is not divisible by ``heads``.
        """
        super().__init__()

        # Fail early: forward() splits embed_size into heads * head_dim.
        if embed_size % heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by heads ({heads})"
            )

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads  # per-head dimension (d_k)

        # Per-head linear projections (no bias), shared across heads.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # Output projection applied to the concatenation of all heads.
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Compute multi-head attention.

        Args:
            values: (N, value_len, embed_size) tensor.
            keys: (N, key_len, embed_size) tensor; key_len must equal
                value_len because both are contracted over the same index.
            query: (N, query_len, embed_size) tensor.
            mask: optional tensor broadcastable to (N, heads, query_len,
                key_len); positions where ``mask == 0`` receive ~zero weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]  # batch size
        value_len = values.shape[1]  # sequence lengths
        key_len = keys.shape[1]
        query_len = query.shape[1]

        # Split the last dimension into (heads, head_dim).
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Each input goes through its own projection.
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # score[n, h, q, k] = <queries[n, q, h, :], keys[n, k, h, :]>
        # score shape: (N, heads, query_len, key_len)
        score = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            # A huge negative logit drives the softmax weight of masked
            # positions to (numerically) zero, e.g. for causal decoding or padding.
            score = score.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), the per-head key dimension, as in the paper.
        attention = torch.softmax(score / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of values, then concatenate heads:
        # (N, heads, query_len, key_len) x (N, value_len, heads, head_dim)
        #   -> (N, query_len, heads * head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # Mix the concatenated heads back to embed_size.
        return self.fc_out(out)
        
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion) -> None:
        """One post-norm Transformer layer: attention sublayer + feed-forward sublayer.

        Args:
            embed_size: model (embedding) dimension; 512 in the paper.
            heads: number of attention heads; 8 in the paper.
            dropout: dropout probability applied to each sublayer's output.
            forward_expansion: hidden width of the feed-forward network as a
                multiple of ``embed_size`` (4 in the paper, i.e. 512 -> 2048).
        """
        super().__init__()

        # Multi-head attention sublayer.
        self.attention = selfAttention(embed_size, heads)

        # One LayerNorm per sublayer (Add & Norm).
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Position-wise feed-forward: expand, ReLU, project back.
        # NOTE: attribute name keeps its original spelling so existing
        # state_dict checkpoints still load.
        self.feed_forawrd = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Apply attention then feed-forward, each wrapped in Add & Norm.

        ``query`` is the residual input; ``value``/``key`` may differ from it
        (e.g. encoder output in encoder-decoder attention). ``mask`` is passed
        straight to the attention sublayer. Returns a tensor with the same
        shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)
        # Residual dropout as in the paper: LayerNorm(x + Dropout(Sublayer(x))).
        # Dropout is applied to the sublayer output *before* the residual add,
        # not to the normalized result.
        x = self.norm1(query + self.dropout(attention))

        forward = self.feed_forawrd(x)
        return self.norm2(x + self.dropout(forward))

class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ) -> None:
        """Transformer encoder: token + learned positional embeddings, then
        ``num_layers`` stacked TransformerBlock layers.

        Args:
            src_vocab_size: size of the source vocabulary.
            embed_size: model (embedding) dimension.
            num_layers: number of stacked encoder layers.
            heads: number of attention heads per layer.
            device: stored as ``self.device`` for backward compatibility;
                position indices are created on the input tensor's device.
            forward_expansion: feed-forward width multiplier per layer.
            dropout: dropout probability.
            max_length: maximum supported source sequence length (number of
                rows in the learned positional-embedding table).
        """
        super().__init__()

        self.embed_size = embed_size
        self.device = device

        # Input representation = token embedding + learned position embedding.
        self.word_embeding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embeding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: (N, seq_length) tensor of token ids.
            mask: attention mask forwarded to every layer (may be None).

        Returns:
            (N, seq_length, embed_size) tensor.

        Raises:
            ValueError: if ``seq_length`` exceeds ``max_length``.
        """
        N, seq_length = x.shape
        max_length = self.position_embeding.num_embeddings
        if seq_length > max_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_length {max_length}"
            )

        # Build positions directly on the input's device so the module works
        # wherever it has been moved, regardless of the stored `device`.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        out = self.dropout(self.word_embeding(x) + self.position_embeding(positions))

        for layer in self.layers:
            # Self-attention: value, key and query are all the same tensor.
            out = layer(out, out, out, mask)

        return out

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device) -> None:
        """One decoder layer: masked self-attention (Add & Norm), followed by
        encoder-decoder attention + feed-forward delegated to a TransformerBlock.

        Args:
            embed_size: model (embedding) dimension.
            heads: number of attention heads.
            forward_expansion: feed-forward width multiplier.
            dropout: dropout probability.
            device: unused; accepted for backward compatibility with callers.
        """
        super().__init__()

        self.norm = nn.LayerNorm(embed_size)
        self.attention = selfAttention(embed_size, heads=heads)
        # NOTE: attribute name keeps its original spelling for state_dict compatibility.
        self.transfromer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, target_mask):
        """Run one decoder layer.

        Args:
            x: decoder input (queries for self-attention).
            value, key: encoder output used by encoder-decoder attention.
            src_mask: mask for encoder-decoder attention.
            target_mask: mask for decoder self-attention (e.g. causal).
        """
        attention = self.attention(x, x, x, target_mask)

        # Residual dropout before the add, as in the paper:
        # LayerNorm(x + Dropout(Sublayer(x))).
        query = self.norm(x + self.dropout(attention))

        # Encoder-decoder attention + feed-forward.
        return self.transfromer_block(value, key, query, src_mask)


class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length
    ) -> None:
        """Transformer decoder: token + learned positional embeddings,
        ``num_layers`` DecoderBlock layers, and a final vocabulary projection.

        Args:
            trg_vocab_size: size of the target vocabulary.
            embed_size: model (embedding) dimension.
            num_layers: number of stacked decoder layers.
            heads: number of attention heads per layer.
            forward_expansion: feed-forward width multiplier per layer.
            dropout: dropout probability.
            device: stored as ``self.device`` for backward compatibility;
                position indices are created on the input tensor's device.
            max_length: maximum supported target sequence length (number of
                rows in the learned positional-embedding table).
        """
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        # Projects each position to one unnormalized score per target token.
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode a batch of target token ids.

        Args:
            x: (N, seq_length) tensor of target token ids.
            enc_out: encoder output, used as value and key in every layer.
            src_mask: mask for encoder-decoder attention.
            trg_mask: mask for decoder self-attention.

        Returns:
            (N, seq_length, trg_vocab_size) tensor of scores (no softmax applied).

        Raises:
            ValueError: if ``seq_length`` exceeds ``max_length``.
        """
        N, seq_length = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_length {max_length}"
            )

        # Build positions on the input's device rather than the stored `device`.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(x)

class transformer(nn.Module):
    """Encoder-decoder Transformer assembled from Encoder and Decoder.

    Args:
        src_vocab_size: size of the source vocabulary.
        trg_vocab_size: size of the target vocabulary.
        src_pad_idx: padding token id in source sequences (masked out).
        trg_pad_idx: padding token id in target sequences; stored but not
            used by the masks built here.
        embed_size, num_layers, forward_expansion, heads, dropout, max_length:
            hyperparameters forwarded to both Encoder and Decoder.
        device: device that the masks are moved to; also forwarded to
            Encoder and Decoder.
    """

    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size = 512,
        num_layers = 6,
        forward_expansion = 4,
        heads = 8,
        dropout = 0,
        device = 'cpu',
        max_length = 100
    ) -> None:
        super().__init__()

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        # Register encoder before decoder so parameter order is unchanged.
        self.Encoder = Encoder(
            src_vocab_size=src_vocab_size,
            embed_size=embed_size,
            num_layers=num_layers,
            heads=heads,
            device=device,
            forward_expansion=forward_expansion,
            dropout=dropout,
            max_length=max_length,
        )
        self.decoder = Decoder(
            trg_vocab_size=trg_vocab_size,
            embed_size=embed_size,
            num_layers=num_layers,
            heads=heads,
            forward_expansion=forward_expansion,
            dropout=dropout,
            device=device,
            max_length=max_length,
        )

    def mask_src_mask(self, src):
        """Boolean padding mask: True at real tokens, False at padding.

        For a (N, src_len) input the result has shape (N, 1, 1, src_len),
        broadcastable over heads and query positions.
        """
        not_pad = src.ne(self.src_pad_idx)
        # Insert two singleton axes after the batch axis.
        return not_pad[:, None, None].to(self.device)

    def mask_trg_mask(self, trg):
        """Causal (lower-triangular) mask of shape (N, 1, trg_len, trg_len).

        Entry [i, j] is 1 when target position i may attend to position j
        (j <= i) and 0 otherwise.
        """
        batch, length = trg.shape
        causal = torch.ones((length, length)).tril()
        return causal.expand(batch, 1, length, length).to(self.device)

    def forward(self, src, trg):
        """Return decoder scores of shape (N, trg_len, trg_vocab_size)."""
        src_mask = self.mask_src_mask(src)
        trg_mask = self.mask_trg_mask(trg)
        memory = self.Encoder(src, src_mask)
        return self.decoder(trg, memory, src_mask, trg_mask)


        

## Data Load

In [7]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Build a word-level vocabulary from the WikiText-2 training split.
tokenizer = get_tokenizer('basic_english')
train_iter = WikiText2(split='train')
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_iter),
    specials=['<unk>'],
)
# Out-of-vocabulary lookups resolve to the '<unk>' index.
vocab.set_default_index(vocab['<unk>'])
# NOTE(review): depending on the torchtext version, `train_iter` may be a
# one-shot iterator that is exhausted at this point; re-create it before
# iterating the training data again.