<a href="https://colab.research.google.com/github/vgandhi13/ReproducingGPT2/blob/main/ReproducingGPT2v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import inspect

class CausalSelfAttention(nn.Module):
  def __init__(self, config):
    assert config.n_embd % config.n_head == 0
    super().__init__()
    #key, query, value projections for all heads, but in a batch
    self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
    #output projection
    self.c_proj = nn.Linear(config.n_embd, config.n_embd)
    self.c_proj.NANOGPT_SCALE_INIT = 1
    #regularization
    self.n_head = config.n_head
    self.n_embd = config.n_embd
    #not really a bias more of a mask, but following Openai/HF naming
    self.register_buffer('bias', torch.tril(torch.ones(config.block_size, config.block_size)).view(1,1,config.block_size, config.block_size))

  def forward(self, x):
    B, T, C = x.size() # batch, size, sequence legnth
    #calculate query , key, values for all heads in batch and move head forward to be the batch
    #nh is 'number of heads', hs is 'head size', and C (number of channels) = ns * hs
    #eg in GPT-2 (124M), n_head = 12, hs = 654, so nh*hs = 768 channels in the transformer
    qkv = self.c_attn(x) #(B,T,3*C)
    q,k,v = qkv.split(self.n_embd, dim=2) #(B,T,C), (B,T,C), (B,T,C)
    #nh is number of heads, hs is head size
    k = k.view(B,T,self.n_head, C//self.n_head).transpose(1,2) #(B,nh,T,hs)
    q = q.view(B,T,self.n_head, C//self.n_head).transpose(1,2) #(B,nh,T,hs)
    v = v.view(B,T,self.n_head, C//self.n_head).transpose(1,2) #(B,nh,T,hs)
    #attention (materializes the large (T,T) matrix for all the queries and keys)

    #REGULAR ATTENTION APPROACH
    # att = (q @ k.transpose(-2, -1)) * (-1.0 / math.sqrt(k.size(-1)))
    # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
    # att = F.softmax(att, dim=-1)
    # y = att @ v # (B,nh, T, T) X [B, nh, T, hs] -> (B,nh,T,hs)

    #FLASH ATTENTION APPROACH
    y = F.scaled_dot_product_attention(q,k,v, is_causal=True)

    y = y.transpose(1,2).contiguous().view(B,T,C) # reassemble all head outputs side by side
    #output projection
    y = self.c_proj(y)
    return y

class MLP(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
    self.gelu = nn.GELU(approximate='tanh')
    self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
    self.c_proj.NANOGPT_SCALE_INIT = 1

  def forward(self, x):
    x = self.c_fc(x)
    x = self.gelu(x)
    x = self.c_proj(x)
    return x

class Block(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.ln_1 = nn.LayerNorm(config.n_embd)
    self.attn = CausalSelfAttention(config)
    self.ln_2 = nn.LayerNorm(config.n_embd)
    self.mlp = MLP(config)

  def forward(self, x):
    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x


@dataclass
class GPTConfig:
  block_size: int = 1024 #max sequence length
  vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftoken|>
  n_layer: int = 12 #number of layers
  n_head: int = 12 #number of heads
  n_embd: int = 768 #embedding dimension

class GPT(nn.Module):#turns into pytorch module
  def __init__(self, config):
    super().__init__()
    self.config = config

    self.transformer = nn.ModuleDict(dict( #we can query using name of param (identical to GPT2)
        wte = nn.Embedding(config.vocab_size, config.n_embd),
        wpe = nn.Embedding(config.block_size, config.n_embd),
        h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), #all the different blocks one after the other, #we can query using number of layer (identical to GPT2)
        ln_f = nn.LayerNorm(config.n_embd) #additional thing added to transformer arch by openai
    ))
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) #final linear layer

    #weight sharing scheme
    #why we do this ?
    # 1. this weight is 768*50257 = 40 mil params, (30% of 124M). So makes training more efficient
    # 2.  Input embeddings and output embeddings represent the same space: tokens ↔ embeddings.
    # 3. Tying ensures consistency: the representation used to encode a word is also used when predicting it.
    # 4. Empirically, this improves perplexity
    self.transformer.wte.weight = self.lm_head.weight #(vocab_size, n_embd)

    #init params
    #iteralte all modules here
    self.apply(self._init_weights)

  #Prevent vanishing/exploding activations
  def _init_weights(self, module):
    if isinstance(module, nn.Linear):
      std = 0.02
      #per-layer scaling trick applied to some linear layers (specifically the projection layers inside attention and MLP blocks)
      #It rescales the initialization standard deviation to account for residual connections.
      #This way, variance is preserved across depth, so activations don’t explode/vanish.
      if hasattr(module, 'NANOGPT_SCALE_INIT'):
        std *= (2 * module.NANOGPT_SCALE_INIT) ** -0.5  #for variance handling
      torch.nn.init.normal_(module.weight, mean=0.0, std = std)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def configure_optimizers(self, weight_decay, learning_rate, device):
    #start with all of the candidate parameters(that require grad)
    param_dict = {pn: p for pn, p in self.named_parameters()}
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # create optim groups. Any parameters that is 2D+ will be weight decayed, otherwise no.
    # ie all weight tensors in matmuls + embeddings decay, all biases and layernorms dont
    # notes(16)
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nondecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nondecay_params, 'weight_decay': 0.0}
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_non_decay_params = sum(p.numel() for p in nondecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nondecay_params)}, with {num_non_decay_params:,} parameters")
    #Create AdamW optimizer and use the fused version if it is available
    # notes(17)
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and 'cuda' in device
    print(f"using fused AdamW:  {use_fused}")
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
    return optimizer

  def forward(self, idx, targets = None):
    # idx is of shape (B,T) , token indices, batch dim of B, time dim of upto t, t
    B, T = idx.size()
    assert T<= self.config.block_size, f"Cannot forward sequence of length {T},block size is"
    #forward the token and position embeddings
    pos = torch.arange(0, T, dtype = torch.long, device = idx.device) #shape(T)
    pos_emb = self.transformer.wpe(pos) #POSITION EMBEDDINGS of shape (T,n_embd)
    tok_emb = self.transformer.wte(idx) #token embeddings of shape (B,T,n_embd)
    x = tok_emb + pos_emb
    #forward the blocks of the transformer
    for block in self.transformer.h:
      x = block(x)

    x = self.transformer.ln_f(x)  #(B,T,n_embd)
    logits = self.lm_head(x) #(B,T,vocab_size) - essentially give for each batch, gives t+1th probable element
    loss = None
    if targets is not None:
      # here cross entropy cannot take 3 dim vector, so it is flattening it by first making logits of shape (B*T, vocab_size), and targets to shape (B*T)
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return logits, loss


  @classmethod
  def from_pretrained(cls, model_type):
    '''Loads pretrained GPT2 model weights from huggingface'''
    assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
    from transformers import GPT2LMHeadModel
    print("loading weights from pretrained gpt: %s" % model_type)

    #n_layer, n_head, and n_embd are determined from model_type
    config_args = {
        'gpt2': dict(n_layer=12, n_head=12, n_embd=768), #124M Params
        'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), #350M params
        'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), #774M params
        'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), #1558M params
    }[model_type]
    config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
    config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
    # create a from scratch initialized miniGPT model
    config = GPTConfig(**config_args)
    model = GPT(config)
    sd = model.state_dict()
    sd_keys = sd.keys()
    sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] #discard this mask/buffer

    #init a huggingface trasformer model
    model_hf = GPT2LMHeadModel.from_pretrained(model_type)
    sd_hf = model_hf.state_dict()

    sd_keys_hf = sd_hf.keys()
    sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] #discard this mask/buff
    sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]
    transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
    assert len(sd_keys) == len(sd_keys_hf), f"mismatched keys: {len(sd_keys)} != {len(sd_keys_hf)}"
    for k in sd_keys:
      if any(k.endswith(w) for w in transposed):
        assert sd_hf[k].shape[::-1] == sd[k].shape
        with torch.no_grad():
          sd[k].copy_(sd_hf[k].t())
      else:
        assert sd_hf[k].shape == sd[k].shape
        with torch.no_grad():
          sd[k].copy_(sd_hf[k])
    return model

