<a href="https://colab.research.google.com/github/samitha278/gpt2-lite/blob/main/repro_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

from transformers import GPT2LMHeadModel , pipeline , set_seed



In [None]:
# Load the pretrained GPT-2 checkpoint and list every parameter tensor
# together with its shape, to see the layout a re-implementation must match.
model = GPT2LMHeadModel.from_pretrained("gpt2")
params_dict = model.state_dict()

for name, tensor in params_dict.items():
  print(name, tensor.shape)

transformer.wte.weight torch.Size([50257, 768])
transformer.wpe.weight torch.Size([1024, 768])
transformer.h.0.ln_1.weight torch.Size([768])
transformer.h.0.ln_1.bias torch.Size([768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias torch.Size([2304])
transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias torch.Size([768])
transformer.h.0.ln_2.weight torch.Size([768])
transformer.h.0.ln_2.bias torch.Size([768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias torch.Size([3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias torch.Size([768])
transformer.h.1.ln_1.weight torch.Size([768])
transformer.h.1.ln_1.bias torch.Size([768])
transformer.h.1.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.1.attn.c_attn.bias torch.Size([2304])
transformer.h.1.attn.c_proj.weight torch.Size([768, 768])
transformer.h.1.attn.c_proj.bias 

In [None]:
# Seed first so the sampled continuation is repeatable across runs.
set_seed(220064)

# Reference generation from the stock Hugging Face GPT-2 pipeline.
gen = pipeline(
    'text-generation',
    model='gpt2',
)
prompt = "Who is the first president is sri lanka?"

# Last expression of the cell: the notebook displays the generated text.
gen(prompt)

Device set to use cpu
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'Who is the first president is sri lanka? sri sri lanka?\n\nPresident\n\npresident, sri, president, sri, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, presid

### Test: batched matrix product of a tensor with its own transpose

In [9]:
# Small 3-D integer tensor: 2 matrices, each 2 x 3.
temp = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [12, 11, 10]],
])

# Transposing only the last two dims gives (2, 3, 2); the matmul is applied
# per leading index, so the result is (2, 2, 2).
print(torch.matmul(temp, temp.transpose(-2, -1)))

tensor([[[ 14,  32],
         [ 32,  77]],

        [[194, 262],
         [262, 365]]])


### Configurations

In [66]:
@dataclass
class GPT2Config:
  """Hyper-parameters shared by the model components in this notebook.

  Every attribute carries a type annotation so that ``@dataclass`` registers
  it as a real field. Without annotations the values are plain class
  attributes and the generated ``__init__`` accepts no arguments, so
  ``GPT2Config(n_head=8)`` raises ``TypeError``.

  Raises:
    ValueError: if ``n_embd`` is not divisible by ``n_head``.
  """

  # Number of sequences per batch.
  batch_size: int = 32
  # Maximum sequence length (side of the causal mask).
  block_size: int = 8
  # Number of attention heads.
  n_head: int = 4
  # Number of layers.
  n_layer: int = 6
  # Embedding width.
  n_embd: int = 128

  def __post_init__(self):
    # Each head gets n_embd // n_head features; a remainder would be
    # silently dropped and the concatenated heads would no longer be
    # n_embd wide.
    if self.n_head <= 0 or self.n_embd % self.n_head != 0:
      raise ValueError(
          f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
      )

## Self Attention Head class + Multi Head class -> Single class

In [67]:
class SelfAttentionHead(nn.Module):
  """One causal (masked) scaled dot-product self-attention head.

  Args:
    config: object exposing ``block_size``, ``n_embd`` and ``n_head``.
      The head width is ``n_embd // n_head``.

  The lower-triangular mask is registered as a buffer named ``tril`` so it
  follows the module across devices and appears in ``state_dict()``.
  """

  def __init__(self,config):
    super().__init__()

    block_size = config.block_size
    n_embd = config.n_embd
    n_head = config.n_head
    head_size = n_embd // n_head

    self.key = nn.Linear(n_embd,head_size)
    self.query = nn.Linear(n_embd,head_size)
    self.value = nn.Linear(n_embd,head_size)
    self.register_buffer('tril' , torch.tril(torch.ones(block_size,block_size)))

  def forward(self,x):
    """Apply causal self-attention.

    Args:
      x: tensor of shape ``(..., T, n_embd)`` with ``T <= block_size``.

    Returns:
      Tensor of shape ``(..., T, head_size)``.

    Raises:
      ValueError: if ``T`` exceeds the ``block_size`` the mask was built for.
    """
    T = x.shape[-2]
    if T > self.tril.shape[0]:
      raise ValueError(
          f"sequence length {T} exceeds block_size {self.tril.shape[0]}"
      )

    key = self.key(x)
    query = self.query(x)
    value = self.value(x)

    # Scale by 1/sqrt(head_size) so the logits do not grow with the head
    # width and saturate the softmax.
    weight = query @ key.transpose(-1,-2) * key.shape[-1] ** -0.5
    # Slice the mask to the actual sequence length; the full
    # (block_size, block_size) mask only broadcasts when T == block_size.
    weight = weight.masked_fill(self.tril[:T,:T]==0,float('-inf'))
    weight = F.softmax(weight,dim = -1)

    out = weight @ value

    return out



In [68]:
class MultiHead(nn.Module):
  """Multi-head self-attention: parallel heads followed by a linear projection.

  Args:
    config: object exposing ``n_head`` and ``n_embd``; it is also passed
      unchanged to every ``SelfAttentionHead``.
  """

  def __init__(self,config):
    super().__init__()

    n_head = config.n_head
    n_embd = config.n_embd

    self.self_attns = nn.ModuleList([SelfAttentionHead(config)  for _ in range(n_head)])
    self.projection = nn.Linear(n_embd,n_embd)

  def forward(self,x):
    """Run every head on ``x`` and project the concatenated result.

    Args:
      x: input tensor whose last dimension is ``n_embd``.

    Returns:
      Tensor with the same leading dimensions as ``x`` and last dimension
      ``n_embd``.
    """
    # Join the heads along the feature axis. torch.cat defaults to dim 0,
    # which would stack the heads along the batch axis and leave the last
    # dim too narrow for the projection.
    out = torch.cat([sa(x) for sa in self.self_attns], dim=-1)
    out = self.projection(out)

    return out





In [69]:
# Build a multi-head attention module from the default configuration, then
# show its structure and the shape of every entry in its state dict.
config = GPT2Config()
h = MultiHead(config)
dic = h.state_dict()

print(h)
for name, tensor in dic.items():
  print(name, tensor.shape)

MultiHead(
  (self_attns): ModuleList(
    (0-3): 4 x SelfAttentionHead(
      (key): Linear(in_features=128, out_features=32, bias=True)
      (query): Linear(in_features=128, out_features=32, bias=True)
      (value): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (projection): Linear(in_features=128, out_features=128, bias=True)
)
self_attns.0.tril torch.Size([8, 8])
self_attns.0.key.weight torch.Size([32, 128])
self_attns.0.key.bias torch.Size([32])
self_attns.0.query.weight torch.Size([32, 128])
self_attns.0.query.bias torch.Size([32])
self_attns.0.value.weight torch.Size([32, 128])
self_attns.0.value.bias torch.Size([32])
self_attns.1.tril torch.Size([8, 8])
self_attns.1.key.weight torch.Size([32, 128])
self_attns.1.key.bias torch.Size([32])
self_attns.1.query.weight torch.Size([32, 128])
self_attns.1.query.bias torch.Size([32])
self_attns.1.value.weight torch.Size([32, 128])
self_attns.1.value.bias torch.Size([32])
self_attns.2.tril torch.Size([8, 8])
self_attns

In [70]:
# Same inspection for a single attention head: module layout first, then the
# shape of each state-dict entry (the `tril` mask buffer is included).
sa = SelfAttentionHead(config)
dic = sa.state_dict()

print(sa)
for name, tensor in dic.items():
  print(name, tensor.shape)

SelfAttentionHead(
  (key): Linear(in_features=128, out_features=32, bias=True)
  (query): Linear(in_features=128, out_features=32, bias=True)
  (value): Linear(in_features=128, out_features=32, bias=True)
)
tril torch.Size([8, 8])
key.weight torch.Size([32, 128])
key.bias torch.Size([32])
query.weight torch.Size([32, 128])
query.bias torch.Size([32])
value.weight torch.Size([32, 128])
value.bias torch.Size([32])


In [71]:
# Random input of shape (32, 8, 128) pushed through the single head;
# the cell displays the shape of the result.
x = torch.randn(32, 8, 128)
sa(x).shape

torch.Size([32, 8, 32])