In [None]:
# Hyperparameters for the 124M-parameter GPT configuration.
GPT_CONFIG_124M = dict(
    vocab_size=50257,      # number of token ids (rows of the token embedding)
    context_length=1024,   # number of positions with a positional embedding
    emb_dim=768,           # width of every embedding / hidden vector
    n_heads=12,            # presumably attention heads per block; unused by the dummy model
    n_layers=12,           # number of transformer blocks
    drop_rate=0.1,         # dropout probability
    qkv_bias=False,        # presumably bias on query/key/value projections; unused by the dummy model
)

In [None]:
import torch
import torch.nn as nn

"""Developing a dummy GPT class"""

class DummyGPTModel(nn.Module):
    """GPT-style language-model skeleton built from placeholder modules.

    Token ids are embedded, summed with learned positional embeddings,
    passed through dropout, ``cfg["n_layers"]`` transformer blocks and a
    final norm, then projected to one logit per vocabulary entry. The
    transformer blocks and the final norm are identity placeholders
    (``DummyTransformerBlock`` / ``DummyLayerNorm``).

    Args:
        cfg: Configuration mapping. Reads ``vocab_size``, ``context_length``,
            ``emb_dim``, ``drop_rate`` and ``n_layers`` (e.g.
            ``GPT_CONFIG_124M``).
    """

    def __init__(self, cfg):
        super().__init__()
        # One learned emb_dim-sized vector per vocabulary entry.
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        # One learned emb_dim-sized vector per position; inputs may therefore
        # be at most context_length tokens long.
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.trf_blocks = nn.Sequential(
            *[DummyTransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        # Final norm over the emb_dim-sized hidden states (identity here).
        self.final_norm = DummyLayerNorm(cfg["emb_dim"])
        # Projects each hidden state from emb_dim (768) back to vocab_size.
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        """Map token ids to per-position vocabulary logits.

        Args:
            in_idx: Integer tensor of token ids, shape ``(batch_size, seq_len)``.
                ``seq_len`` must not exceed ``context_length``, otherwise the
                positional-embedding lookup is out of range.

        Returns:
            Logits tensor of shape ``(batch_size, seq_len, vocab_size)``.
        """
        # Unpacking also enforces that the input is 2-D.
        _batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)  # (batch_size, seq_len, emb_dim)
        # Position ids 0..seq_len-1, created on the input's device so the
        # lookup never mixes devices.
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        # (seq_len, emb_dim) broadcasts across the batch dimension.
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        return self.out_head(x)


# The placeholder modules are defined at module level: names bound inside a
# class body are not visible from that class's methods, so nesting them in
# DummyGPTModel made DummyGPTModel.__init__ fail with NameError.
class DummyTransformerBlock(nn.Module):
    """Placeholder transformer block that returns its input unchanged.

    ``cfg`` is accepted so the constructor matches the call in
    ``DummyGPTModel.__init__``; it is not used.
    """

    def __init__(self, cfg):
        super().__init__()

    def forward(self, x):
        return x


class DummyLayerNorm(nn.Module):
    """Placeholder layer norm that returns its input unchanged.

    ``normalized_shape`` and ``eps`` mirror the ``nn.LayerNorm`` constructor
    arguments but are ignored.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()

    def forward(self, x):
        return x