# Transformer Decoder

A decoder-only transformer is designed for autoregressive sequence-generation tasks such as text generation and completion.

Its architecture is similar to that of an encoder-only transformer, with two differences:

* **Masked multi-head self-attention**. It trains the model to predict the next token in a sequence one step at a time, so that it can iteratively generate messages, answers, or any other text, just as GPT and other autoregressive LLMs do. For each token in the target sequence, only the token itself and the previously generated tokens are observed. Subsequent tokens are hidden by a triangular mask that blocks every position above the diagonal, which prevents attending to future positions.

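* **Transformer head**: a linear layer followed by a softmax over the entire vocabulary. It estimates the probability of each word or token being the next one to generate. During generation, the most likely token (or one sampled from this distribution) is then selected. The implementation below returns log-probabilities via `log_softmax`.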

## Masked Self-Attention

<img src="./img/masked_self_attn.png" alt="masked_self_attn" style="width: 300px;"/>

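The mask is a lower-triangular boolean matrix: entry (i, j) is `True` when token i is allowed to attend to token j, i.e. when j ≤ i. By passing this matrix to the attention heads, each token in the sequence attends only to itself and to the "past" tokens on its left-hand side.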

In [1]:
import torch
import torch.nn as nn

In [2]:
# Small sequence length so the printed mask in the next cells stays readable.
sequence_length = 5

In [3]:
# Keep the lower triangle (diagonal included) of an all-ones matrix and cast
# it to bool: entry [0, i, j] is True exactly when j <= i, i.e. the positions
# token i is meant to attend to. The leading dimension of size 1 is kept.
self_attention_mask = torch.tril(
    torch.ones(1, sequence_length, sequence_length)
).bool()

In [4]:
# Display the mask: True entries fill the lower triangle, including the diagonal.
self_attention_mask

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])

## Transformer Decoder from scratch

In [5]:
from transformers_utils import PositionalEncoder, MultiHeadAttention, FeedForwardTransformation

In [6]:
class DecoderLayer(nn.Module):
    """A single decoder block: masked self-attention followed by a feed-forward net.

    Each sub-block's output is passed through dropout, added back to the
    sub-block's input (residual connection), and the sum is layer-normalized
    (normalization is applied after the residual add).

    Args:
        d_model: hidden size of the token representations.
        num_heads: number of attention heads passed to ``MultiHeadAttention``.
        d_ff: hidden size passed to ``FeedForwardTransformation``.
        dropout: dropout probability applied to both sub-block outputs.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        # Submodules are created in this exact order so parameter registration
        # (and hence random initialization order) is unchanged.
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardTransformation(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        """Return ``norm(residual + dropout(sublayer_out))``."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, mask):
        """Run the block on ``x`` using ``mask`` for self-attention.

        Args:
            x: token representations; presumably (batch, seq_len, d_model),
                TODO confirm against transformers_utils.
            mask: attention mask forwarded unchanged to ``MultiHeadAttention``.

        Returns:
            A tensor with the same shape as the ``LayerNorm`` output of ``x``.
        """
        # Self-attention: queries, keys and values all come from x.
        hidden = self._add_and_norm(self.norm1, x, self.attention(x, x, x, mask))
        # Feed-forward sub-block with its own residual connection and norm.
        return self._add_and_norm(self.norm2, hidden, self.feed_forward(hidden))


class TransformerDecoder(nn.Module):
    """Decoder-only transformer mapping token ids to next-token log-probabilities.

    Token ids are embedded, combined with positional information, passed
    through a stack of ``DecoderLayer`` blocks, and projected onto the
    vocabulary by a linear head followed by ``log_softmax``.

    Args:
        vocab_size: number of tokens in the vocabulary (embedding rows and
            size of the output distribution).
        d_model: embedding / hidden size.
        num_layers: number of stacked ``DecoderLayer`` blocks.
        num_heads: attention heads per layer.
        d_ff: feed-forward hidden size per layer.
        dropout: dropout probability used inside each layer.
        max_seq_length: forwarded to ``PositionalEncoder``; presumably the
            longest sequence it can encode (TODO confirm in transformers_utils).
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_seq_length):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoder(d_model, max_seq_length)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        # linear layer (head) for next-word prediction
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, self_mask=None):
        """Compute log-probabilities over the vocabulary for every position.

        Args:
            x: integer token ids, presumably shaped (batch, seq_len).
            self_mask: self-attention mask passed to every layer. This notebook
                uses a bool tensor of shape (1, seq_len, seq_len) that is True
                on and below the diagonal. If ``None``, a mask of exactly that
                form is built from ``x.size(1)`` on ``x``'s device.

        Returns:
            ``log_softmax`` of the head output along the last dimension, whose
            size is ``vocab_size``.
        """
        if self_mask is None:
            # Same causal format as the notebook cells: lower triangle = True.
            seq_len = x.size(1)
            self_mask = torch.tril(
                torch.ones(1, seq_len, seq_len, device=x.device)
            ).bool()
        x = self.embedding(x)
        # Fixed: attribute was previously misspelled "poistional_encoding",
        # which raised AttributeError on every forward pass.
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, self_mask)
        # Apply the forward pass through the model head
        x = self.fc(x)
        return nn.functional.log_softmax(x, dim=-1)
        

## Testing Decoder

In [7]:
# Hyperparameters for a quick forward-pass / shape check of the decoder.
num_classes = 3  # NOTE(review): not used by any cell shown here
vocab_size = 10000  # vocabulary size; also the size of each output distribution
batch_size = 8
d_model = 512  # embedding / hidden size
num_heads = 8  # attention heads per layer
num_layers = 6  # number of stacked DecoderLayer blocks
d_ff = 2048  # feed-forward hidden size
sequence_length = 64  # overrides the earlier value of 5; also used as max_seq_length below
dropout = 0.1

In [8]:
# Random token ids in [0, vocab_size) with shape (batch_size, sequence_length).
input_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))

In [9]:
## Triangular attention mask for causal attention:
## shape (1, sequence_length, sequence_length), True on and below the diagonal,
## i.e. the positions each token is meant to attend to (itself and the past).
self_attention_mask = torch.tril(
    torch.ones(1, sequence_length, sequence_length)
).bool()

In [10]:
# Instantiate the decoder transformer (sequence_length is passed as max_seq_length).
# NOTE: a freshly built nn.Module is in training mode, so dropout is active
# and repeated runs give different outputs.
decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, sequence_length)

# Expected shape: (batch_size, sequence_length, vocab_size) -- one
# log-probability distribution over the vocabulary per position.
output = decoder(input_sequence, self_attention_mask)
print(output.shape)
print(output[0])  # log-probabilities for the first sequence in the batch

torch.Size([8, 64, 10000])
tensor([[ -8.8693,  -9.6007,  -8.9304,  ...,  -9.6856,  -9.3465,  -9.6564],
        [ -9.8812, -10.1586, -10.0967,  ...,  -8.9913,  -9.0750,  -9.9167],
        [ -8.8258,  -9.0934,  -9.8143,  ...,  -9.5052,  -9.8370,  -8.6690],
        ...,
        [ -9.6214,  -9.4889,  -8.9430,  ...,  -9.5053,  -9.2300,  -9.4115],
        [ -8.7715, -10.4404,  -8.6053,  ...,  -9.7100,  -8.7736,  -8.5819],
        [ -7.7874,  -9.9393,  -9.2000,  ...,  -9.2514,  -8.7640,  -9.0817]],
       grad_fn=<SelectBackward0>)
