<a href="https://colab.research.google.com/github/samitha278/gpt2-lite/blob/main/repro_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

from transformers import GPT2LMHeadModel , pipeline , set_seed


In [None]:
# Load the pretrained "gpt2" weights through transformers and grab the
# state dict so the parameter names and shapes can be inspected.
model = GPT2LMHeadModel.from_pretrained("gpt2")
params_dict = model.state_dict()

# One line per entry: parameter name followed by its tensor shape.
for k, v in params_dict.items():
  print(f"{k} {v.shape}")

transformer.wte.weight torch.Size([50257, 768])
transformer.wpe.weight torch.Size([1024, 768])
transformer.h.0.ln_1.weight torch.Size([768])
transformer.h.0.ln_1.bias torch.Size([768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias torch.Size([2304])
transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias torch.Size([768])
transformer.h.0.ln_2.weight torch.Size([768])
transformer.h.0.ln_2.bias torch.Size([768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias torch.Size([3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias torch.Size([768])
transformer.h.1.ln_1.weight torch.Size([768])
transformer.h.1.ln_1.bias torch.Size([768])
transformer.h.1.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.1.attn.c_attn.bias torch.Size([2304])
transformer.h.1.attn.c_proj.weight torch.Size([768, 768])
transformer.h.1.attn.c_proj.bias 

In [None]:
# Seed before building the pipeline and sampling so the run is repeatable.
set_seed(220064)
prompt = "Who is the first president is sri lanka?"
gen = pipeline(task='text-generation', model='gpt2')
gen(prompt)

Device set to use cpu
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'Who is the first president is sri lanka? sri sri lanka?\n\nPresident\n\npresident, sri, president, sri, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, president, presid

### Test: batched matrix multiplication with a transposed operand

In [9]:
# Two stacked 2x3 matrices, i.e. a tensor of shape (2, 2, 3).
temp = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [12, 11, 10]],
])

# Multiply each 2x3 matrix by its own transpose; only the last two
# dimensions take part, the leading dimension is treated as a batch.
print(torch.matmul(temp, temp.transpose(-1, -2)))

tensor([[[ 14,  32],
         [ 32,  77]],

        [[194, 262],
         [262, 365]]])


### Configurations

In [29]:
@dataclass
class GPT2Config:
  """Hyperparameters shared by the attention modules in this notebook.

  Every field carries a type annotation: ``@dataclass`` only turns
  *annotated* class attributes into fields. Without the annotations the
  generated ``__init__`` takes no arguments, so overriding a value with
  e.g. ``GPT2Config(n_head=8)`` raises ``TypeError``, and the generated
  ``repr``/``eq`` ignore the values.
  """
  batch_size: int = 32   # presumably sequences per batch; not read by the modules shown here
  block_size: int = 8    # side length of the causal mask, i.e. the longest supported sequence
  n_head: int = 4        # number of attention heads
  n_layer: int = 6       # presumably the number of stacked blocks; not read by the modules shown here
  n_embd: int = 128      # embedding width; each head gets n_embd // n_head features

## Attention

In [89]:
class SelfAttentionHead(nn.Module):
  """One head of causal scaled dot-product self-attention.

  Args:
    config: object exposing ``block_size``, ``n_embd`` and ``n_head``.
      The head projects ``n_embd`` features down to
      ``n_embd // n_head`` features.

  ``forward(x)`` takes a tensor whose last two dimensions are
  ``(T, n_embd)`` with ``T <= block_size`` and returns a tensor whose
  last two dimensions are ``(T, head_size)``.
  """

  def __init__(self,config):
    super().__init__()

    block_size = config.block_size
    n_embd = config.n_embd
    n_head = config.n_head
    head_size = n_embd // n_head

    self.head_size = head_size

    self.key = nn.Linear(n_embd,head_size)
    self.query = nn.Linear(n_embd,head_size)
    self.value = nn.Linear(n_embd,head_size)
    # Lower-triangular matrix of ones: position i may attend to j <= i.
    self.register_buffer('tril' , torch.tril(torch.ones(block_size,block_size)))

  def forward(self,x):
    """Apply masked attention to ``x`` and return the weighted values.

    Raises:
      ValueError: if the sequence length exceeds ``block_size``.
    """
    T = x.size(-2)
    max_len = self.tril.size(0)
    if T > max_len:
      raise ValueError(
          f"sequence length {T} exceeds block_size {max_len}")

    key = self.key(x)
    query = self.query(x)

    # Scale by 1/sqrt(head_size) to keep the softmax inputs well-ranged.
    weight = (query @ key.transpose(-1,-2)) * (self.head_size**-0.5)
    # Crop the mask to the actual sequence length; the full
    # (block_size, block_size) mask only broadcasts when T == block_size.
    weight = weight.masked_fill(self.tril[:T,:T]==0,float('-inf'))
    weight = F.softmax(weight,dim = -1)

    value = self.value(x)

    out = weight @ value

    return out



In [90]:
class MultiHead(nn.Module):
  """Multi-head attention assembled from independent SelfAttentionHead modules.

  ``config.n_head`` heads are run on the same input, their outputs are
  concatenated along the last dimension, and the result goes through a
  linear projection of size ``n_embd -> n_embd``.
  """

  def __init__(self,config):
    super().__init__()

    # Heads are created before the projection, as in the original layout.
    heads = [SelfAttentionHead(config) for _ in range(config.n_head)]
    self.self_attns = nn.ModuleList(heads)
    self.projection = nn.Linear(config.n_embd, config.n_embd)

  def forward(self,x):
    """Run every head on ``x``, join the results and project them."""
    per_head = []
    for head in self.self_attns:
      per_head.append(head(x))

    joined = torch.cat(per_head, dim=-1)
    return self.projection(joined)





In [91]:
# Build a MultiHead module from the default config and show its layout.
config = GPT2Config()
h = MultiHead(config)
dic = h.state_dict()
print(h)

# One line per state-dict entry: name followed by tensor shape.
for k, v in dic.items():
  print(f"{k} {v.shape}")

MultiHead(
  (self_attns): ModuleList(
    (0-3): 4 x SelfAttentionHead(
      (key): Linear(in_features=128, out_features=32, bias=True)
      (query): Linear(in_features=128, out_features=32, bias=True)
      (value): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (projection): Linear(in_features=128, out_features=128, bias=True)
)
self_attns.0.tril torch.Size([8, 8])
self_attns.0.key.weight torch.Size([32, 128])
self_attns.0.key.bias torch.Size([32])
self_attns.0.query.weight torch.Size([32, 128])
self_attns.0.query.bias torch.Size([32])
self_attns.0.value.weight torch.Size([32, 128])
self_attns.0.value.bias torch.Size([32])
self_attns.1.tril torch.Size([8, 8])
self_attns.1.key.weight torch.Size([32, 128])
self_attns.1.key.bias torch.Size([32])
self_attns.1.query.weight torch.Size([32, 128])
self_attns.1.query.bias torch.Size([32])
self_attns.1.value.weight torch.Size([32, 128])
self_attns.1.value.bias torch.Size([32])
self_attns.2.tril torch.Size([8, 8])
self_attns

In [92]:
# A single attention head on its own, for comparison with MultiHead.
sa = SelfAttentionHead(config)
dic = sa.state_dict()
print(sa)

# One line per state-dict entry: name followed by tensor shape.
for k, v in dic.items():
  print(f"{k} {v.shape}")

SelfAttentionHead(
  (key): Linear(in_features=128, out_features=32, bias=True)
  (query): Linear(in_features=128, out_features=32, bias=True)
  (value): Linear(in_features=128, out_features=32, bias=True)
)
tril torch.Size([8, 8])
key.weight torch.Size([32, 128])
key.bias torch.Size([32])
query.weight torch.Size([32, 128])
query.bias torch.Size([32])
value.weight torch.Size([32, 128])
value.bias torch.Size([32])


In [93]:
# Random input of shape (32, 8, 128); the trailing 128 matches n_embd
# and the 8 matches block_size in GPT2Config.
x = torch.randn(32, 8, 128)
sa(x).shape

torch.Size([32, 8, 32])

In [94]:
# Output shape of h applied to the same batch x used for the single head.
h(x).shape

torch.Size([32, 8, 128])

## Merging the SelfAttentionHead and MultiHead classes into a single Attention class

In [106]:
# Fixed seed so the integer tensor below is the same on every run.
torch.manual_seed(278)
# Integers in [0, 5) with shape (32, 4, 4, 3).
out = torch.randint(5, (32, 4, 4, 3))
# Show the four (4, 3) slices that make up the first item.
out[0, 0], out[0, 1], out[0, 2], out[0, 3]

(tensor([[0, 4, 0],
         [0, 1, 1],
         [0, 0, 4],
         [4, 2, 4]]),
 tensor([[4, 0, 4],
         [0, 2, 4],
         [2, 2, 4],
         [3, 2, 2]]),
 tensor([[0, 2, 4],
         [2, 3, 3],
         [0, 1, 4],
         [1, 4, 4]]),
 tensor([[3, 1, 1],
         [3, 1, 4],
         [3, 4, 4],
         [3, 0, 3]]))

In [105]:
# Unpack the four dimensions of out, then swap dims 1 and 2 and flatten
# the last two into one, giving shape (B, T, nh*C).
B, nh, T, C = out.shape
out = out.transpose(1, 2).reshape(B, T, nh * C)
out[0]

tensor([[0, 4, 0, 0, 2, 4, 1, 0, 0, 4, 0, 0],
        [0, 1, 1, 2, 3, 3, 1, 2, 2, 4, 3, 1],
        [0, 0, 4, 0, 1, 4, 4, 2, 4, 2, 1, 2],
        [4, 2, 4, 1, 4, 4, 0, 3, 3, 3, 1, 1],
        [4, 0, 4, 3, 1, 1, 1, 0, 1, 2, 4, 1],
        [0, 2, 4, 3, 1, 4, 2, 4, 4, 2, 2, 1],
        [2, 2, 4, 3, 4, 4, 2, 4, 3, 2, 2, 1],
        [3, 2, 2, 3, 0, 3, 0, 4, 1, 3, 1, 0]])

In [26]:
# Four independent 128 -> 32 linear layers, one per head.
key = nn.ModuleList([nn.Linear(128, 32) for _ in range(4)])
x = torch.randn(32, 8, 128)

# Stack the per-layer outputs on a new dim 1: shape (32, 4, 8, 32).
out1 = torch.stack([layer(x) for layer in key], dim=1)
out2 = torch.stack([layer(x) for layer in key], dim=1)

# Batched product over the last two dims: (8, 32) @ (32, 8) -> (8, 8).
out = torch.matmul(out1, out2.transpose(-2, -1))
out.shape

torch.Size([32, 4, 8, 8])

In [82]:
class Attention(nn.Module):
  """Causal multi-head self-attention with all heads evaluated together.

  Args:
    config: object exposing ``block_size``, ``n_embd`` and ``n_head``.
      Each head has ``n_embd // n_head`` features.

  ``forward(x)`` takes ``x`` of shape ``(B, T, n_embd)`` with
  ``T <= block_size`` and returns the concatenated head outputs with
  shape ``(B, T, n_head * head_size)``.
  """

  def __init__(self,config):
    super().__init__()

    block_size = config.block_size
    n_embd = config.n_embd
    n_head = config.n_head
    head_size = n_embd // n_head
    self.head_size = head_size

    # One linear layer per head for each of key / query / value.
    self.key = nn.ModuleList(nn.Linear(n_embd,head_size) for _ in range(n_head))
    self.query = nn.ModuleList(nn.Linear(n_embd,head_size) for _ in range(n_head))
    self.value = nn.ModuleList(nn.Linear(n_embd,head_size) for _ in range(n_head))

    # Lower-triangular matrix of ones: position i may attend to j <= i.
    self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))

  def forward(self,x):
    """Apply masked, scaled attention in every head and merge the heads.

    Raises:
      ValueError: if the sequence length exceeds ``block_size``.
    """
    T = x.size(-2)
    max_len = self.tril.size(0)
    if T > max_len:
      raise ValueError(
          f"sequence length {T} exceeds block_size {max_len}")

    # Stack per-head projections on a new dim 1: (B, n_head, T, head_size).
    key = torch.stack([k(x) for k in self.key],dim=1)
    query = torch.stack([q(x) for q in self.query],dim=1)

    # Scale by 1/sqrt(head_size), matching SelfAttentionHead; the unscaled
    # scores made this class disagree with the per-head implementation.
    weight = (query @ key.transpose(-1,-2)) * (self.head_size**-0.5)
    # Crop the mask to the actual sequence length; the full
    # (block_size, block_size) mask only broadcasts when T == block_size.
    weight = weight.masked_fill(self.tril[:T,:T]==0,float('-inf'))
    weight = F.softmax(weight,dim=-1)

    value = torch.stack([v(x) for v in self.value],dim=1)

    out = weight @ value

    # (B, n_head, T, head_size) -> (B, T, n_head * head_size)
    B,nh,T,C = out.shape
    out = out.permute(0,2,1,3)
    out = out.reshape(B,T,nh*C)

    return out

In [83]:
# Build the combined Attention module from the default config and show
# its layout.
config = GPT2Config()
h = Attention(config)
dic = h.state_dict()
print(h)

# One line per state-dict entry: name followed by tensor shape.
for k, v in dic.items():
  print(f"{k} {v.shape}")

Attention(
  (key): ModuleList(
    (0-3): 4 x Linear(in_features=128, out_features=32, bias=True)
  )
  (query): ModuleList(
    (0-3): 4 x Linear(in_features=128, out_features=32, bias=True)
  )
  (value): ModuleList(
    (0-3): 4 x Linear(in_features=128, out_features=32, bias=True)
  )
)
tril torch.Size([8, 8])
key.0.weight torch.Size([32, 128])
key.0.bias torch.Size([32])
key.1.weight torch.Size([32, 128])
key.1.bias torch.Size([32])
key.2.weight torch.Size([32, 128])
key.2.bias torch.Size([32])
key.3.weight torch.Size([32, 128])
key.3.bias torch.Size([32])
query.0.weight torch.Size([32, 128])
query.0.bias torch.Size([32])
query.1.weight torch.Size([32, 128])
query.1.bias torch.Size([32])
query.2.weight torch.Size([32, 128])
query.2.bias torch.Size([32])
query.3.weight torch.Size([32, 128])
query.3.bias torch.Size([32])
value.0.weight torch.Size([32, 128])
value.0.bias torch.Size([32])
value.1.weight torch.Size([32, 128])
value.1.bias torch.Size([32])
value.2.weight torch.Size([32

In [84]:
# Random input of shape (32, 8, 128) for the Attention module.
x = torch.randn(32, 8, 128)

In [85]:
# Forward pass of h on the random batch x.
out = h(x)

In [86]:
# Display the shape of the forward-pass result.
out.shape

torch.Size([32, 8, 128])