In [None]:
%cd ..

import numpy as np
import random

import KittyLM
from KittyLM.layers import Attention, KarpathyCausalSelfAttention
import torch
import torch.nn as nn

print(KittyLM.__version__)

# Set seeds for reproducibility
random.seed(42)               # Python random seed
np.random.seed(42)            # Numpy random seed
torch.manual_seed(42)         # PyTorch CPU seed
# torch.cuda.manual_seed(42)    # PyTorch GPU seed (if using CUDA)



class KittyLMConfig:
    """
    Config according to the GPT-2 weights on huggingface.
    Using a vocab size that is a multiple of 64 to speed up the processing

    """
    block_size = 1024
    vocab_size = 50304 # 50257 in the original and hf implementation weights
    n_layer = 12
    n_heads = 12
    d_model = 768
    dropout = 0.0
    bias = True

def parity_check_attn(config, input_B, input_T):

    # create random input tensor
    B, T, dim, n_heads = input_B, input_T, config.d_model, config.n_heads
    input_tensor = torch.randn(B, T, dim)

    # Calculate attention on input tensor using custom implemented attention class 
    attention_layer = Attention(config)
    custom_output = attention_layer(input_tensor)
    print(custom_output)
    k_attention = KarpathyCausalSelfAttention(config)
    k_output = k_attention(input_tensor)
    print(k_output)
    # Calculate attention using torch.nn.MultiheadAttention
    # https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
    multihead_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, dropout=config.dropout, bias=config.bias, batch_first=True )
    query = input_tensor.view(B, T, dim)
    key = query.clone()
    value = query.clone()

    attn_output, attn_output_weights = multihead_attn(query, key, value)
    print(attn_output)
    assert k_output.size() == custom_output.size(), f"custom attn output and pytorch attn output not same size: {custom_output.size()} vs. {attn_output.size()}"
    
    diff = torch.max(torch.abs(k_output -  attn_output))

    return "diff btwn custom implemented and pytorch multihead attn", diff.item()
    
print(parity_check_attn(KittyLMConfig, 1, 10))



In [15]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    #pass 
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.d_model, 4*config.d_model, bias = config.bias)
        self.c_proj = nn.Linear(4*config.d_model, config.d_model, bias = config.bias)
        self.activation = nn.GELU() # avoid sudden zeroout of gradients and have a smoother actovation 
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input):
        return self.dropout(self.c_proj(self.activation(self.c_fc(input))))
        

class Attention(nn.Module):
    #pass
    def __init__(self, config):
        super(Attention, self).__init__()
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias = config.bias)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias = config.bias)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias = config.bias)
        # self.c_attn = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
        # final projection after attention
        self.projection = nn.Linear(config.d_model, config.d_model, bias = config.bias)

        # these are self-explanatory
        self.attention_dropout = nn.Dropout(config.dropout)
        self.residual_dropout = nn.Dropout(config.dropout)

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.dropout = config.dropout
        self.head_size = self.d_model // self.n_heads

        self.register_buffer(
            'causal_mask', 
            torch.tril(torch.ones(config.block_size, config.block_size)) # create a block_size * block_size mask
            .view(1, 1, config.block_size, config.block_size) # add singletons so that shape is B * nh * block_size * block_size
        )

    def forward(self, input):
        B, T, D = input.size() # batch, length, dimension

        # reshape q,k,v to (B, nh, T, hs) from (B, T, D) -> (B, T, nh, hs) -> (B, nh, T, hs)
        # view shouldnt be used to transpose / permute as it messes up the data. chain a 
        # seperate transpose operation to transpose the the sequence length and head dimensions 
        # q, k, v  = self.c_attn(input).split(self.d_model, dim=2)
        q = self.q_proj(input).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = self.k_proj(input).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = self.v_proj(input).view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        # lets manually compute the attention score without einsum
        e = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        e = e.masked_fill(self.causal_mask[:, :, :T, :T] == 0, float('-inf'))  # masking only the actual inportant information across sequencelength and head dimension
        alpha = F.softmax(e, dim = -1)
        alpha = self.attention_dropout(alpha)
        attention = alpha @ v
        attention = attention.transpose(1, 2).contiguous().view(B, T, D) # hstack all heads
        attention = self.projection(attention)
        attention = self.residual_dropout(attention)

        return attention


