In [20]:
import torch
from torch import nn, optim
from torch.nn import functional as F
import copy
import math
import numpy as np

In [35]:
class Transformer(nn.Module):
    def __init__(self, src_embed, trg_embed, encoder, decoder, fc_layer):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.fc_layer = fc_layer
    
    def forward(self, x, z, src_mask, trg_mask):
        c = self.encoder(x, src_mask)
        y = self.decoder(z, trg_mask, c, src_mask)
        y = self.fc_layer(y)
        y = F.log_softmax(y, dim=-1)
        return y

In [30]:
class TransformerEmbedding(nn.Module):
    def __init__(self, embedding, positional_encoding):
        super(TransformerEmbedding, self).__init__()
        self.embedding = nn.Sequential(embedding, positional_encoding)

    def forward(self, x):
        out = self.embedding(x)
        return out

In [31]:
class Embedding(nn.Module):
    def __init__(self, d_embed, vocab):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(len(vocab), d_embed)
        self.vocab = vocab
        self.d_embed = d_embed

    def forward(self, x):
        out = self.embedding(x) * math.sqrt(self.d_embed)
        return out

In [33]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_embed, max_seq_len=5000):
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(max_seq_len, d_embed)
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_embed, 2) * -(math.log(10000.0) / d_embed))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = encoding

    def forward(self, x):
        out = x + torch.tensor(self.encoding[:, :x.size(1)], requires_grad=False)
        out = self.dropout(out)
        return out

In [9]:
class Encoder(nn.Module):
    def __init__(self, encoder_layer, n_layer):
        super(Encoder, self).__init__()
        self.layer = []
        for i in range(n_layer):
            self.layer.append(copy.deepcopy(encoder_layer))
    
    def forward(self, x, mask):
        out = x
        for layer in self.layer:
            out = layer(out, mask)
        return out

In [15]:
class EncoderLayer(nn.Module):
    def __init__(self, multi_head_attention_layer, position_wise_feed_forward_layer, norm_layer):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention_layer = multi_head_attention_layer
        self.position_wise_feed_forward_layer = position_wise_feed_forward_layer
        self.residual_connection_layer = [ResidualConnectionLayer(copy.deepcopy(norm_layer)) for i in range(2)]

    def forward(self, x, mask):
        out = self.residual_connection_layer[0](x, lambda x: self.multi_head_attention_layer(x, x, x, mask))
        out = self.residual_connection_layers[1](x, lambda x: self.position_wise_feed_forward_layer(x))
        return out

In [25]:
class Decoder(nn.Module):
    def __init__(self, sub_layer, n_layer):
        super(Decoder, self).__init_()
        self.layers = []
        for i in range(n_layer):
            self.layers.append(copy.deepcopy(sub_layer))
    
    #mask=subs, encoder_mask=pad
    def forward(self, x, mask, encoder_output, enmcoder_mask):
        out = x
        for layer in self.layers:
            out = layer(out, mask, encoder_output, enmcoder_mask)
        return out

In [28]:
class DecoderLayer(nn.Module):
    def __init__(self, masked_multi_head_attention_layer, multi_head_attention_layer, position_wise_feed_forward_layer, norm_layer):
        super(DecoderLayer, self).__init__()
        self.masked_multi_head_attention_layer = ResidualConnectionLayer(masked_multi_head_attention_layer, copy.deepcopy(norm_layer))
        self.multi_head_attention_layer = ResidualConnectionLayer(multi_head_attention_layer, copy.deepcopy(norm_layer))
        self.position_wise_feed_forward_layer = ResidualConnectionLayer(position_wise_feed_forward_layer, copy.deepcopy(norm_layer))
    
    def forward(self, x, mask, encoder_output, encoder_mask):
        out = self.masked_multi_head_attention_layer(query=x, key=x, value=x, mask=mask)
        out = self.multi_head_attention_layer(query=out, key=encoder_output, value=encoder_output, mask=encoder_mask)
        out = self.position_wise_feed_forward_layer(x=out)
        return out