#----------------------------------------------------------------------------------------------------------------------------------------------------
#DATA LOADER

import tiktoken

class DataLoaderLite:
  def __init__(self, B, T):
    self.B = B
    self.T = T
    #load tiny shakespeare dataset
    !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
    with open('input.txt', 'r', encoding='utf-8') as f:
      text = f.read()

    enc = tiktoken.get_encoding('gpt2')
    tokens = enc.encode(text)
    self.tokens = torch.tensor(tokens)
    print(f"loaded {len(self.tokens)} tokens")
    print(f"1 epoch = {len(self.tokens) // (B*T)} batches")

    #state
    self.current_position = 0

  def next_batch(self):
    B, T = self.B, self.T

    # Now, Process token sequences and feed them into transformer. Rearrange tokens into idx variable feeding into transformer.
    # We dont want single very long one dimensional sequence, we want a batch where each sequence is upto T tokens (T cannot be larger than maximum sequence length).
    # We have B indpendent examples of T sequences. So we, need to create a (B,T) tensor which we can feed to the forward out of this 1 dimensional sequences
    buf = self.tokens[self.current_position : self.current_position+B*T+1] # B*T + 1 because we need the next token of the last (B*Tth) token as well for training
    x = buf[:-1].view(B,T) #we exclude last one because it is the extra target token for (B*Tth) token
    y = buf[1:].view(B,T) # we exclude first one because it is not a target for any token

    #advance the position in the tensor
    self.current_position += B*T
    #if loading the next batch would be out of bounds, reset
    if self.current_position + (B*T+1) > len(self.tokens):
      self.current_position = 0
    return x, y


