# In [None]:
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

#------------------------------------------------------------------------------------------------------------------

class MLP(nn.Module):
    """Feed-forward sub-network: widen to 4x ``n_embd``, apply GELU, project back.

    Args:
        config: Object exposing ``n_embd``, the feature width of both the
            input and the output of this module.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width  # inner layer is four times the embedding width
        # Submodules are created in the order c_fc, gelu, c_proj so that
        # parameter registration (and random initialisation) is unchanged.
        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden, width)

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``; the last dimension stays ``n_embd``."""
        return self.c_proj(self.gelu(self.c_fc(x)))
    
    
class Block(nn.Module):
    """One hidden layer of the model: self-attention then MLP, each pre-normed.

    Both sub-layers are applied as ``x + sublayer(LayerNorm(x))``, i.e. the
    LayerNorm runs before the sub-layer and the result is added back onto the
    residual stream. ``GPT`` stacks instances of this class in its ``h``
    module list.

    Args:
        config: Object exposing ``n_embd`` (feature width of both LayerNorms).
            The same object is passed unchanged to ``CausalSelfAttention``
            and ``MLP``.
    """

    def __init__(self, config):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        # NOTE(review): CausalSelfAttention is not defined in this part of
        # the file; it is assumed to be provided elsewhere in the module.
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            x: Input tensor whose last dimension is ``n_embd``.

        Returns:
            Tensor produced by adding each sub-layer's output to its input.
        """
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    """Hyperparameters read by the model classes in this file."""

    # Number of rows in the positional embedding table (``wpe``), i.e. how
    # many distinct positions can be embedded.
    block_size: int = 256
    # Number of rows in the token embedding table (``wte``) and number of
    # output features of the language-model head.
    vocab_size: int = 65
    # Presumably the number of ``Block`` layers stacked in ``GPT`` --
    # TODO confirm against the attribute name GPT.__init__ reads.
    n_layers: int = 6
    # Not read by any code visible here; presumably the attention-head count
    # consumed by CausalSelfAttention -- TODO confirm.
    n_head: int = 6
    # Feature width shared by the embeddings, LayerNorms, MLP and the input
    # of the language-model head.
    n_embd: int = 384
class GPT(nn.Module):
    """GPT-2 style model skeleton: embeddings, block stack, final norm, LM head.

    ``self.transformer`` is an ``nn.ModuleDict``, so its submodules are looked
    up by string key just like a dictionary:

    * ``wte``  -- token embedding table, ``vocab_size`` x ``n_embd``.
    * ``wpe``  -- positional embedding table, ``block_size`` x ``n_embd``.
    * ``h``    -- ``nn.ModuleList`` of ``Block`` hidden layers, indexed by
      integer layer number.
    * ``ln_f`` -- final LayerNorm over ``n_embd`` features.

    ``self.lm_head`` is the final classifier / language-model head: a
    bias-free linear layer projecting ``n_embd`` features to ``vocab_size``
    outputs.

    Args:
        config: Object exposing ``vocab_size``, ``block_size``, ``n_embd`` and
            the layer count as ``n_layers`` (the ``GPTConfig`` field name) or,
            as a fallback, ``n_layer``. It is also passed to every ``Block``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # GPTConfig spells the field ``n_layers``; the singular ``n_layer``
        # is still accepted for config objects that use that spelling.
        n_layers = getattr(config, "n_layers", None)
        if n_layers is None:
            n_layers = config.n_layer

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(n_layers)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)