In [16]:
import os
import sys
import importnb
from torch import nn
import torch
import numpy as np

In [17]:
# Put the parent of the current working directory on sys.path so that the
# `utils` package can be imported from this notebook.
notebook_path = os.getcwd()
parent_dir = os.path.dirname(notebook_path)
# Guard the append: re-running this cell must not keep stacking duplicate
# entries onto sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# importnb's Notebook() context is active while the imports below run
# (presumably because utils.tools is itself a notebook -- confirm).
with importnb.Notebook():
    from utils.tools import (
        MultiHeadAttention,
        AddPositionalEncoding,
        TransformerFFN,
    )

In [18]:
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer's input (residual connection) and is then layer-normalised.

    Args:
        d_model: Feature size handed to the attention, FFN and LayerNorm layers.
        d_ff: Second argument of ``TransformerFFN`` (presumably its hidden
            width -- confirm against utils.tools).
        num_head: First argument of ``MultiHeadAttention``.
        dropout_rate: Probability used by both dropout layers.
        layer_norm_eps: ``eps`` used by both LayerNorm layers.
    """

    def __init__(
        self,
        d_model:int,
        d_ff:int,
        num_head:int,
        dropout_rate:float,
        layer_norm_eps:float,
    ) -> None:
        super().__init__()
        # Self-attention sub-layer and its norm / dropout.
        self.mha = MultiHeadAttention(num_head, d_model)
        self.layernorm_mha = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout_mha = nn.Dropout(dropout_rate)

        # Feed-forward sub-layer and its dropout / norm.
        self.ffn = TransformerFFN(d_model, d_ff)
        self.dropout_ffn = nn.Dropout(dropout_rate)
        self.layernorm_ffn = nn.LayerNorm(d_model, eps=layer_norm_eps)

    def forward(
        self,
        x:torch.Tensor,
        mask:torch.Tensor=None
    ) -> torch.Tensor:
        """Run both sub-layers on ``x`` and return the result.

        ``mask`` is forwarded untouched to the attention module.
        """
        # Attention branch: add the residual, then normalise.
        attn_branch = self.__get_mha_output(x, mask)
        attn_normed = self.layernorm_mha(attn_branch + x)

        # Feed-forward branch: add the residual, then normalise.
        ffn_branch = self.__get_ffn_output(attn_normed)
        return self.layernorm_ffn(ffn_branch + attn_normed)

    def __get_mha_output(
        self,
        x:torch.Tensor,
        mask:torch.Tensor=None
    ) -> torch.Tensor:
        """Self-attention (query = key = value = ``x``) followed by dropout."""
        return self.dropout_mha(self.mha(x, x, x, mask))

    def __get_ffn_output(
        self,
        x:torch.Tensor,
    ) -> torch.Tensor:
        """Feed-forward network followed by dropout."""
        return self.dropout_ffn(self.ffn(x))

In [24]:
class TransformerEncoder(nn.Module):
    """Token embedding, positional encoding and a stack of ``N`` encoder layers.

    Args:
        d_model: Embedding size and feature size of every encoder layer.
        d_ff: Forwarded to each ``TransformerEncoderLayer``.
        num_head: Forwarded to each ``TransformerEncoderLayer``.
        dropout_rate: Forwarded to each ``TransformerEncoderLayer``.
        layer_norm_eps: Forwarded to each ``TransformerEncoderLayer``.
        max_len: Forwarded to ``AddPositionalEncoding`` (presumably the maximum
            sequence length it supports -- confirm against utils.tools).
        src_vocab_size: Number of rows in the embedding table.
        N: Number of stacked encoder layers.
        pad_idx: ``padding_idx`` of the embedding table.
        device: Forwarded to ``AddPositionalEncoding``.
    """

    def __init__(
        self,
        d_model:int,
        d_ff:int,
        num_head:int,
        dropout_rate:float,
        layer_norm_eps:float,
        max_len:int,
        src_vocab_size:int,
        N:int,
        pad_idx:int,
        device:torch.device=torch.device("cpu"),
    ) -> None:
        super().__init__()
        # Input embedding.
        self.embedding = nn.Embedding(src_vocab_size,d_model,pad_idx)
        # Positional encoding.
        self.pos = AddPositionalEncoding(d_model,max_len,device)
        # Stack of encoder layers.
        self.encoder_layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model,d_ff,num_head,dropout_rate,layer_norm_eps
                )
                for _ in range(N)
            ]
        )

    def forward(
        self,
        x:torch.Tensor,
        mask:torch.Tensor=None,
    ) -> torch.Tensor:
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids (int32 or int64).
            mask: Passed unchanged to every encoder layer.

        Returns:
            The output of the last encoder layer.

        Raises:
            AssertionError: If ``x`` is not an int32/int64 tensor.
        """
        # The embedding lookup needs integer indices; fail early with a clear message.
        assert x.dtype == torch.int64 or x.dtype == torch.int32,"xを整数型にしてください"
        x = self.embedding(x)
        # Apply the positional encoding, exactly as TransformerDecoder does.
        # This call was previously commented out, which left `self.pos` unused
        # and gave the encoder no information about token order.
        x = self.pos(x)
        for layer in self.encoder_layers:
            x = layer(x,mask)
        return x