#----------------------------------------------------------------------------------------------------------------------------------------------------
#TRAINING


# setting up gpu use
device = 'cpu'
if torch.cuda.is_available():
  device = 'cuda'

#for reproducibility
torch.manual_seed(1337)
if torch.cuda.is_available():
  torch.cuda.manual_seed(1337)


train_loader = DataLoaderLite(B=16, T=1024)
torch.set_float32_matmul_precision('high') #for faster training. Every where there is a multiplcation in our linear layers, pytorch will now run this mult on cores, utilizung tf32 precision


# model = GPT.from_pretrained('gpt2')
# model.eval()
# model.to(device)
# print('didnt crash yay!!')
model = GPT(GPTConfig(vocab_size=50304)) #padding of tokens to improve speed. check notes(11)
model.to(device)

#compile for neural nets, like gcc in c. makes stuff faster. reduces python overhead and GPU read/writes
#unlike python interpreter that looks at code sequentially, this looks at all the operations ahead and time and then optimizes things based on that.
model = torch.compile(model)

##Learning Rate Scheduler
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 50
def get_lr(it):
  #1. linear warmup for warmup_iters steps
  if it < warmup_steps:
    return max_lr * (it+1) / warmup_steps #LR ramping up linearly
  #2. if it > lr_decay_iters, return min learning rate. Learning rate stays at min_lr (flat) after training schedule finishes.
  if it > max_steps:
    return min_lr
  #3. in between, use cosine decay down to min learning rate
  #LR decays cosine-shaped from max_lr down to min_lr
  decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
  assert 0 <= decay_ratio <= 1
  coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) #coeff starts at 1 and goes to 0
  return min_lr + coeff * (max_lr - min_lr)



import time
#optimize!!
optimizer = model.configure_optimizers(weight_decay = 0.1, learning_rate=6e-4, device=device)