In [7]:
#쿼리는 특정 토큰
#키와 밸류는 동일한 토큰을 가리킴
#dk는 key의 디멘

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, d_model, h, qkv_fc_layer, fc_layer):
        super(MultiHeadAttentionLayer, self).__init__()
        self.d_model = d_model
        self.h = h
        self.query_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.key_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.value_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.fc_layer = fc_layer
    
    def calcualate_attention(self, query, key, value, mask):
        d_k = key.size(-1)
        attention_score = torch.matmul(query, key.traspose(-2, -1))
        attention_score = attention_score / math.sqrt(d_k)
        if mask is not None:
            attention_score = attention_score.masked_fill(mask==0, -1e9)
        attention_prob = F.softmax(score, dim=-1)
        out = torch.matmul(attention_prob, value)
        return out
    
    def forward(self, query, key, value, mask=None):
        n_batch = query.shape[0]
        
        def transform(x, fc_layer):
            out = fc_layer(x)
            out = out.view(n_batch, -1, self.h, self.d_model // self.h)
            out = out.transpose(1, 2)
            return out
        
        query = transform(query, self.query_fc_layer)
        key = transform(key, self.key_fc_layer)
        value = transform(value, self.value_fc_layera)
        
        if mask is not None:
            mask = mask.unsqueeze(1)
        out = self.calcualate_attention(query, key, value, mask)
        out = out.transpose(1, 2)
        out = out.contiguous().view(n_batch, -1, self.d_model)
        out = self.fc_layer(out)
        return out

In [11]:
class PositionWiseFeedForwardLayer(nn.Module):
    def __init__(self, first_fc_layer, second_fc_layer):
        self.first_fc_layer = first_fc_layer
        self.second_fc_layer = second_fc_layer
        
    def forward(self, x):
        out = self.first_fc_layer(x)
        out = F.relu(out)
        out = F.dropout(out)
        out = self.second_fc_layer(out)
        return out

In [12]:
class ResidualConnectionLayer(nn.Module):
    def __init__(self, norm_layer):
        super(ResidualConnectionLayer, self).__init__()
        self.norm_layer = norm_layer
    
    def forward(self, x, sub_layer):
        out = sub_layer(x) + x
        out = self.norm_layer(out)
        return out

In [22]:
def subsequent_mask(size):
    atten_shape = (1, size, size)
    mask = np.triu(np.ones(atatn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0

def make_std_mask(tgt, pad):
    tgt_mask = (tgt != pad)
    tgt_mask = tgt_mask.unsqueeze(-2)
    tgt_mask = tgt_mask & torch.tensor(subsequent_mask(tgt.size(-1))).type_as(tgt_mask.data)
    return tgt_mask

In [37]:
def make_model(
    src_vocab, 
    trg_vocab, 
    d_embed = 512, 
    n_layer = 6, 
    d_model = 512, 
    h = 8, 
    d_ff = 2048):

    cp = lambda x: copy.deepcopy(x)

    # multi_head_attention_layer 생성한 뒤 copy해 사용
    multi_head_attention_layer = MultiHeadAttentionLayer(
                                    d_model = d_model,
                                    h = h,
                                    qkv_fc_layer = nn.Linear(d_embed, d_model),
                                    fc_layer = nn.Linear(d_model, d_embed))

    # position_wise_feed_forward_layer 생성한 뒤 copy해 사용    
    position_wise_feed_forward_layer = PositionWiseFeedForwardLayer(
                                        first_fc_layer = nn.Linear(d_embed, d_ff),
                                        second_fc_layer = nn.Linear(d_ff, d_embed))
    
    # norm_layer 생성한 뒤 copy해 사용
    norm_layer = nn.LayerNorm(d_embed, eps=1e-6)

    # 실제 model 생성
    model = Transformer(
                src_embed = TransformerEmbedding(    # SRC embedding 생성
                                embedding = Embedding(
                                                d_embed = d_embed, 
                                                vocab = src_vocab), 
                                positional_encoding = PositionalEncoding(
                                                d_embed = d_embed)), 

                trg_embed = TransformerEmbedding(    # TRG embedding 생성
                                embedding = Embedding(
                                                d_embed = d_embed, 
                                                vocab = trg_vocab), 
                                positional_encoding = PositionalEncoding(
                                                d_embed = d_embed)),
                encoder = Encoder(                    # Encoder 생성
                                sub_layer = EncoderLayer(
                                                multi_head_attention_layer = cp(multi_head_attention_layer),
                                                position_wise_feed_forward_layer = cp(position_wise_feed_forward_layer),
                                                norm_layer = cp(norm_layer)),
                                n_layer = n_layer),
                decoder = Decoder(                    # Decoder 생성
                                sub_layer = DecoderLayer(
                                                masked_multi_head_attention_layer = cp(multi_head_attention_layer),
                                                multi_head_attention_layer = cp(multi_head_attention_layer),
                                                position_wise_feed_forward_layer = cp(position_wise_feed_forward_layer),
                                                norm_layer = cp(norm_layer)),
                                n_layer = n_layer),
                fc_layer = nn.Linear(d_model, len(trg_vocab)))    # Generator의 FC Layer 생성
    
    return model

In [38]:
make_model()

TypeError: make_model() missing 2 required positional arguments: 'src_vocab' and 'trg_vocab'