class LayerNorm(nn.Module):
    #pass
    def __init__(self, d_model, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        if bias is not None:
            self.bias = nn.Parameter(torch.ones(d_model))

    def forward(self, input):
        ln = F.layer_norm(
            input = input,
            normalized_shape = self.weight.shape,
            weight = self.weight,
            bias = self.bias
        )
        return ln




In [16]:
import math
#import torch
#import torch.nn as nn
#import torch.nn.functional as F

#from layers import MLP, Attention, LayerNorm

class KittyLMConfig:
    """
    Config according to the GPT-2 weights on huggingface.
    Using a vocab size that is a multiple of 64 to speed up the processing

    """
    block_size = 1024
    vocab_size = 50257 # 50257 in the original and hf implementation weights but 50304 is faster
    n_layer = 12
    n_heads = 12
    d_model = 768
    dropout = 0.0
    bias = None

class KittyLMBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.preln = LayerNorm(config.d_model, bias = config.bias)
        self.attention = Attention(config)
        self.postln = LayerNorm(config.d_model, bias = config.bias)
        self.mlp = MLP(config)

    def forward(self, input):
        input = self.preln(input)
        input = self.attention(input)
        input = self.postln(input)
        output = self.mlp(input)
        return output
        # pass

class KittyLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embeddings = nn.Embedding(num_embeddings = config.vocab_size, embedding_dim = config.d_model)
        self.position_embeddings = nn.Embedding(num_embeddings = config.block_size, embedding_dim = config.d_model)
        self.blocks = nn.ModuleList([KittyLMBlock(config) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)
        self.ln_f = LayerNorm(config.d_model, bias = config.bias)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias = False)

        # weight tying
        self.token_embeddings.weight = self.lm_head.weight

        #init weights
        self.apply(self._init_weights)
        for name, parameter in self.named_parameters():
            if name.endswith('projection.weight'):
                nn.init.normal_(parameter, mean = 0.0, std = 0.2 / math.sqrt(2 * config.n_layer))

        #print(" parameter count : %.2fM" % (self._get_parameter_count(non_embedding = False) / 1e6))
        print(" parameter count : ", (self._get_parameter_count(non_embedding = False)))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean = 0.0, std = 0.2)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean = 0.0, std = 0.2)

    def _get_parameter_count(self, non_embedding = True):

        nparams = sum(param.numel() for param in self.parameters())
        if non_embedding:
            nparams -= self.position_embeddings.weight.numel()
        return nparams

    
    def forward(self, input_ids):
        B, T = input_ids.size()
        assert T <= self.config.block_size, "Sequence length cannnot be greater than model capacity"

        token_embeddings = self.token_embeddings(input_ids)
        position_ids = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.position_embedding(position_ids)

        x = token_embeddings + position_embeddings
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits



In [17]:
k = KittyLM(KittyLMConfig)

 parameter count :  124337664


In [36]:
named_layers = dict(k.named_modules())
print(named_layers)

