In [177]:
!pip install einops -q

In [178]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from types import SimpleNamespace
from einops import rearrange
import torch.nn.functional as F

In [179]:
# Load the pretrained 'gpt2' checkpoint from Hugging Face; it serves as the reference
# architecture and as the source of the weights copied into the re-implementation below.
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

In [180]:
# Display the module tree of the reference model.
gpt2

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [181]:
# List every state-dict entry of the reference model; entries under the
# 'transformer' sub-module are shown indented with that prefix stripped.
sd = gpt2.state_dict()

print('transformer')
for key in sd:
    if 'transformer' not in key:
        print(key)
        continue
    print('\t', key.replace('transformer.', ''))

transformer
	 wte.weight
	 wpe.weight
	 h.0.ln_1.weight
	 h.0.ln_1.bias
	 h.0.attn.c_attn.weight
	 h.0.attn.c_attn.bias
	 h.0.attn.c_proj.weight
	 h.0.attn.c_proj.bias
	 h.0.ln_2.weight
	 h.0.ln_2.bias
	 h.0.mlp.c_fc.weight
	 h.0.mlp.c_fc.bias
	 h.0.mlp.c_proj.weight
	 h.0.mlp.c_proj.bias
	 h.1.ln_1.weight
	 h.1.ln_1.bias
	 h.1.attn.c_attn.weight
	 h.1.attn.c_attn.bias
	 h.1.attn.c_proj.weight
	 h.1.attn.c_proj.bias
	 h.1.ln_2.weight
	 h.1.ln_2.bias
	 h.1.mlp.c_fc.weight
	 h.1.mlp.c_fc.bias
	 h.1.mlp.c_proj.weight
	 h.1.mlp.c_proj.bias
	 h.2.ln_1.weight
	 h.2.ln_1.bias
	 h.2.attn.c_attn.weight
	 h.2.attn.c_attn.bias
	 h.2.attn.c_proj.weight
	 h.2.attn.c_proj.bias
	 h.2.ln_2.weight
	 h.2.ln_2.bias
	 h.2.mlp.c_fc.weight
	 h.2.mlp.c_fc.bias
	 h.2.mlp.c_proj.weight
	 h.2.mlp.c_proj.bias
	 h.3.ln_1.weight
	 h.3.ln_1.bias
	 h.3.attn.c_attn.weight
	 h.3.attn.c_attn.bias
	 h.3.attn.c_proj.weight
	 h.3.attn.c_proj.bias
	 h.3.ln_2.weight
	 h.3.ln_2.bias
	 h.3.mlp.c_fc.weight
	 h.3.mlp.c_fc.bias


In [182]:
# Inspect the first transformer block of the reference model.
gpt2.transformer.h[0]

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D()
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [183]:
# State-dict keys of the first block's MLP.
gpt2.transformer.h[0].mlp.state_dict().keys()

odict_keys(['c_fc.weight', 'c_fc.bias', 'c_proj.weight', 'c_proj.bias'])

In [184]:
# State-dict keys of the first block's first LayerNorm.
gpt2.transformer.h[0].ln_1.state_dict().keys()

odict_keys(['weight', 'bias'])

In [185]:
# The `wpe` embedding table of the reference model.
gpt2.transformer.wpe

Embedding(1024, 768)

In [186]:
def params(m):
    """Return the total number of elements across everything yielded by ``m.parameters()``."""
    total = 0
    for tensor in m.parameters():
        total += tensor.numel()
    return total

In [187]:
# Total parameter count of the reference model.
params(gpt2)

124439808

