In [4]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# (e.g. the imported gpt_tests helper) before each cell is executed.
# NOTE(review): re-running this cell prints an "already loaded" notice, which is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
import torch as t
from torch import einsum
from einops import rearrange, reduce, repeat
import gpt_tests
from torch import nn
from torch.nn import Module
from math import sqrt

## Making the GPT-2 module

In [44]:
class MultiHeadedAttention(Module):
    """Causal (unidirectional) multi-head self-attention.

    One linear layer projects the input to queries, keys and values for every
    head at once. Each position attends to itself and to earlier positions
    only. The per-head results are concatenated and sent through an output
    projection.

    Args:
        hidden_size: Size of the last dimension of both input and output.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size`` evenly.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Integer division: head_size is used as a tensor dimension below.
        self.head_size = hidden_size // num_heads
        self.attn_lin = nn.Linear(hidden_size, hidden_size * 3)
        self.out_lin = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        batch, seq_len = x.shape[0], x.shape[1]

        # The projected last axis is laid out as (qkv, head, head_size), i.e. the
        # einops pattern 'b n (qkv h p) -> qkv b h n p'.
        product = self.attn_lin(x)
        split = product.reshape(batch, seq_len, 3, self.num_heads, self.head_size)
        queries, keys, values = split.permute(2, 0, 3, 1, 4).unbind(0)

        # attn_score[b, h, f, t]: key at position f scored against query at position t.
        attn_score = t.einsum('bhfp,bhtp->bhft', keys, queries) / sqrt(self.head_size)

        # Mask every (f, t) pair whose key lies after the query. A large negative
        # value (not -inf) is kept so results match what the notebook's test expects.
        positions = t.arange(seq_len, device=x.device)
        future_mask = positions[:, None] > positions[None, :]
        attn_score = attn_score.masked_fill(future_mask, -1e4)

        # Normalise over the key axis f so each query's weights sum to 1.
        attn_pattn = t.softmax(attn_score, dim=-2)
        out_by_head = t.einsum('bhft,bhfp->bhtp', attn_pattn, values)

        # Concatenate heads: 'b h t p -> b t (h p)'.
        out = out_by_head.permute(0, 2, 1, 3).reshape(batch, seq_len, self.hidden_size)
        return self.out_lin(out)

In [45]:
# Run the gpt_tests check for the attention module, passing the class itself.
gpt_tests.test_unidirectional_attn(MultiHeadedAttention)

Congrats! You've passed the test!


In [46]:
class GPT2Block(Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual.

    Args:
        hidden_size: Width of the residual stream.
        num_heads: Number of attention heads passed to ``MultiHeadedAttention``.
        dropout: Dropout probability applied to the MLP output before it is
            added back to the residual stream. A no-op in eval mode or when 0.
        layer_norm_epsilon: ``eps`` used by both LayerNorm layers.
    """

    def __init__(self, hidden_size, num_heads, dropout, layer_norm_epsilon):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.attn = MultiHeadedAttention(hidden_size, num_heads)
        self.ln2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        # Dropout goes last so the Linear layers keep indices 0 and 2; it has no
        # parameters, so state-dict keys are unchanged. Previously the `dropout`
        # argument was accepted but never used.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return the block output; the result has the same shape as ``x``."""
        # Attention sub-layer: normalise, attend, add the residual.
        attn_out = self.attn(self.ln1(x)) + x
        # MLP sub-layer: normalise, transform (with dropout), add the residual.
        return self.mlp(self.ln2(attn_out)) + attn_out
        

In [47]:
# Run the gpt_tests check for the transformer block, passing the class itself.
gpt_tests.test_gpt_block(GPT2Block)

Congrats! You've passed the test!


In [64]:
from dataclasses import dataclass
from torchtyping import TensorType


@dataclass
class GPT2Output:
    """Container for the two tensors returned by ``GPT2.forward``."""
    # Vocabulary scores; annotated as (batch_size, vocab_size).
    logits: TensorType["batch_size", "vocab_size"]
    # Hidden state the logits are computed from; annotated as (batch_size, hidden_size).
    final_encoding: TensorType["batch_size", "hidden_size"]

    
class GPT2(Module):
    """GPT-2 style decoder that scores the next token for the last position.

    Args:
        num_layers: Number of ``GPT2Block`` layers.
        num_heads: Attention heads per block.
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Embedding / residual stream width.
        max_position_embeddings: Number of rows in the positional table, i.e.
            the longest sequence ``forward`` accepts.
        dropout: Probability for the embedding dropout; also forwarded to
            every ``GPT2Block``.
        layer_norm_epsilon: ``eps`` for the final LayerNorm; also forwarded to
            every ``GPT2Block``.
    """

    def __init__(self, num_layers, num_heads, vocab_size, hidden_size,
                 max_position_embeddings, dropout, layer_norm_epsilon):
        super().__init__()

        # Raw parameter tables (not nn.Embedding), so their state-dict keys are
        # exactly "token_embedding" and "pos_embedding".
        self.token_embedding = nn.Parameter(t.randn(vocab_size, hidden_size))
        self.pos_embedding = nn.Parameter(t.randn(max_position_embeddings, hidden_size))

        self.dropout = nn.Dropout(dropout)

        # The final LayerNorm sits at index `num_layers` of the Sequential.
        self.blocks = nn.Sequential(
            *[GPT2Block(hidden_size, num_heads, dropout, layer_norm_epsilon)
              for _ in range(num_layers)],
            nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        )

    def forward(self, x):
        """Encode token ids and score the vocabulary at the last position.

        Args:
            x: Integer token ids of shape (batch, seq_len); used to index the
                token embedding table.

        Returns:
            GPT2Output holding ``logits`` (batch, vocab_size) and
            ``final_encoding`` (batch, hidden_size), both for the last position.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table length.
        """
        seq_len = x.shape[1]
        max_len = self.pos_embedding.shape[0]
        if seq_len > max_len:
            # Without this check the slice below is silently truncated, which
            # ends in an opaque broadcast error (or a wrong broadcast if max_len == 1).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_position_embeddings ({max_len})"
            )

        embedding = self.token_embedding[x] + self.pos_embedding[:seq_len]

        encoding = self.blocks(self.dropout(embedding))
        final_encoding = encoding[:, -1]

        # Output projection reuses the token embedding matrix (tied weights).
        logits = t.einsum('vc,bc->bv', self.token_embedding, final_encoding)

        return GPT2Output(logits=logits, final_encoding=final_encoding)
    
    
    
    

