In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Bare reference to torch.einsum: nothing is called or assigned here.
# Presumably left over from inspecting the function in the notebook.
torch.einsum

# Implementation from ChatGPT

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of each input is split into ``heads`` chunks of
    ``head_dim = embed_size // heads`` features. One ``head_dim -> head_dim``
    linear map per role (values, keys, queries) is applied to every chunk, so
    the projection weights are shared across heads. The per-head results are
    concatenated and passed through a final ``embed_size -> embed_size``
    linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask=None):
        """Attend from ``queries`` over ``keys`` and ``values``.

        Args:
            values: Tensor reshaped here to ``(N, value_len, heads, head_dim)``,
                i.e. ``(N, value_len, embed_size)`` on input.
            keys: Tensor of shape ``(N, key_len, embed_size)``. ``key_len``
                must equal ``value_len`` for the weighted sum to be valid.
            queries: Tensor of shape ``(N, query_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(N, heads, query_len, key_len)``. Positions where the mask
                equals 0 receive no attention weight (unless every key of a
                query is masked, in which case the weights are uniform).

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into heads, then project each head's slice.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(queries.reshape(N, query_len, self.heads, self.head_dim))

        # Raw attention logits: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Each dot product sums over head_dim features, so that is the
        # dimension to normalise by (not the full embed_size).
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the most negative finite value of the working dtype so the
            # fill stays representable in reduced precision, unlike a
            # hard-coded -1e20 which is out of range for float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of values per head, then merge heads back together.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)


# Transformer Block

In [None]:
class TransformerDecoderBlock(nn.Module):
    """One decoder layer: self-attention followed by a feed-forward network.

    After each of the two sub-layers the sub-layer input is added back
    (residual connection), the sum is layer-normalised, and dropout is applied
    to the normalised result.

    Args:
        embed_size: Feature size passed to the attention module, both
            ``LayerNorm`` layers, and the feed-forward input/output.
        heads: Number of heads passed to ``MultiHeadAttention``.
        forward_expansion: Multiplier giving the feed-forward hidden width
            (``forward_expansion * embed_size``).
        dropout: Probability handed to ``nn.Dropout``.
        device: Stored on the instance as ``self.device``; not otherwise used
            by this class.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Two linear layers with a ReLU in between: widen, then project back.
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )

        # A single Dropout module is reused after both normalisation steps.
        self.dropout = nn.Dropout(dropout)
        self.device = device

    def forward(self, x, mask):
        """Run the block on ``x``.

        Args:
            x: Input tensor; presumably ``(batch, seq_len, embed_size)`` -
                confirm against the caller. It serves as values, keys, and
                queries for the attention call.
            mask: Forwarded unchanged to the attention module.

        Returns:
            The result of the second residual + norm + dropout step.
        """
        # Sub-layer 1: self-attention with a residual connection around it.
        attended = self.attention(x, x, x, mask)
        hidden = self.norm1(attended + x)
        hidden = self.dropout(hidden)

        # Sub-layer 2: feed-forward network with its own residual connection.
        transformed = self.feed_forward(hidden)
        normalised = self.norm2(transformed + hidden)
        return self.dropout(normalised)


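To use this in a decoder-only architecture, like GPT, you would stack multiple instances of TransformerDecoderBlock, feeding the output of one block as the input to the next. Additionally, you'd typically start with an embedding layer to convert token indices into vectors (plus a positional embedding or encoding, since attention by itself has no notion of token order) and end with a linear layer to project the transformer output back to the vocabulary space for prediction. For next-token prediction the mask passed to each block must be causal: a lower-triangular matrix, broadcastable to (batch, heads, query_len, key_len), so that position i can only attend to positions up to and including i. With that in place, the architecture would learn to generate text or other sequence data by predicting the next token in a sequence given the previous tokens.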