{'': KittyLM(
  (token_embeddings): Embedding(50257, 768)
  (position_embeddings): Embedding(1024, 768)
  (blocks): ModuleList(
    (0-11): 12 x KittyLMBlock(
      (preln): LayerNorm()
      (attention): Attention(
        (q_proj): Linear(in_features=768, out_features=768, bias=False)
        (k_proj): Linear(in_features=768, out_features=768, bias=False)
        (v_proj): Linear(in_features=768, out_features=768, bias=False)
        (projection): Linear(in_features=768, out_features=768, bias=False)
        (attention_dropout): Dropout(p=0.0, inplace=False)
        (residual_dropout): Dropout(p=0.0, inplace=False)
      )
      (postln): LayerNorm()
      (mlp): MLP(
        (c_fc): Linear(in_features=768, out_features=3072, bias=False)
        (c_proj): Linear(in_features=3072, out_features=768, bias=False)
        (activation): GELU(approximate='none')
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (ln_f): LayerNorm(

In [37]:
# for idx, (module, param) in enumerate(k.named_parameters()):
#    print(f"{module}.{idx} parameter count: {sum(p.numel() for p in param)}")

for module, param in k.named_parameters():
    info = module.split(".")
    if len(info) > 2:
        
    print(f"{module} : {sum(p.numel() for p in param)}")

IndentationError: expected an indented block after 'if' statement on line 6 (2775956006.py, line 8)

In [39]:
from collections import defaultdict
d = defaultdict(lambda: defaultdict(int))
for module, param in k.named_parameters():
    info = module.split(".")
    if len(info) > 2:
        if len(info) == 5:
            d[info[2]][info[3]] += param.numel()
        elif len(info) == 4:
            d[info[2]][""] += param.numel()

    else:
        d[info[0]][""] += param.numel()
    #print(f"{module} : {sum(p.numel() for p in param)}")



In [40]:
for module_name, sub_dict in d.items():
    print(f"Module: {module_name}")
    
    for param_name, count in sub_dict.items():
        print(f"  {param_name}: {count}")

Module: token_embeddings
  : 38597376
Module: position_embeddings
  : 786432
Module: preln
  : 9216
Module: attention
  q_proj: 7077888
  k_proj: 7077888
  v_proj: 7077888
  projection: 7077888
Module: postln
  : 9216
Module: mlp
  c_fc: 28311552
  c_proj: 28311552
Module: ln_f
  : 768


In [41]:
d["attention"]

defaultdict(int,
            {'q_proj': 7077888,
             'k_proj': 7077888,
             'v_proj': 7077888,
             'projection': 7077888})

In [1]:
import lightning as L

In [1]:
from huggingface_hub import snapshot_download

snapshot_download(repo_id='openai-community/gpt2', allow_patterns='*.safetensors')

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

'/home/tororo.in/.cache/huggingface/hub/models--openai-community--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e'

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

hf_model = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(hf_model)
hf_gpt2 = AutoModelForCausalLM.from_pretrained(hf_model)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [3]:
print(hf_gpt2)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [6]:
c_attn_weight = hf_gpt2.transformer.h[0].attn.c_attn.weight

In [7]:
c_attn_weight.shape

torch.Size([768, 2304])

In [51]:
hf_q, hf_k, hf_v = torch.split(c_attn_weight, 768, dim = -1)
hf_q.shape

torch.Size([768, 768])

In [10]:
import torch 
hidden_size = hf_gpt2.config.hidden_size
input_ids = torch.randint(0, hf_gpt2.config.vocab_size, (2, 10))

input_embed = hf_gpt2.transformer.wte(input_ids)
input_embed.shape

torch.Size([2, 10, 768])

In [52]:
# passing embd through first layer of hf gpt2 model

hf_0 = hf_gpt2.transformer.h[0] # layer 0

hf_qkv = hf_0.attn.c_attn(input_embed)
hf_qkv.shape # should be batch x length x 3 * d_model

torch.Size([2, 10, 2304])

In [53]:
out_hf_q, out_hf_k, out_hf_v = hf_qkv.split(768, dim = -1)

In [54]:
out_hf_q.shape

torch.Size([2, 10, 768])

In [106]:
with torch.no_grad():
    k.blocks[0].attention.q_proj.weight = torch.nn.Parameter(hf_q)
    k.blocks[0].attention.k_proj.weight = torch.nn.Parameter(hf_k)
    k.blocks[0].attention.v_proj.weight = torch.nn.Parameter(hf_v)

In [107]:
out_q_custom = k.blocks[0].attention.q_proj(input_embed)
out_k_custom = k.blocks[0].attention.k_proj(input_embed)
out_v_custom = k.blocks[0].attention.v_proj(input_embed)



In [108]:
assert torch.allclose(out_hf_q, out_q_custom, atol=1e-5), "Query outputs do not match!"
assert torch.allclose(out_hf_k, out_k_custom, atol=1e-5), "Key outputs do not match!"
assert torch.allclose(out_hf_v, out_v_custom, atol=1e-5), "Value outputs do not match!"

AssertionError: Query outputs do not match!

In [109]:
print(k.blocks[0].attention.q_proj.weight)

Parameter containing:
tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.3237, -0.0483, -0.2235],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0770, -0.1492,  0.1507],
        [ 0.0039,  0.0695,  0.3668,  ..., -0.1235, -0.1660, -0.0480],
        ...,
        [-0.2592, -0.0164,  0.1991,  ..., -0.0335,  0.1455,  0.0333],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0827, -0.0533, -0.0071],
        [-0.4100, -0.1924, -0.2400,  ...,  0.2170,  0.1470, -0.0557]],
       requires_grad=True)


In [110]:
print(hf_q)

tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.3237, -0.0483, -0.2235],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0770, -0.1492,  0.1507],
        [ 0.0039,  0.0695,  0.3668,  ..., -0.1235, -0.1660, -0.0480],
        ...,
        [-0.2592, -0.0164,  0.1991,  ..., -0.0335,  0.1455,  0.0333],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0827, -0.0533, -0.0071],
        [-0.4100, -0.1924, -0.2400,  ...,  0.2170,  0.1470, -0.0557]],
       grad_fn=<SplitBackward0>)