In [188]:
class GPT2Attention(nn.Module):
    """Multi-head causal self-attention with the same parameter names as the reference block.

    The fused ``c_attn`` projection produces queries, keys and values in one
    matmul; ``c_proj`` mixes the concatenated heads back into ``embed_dim``.

    Args:
        config: object exposing ``embed_dim``, ``num_heads``, ``seq_len``,
            ``attention_dropout`` and ``residual_dropout``.

    Note:
        ``mask`` is registered as a (persistent) buffer, so it appears in
        ``state_dict()`` under ``...attn.mask``.
    """
    def __init__(self,config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, 'embedding dimension must be divisible by number of heads'
        self.head_size = self.embed_dim // self.n_heads
        self.seq_len = config.seq_len

        # single projection that yields q, k and v side by side along the last dim
        self.c_attn = nn.Linear(self.embed_dim, self.head_size * self.n_heads * 3,bias=True)
        self.scale = self.head_size ** -0.5

        # lower-triangular causal mask, broadcast over batch and heads
        self.register_buffer('mask',torch.tril(torch.ones(1,1,self.seq_len,self.seq_len)))

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.residual_dropout)


    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape batch x seq_len x embed_dim.

        Returns:
            Tensor of shape batch x seq_len x embed_dim.
        """
        b,t,c = x.shape
        # q,k,v shape individually: batch_size x seq_len x embed_dim
        q,k,v = self.c_attn(x).chunk(3,dim=-1)
        # Split embed_dim head-major, '(n h)': each head owns a contiguous slice of
        # head_size features. This must mirror the '(n h)' merge below; splitting as
        # '(h n)' would interleave heads and scramble the roles of pretrained weights.
        q = rearrange(q,'b t (n h) -> b n t h',n=self.n_heads) # h = head_size
        k = rearrange(k,'b t (n h) -> b n t h',n=self.n_heads)
        v = rearrange(v,'b t (n h) -> b n t h',n=self.n_heads)

        # scaled dot-product scores: batch x n_heads x seq_len x seq_len
        qk_t = (q@k.transpose(-2,-1)) * self.scale
        # crop the mask to [:,:,:t,:t], otherwise prompts shorter than seq_len will not work
        qk_t = qk_t.masked_fill(self.mask[:,:,:t,:t]==0,float('-inf'))
        qk_t = F.softmax(qk_t,dim=-1)

        weights = self.attn_dropout(qk_t)

        attention = weights @ v # batch x n_heads x seq_len x head_size
        attention = rearrange(attention,'b n t h -> b t (n h)') # batch x seq_len x embed_dim

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out
    

class GPT2MLP(nn.Module):
    """Position-wise feed-forward network: expand, GELU, project back, dropout.

    Args:
        config: object exposing ``embed_dim``, ``mlp_ratio`` (hidden width as a
            multiple of ``embed_dim``) and ``mlp_dropout``.
    """
    def __init__(self,config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.mlp_ratio = config.mlp_ratio
        self.mlp_dropout = config.mlp_dropout

        self.c_fc = nn.Linear(self.embed_dim,self.embed_dim*self.mlp_ratio)
        self.c_proj = nn.Linear(self.embed_dim*self.mlp_ratio,self.embed_dim)
        # The reference model uses NewGELUActivation, i.e. the tanh approximation of
        # GELU; the default (erf-based) nn.GELU() gives slightly different activations.
        self.act = nn.GELU(approximate='tanh')
        self.dropout = nn.Dropout(self.mlp_dropout)

    def forward(self,x):
        """Map a tensor whose last dim is ``embed_dim`` to a tensor of the same shape."""
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
    

class GPT2Block(nn.Module):
    """One transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""
    def __init__(self, config):
        super().__init__()
        width = config.embed_dim
        self.embed_dim = width
        # registration order (ln_1, attn, ln_2, mlp) is kept as-is
        self.ln_1 = nn.LayerNorm(width)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = GPT2MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates (shape unchanged)."""
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))
    

class GPT2Model(nn.Module):
    """GPT-2 style language model whose parameter names mirror the reference checkpoint.

    Args:
        config: object exposing ``vocab_size``, ``embed_dim``, ``seq_len``,
            ``emb_dropout`` and ``depth``, plus whatever ``GPT2Block`` reads.
    """
    def __init__(self,config):
        super().__init__()

        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size,config.embed_dim),
            wpe = nn.Embedding(config.seq_len,config.embed_dim),
            drop = nn.Dropout(config.emb_dropout),
            h = nn.ModuleList([GPT2Block(config) for _ in range(config.depth)]),
            ln_f = nn.LayerNorm(config.embed_dim)
        ))
        self.lm_head = nn.Linear(config.embed_dim,config.vocab_size,bias=False)
        # weight tying: the token embedding and the output projection share one parameter
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self,x):
        """Compute next-token logits for the last position only.

        Args:
            x: integer token ids of shape batch x seq_len.

        Returns:
            Logits of shape batch x 1 x vocab_size.
        """
        _, t = x.shape # also enforces a 2-D input
        token_embeddings = self.transformer.wte(x) # batch x seq_len x embed_dim
        positions = torch.arange(t, device=x.device) # created directly on the input's device
        positional_embeddings = self.transformer.wpe(positions) # seq_len x embed_dim, broadcast over batch
        x = self.transformer.drop(token_embeddings+positional_embeddings)
        for h in self.transformer.h:
            x = h(x) # batch_size x seq_len x embed_dim
        x = self.transformer.ln_f(x)[:,[-1],:] # get last hidden state: batch_size x 1 x embed_dim
        x = self.lm_head(x) # batch_size x 1 x vocab_size

        return x

    @torch.no_grad()
    def generate(self,idx,max_new_tokens=5,temperature=1.0):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: integer token ids of shape batch x prompt_len.
            max_new_tokens: number of tokens to sample.
            temperature: divisor applied to the logits before the softmax.

        Returns:
            Tensor of shape batch x (prompt_len + max_new_tokens).

        Note:
            This method does not switch the module to eval mode; call
            ``model.eval()`` first if dropout should be disabled while sampling.
        """
        for _ in range(max_new_tokens):

            # Feed at most the last seq_len tokens. The crop must come from `idx`
            # (the running sequence); slicing is a no-op for shorter sequences.
            inp = idx[:,-self.config.seq_len:]
            out = self(inp)
            out = out[:, -1, :] / temperature
            probs = F.softmax(out, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [189]:
# Hyper-parameters for the re-implementation, chosen to match the 'gpt2' reference model.
gpt2_small = SimpleNamespace(**{
    # model size
    'vocab_size': 50_257,
    'embed_dim': 768,
    'num_heads': 12,
    'seq_len': 1024,
    'depth': 12,
    # attention sub-layer dropout
    'attention_dropout': 0.1,
    'residual_dropout': 0.1,
    # feed-forward sub-layer
    'mlp_ratio': 4,
    'mlp_dropout': 0.1,
    # dropout applied to the summed token + position embeddings
    'emb_dropout': 0.1,
})

In [190]:
# Instantiate the re-implementation; no pretrained weights have been copied into it yet.
mygpt = GPT2Model(gpt2_small)

In [191]:
# Sanity check: the re-implementation and the reference model should have the same parameter count.
params(mygpt), params(gpt2)

(124439808, 124439808)

In [192]:
# State-dict keys present in the re-implementation but absent from the reference model.
trained_unique_keys = set(gpt2.state_dict())
my_unique_keys = set(mygpt.state_dict())
my_unique_keys.difference(trained_unique_keys)

{'transformer.h.0.attn.mask',
 'transformer.h.1.attn.mask',
 'transformer.h.10.attn.mask',
 'transformer.h.11.attn.mask',
 'transformer.h.2.attn.mask',
 'transformer.h.3.attn.mask',
 'transformer.h.4.attn.mask',
 'transformer.h.5.attn.mask',
 'transformer.h.6.attn.mask',
 'transformer.h.7.attn.mask',
 'transformer.h.8.attn.mask',
 'transformer.h.9.attn.mask'}

In [193]:
@torch.no_grad()
def copy_state_dict(my_sd, hf_sd, verbose=True):
    """Copy tensors from ``hf_sd`` into the same-named entries of ``my_sd`` in place.

    Buffer-only entries are ignored on both sides (``.attn.mask`` in ``my_sd``;
    ``.attn.masked_bias`` and ``.attn.bias`` in ``hf_sd``). Weights listed in
    ``transposed`` are stored as (in, out) in ``hf_sd`` and are transposed to
    the (out, in) layout used by ``nn.Linear`` before copying.

    Args:
        my_sd: mapping of key -> tensor to be overwritten.
        hf_sd: mapping of key -> tensor to read from.
        verbose: when True (default), print one line per copied tensor.

    Returns:
        ``my_sd`` (the same object, with updated tensors).

    Raises:
        AssertionError: if the two key sets differ after filtering, or if a
            pair of tensors has incompatible shapes.
    """
    # Filter OUR keys from my_sd (not hf_sd), otherwise the check below compares hf_sd with itself.
    sd_keys_my = [k for k in my_sd.keys() if not k.endswith('.attn.mask')]
    # ignore these, just buffers (the causal mask and its fill value)
    sd_keys_hf = [k for k in hf_sd.keys() if not k.endswith(('.attn.masked_bias', '.attn.bias'))]
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')

    missing = sorted(set(sd_keys_hf) - set(sd_keys_my))
    unexpected = sorted(set(sd_keys_my) - set(sd_keys_hf))
    assert not missing and not unexpected, (
        f'state dict keys differ; missing from my_sd: {missing}; not in hf_sd: {unexpected}'
    )

    for key in sd_keys_hf:
        if key.endswith(transposed):
            if verbose:
                print('transposed',key,my_sd[key].shape,hf_sd[key].shape[::-1])
            assert my_sd[key].shape == hf_sd[key].shape[::-1]
            my_sd[key].copy_(hf_sd[key].t())
        else:
            if verbose:
                print('normal',key,my_sd[key].shape,hf_sd[key].shape)
            assert my_sd[key].shape == hf_sd[key].shape
            my_sd[key].copy_(hf_sd[key])

    return my_sd

In [194]:
# Copy the reference weights into the state dict of the re-implementation (prints one line per tensor).
copied_sd = copy_state_dict(mygpt.state_dict(),gpt2.state_dict())

normal transformer.wte.weight torch.Size([50257, 768]) torch.Size([50257, 768])
normal transformer.wpe.weight torch.Size([1024, 768]) torch.Size([1024, 768])
normal transformer.h.0.ln_1.weight torch.Size([768]) torch.Size([768])
normal transformer.h.0.ln_1.bias torch.Size([768]) torch.Size([768])
transposed transformer.h.0.attn.c_attn.weight torch.Size([2304, 768]) torch.Size([2304, 768])
normal transformer.h.0.attn.c_attn.bias torch.Size([2304]) torch.Size([2304])
transposed transformer.h.0.attn.c_proj.weight torch.Size([768, 768]) torch.Size([768, 768])
normal transformer.h.0.attn.c_proj.bias torch.Size([768]) torch.Size([768])
normal transformer.h.0.ln_2.weight torch.Size([768]) torch.Size([768])
normal transformer.h.0.ln_2.bias torch.Size([768]) torch.Size([768])
transposed transformer.h.0.mlp.c_fc.weight torch.Size([3072, 768]) torch.Size([3072, 768])
normal transformer.h.0.mlp.c_fc.bias torch.Size([3072]) torch.Size([3072])
transposed transformer.h.0.mlp.c_proj.weight torch.Size(

In [195]:
# Load the copied tensors into the model; by default load_state_dict raises on missing or unexpected keys.
mygpt.load_state_dict(copied_sd)

<All keys matched successfully>