for step in range(max_steps):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x, y = x.to(device), y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16): # TF32 for FP32 matmuls outside autocast. BF16 for eligible ops inside autocast, FP32 for critical ops (loss, gradients, optimizer state)
    logits, loss = model(x, y)
  loss.backward()
  norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) #calculate global norm of params, every single grad on all params, we square it, add it all up, and take a big square root and this is the norm of param vector. check note(13)
  #determine and set the learning rate for this iteration
  lr = get_lr(step)
  for param_group in optimizer.param_groups: #there is a notion of diff param groups that can exist in optimizer. in our case there is just 1, but we have to do this
    param_group['lr'] = lr
  optimizer.step()
  torch.cuda.synchronize() #cpu might already reach here, but this will wait for gpus to finish the work they were assigned. is only needed for benchmarking. For real training you can drop it (but keep it if you want accurate dt timings).
  t1 = time.time()
  dt = (t1-t0)*1000 #time difference in miliseconds
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1-t0)
  print(f'step {step}, loss: {loss.item()}, norm: {norm:.4f}, lr: {lr:.4e}, dt: {dt:.2f}, tok/sec: {tokens_per_sec:.2f}')


print(loss)
print(logits.shape)
#


--2025-09-01 12:55:02--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2025-09-01 12:55:02 (18.0 MB/s) - ‘input.txt.3’ saved [1115394/1115394]

loaded 338025 tokens
1 epoch = 20 batches
num decayed parameter tensors: 50, with 124,354,560 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW:  True
step 0, loss: 10.963605880737305, norm: 13.3154, lr: 6.0000e-05, dt: 1610.54, tok/sec: 10172.97
step 1, loss: 9.654325485229492, norm: 6.5502, lr: 1.2000e-04, dt: 96.66, tok/sec: 169497.76
step 2, loss: 9.386366844177246, norm: 6.2438, lr: 1.8000e-04, dt: 88.15, tok/sec: 185857.91
st

Notes:

1. In pytorch you can directly do model.to(device), but for tensors, you need to save them in variable x = x.to(device)
2. All tensors in pytorch by default are float32 which 19.5 teraflops on A100. If numbers have fewer bits of representation, it is easier to move them around. We have a finite capacity of bits our GPU can store but in addition to that there is a limitation to the speed with which we can access. Many of deep learning workload memory bounds, so most of tensor cores are idle because they are waiting around for data. So even if we get 60% utilization of hardware, we are doing well
3. Most of the computation happens in linear layers. Entire transformer is bunch of matrix multiplicaton. Biggest one is the classifier layer at the top that goes from 768 to 50257. Matrix mult becomes faster by lowering precision
4. TF32 is 13 bits lesser in mantissa than FP32. Inputs and outputs are fp32, but internally it switches to tf32 and this increases speed by 8 times, and we cannot tell much difference in the results.
5. Numbers like 16 and 32 are good for Batches and Time. But something like 17 is bad.
6. Lowering precision in model training typically decreases training time because lower-precision data (like FP16 compared to FP32) allows hardware to perform computations faster, requires less memory bandwidth, and enables higher parallelism—so more operations happen per clock cycle. However, we expect slight less accurate results but empirically this is a worth it tradeoff. Because you can train longer to make up for the precision decrease
7. HBM is connected with GPU. GPU is where most calc happens, but it also has some memory. Most of memory is in high bandwidth memory(HBM). These are two separate chips. HBM is off chip. On GPU, there are large number of streaming processor all of which are SM, and this is where lot of the calculations happen. Single SM has 4 quadrants, each has a tensor core, and different subunits where calcs happen. On GPU chips there is L2 cache, but then on the SM in GPU, there is l1 cache,and registers. The way the memory is stored on GPU is quite different from HBM. So, there is meomery inside GUP but it is not a lot of memory.
8. Now, even if main memory of computer is very large, it is very inefficient as GPU would have to go through CPU to reach disk, and this is very time intensive. Then GPUs have HBM which are large in memory, but are also expensive to access. Then, on the GPU chips itself everything is very fast, but there is very few memory on it(in MBs as opposed to gbS), but it is lightning fast. So basically whenver, we have these kernels, we take these inputs which live on HBM, we start streaming data to gpu chip, we do computations, and then send it back to HBM. If we dont use torch.compile, we we are doing this HBM->GPU->HBM transfers many more times. But when we use it, since it already know stuff ahead in time, it optimizes data transfer and does all the computations together. So operation fusion, allows to keep chunk of data on chip, do lots of computation, and then do a single transfer back to HBM.
9. Flash attention: Fusing matmul, dropout, softmax, mask, and matmul to a fused kernel of flash attention. It is a kernel fusing algorithm torch.compile cannot find, and the reason is algorithmic rewrite of how attention is implemented. Flashattention does more flops(arithmetic ops) than regular attention. But it is significantly faster(7.6x) because very mindful of the memory heirarchy described above. Very mindful about what is in HBM, what is in shared memory. Very careful of how it orchastrates the computation, such that we have a fewer reads and writes with HBM. So even though we are doing more flops, the expensive part is the load and store. The NXN matrix (att in our code), never gets materialized at any point in HBM, and never gets read or written to HBM. For each head,
10. It is good to deal with powers of two because that is how cuda works
11. When we pass in custom vocab_size=50304, which is > 50257, (to make a an ugly number into a nicer number power of 2 to increase speed), wte becomes larger, but these newer tokens rows/vectors are never used because GPT2 only has 50257 tokens. We will never index into these rows, so wasting a little bit of space. Now, it is share with classifier at the end,  so we are predicting probabilities for tokens that will never be present in the training set and so therefore the network has to learn that these probabilities have to be driven to 0. And so the logits that the network produces have to drive those dimensions of the output to negative infinity. But this is no different to tokens not present in our dataset, ie Shakespeare dataset only uses 1000 of the 50000+ tokens. So functionally nothing breaks, we just use extra memory. It is running faster because many kernels use block tiles which are powers of 2, so calcs done in chunks of 64. When desired calc does not neatly fit into those block tiles, there are all kind of boundary kernels that can kick in to do the last part. In a lot of kernels, they will truncate up your input and will do the nice part first and then they have a whole second phase where they come back to anything that remains and process it, but this could be very inefficient. So instead pad the input, and make it fit nicely.
12. According to Andrej, gpt 2 has open weights, but paper does not have much detail for training, for gpt 3, paper has a lot of info, but weights not released. Roughly speaking they are very similar architecturally  , apart from some hyperparameters,context length size, training time, more data
13. We make sure length is not more than 1.0. People like to use this gradient norm clipping, because sometimes you can get unlucky with optimization because of bad data batch or something like that. If you get very unlucky in batch you might get really high loss and really high loss could lead to really high gradient and this could basically shock your model and the optimization . So this prevents model from getting too big of shocks in terms of gradient magnitudes. Bit of a hacky solution, patch on top of bigger issues. You can visualize this norm, and if norm of gradient is horizontal, it is well behaved and fine, if it is climbing, things are bad, sometimes you can get spike which is also bad due to training instability. Norm very high in begining as learning a lot of new stuff.
14. For context, I trained the model on A100 GPU
15. Cosine learning rate that we used here has been popularize by gpt 2 and 3 papers, but there are other learning rate schedules and this is an active area of research.
16. We weight decay weights involved in matrix multiplications, embeddings, etc, but not those which are biases, or layer norms, etc. We decay weights, because it is like regularization. when you are pulling down weights all the weights, you are forcing the optimization to use more of the weights and you are not allowing any of the weights to be way too large. You are forcing the network to distribute the work across more channels
17. Fused AdamW - Not available by default, need to turn it on when running on cuda. What it does is, instead of running in a for loop over all the paramter tensors and updating them, that would launch a lot of kernels. Fused just means all those kernels are fused into a single kernel, you get rid of a lot of overhead, and you single time on all the parameters call a kernel that updates them. So it is basically a kernel fusion for AdamW update instead of updating all the tensors.
    

In [None]:
num_return_sequences = 5
max_length = 30