In [111]:
print(out_q_custom, out_q_custom.shape)

tensor([[[-1.1243,  1.8245, -0.1568,  ..., -0.1497, -0.3568,  0.9658],
         [-0.3063,  0.7632,  0.6724,  ...,  0.0888, -1.1617,  0.3162],
         [-0.9941, -0.0599,  1.4607,  ..., -1.6881,  0.9942,  2.1041],
         ...,
         [ 1.1804,  0.3327, -0.2058,  ..., -0.1540, -0.1996,  0.8665],
         [-0.8858,  0.5276,  0.6065,  ..., -0.8970,  1.4199, -0.4245],
         [-0.8390,  0.0107, -1.0630,  ..., -0.9640, -0.7801,  0.4529]],

        [[ 1.0096, -0.2976,  0.5638,  ..., -0.3491, -0.1468,  0.2377],
         [-0.9576, -1.4123,  0.5077,  ...,  0.1489,  0.3774,  0.0109],
         [-1.4285, -1.2211,  1.0694,  ...,  0.0785, -0.1604,  0.3087],
         ...,
         [ 0.0854, -0.5799,  1.0094,  ..., -0.5405, -0.3010,  0.2759],
         [ 0.2736, -0.5153,  0.7052,  ..., -0.3818, -0.3548, -1.7167],
         [ 0.1421, -0.2851, -0.1371,  ..., -1.1382, -0.1506, -0.9867]]],
       grad_fn=<UnsafeViewBackward0>) torch.Size([2, 10, 768])


In [112]:
print(out_hf_q, out_hf_q.shape)

tensor([[[ 0.5046, -0.1808, -2.2492,  ...,  0.0149,  0.4377, -1.7048],
         [-0.4218, -1.0439, -1.6336,  ...,  1.4847, -0.2523, -1.0871],
         [-1.7190,  0.4998, -2.7904,  ...,  1.1032,  0.0187, -1.6178],
         ...,
         [ 0.6493, -0.1796, -0.4503,  ...,  0.8349,  0.2548, -1.8746],
         [ 0.2289, -1.3614, -1.0591,  ...,  1.6146, -0.6533, -1.7017],
         [ 1.1903, -0.2325, -1.1019,  ...,  0.5890, -0.5762, -1.6670]],

        [[ 0.3174,  1.1568, -0.8205,  ...,  1.0714, -0.4499, -1.3563],
         [-0.4712,  0.0095, -1.8863,  ...,  0.8182,  1.2508, -1.9979],
         [ 0.1650, -0.3094, -0.0883,  ...,  1.3423,  0.1067, -1.1818],
         ...,
         [-1.1883,  0.7363, -0.6516,  ...,  1.4726, -0.3120, -1.2400],
         [ 0.2648,  1.5586, -1.9121,  ...,  1.3525,  0.2994, -1.8890],
         [-1.1054,  0.3067, -0.6996,  ...,  1.4577,  0.7222, -0.9795]]],
       grad_fn=<SplitBackward0>) torch.Size([2, 10, 768])


