# Transformer Encoder-Decoder

The encoder and decoder are connected via a cross-attention layer.

<img src="./img/encoder_decoder_cross_attn.png" alt="encoder_decoder_cross_attn" style="width: 400px;"/>

A cross-attention layer is added to each decoder layer after the masked self-attention. It takes two inputs: the information processed through the decoder, and the final hidden states produced by the encoder, thereby linking the transformer's two main building blocks.

This enables the decoder to look back at the input sequence and generate the next token in the target sequence.

<img src="./img/translation_task.png" alt="translation_task" style="width: 400px;"/>

In [1]:
import torch
import torch.nn as nn

from transformers_utils import *

In [35]:
class DecoderLayerCrossAttn(nn.Module):
    """Single decoder layer with self-attention, cross-attention and feed-forward.

    Every sub-layer follows the same pattern: its output goes through the
    shared dropout, is added back to the sub-layer input (residual
    connection) and the sum is layer-normalized.

    Args:
        d_model: Feature size handed to the attention blocks, the
            feed-forward block and the ``LayerNorm`` modules.
        num_heads: Number of heads handed to both attention blocks.
        d_ff: Hidden size handed to the feed-forward block.
        dropout: Probability used by the shared ``nn.Dropout``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        # Sub-layers, in the order ``forward`` applies them.
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardTransformation(d_model, d_ff)
        # One normalization per residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, causal_mask, cross_mask):
        """Run the layer.

        Args:
            x: Decoder information flow; supplies the queries.
            y: Encoder output; supplies the cross-attention keys and values.
            causal_mask: Mask forwarded to the self-attention block.
            cross_mask: Mask forwarded to the cross-attention block.

        Returns:
            The output of the third ``LayerNorm``.
        """
        # 1) Causal (masked) self-attention over the decoder stream.
        attended = self.self_attention(x, x, x, causal_mask)
        hidden = self.norm1(x + self.dropout(attended))

        # 2) Cross-attention: decoder stream queries the encoder output.
        attended = self.cross_attention(hidden, y, y, cross_mask)
        hidden = self.norm2(hidden + self.dropout(attended))

        # 3) Position-wise feed-forward transformation.
        transformed = self.feed_forward(hidden)
        return self.norm3(hidden + self.dropout(transformed))
        

class TransformerDecoderCrossAttn(nn.Module):
    """Stack of cross-attention decoder layers with a next-token prediction head.

    Fixes relative to the previous revision:

    * the stack is built from ``DecoderLayerCrossAttn`` (the layer that
      accepts the encoder output and a cross-attention mask) instead of
      ``DecoderLayer``;
    * the output of each layer is fed into the next one; previously every
      layer received the same input and only the last layer's result was
      kept;
    * ``y`` is used as-is. It carries the encoder's hidden states, so it must
      not be sent through an ``nn.Embedding`` (which expects integer token
      ids). The ``encoder_embedding`` table that existed only for that
      purpose has been removed.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows and
            output logits).
        d_model: Embedding / hidden feature size.
        num_layers: Number of stacked decoder layers.
        num_heads: Number of attention heads per layer.
        d_ff: Hidden size of each layer's feed-forward block.
        dropout: Dropout probability (embeddings and every layer).
        max_seq_length: Maximum sequence length handed to the positional
            encoder.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()

        # Token embedding for the decoder-side (target) sequence.
        self.decoder_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoder(d_model, max_seq_length)
        self.layers = nn.ModuleList(
            [DecoderLayerCrossAttn(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        # Linear layer (head) for next-word prediction.
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, causal_mask, cross_mask):
        """Decode ``x`` while attending to the encoder output ``y``.

        Args:
            x: Integer token ids of the decoder-side sequence.
            y: Encoder output (hidden states); used as cross-attention keys
                and values in every layer.
            causal_mask: Mask forwarded to each layer's self-attention.
            cross_mask: Mask forwarded to each layer's cross-attention.

        Returns:
            Log-probabilities over the vocabulary (``log_softmax`` on the
            last dimension of the head's output).
        """
        # Embed the decoder tokens and add positional information.
        x = self.dropout(self.positional_encoding(self.decoder_embedding(x)))

        # Chain the layers: each one refines the previous layer's output.
        for layer in self.layers:
            x = layer(x, y, causal_mask, cross_mask)

        # Apply the forward pass through the model head.
        logits = self.fc(x)
        return nn.functional.log_softmax(logits, dim=-1)

The decoder only needs to take actual target sequences during training time. 

In translation, these would be examples of translations associated with the source-language sequences fed to the encoder. 

In text summarization, the output embeddings for the decoder are summarized versions of the input embeddings for the encoder, and so on.

Words in the target sequence act as our training labels during the next-word generation process. At inference time, the decoder takes over the role of generating the target sequence: it starts with an empty output embedding and, step by step, takes as its input the target words it has generated so far.

# Complete Transformer Model

In [36]:
class Trasformer(nn.Module):
    """Complete encoder-decoder transformer.

    NOTE: the class name is misspelled ("Trasformer"); it is kept unchanged
    so existing references keep working.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Embedding / hidden feature size.
        num_layers: Number of layers in both the encoder and the decoder.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward blocks.
        dropout: Dropout probability.
        max_seq_length: Maximum supported sequence length.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()
        self.encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)
        self.decoder = TransformerDecoderCrossAttn(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length)

    def forward(self, src, src_mask, causal_mask, tgt=None):
        """Encode ``src`` and decode against the encoder output.

        Args:
            src: Source token ids fed to the encoder.
            src_mask: Mask for the source sequence. It is passed to the
                encoder and also used as the decoder's cross-attention mask
                (the previous code referenced an undefined name ``mask``
                here, which raised ``NameError``).
                NOTE(review): assumes the mask accepted by the encoder is
                also valid for cross-attention -- confirm against
                ``MultiHeadAttention``.
            causal_mask: Mask for the decoder's self-attention.
            tgt: Optional target token ids fed to the decoder. When omitted,
                ``src`` is used, which matches the earlier behaviour.

        Returns:
            The decoder output.
        """
        encoder_out = self.encoder(src, src_mask)
        if tgt is None:
            tgt = src
        return self.decoder(tgt, encoder_out, causal_mask, src_mask)

In [37]:
# Model hyperparameters
vocab_size = 10000
d_model = 512
d_ff = 2048
num_heads = 8
num_layers = 6
dropout = 0.1

# Dummy-data dimensions
batch_size = 8
sequence_length = 64

# Not referenced by the cells shown in this section.
num_classes = 3

In [38]:
# Dummy batch of token ids, shape (batch_size, sequence_length).
input_sequence = torch.randint(
    low=0, high=vocab_size, size=(batch_size, sequence_length)
)
# Random 0/1 matrix used as a stand-in padding mask for the demo.
padding_mask = torch.randint(
    low=0, high=2, size=(sequence_length, sequence_length)
)
# Ones strictly above the main diagonal, zeros elsewhere.
causal_mask = torch.ones(sequence_length, sequence_length).triu(diagonal=1)

In [39]:
# Both models share the same hyperparameters; the demo sequence length
# doubles as the maximum sequence length.
encoder = TransformerEncoder(
    vocab_size, d_model, num_layers, num_heads, d_ff, dropout, sequence_length
)
decoder = TransformerDecoderCrossAttn(
    vocab_size, d_model, num_layers, num_heads, d_ff, dropout, sequence_length
)

In [None]:
# Encode the source batch, then decode while attending to the encoder output.
encoder_output = encoder(input_sequence, padding_mask)
# Masks are passed by keyword so they cannot be swapped: the decoder's
# forward signature is (x, y, causal_mask, cross_mask). The previous call
# passed padding_mask as the causal mask and causal_mask as the cross mask.
decoder_output = decoder(
    input_sequence,
    encoder_output,
    causal_mask=causal_mask,
    cross_mask=padding_mask,
)
print("Batch's output shape: ", decoder_output.shape)