In [1]:
import torch

### Decoder Layer

##### Example

In [1]:
from torch import nn
import torch

from foundation.transformer.encoder import ResidualLayerNorm, PostionWiseFeedForward
from foundation.transformer.efficient_attention import MultiHeadAttention

In [4]:
#| export
class DecoderLayer(nn.Module):
    """Single Transformer decoder layer.

    Runs three sub-layers in sequence:

    1. masked multi-head self-attention over the decoder input,
    2. encoder-decoder multi-head attention (queries come from the decoder,
       keys and values from the encoder output),
    3. a position-wise feed-forward network.

    The output of each sub-layer is passed, together with that sub-layer's
    input, to its own ``ResidualLayerNorm``.

    The class derives from ``nn.Module`` (not ``nn.ModuleList``): it is one
    layer with named sub-modules, not an indexable container of layers.
    Parameter names / ``state_dict`` keys are unaffected by this choice.

    Args:
        d_model: model dimension, forwarded to every sub-layer.
        n_heads: number of attention heads used by both attention blocks.
        d_ff: second argument of ``PostionWiseFeedForward`` (its inner size).
        dropout: accepted for interface compatibility. NOTE(review): it is
            currently NOT forwarded to any sub-layer, so it has no effect
            here; the sub-layers are constructed without it.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        # One residual/normalisation wrapper per sub-layer.
        self.norm_1 = ResidualLayerNorm(d_model)
        self.norm_2 = ResidualLayerNorm(d_model)
        self.norm_3 = ResidualLayerNorm(d_model)

        self.masked_mha = MultiHeadAttention(d_model, n_heads)
        self.encoder_decoder_mha = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = PostionWiseFeedForward(d_model, d_ff)

    def forward(self, x, encoder_outputs, trg_mask, src_mask):
        """Apply the decoder layer.

        Shapes below are the ones the notebook author documents; they are not
        enforced by this method.

        Args:
            x: decoder input, [batch_size x trg_seq_len x d_model].
            encoder_outputs: encoder output,
                [batch_size x src_seq_len x d_model].
            trg_mask: mask handed to the self-attention block.
            src_mask: mask handed to the encoder-decoder attention block.

        Returns:
            A 3-tuple ``(output, self_attn_weights, cross_attn_weights)``:
            the layer output followed by the attention weights returned by
            the self-attention and encoder-decoder attention blocks.
        """
        # shape(masked_mha) = [batch_size x trg_seq_len x d_model]
        # shape(masked_mha_attn_weights)
        # = [batch_size x n_heads x trg_seq_len x trg_seq_len]
        masked_mha, masked_mha_attn_weights = self.masked_mha(
            x, x, x, mask=trg_mask
        )
        norm_1 = self.norm_1(masked_mha, x)

        # Queries come from the decoder, keys/values from the encoder, so the
        # last axis of the weights is expected to follow the source length.
        # shape(encoder_decoder_mha) = [batch_size x trg_seq_len x d_model]
        # shape(encoder_decoder_mha_attn_weights)
        # = [batch_size x n_heads x trg_seq_len x src_seq_len]
        encoder_decoder_mha, encoder_decoder_mha_attn_weights = self.encoder_decoder_mha(
            pre_q=norm_1, pre_k=encoder_outputs, pre_v=encoder_outputs,
            mask=src_mask
        )
        norm_2 = self.norm_2(encoder_decoder_mha, norm_1)

        feed_forward = self.feed_forward(norm_2)
        norm_3 = self.norm_3(feed_forward, norm_2)
        return norm_3, masked_mha_attn_weights, encoder_decoder_mha_attn_weights

In [5]:
# Build a small decoder layer to use as the running example.
decoder_layer = DecoderLayer(
    d_model=10,
    n_heads=2,
    d_ff=16,
    dropout=0.3,
)

In [None]:
# Smoke-test the layer on random inputs. `forward` needs four arguments, so a
# bare `decoder_layer()` call raises TypeError.
# The last dimension (10) must match the d_model the layer was built with.
example_trg = torch.rand(2, 5, 10)   # [batch_size x trg_seq_len x d_model]
example_src = torch.rand(2, 7, 10)   # [batch_size x src_seq_len x d_model]

# NOTE(review): assumes MultiHeadAttention treats mask=None as "no masking" —
# confirm against foundation.transformer.efficient_attention.
example_out, example_self_attn, example_cross_attn = decoder_layer(
    example_trg, example_src, trg_mask=None, src_mask=None
)
example_out.shape, example_self_attn.shape, example_cross_attn.shape