In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from types import SimpleNamespace

In [55]:
def params(m):
    """Return the total number of scalar parameters owned by module ``m``.

    Counts ``numel()`` over ``m.parameters()``; a Parameter shared by several
    submodules (e.g. tied weights) is yielded once and so counted once.
    """
    total = 0
    for tensor in m.parameters():
        total += tensor.numel()
    return total

In [56]:
class GPT2Attention(nn.Module):
    """Causal multi-head self-attention with GPT-2 style parameter names.

    A single ``c_attn`` projection produces q, k and v; ``c_proj`` maps the
    merged heads back to ``embed_dim``. Each head owns a contiguous slice of
    ``head_size`` channels, and heads are split and merged with that same layout.

    Args:
        config: namespace providing ``embed_dim``, ``num_heads``, ``seq_len``
            (maximum sequence length), ``attention_dropout`` and
            ``residual_dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, 'embedding dimension must be divisible by number of heads'
        self.head_size = self.embed_dim // self.n_heads
        self.seq_len = config.seq_len

        # one projection yields q, k and v side by side: embed_dim -> 3 * embed_dim
        self.c_attn = nn.Linear(self.embed_dim, self.head_size * self.n_heads * 3, bias=True)
        self.scale = self.head_size ** -0.5

        # lower-triangular causal mask for the longest supported sequence;
        # forward() slices it down to the actual length t
        self.register_buffer('mask', torch.tril(torch.ones(1, 1, self.seq_len, self.seq_len)))

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.residual_dropout)

    def _split_heads(self, z, b, t):
        # (b, t, embed_dim) -> (b, n_heads, t, head_size); head i owns channels
        # [i * head_size, (i + 1) * head_size), the same layout forward() merges back
        return z.view(b, t, self.n_heads, self.head_size).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (batch, t, embed_dim); t must not exceed seq_len.

        Returns:
            Tensor of shape (batch, t, embed_dim).
        """
        b, t, c = x.shape
        # q, k, v: (b, t, embed_dim) each
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        q = self._split_heads(q, b, t)  # (b, n_heads, t, head_size)
        k = self._split_heads(k, b, t)
        v = self._split_heads(v, b, t)

        scores = (q @ k.transpose(-2, -1)) * self.scale  # (b, n_heads, t, t)
        # slice the mask to the current length so sequences shorter than
        # seq_len work; position i may only attend to positions <= i
        scores = scores.masked_fill(self.mask[:, :, :t, :t] == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        attention = weights @ v  # (b, n_heads, t, head_size)
        # merge heads back with the same contiguous layout: (b, t, embed_dim)
        attention = attention.transpose(1, 2).reshape(b, t, c)

        return self.resid_dropout(self.c_proj(attention))

In [57]:
# GPT-2 small hyper-parameters
config = SimpleNamespace(
    vocab_size = 50_257,
    embed_dim = 768,
    num_heads = 12,
    seq_len = 1024,       # context length: size of the position table and causal mask
    depth = 12,           # number of transformer blocks
    attention_dropout = 0.1,
    residual_dropout = 0.1,
    mlp_ratio = 4,        # MLP hidden width = mlp_ratio * embed_dim
    mlp_dropout = 0.1,
    emb_dropout = 0.1,
)

# seed torch's global RNG so weight init, random inputs and sampling in the
# cells below are reproducible under Restart & Run All
SEED = 42
torch.manual_seed(SEED)

config

namespace(vocab_size=50257,
          embed_dim=768,
          num_heads=12,
          seq_len=1024,
          depth=12,
          attention_dropout=0.1,
          residual_dropout=0.1,
          mlp_ratio=4,
          mlp_dropout=0.1,
          emb_dropout=0.1)

In [58]:
# build one attention layer and a random (1, seq_len, embed_dim) input for it,
# then show the input shape alongside the layer's parameter count
attn = GPT2Attention(config)
x = torch.rand((1, config.seq_len, config.embed_dim))
(x.shape, params(attn))

(torch.Size([1, 1024, 768]), 2362368)

In [59]:
# parameter names plus the 'mask' buffer (registered with register_buffer, so it is saved too)
attn.state_dict().keys()

odict_keys(['mask', 'c_attn.weight', 'c_attn.bias', 'c_proj.weight', 'c_proj.bias'])

In [60]:
# module repr: the registered submodules of the attention layer
attn

GPT2Attention(
  (c_attn): Linear(in_features=768, out_features=2304, bias=True)
  (c_proj): Linear(in_features=768, out_features=768, bias=True)
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

In [61]:
class GPT2MLP(nn.Module):
    """Position-wise feed-forward sub-layer of a GPT-2 block.

    Widens ``embed_dim`` by ``mlp_ratio``, applies GELU, projects back to
    ``embed_dim`` and applies dropout.

    Args:
        config: namespace providing ``embed_dim``, ``mlp_ratio`` and ``mlp_dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.mlp_ratio = config.mlp_ratio
        self.mlp_dropout = config.mlp_dropout

        hidden_dim = self.embed_dim * self.mlp_ratio
        # registration order (c_fc, c_proj, act, dropout) fixes the state_dict key order
        self.c_fc = nn.Linear(self.embed_dim, hidden_dim)
        self.c_proj = nn.Linear(hidden_dim, self.embed_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(self.mlp_dropout)

    def forward(self, x):
        """Map (..., embed_dim) -> (..., embed_dim)."""
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))

In [62]:
# build one MLP sub-layer and show its parameter names
mlp = GPT2MLP(config)
mlp.state_dict().keys()

odict_keys(['c_fc.weight', 'c_fc.bias', 'c_proj.weight', 'c_proj.bias'])

In [63]:
class GPT2Block(nn.Module):
    """One pre-LayerNorm transformer layer.

    Attention and MLP each see a LayerNorm-ed input, and each result is added
    back onto its input through a residual connection.

    Args:
        config: namespace forwarded to ``GPT2Attention`` and ``GPT2MLP``;
            ``embed_dim`` also sizes the two LayerNorms.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        # submodules registered in order ln_1, attn, ln_2, mlp (state_dict / repr order)
        self.ln_1 = nn.LayerNorm(self.embed_dim)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(self.embed_dim)
        self.mlp = GPT2MLP(config)

    def forward(self, x):
        """Return (x + attn(ln_1(x))) passed through the residual MLP sub-layer."""
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))