In [26]:
class TransformerDecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder-decoder attention and an FFN.

    Each of the three sub-layers is followed by dropout, a residual
    connection and LayerNorm.

    Args:
        d_model: Feature size handed to the attention, FFN and LayerNorm layers.
        d_ff: Second argument of ``TransformerFFN``.
        num_head: First argument of both ``MultiHeadAttention`` modules.
        dropout_rate: Probability used by all dropout layers.
        layer_norm_eps: ``eps`` used by all LayerNorm layers.
    """

    def __init__(
        self,
        d_model:int,
        d_ff:int,
        num_head:int,
        dropout_rate:float,
        layer_norm_eps:float
    ) -> None:
        super().__init__()
        # Layer declarations. `self_mha` attends over the decoder input itself;
        # `mha` attends over the encoder output.
        self.self_mha = MultiHeadAttention(num_head,d_model)
        self.mha = MultiHeadAttention(num_head,d_model)
        self.layernorm_self_mha = nn.LayerNorm(d_model,eps=layer_norm_eps)
        self.layernorm_mha = nn.LayerNorm(d_model,eps=layer_norm_eps)
        self.dropout_self_mha = nn.Dropout(dropout_rate)
        self.dropout_mha = nn.Dropout(dropout_rate)
        self.ffn = TransformerFFN(d_model,d_ff)
        self.dropout_ffn = nn.Dropout(dropout_rate)
        self.layernorm_ffn = nn.LayerNorm(d_model,eps=layer_norm_eps)

    def forward(
        self,
        x:torch.Tensor,
        encoder_key_value:torch.Tensor,
        mask_key_value:torch.Tensor,
        mask_self:torch.Tensor,
    )->torch.Tensor:
        """Run the three sub-layers on ``x``.

        Args:
            x: Decoder-side input.
            encoder_key_value: Encoder output, used as both key and value of
                the encoder-decoder attention.
            mask_key_value: Mask forwarded to the encoder-decoder attention.
            mask_self: Mask forwarded to the self-attention.

        Returns:
            The output of the final (feed-forward) sub-layer.
        """
        # Self-attention, residual add, LayerNorm.
        x = self.layernorm_self_mha(self.__get_self_mha_output(x,mask_self)+x)
        # Encoder-decoder attention, residual add, LayerNorm.
        x = self.layernorm_mha(self.__get_mha_output(
            q=x,
            k=encoder_key_value,
            v=encoder_key_value,
            mask=mask_key_value
        )+x)
        # Feed-forward network, residual add, LayerNorm.
        x = self.layernorm_ffn(self.__get_ffn_output(x)+x)
        return x

    def __get_self_mha_output(
        self,
        x:torch.Tensor,
        mask:torch.Tensor=None
    )->torch.Tensor:
        """Self-attention (query = key = value = ``x``) followed by dropout."""
        # Use the dedicated self-attention module. This previously called
        # `self.mha`, so `self_mha` was never used and the self-attention
        # shared its weights with the encoder-decoder attention.
        x = self.self_mha(x,x,x,mask)
        x = self.dropout_self_mha(x)
        return x

    def __get_mha_output(
        self,
        q:torch.Tensor,
        k:torch.Tensor,
        v:torch.Tensor,
        mask:torch.Tensor=None
    )->torch.Tensor:
        """Encoder-decoder attention followed by dropout."""
        q = self.mha(q,k,v,mask)
        q = self.dropout_mha(q)
        return q

    def __get_ffn_output(
        self,
        x:torch.Tensor,
    )->torch.Tensor:
        """Feed-forward network followed by dropout."""
        x = self.ffn(x)
        x = self.dropout_ffn(x)
        return x

In [28]:
class TransformerDecoder(nn.Module):
    """Token embedding, positional encoding and a stack of ``N`` decoder layers.

    Args:
        tgt_vocab_size: Number of rows in the embedding table.
        max_len: Forwarded to ``AddPositionalEncoding``.
        pad_idx: ``padding_idx`` of the embedding table.
        d_model: Embedding size and feature size of every decoder layer.
        N: Number of stacked decoder layers.
        d_ff: Forwarded to each ``TransformerDecoderLayer``.
        num_head: Forwarded to each ``TransformerDecoderLayer``.
        dropout_rate: Forwarded to each ``TransformerDecoderLayer``.
        layer_norm_eps: Forwarded to each ``TransformerDecoderLayer``.
        device: Forwarded to ``AddPositionalEncoding``.
    """

    def __init__(
        self,
        tgt_vocab_size:int,
        max_len:int,
        pad_idx:int,
        d_model:int,
        N:int,
        d_ff:int,
        num_head:int,
        dropout_rate:float,
        layer_norm_eps:float,
        device:torch.device = torch.device('cpu')
    )->None:
        super().__init__()
        self.embedding = nn.Embedding(tgt_vocab_size, d_model, pad_idx)
        self.pos = AddPositionalEncoding(d_model, max_len, device=device)

        # Build the layer stack one block at a time, in order.
        self.decoder_layers = nn.ModuleList()
        for _ in range(N):
            self.decoder_layers.append(
                TransformerDecoderLayer(
                    d_model, d_ff, num_head, dropout_rate, layer_norm_eps
                )
            )

    def forward(
        self,
        tgt:torch.Tensor,
        src:torch.Tensor,
        mask_key_value:torch.Tensor,
        mask_self:torch.Tensor
    )->torch.Tensor:
        """Decode ``tgt`` token ids against the encoder output ``src``.

        Both masks are handed unchanged to every decoder layer; the return
        value is the output of the last layer.
        """
        hidden = self.pos(self.embedding(tgt))
        for block in self.decoder_layers:
            hidden = block(
                x=hidden,
                encoder_key_value=src,
                mask_key_value=mask_key_value,
                mask_self=mask_self,
            )
        return hidden

In [33]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a final linear projection to the
    target vocabulary.

    Args:
        src_vocab_size: Source vocabulary size (encoder embedding rows).
        tgt_vocab_size: Target vocabulary size (decoder embedding rows and
            output width of the final linear layer).
        max_len: Forwarded to the encoder and decoder.
        d_model: Model feature size.
        num_head: Forwarded to the encoder and decoder.
        d_ff: Forwarded to the encoder and decoder.
        N: Number of layers in both the encoder and the decoder.
        dropout_rate: Forwarded to the encoder and decoder.
        layer_norm_eps: Forwarded to the encoder and decoder.
        pad_idx: Token id treated as padding when building masks.
        device: Device the masks are placed on; also forwarded to the
            encoder and decoder.
    """

    def __init__(
        self,
        src_vocab_size:int,
        tgt_vocab_size:int,
        max_len:int,
        d_model:int=512,
        num_head:int=8,
        d_ff:int=2048,
        N:int=6,
        dropout_rate:float=0.1,
        layer_norm_eps:float=1e-5,
        pad_idx:int=0,
        device:torch.device=torch.device("cpu")
    )->None:
        super().__init__()
        self.pad_idx = pad_idx
        self.device = device
        self.max_len = max_len
        # Sub-module initialisation.
        self.encoder = TransformerEncoder(
            src_vocab_size=src_vocab_size,
            max_len=max_len,
            pad_idx=pad_idx,
            d_model=d_model,
            N=N,
            d_ff=d_ff,
            num_head=num_head,
            dropout_rate=dropout_rate,
            layer_norm_eps=layer_norm_eps,
            device=device
        )
        self.decoder = TransformerDecoder(
            tgt_vocab_size=tgt_vocab_size,
            max_len=max_len,
            pad_idx=pad_idx,
            d_model=d_model,
            N=N,
            d_ff=d_ff,
            num_head=num_head,
            dropout_rate=dropout_rate,
            layer_norm_eps=layer_norm_eps,
            device=device
        )
        self.linear = nn.Linear(
            d_model,
            tgt_vocab_size
        )

    def forward(
        self,
        src:torch.Tensor,
        tgt:torch.Tensor
    )->torch.Tensor:
        """Run the full model on a batch of source / target token ids.

        Args:
            src: Source token ids; indexed as (batch, src_len).
            tgt: Target token ids; indexed as (batch, tgt_len).

        Returns:
            The decoder output passed through the final linear layer.
        """
        # Encoder self-attention mask: (batch, src_len, src_len).
        src_mask = self.__pad_mask(src)
        memory = self.encoder(
            x=src,
            mask=src_mask
        )
        # Encoder-decoder attention has tgt_len queries over src_len keys, so
        # its mask needs tgt_len rows. Reusing `src_mask` is only correct when
        # both lengths happen to match.
        # NOTE(review): assumes MultiHeadAttention takes a
        # (batch, query_len, key_len) mask -- confirm against utils.tools.
        cross_mask = self.__pad_mask(src, query_len=tgt.shape[1])
        # Decoder self-attention: hide future positions and padding.
        dec_mask = torch.logical_or(
            self.__subsequent_mask(tgt),self.__pad_mask(tgt)
        )
        out = self.decoder(
            tgt,
            memory,
            cross_mask,
            dec_mask
        )
        return self.linear(out)

    def __pad_mask(
        self,
        x:torch.Tensor,
        query_len:int=None
    )->torch.Tensor:
        """Boolean mask that is True wherever a key position holds ``pad_idx``.

        Args:
            x: Token ids, indexed as (batch, key_len).
            query_len: Number of rows to produce per batch item. Defaults to
                the sequence length of ``x``.

        Returns:
            Bool tensor of shape (batch, query_len, key_len) on ``self.device``.
        """
        if query_len is None:
            query_len = x.shape[1]
        mask = x.eq(self.pad_idx)
        mask = mask.unsqueeze(1)
        mask = mask.repeat(1,query_len,1)
        return mask.to(self.device)

    def __subsequent_mask(
        self,
        x:torch.Tensor
    )->torch.Tensor:
        """Boolean mask that is True strictly above the diagonal.

        The mask is sized from the actual sequence length of ``x`` rather
        than ``self.max_len``, so it always lines up with the pad mask it is
        combined with.

        Returns:
            Bool tensor of shape (batch, seq_len, seq_len) on ``self.device``.
        """
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        # Allocate directly on the target device instead of building on the
        # CPU and copying afterwards.
        lower = torch.tril(
            torch.ones(batch_size,seq_len,seq_len,device=self.device)
        )
        return lower.eq(0)