In [2]:
import torch
import torch.nn as nn
import math

In [3]:
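# Deliberately not using PyTorch's built-in attention/transformer modules (e.g. nn.MultiheadAttention, nn.TransformerDecoderLayer); the attention and decoder blocks below are implemented from scratch.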

In [4]:
# Model (embedding) width. The classes below receive their width through their own
# `dim_model` constructor argument, so this global only matters where it is passed in.
dim_model: int = 768

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention, implemented from scratch.

    The input is projected into ``num_heads`` query/key/value heads of width
    ``dim_k = dim_model // num_heads``. Each head computes
    ``softmax(Q K^T / sqrt(dim_k)) V``. The heads are concatenated and mixed by a
    final linear layer back to ``dim_model`` features.

    Args:
        num_heads: Number of attention heads. Must divide ``dim_model``.
        dim_model: Width of the input and output features.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, num_heads, dim_model, dropout=0.1):
        # nn.Module.__init__ must run before any submodule/parameter is assigned.
        super().__init__()
        assert dim_model % num_heads == 0, "dim_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_k = dim_model // num_heads  # per-head width; we assume dim_k == dim_v

        self.query_projections = nn.Linear(dim_model, self.dim_k * num_heads)
        self.key_projections = nn.Linear(dim_model, self.dim_k * num_heads)
        self.value_projections = nn.Linear(dim_model, self.dim_k * num_heads)
        self.fc = nn.Linear(self.dim_k * num_heads, dim_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size, seq_length):
        """Reshape (batch, seq, num_heads * dim_k) -> (batch, num_heads, seq, dim_k)."""
        return projected.view(batch_size, seq_length, self.num_heads, self.dim_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (batch_size, seq_length, dim_model).
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, seq_length, seq_length). Key positions where
                ``mask == 0`` are excluded from attention. NOTE: a query row whose
                mask is entirely zero produces NaN weights (softmax over all -inf).

        Returns:
            Tensor of shape (batch_size, seq_length, dim_model).
        """
        batch_size, seq_length = x.shape[:2]

        # Shape: (batch_size, num_heads, seq_length, dim_k)
        queries = self._split_heads(self.query_projections(x), batch_size, seq_length)
        keys = self._split_heads(self.key_projections(x), batch_size, seq_length)
        values = self._split_heads(self.value_projections(x), batch_size, seq_length)

        # Shape: (batch_size, num_heads, seq_length, seq_length)
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / (self.dim_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = nn.functional.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Merge heads back: (batch, heads, seq, dim_k) -> (batch, seq, heads * dim_k)
        multi_context = torch.matmul(attention_weights, values)
        merged = multi_context.transpose(1, 2).contiguous().view(
            batch_size, seq_length, self.num_heads * self.dim_k
        )
        return self.fc(merged)

In [None]:
class DecoderLayer(nn.Module):
    """One post-LayerNorm decoder block: causal self-attention followed by an FFN.

    Computes::

        h   = LayerNorm_1(x + Dropout(MaskedSelfAttention(x)))
        out = LayerNorm_2(h + Dropout(FFN(h)))

    where FFN is Linear(dim_model -> 4*dim_model) -> GELU -> Linear(4*dim_model -> dim_model).
    There is no cross-attention, so this is a decoder-only (GPT-style) block.

    Args:
        dim_model: Width of the input/output features.
        num_heads: Number of attention heads (must divide ``dim_model``).
        dropout: Dropout probability used for the attention weights, the attention
            output and the FFN output.
    """

    def __init__(self, dim_model, num_heads, dropout=0.1):
        super().__init__()
        # Forward `dropout` so the attention-weight dropout follows this layer's setting.
        self.masked_multi_attention = MultiHeadAttention(num_heads, dim_model, dropout=dropout)
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=dim_model)
        self.attention_dropout = nn.Dropout(p=dropout)

        # FFN
        self.fc1 = nn.Linear(dim_model, dim_model * 4)
        self.gelu1 = nn.GELU()
        self.fc2 = nn.Linear(dim_model * 4, dim_model)
        self.ffn_dropout = nn.Dropout(p=dropout)
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=dim_model)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise Linear weights ~ N(0, 0.02), biases to 0, LayerNorm to identity."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)  # GPT-1 style
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the block on ``x`` of shape (batch_size, seq_length, dim_model).

        Returns a tensor of the same shape.
        """
        seq_length = x.shape[1]

        # Causal mask: query position i may attend only to key positions <= i.
        # Shape (1, 1, seq, seq) broadcasts over batch and heads.
        causal_mask = torch.tril(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device)
        ).view(1, 1, seq_length, seq_length)

        context = self.masked_multi_attention(x, mask=causal_mask)
        context = self.attention_dropout(context)

        x = context + x  # residual connection
        layer_norm_1 = self.layer_norm_1(x)

        # FFN
        f1 = self.fc1(layer_norm_1)
        act1 = self.gelu1(f1)
        f2 = self.fc2(act1)
        ffn_dropout = self.ffn_dropout(f2)

        f_out = ffn_dropout + layer_norm_1  # residual connection around the FFN
        layer_norm_2 = self.layer_norm_2(f_out)  # Post-LN as in GPT-1
        return layer_norm_2

In [None]:
class GPT1(nn.Module):
    """Decoder-only language model: token + learned positional embeddings,
    a stack of ``DecoderLayer`` blocks, a final LayerNorm and a tied output head.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        dim_model: Model width.
        num_heads: Attention heads per decoder layer.
        num_decoder_layers: Number of stacked decoder layers.
        max_seq_length: Maximum supported sequence length (size of the positional table).
        dropout: Dropout probability, used on the embeddings and inside every decoder layer.

    NOTE(review): every DecoderLayer already ends with a LayerNorm (post-LN), so
    ``self.layer_norm`` adds a second consecutive normalisation. The original GPT-1
    architecture is usually described without it; confirm it is intended.
    """

    def __init__(self, vocab_size, dim_model, num_heads, num_decoder_layers, max_seq_length, dropout=0.1):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.token_embedding = nn.Embedding(vocab_size, dim_model)
        self.positional_embedding = nn.Embedding(max_seq_length, dim_model)
        # Forward `dropout` so the model-level setting reaches every decoder layer.
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(dim_model, num_heads, dropout=dropout) for _ in range(num_decoder_layers)]
        )
        self.layer_norm = nn.LayerNorm(dim_model)  # final LN
        self.output_head = nn.Linear(dim_model, vocab_size, bias=False)
        # Weight tying: the output projection shares its matrix with the token embedding.
        self.output_head.weight = self.token_embedding.weight
        self.dropout = nn.Dropout(dropout)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise Linear/Embedding weights ~ N(0, 0.02), biases to 0, LayerNorm to identity."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, input_ids, targets=None):
        """Compute next-token logits and, optionally, the language-modelling loss.

        Args:
            input_ids: Integer token ids of shape (batch_size, seq_length), with
                ``seq_length <= max_seq_length``.
            targets: Optional token ids aligned with ``input_ids`` (e.g. ``targets =
                input_ids``). They are shifted internally, so the logits at position t
                are scored against ``targets[..., t + 1]``.

        Returns:
            ``logits`` of shape (batch_size, seq_length, vocab_size) when ``targets`` is
            None, otherwise the tuple ``(logits, loss)``.

        Raises:
            ValueError: if ``seq_length`` exceeds ``max_seq_length``, or if ``targets``
                is given for sequences shorter than 2 tokens (no next token to predict).
        """
        _, seq_length = input_ids.shape
        if seq_length > self.max_seq_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_length {self.max_seq_length}"
            )
        # Positions of shape (1, seq_length) broadcast across the batch dimension.
        positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.positional_embedding(positions)
        x = self.dropout(x)
        for layer in self.decoder_layers:
            x = layer(x)
        x = self.layer_norm(x)
        logits = self.output_head(x)
        if targets is None:
            return logits
        if seq_length < 2:
            # Otherwise cross_entropy over an empty tensor silently returns NaN.
            raise ValueError("next-token loss needs at least 2 tokens per sequence")
        # Shift for next-token prediction
        shift_logits = logits[..., :-1, :]
        shift_targets = targets[..., 1:]
        loss = nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)), shift_targets.reshape(-1)
        )
        return logits, loss