In [64]:
# build one transformer layer, print its structure, then show its state_dict keys
block = GPT2Block(config)
print(block)
block.state_dict().keys()

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Linear(in_features=768, out_features=2304, bias=True)
    (c_proj): Linear(in_features=768, out_features=768, bias=True)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Linear(in_features=768, out_features=3072, bias=True)
    (c_proj): Linear(in_features=3072, out_features=768, bias=True)
    (act): GELU(approximate='none')
    (dropout): Dropout(p=0.1, inplace=False)
  )
)


odict_keys(['ln_1.weight', 'ln_1.bias', 'attn.mask', 'attn.c_attn.weight', 'attn.c_attn.bias', 'attn.c_proj.weight', 'attn.c_proj.bias', 'ln_2.weight', 'ln_2.bias', 'mlp.c_fc.weight', 'mlp.c_fc.bias', 'mlp.c_proj.weight', 'mlp.c_proj.bias'])

In [116]:
class GPT2Model(nn.Module):
    """GPT-2 language model.

    The model is token + position embeddings, then ``config.depth`` ``GPT2Block``
    layers, then a final LayerNorm and an LM head. The LM head's weight is tied
    to the token embedding. ``forward`` returns logits for the last position only,
    which is what ``generate`` consumes.

    Args:
        config: namespace providing ``vocab_size``, ``embed_dim``, ``seq_len``
            (context length), ``emb_dropout`` and ``depth``, plus whatever
            ``GPT2Block`` reads.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.embed_dim),
            wpe = nn.Embedding(config.seq_len, config.embed_dim),
            drop = nn.Dropout(config.emb_dropout),
            h = nn.ModuleList([GPT2Block(config) for _ in range(config.depth)]),
            ln_f = nn.LayerNorm(config.embed_dim),
        ))
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        # weight tying: token embedding and output projection share one matrix
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, x):
        """Compute next-token logits for the last position.

        Args:
            x: LongTensor of token ids, shape (batch, t) with t <= config.seq_len
                (positions index the wpe table of size seq_len).

        Returns:
            Logits of shape (batch, 1, vocab_size).
        """
        t = x.size(1)
        token_embeddings = self.transformer.wte(x)  # (batch, t, embed_dim)
        positions = torch.arange(t, device=x.device)
        positional_embeddings = self.transformer.wpe(positions)  # (t, embed_dim), broadcast over batch
        x = self.transformer.drop(token_embeddings + positional_embeddings)
        for block in self.transformer.h:
            x = block(x)  # (batch, t, embed_dim)
        # keep only the last hidden state; indexing with [-1] keeps the time axis
        x = self.transformer.ln_f(x)[:, [-1], :]  # (batch, 1, embed_dim)
        return self.lm_head(x)  # (batch, 1, vocab_size)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=5, temperature=1.0):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Dropout stays active while the model is in training mode, so call
        ``model.eval()`` first for inference.

        Args:
            idx: LongTensor of prompt token ids, shape (batch, t). t may exceed
                ``config.seq_len``; only the most recent ``seq_len`` tokens are
                used as context.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by this before softmax; must be > 0.

        Returns:
            LongTensor of shape (batch, max_new_tokens) holding only the newly
            sampled tokens (empty along dim 1 when max_new_tokens is 0).
        """
        prompt_len = idx.size(1)
        for _ in range(max_new_tokens):
            # crop the *current* sequence (prompt + tokens sampled so far) to the
            # context window so the newest token is always part of the input
            context = idx[:, -self.config.seq_len:]
            logits = self(context)[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx[:, prompt_len:]

In [117]:
# full GPT-2 small model with freshly initialized weights (nothing is loaded into it in this notebook)
gpt2 = GPT2Model(config)  

In [84]:
# module tree: embeddings, `depth` blocks under `h`, final LayerNorm, LM head
gpt2

GPT2Model(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [91]:
# random token ids in [0, vocab_size); torch.rand(...).long() would truncate
# every value in [0, 1) to token id 0
x = torch.randint(0, config.vocab_size, (1, config.seq_len))
x.shape

torch.Size([1, 1024])

In [92]:
# forward keeps only the last position, so logits are (batch, 1, vocab_size)
gpt2(x).shape

torch.Size([1, 1024, 768])


torch.Size([1, 1, 50257])

In [105]:
# a random prompt filling the whole context window; length comes from config instead of a hard-coded 1024
x = torch.randint(0, config.vocab_size, (1, config.seq_len))

In [118]:
# turn dropout off before sampling; the model is otherwise still in training mode
gpt2.eval()
res = gpt2.generate(x)
res.shape

torch.Size([1, 5])

In [120]:
# last five prompt tokens (generate's default max_new_tokens is 5), for comparison with `res`
x[:,-5:]

tensor([[38890,  9988, 11832, 30680, 33973]])

In [119]:
# the sampled continuation only; the prompt is not included
res

tensor([[ 7894,  7780, 16557, 37860, 30901]])

In [123]:
# lm_head has no bias and shares its weight with transformer.wte, so this count
# already includes the LM head's parameters
params(gpt2.transformer)

124439808