In [49]:
# Run the gpt_tests check for the full model; its output reports logits and final encodings.
gpt_tests.test_gpt(GPT2)

Checking logits:
Congrats! You've passed the test!
Checking final encodings:
Congrats! You've passed the test!


## Loading pretrained weights

In [61]:
# Fresh, randomly initialised model. The sizes are presumably chosen to match
# the pretrained checkpoint loaded below (TODO confirm).
my_gpt = GPT2(
    num_layers=12,
    num_heads=12,
    vocab_size=50257,
    hidden_size=768,
    max_position_embeddings=1024,
    dropout=0.1,
    layer_norm_epsilon=1e-5,
)

# Reference model whose weights are renamed for my_gpt further down.
pretrained_gpt = gpt_tests.get_pretrained_gpt()

In [50]:
# NOTE(review): this cell repeats the my_gpt construction from the previous cell
# verbatim and carries an older execution count. On a top-to-bottom run it just
# replaces my_gpt with another freshly initialised model; consider deleting it.
my_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768,
                 max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5)

In [62]:
def string_replace(s, num_layers=12):
    """Translate a pretrained parameter name into this notebook's GPT2 naming.

    Args:
        s: Parameter name taken from the pretrained model's state dict.
        num_layers: Number of transformer blocks in the target model. The final
            LayerNorm is stored at this index of ``blocks``. Defaults to 12.

    Returns:
        The renamed key. Names matching none of the patterns come back unchanged.
    """
    # Embeddings are bare parameters here, so there is no ".weight" suffix.
    s = s.replace("embedding.weight", "embedding")
    # The MLP is an nn.Sequential: Linear at index 0, activation at 1, Linear at 2.
    s = s.replace("linear1", "mlp.0")
    s = s.replace("linear2", "mlp.2")
    # "ln." is the final LayerNorm only; per-block "ln1."/"ln2." do not contain it.
    s = s.replace("ln.", f"blocks.{num_layers}.")
    return s

# Take the pretrained weights and re-key them in place to match my_gpt's names.
their_dict = pretrained_gpt.state_dict()
# Snapshot the keys first: the dict is mutated while we walk over them.
pretrained_names = tuple(their_dict.keys())
for old_name in pretrained_names:
    # Pop then re-insert; every entry moves to the end in turn, so order is kept.
    weight = their_dict.pop(old_name)
    their_dict[string_replace(old_name)] = weight

In [63]:
# List the renamed keys, one per line, to eyeball them against my_gpt's names.
for param_name in their_dict.keys():
    print(param_name)

token_embedding
pos_embedding
blocks.0.ln1.weight
blocks.0.ln1.bias
blocks.0.attn.attn_lin.weight
blocks.0.attn.attn_lin.bias
blocks.0.attn.out_lin.weight
blocks.0.attn.out_lin.bias
blocks.0.ln2.weight
blocks.0.ln2.bias
blocks.0.mlp.0.weight
blocks.0.mlp.0.bias
blocks.0.mlp.2.weight
blocks.0.mlp.2.bias
blocks.1.ln1.weight
blocks.1.ln1.bias
blocks.1.attn.attn_lin.weight
blocks.1.attn.attn_lin.bias
blocks.1.attn.out_lin.weight
blocks.1.attn.out_lin.bias
blocks.1.ln2.weight
blocks.1.ln2.bias
blocks.1.mlp.0.weight
blocks.1.mlp.0.bias
blocks.1.mlp.2.weight
blocks.1.mlp.2.bias
blocks.2.ln1.weight
blocks.2.ln1.bias
blocks.2.attn.attn_lin.weight
blocks.2.attn.attn_lin.bias
blocks.2.attn.out_lin.weight
blocks.2.attn.out_lin.bias
blocks.2.ln2.weight
blocks.2.ln2.bias
blocks.2.mlp.0.weight
blocks.2.mlp.0.bias
blocks.2.mlp.2.weight
blocks.2.mlp.2.bias
blocks.3.ln1.weight
blocks.3.ln1.bias
blocks.3.attn.attn_lin.weight
blocks.3.attn.attn_lin.bias
blocks.3.attn.out_lin.weight
blocks.3.attn.out_lin.b