In [84]:
assert torch.allclose(k.blocks[0].attention.q_proj.weight.t(), hf_q), "Query weights mismatch!"


In [85]:
out_hf_q_manual = input_embed @ hf_q

print(out_hf_q_manual)

tensor([[[ 0.0243,  0.3447, -1.8199,  ..., -0.2612,  0.8607, -1.4671],
         [-0.9021, -0.5185, -1.2044,  ...,  1.2086,  0.1707, -0.8494],
         [-2.1993,  1.0252, -2.3611,  ...,  0.8270,  0.4417, -1.3800],
         ...,
         [ 0.1689,  0.3458, -0.0210,  ...,  0.5588,  0.6778, -1.6368],
         [-0.2514, -0.8359, -0.6299,  ...,  1.3385, -0.2302, -1.4640],
         [ 0.7100,  0.2929, -0.6726,  ...,  0.3129, -0.1532, -1.4293]],

        [[-0.1630,  1.6822, -0.3913,  ...,  0.7953, -0.0268, -1.1185],
         [-0.9515,  0.5349, -1.4570,  ...,  0.5421,  1.6738, -1.7602],
         [-0.3153,  0.2160,  0.3410,  ...,  1.0662,  0.5297, -0.9441],
         ...,
         [-1.6686,  1.2618, -0.2224,  ...,  1.1965,  0.1110, -1.0023],
         [-0.2155,  2.0840, -1.4828,  ...,  1.0764,  0.7224, -1.6513],
         [-1.5857,  0.8321, -0.2703,  ...,  1.1816,  1.1453, -0.7418]]],
       grad_fn=<UnsafeViewBackward0>)


In [90]:
print(torch.all(out_hf_q_manual - out_q_custom) == 0)

tensor(True)


In [94]:
assert torch.allclose(out_hf_q, out_q_custom), "Q mismatch"

AssertionError: Q mismatch

In [95]:
out_hf_q

tensor([[[ 0.5046, -0.1808, -2.2492,  ...,  0.0149,  0.4377, -1.7048],
         [-0.4218, -1.0439, -1.6336,  ...,  1.4847, -0.2523, -1.0871],
         [-1.7190,  0.4998, -2.7904,  ...,  1.1032,  0.0187, -1.6178],
         ...,
         [ 0.6493, -0.1796, -0.4503,  ...,  0.8349,  0.2548, -1.8746],
         [ 0.2289, -1.3614, -1.0591,  ...,  1.6146, -0.6533, -1.7017],
         [ 1.1903, -0.2325, -1.1019,  ...,  0.5890, -0.5762, -1.6670]],

        [[ 0.3174,  1.1568, -0.8205,  ...,  1.0714, -0.4499, -1.3563],
         [-0.4712,  0.0095, -1.8863,  ...,  0.8182,  1.2508, -1.9979],
         [ 0.1650, -0.3094, -0.0883,  ...,  1.3423,  0.1067, -1.1818],
         ...,
         [-1.1883,  0.7363, -0.6516,  ...,  1.4726, -0.3120, -1.2400],
         [ 0.2648,  1.5586, -1.9121,  ...,  1.3525,  0.2994, -1.8890],
         [-1.1054,  0.3067, -0.6996,  ...,  1.4577,  0.7222, -0.9795]]],
       grad_fn=<SplitBackward0>)