In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import torch as t
from typing import *
from torch import einsum
from einops import rearrange, reduce, repeat
import gpt_tests
from torch import nn
from torch.nn import Module
from math import sqrt
import bert_sol

## Making the GPT-2 module

In [23]:
class MultiHeadedAttention(Module):
    """Causal (unidirectional) multi-head self-attention.

    A single linear layer (`attn_lin`) produces queries, keys and values for
    all heads at once; its output features are laid out as (qkv, head, head_dim).
    Every position may attend to itself and to earlier positions only.

    Args:
        hidden_size: width of the input and output features.
        num_heads: number of attention heads; must divide `hidden_size`.

    Raises:
        ValueError: if `hidden_size` is not divisible by `num_heads`.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Integer division: head_size is a tensor dimension, not a float.
        self.head_size = hidden_size // num_heads
        self.attn_lin = nn.Linear(hidden_size, hidden_size*3)
        self.out_lin = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Apply masked self-attention.

        Args:
            x: tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        batch, seq_len, _ = x.shape
        # (b, n, 3*hidden) -> (b, n, qkv, h, p) -> (qkv, b, h, n, p)
        qkv = self.attn_lin(x).reshape(batch, seq_len, 3, self.num_heads, self.head_size)
        queries, keys, values = qkv.permute(2, 0, 3, 1, 4)

        # Scores are stored as (batch, head, key_pos, query_pos).
        attn_score = t.einsum('bhkp,bhqp->bhkq', keys, queries) / sqrt(self.head_size)

        # A key that lies after the query position is "in the future": mask it.
        positions = t.arange(seq_len, device=x.device)
        is_future = positions[:, None] > positions[None, :]
        attn_score = attn_score.masked_fill(is_future, -1e4)

        # Normalise over the key axis so each query's weights sum to one.
        attn_pattn = t.softmax(attn_score, dim=-2)
        out_by_head = t.einsum('bhkq,bhkp->bhqp', attn_pattn, values)
        # (b, h, n, p) -> (b, n, h*p)
        out = out_by_head.transpose(1, 2).reshape(batch, seq_len, self.hidden_size)
        return self.out_lin(out)

In [4]:
# Run the gpt_tests check for the attention class (prints a pass message on success).
gpt_tests.test_unidirectional_attn(MultiHeadedAttention)

Congrats! You've passed the test!


In [5]:
class GPT2Block(Module):
    """One pre-LayerNorm transformer block: causal attention followed by an MLP.

    Both sub-layers are wrapped in a residual connection. The `dropout`
    probability is applied to the output of each sub-layer before it is added
    back to the residual stream; it is a no-op in eval mode or when
    `dropout` is 0, so inference results are unchanged.

    Args:
        hidden_size: width of the residual stream.
        num_heads: number of attention heads.
        dropout: dropout probability for the sub-layer outputs.
        layer_norm_epsilon: epsilon passed to both LayerNorm layers.
    """

    def __init__(self, hidden_size, num_heads, dropout, layer_norm_epsilon):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.attn = MultiHeadedAttention(hidden_size, num_heads)
        self.ln2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size*4),
            nn.GELU(),
            nn.Linear(hidden_size*4, hidden_size))
        # Parameter-free, so state-dict keys are unaffected.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the block output for `x`; the output has the same shape as `x`."""
        after_attn = self.dropout(self.attn(self.ln1(x))) + x
        return self.dropout(self.mlp(self.ln2(after_attn))) + after_attn
        

In [6]:
# Run the gpt_tests check for the transformer block (prints a pass message on success).
gpt_tests.test_gpt_block(GPT2Block)

Congrats! You've passed the test!


In [7]:
from dataclasses import dataclass
from torchtyping import TensorType


@dataclass
class GPT2Output:
    """Pair of tensors returned by the GPT-2 forward pass in this notebook."""
    # Scores over the vocabulary, one row per batch element.
    logits: TensorType["batch_size", "vocab_size"]
    # Encoding at the last sequence position, one row per batch element.
    final_encoding: TensorType["batch_size", "hidden_size"]

    
class GPT2(Module):
    """GPT-2 style language model returning next-token logits for the last position.

    Token and position embeddings are summed, passed through dropout and a
    stack of `num_layers` GPT2Block modules followed by a final LayerNorm
    (all stored in `self.blocks`). The logits reuse the token embedding matrix
    as the output projection.
    """

    def __init__(self, num_layers, num_heads, vocab_size, hidden_size,
                 max_position_embeddings, dropout, layer_norm_epsilon):
        super().__init__()

        self.token_embedding = nn.Parameter(t.randn(vocab_size, hidden_size))
        self.pos_embedding = nn.Parameter(t.randn(max_position_embeddings, hidden_size))
        self.dropout = nn.Dropout(dropout)

        # Transformer blocks first, final LayerNorm last (index `num_layers`).
        layers = []
        for _ in range(num_layers):
            layers.append(GPT2Block(hidden_size, num_heads, dropout, layer_norm_epsilon))
        layers.append(nn.LayerNorm(hidden_size, eps=layer_norm_epsilon))
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Run the model on token ids `x` (indexed as batch x sequence).

        Returns a GPT2Output holding the logits and the encoding of the last
        sequence position.
        """
        num_positions = x.shape[1]
        token_part = self.token_embedding[x]
        position_part = self.pos_embedding[:num_positions]

        hidden = self.blocks(self.dropout(token_part + position_part))
        last_hidden = hidden[:, -1]

        # Project back onto the vocabulary with the (tied) embedding matrix.
        vocab_scores = t.einsum('vc,bc->bv', self.token_embedding, last_hidden)
        return GPT2Output(logits=vocab_scores, final_encoding=last_hidden)
    
    
    
    

In [8]:
# Run the gpt_tests check for the full model (reports on logits and final encodings).
gpt_tests.test_gpt(GPT2)

Checking logits:
Congrats! You've passed the test!
Checking final encodings:
Congrats! You've passed the test!


## Loading pretrained weights

In [9]:
# Build a GPT2 with the hyperparameters below and fetch the reference model from
# gpt_tests; the reference weights are loaded into my_gpt in a later cell.
my_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768,
                 max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5)
pretrained_gpt = gpt_tests.get_pretrained_gpt()

In [10]:
# NOTE(review): this repeats the construction from the previous cell, so my_gpt is
# replaced by a second, freshly initialised instance with the same hyperparameters.
my_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768,
                 max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5)

In [11]:
def string_replace(s):
    """Rewrite a pretrained state-dict key into the name used by this notebook's GPT2.

    The substitutions are applied in order, each on the result of the previous one.
    """
    substitutions = (
        ("embedding.weight", "embedding"),
        ("linear1", "mlp.0"),
        ("linear2", "mlp.2"),
        ("ln.", "blocks.12."),
    )
    for old_fragment, new_fragment in substitutions:
        s = s.replace(old_fragment, new_fragment)
    return s

# Re-key the reference state dict in place so its names match my_gpt's parameters.
their_dict = pretrained_gpt.state_dict()
for old_key in tuple(their_dict):
    renamed_key = string_replace(old_key)
    their_dict[renamed_key] = their_dict.pop(old_key)

In [12]:
# Load the re-keyed reference tensors into my_gpt; the cell output reports the key match.
my_gpt.load_state_dict(their_dict)

<All keys matched successfully>

## Efficient text generation

In [13]:
import transformers
# Set TOKENIZERS_PARALLELISM=false in the kernel's environment via the %env magic.
%env TOKENIZERS_PARALLELISM=false

# GPT-2 tokenizer, loaded by name through transformers.
tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
# print(tokenizer(['Hello, I am a sentence.']))


env: TOKENIZERS_PARALLELISM=false


In [14]:
def feed_gpt(model: nn.Module, text: str, tokenizer, top_k: int = 10):
    """Print the `top_k` most likely next tokens for `text` and their probabilities.

    Args:
        model: callable taking a (1, seq_len) tensor of token ids and returning
            an object with a `.logits` tensor of shape (1, vocab_size).
        text: prompt to tokenize and feed to the model.
        tokenizer: callable returning a mapping with an "input_ids" list, with a
            `decode` method that accepts a tensor of token ids.
        top_k: number of candidates to show (capped at the vocabulary size).

    Returns:
        None; the decoded tokens and their probabilities are printed.
    """
    input_ids: List[int] = tokenizer(text)["input_ids"]
    # Pure inference: skip autograd so no graph is built and the printed
    # probabilities carry no grad_fn.
    with t.no_grad():
        logits = model(t.tensor([input_ids], dtype=t.long)).logits
        probs = t.softmax(logits, dim=-1)
    # topk returns indices ordered by descending logit without sorting the whole vocabulary.
    num_shown = min(top_k, logits.shape[-1])
    top_logit_idxs = t.topk(logits[0], num_shown).indices
    top_logit_words = tokenizer.decode(top_logit_idxs)
    print(top_logit_words)
    print(probs[0,top_logit_idxs])

In [15]:
# Move the reference model to CPU, then print its top next-token candidates for the prompt.
pretrained_gpt.cpu()
feed_gpt(pretrained_gpt, "Students at the machine learning bootcamp really enjoyed the", tokenizer)

  return self._grad


 experience learning opportunity program process course training time work challenge
tensor([0.1915, 0.0389, 0.0338, 0.0290, 0.0174, 0.0165, 0.0162, 0.0151, 0.0129,
        0.0129], grad_fn=<IndexBackward0>)


In [16]:
def feed_gpt_top(model: nn.Module, input_ids: List[int], top_k: int = 10):
    """Return the id of the single most likely next token (greedy decoding).

    Args:
        model: callable taking a (1, seq_len) tensor of token ids and returning
            an object with a `.logits` tensor of shape (1, vocab_size).
        input_ids: token ids of the prompt so far.
        top_k: unused; kept so existing callers keep working.

    Returns:
        A 0-dimensional integer tensor holding the arg-max token id.
    """
    # Inference only: no autograd graph is needed.
    with t.no_grad():
        logits = model(t.tensor([input_ids], dtype=t.long)).logits
    # The arg-max of the logits is also the arg-max of the softmax, so the
    # softmax and the full vocabulary sort are unnecessary.
    return logits[0].argmax()

In [17]:
# Put my_gpt in evaluation mode (turns dropout off); the trailing ';' hides the cell output.
my_gpt.eval();

In [18]:
# Greedy continuation: repeatedly ask the model for its top token, record it in the
# running id list and print its decoded text as it is produced.
start_str = "The machine learning bootcamp started out nicely. But soon, I got an ominous feeling. Shockingly, I discovered"
input_ids = tokenizer(start_str)["input_ids"]
for i in range(100):
    new_token = feed_gpt_top(my_gpt, input_ids, 1)
    input_ids.append(new_token)
    decoded_piece = tokenizer.decode(new_token)
    print(decoded_piece, end = " ")

 that  the  machine  learning  boot camp  was  actually  a  very  bad  idea . 
 
 I  was  not  alone .  I  was  also  not  alone  in  my  own  experience .  I  was  also  not  alone  in  my  own  experience . 
 
 I  was  not  alone  in  my  own  experience .  I  was  also  not  alone  in  my  own  experience . 
 
 I  was  not  alone  in  my  own  experience .  I  was  also  not  alone  in  my  own  experience . 
 
 I  was  not  alone  in  my  own  experience .  I  was  also  not  alone  in  my 

In [19]:
class GPT2Modified(Module):
    """GPT2 variant that also keeps the encodings of every position.

    The architecture and parameter names are the same as GPT2's; the only
    difference is that `forward` stores the full output of `self.blocks` on the
    module as `self.encoding` so it can be inspected afterwards.
    """

    def __init__(self, num_layers, num_heads, vocab_size, hidden_size,
                 max_position_embeddings, dropout, layer_norm_epsilon):
        super().__init__()
        self.token_embedding = nn.Parameter(t.randn(vocab_size, hidden_size))
        self.pos_embedding = nn.Parameter(t.randn(max_position_embeddings, hidden_size))
        self.dropout = nn.Dropout(dropout)
        transformer_layers = [
            GPT2Block(hidden_size, num_heads, dropout, layer_norm_epsilon)
            for _ in range(num_layers)
        ]
        final_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.blocks = nn.Sequential(*transformer_layers, final_norm)

    def forward(self, x):
        """Run the model on token ids `x` and remember the per-position encodings."""
        positions = self.pos_embedding[:x.shape[1]]
        summed = self.token_embedding[x] + positions
        # Kept on the module for later inspection.
        self.encoding = self.blocks(self.dropout(summed))
        final_encoding = self.encoding[:, -1]
        logits = t.einsum('vc,bc->bv', self.token_embedding, final_encoding)
        return GPT2Output(logits=logits, final_encoding=final_encoding)

In [20]:
# Instantiate the encoding-recording variant with the same hyperparameters as my_gpt.
my_gpt_modified = GPT2Modified(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768,
                 max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5)

In [21]:
# Load the same re-keyed reference tensors into the modified model.
my_gpt_modified.load_state_dict(their_dict)

<All keys matched successfully>

In [22]:
def create_padded_thing(tokenizer):
    """Build five zero-padded prefixes of one fixed, tokenized sentence.

    The sentence is tokenized with `tokenizer`; for prefix lengths 4 through 8
    the first `length` ids are followed by `10 - length` zeros.

    Returns:
        A list of five lists of token ids.
    """
    s = 'My life motto: Fortune favors the bold'
    ids = tokenizer(s)["input_ids"]
    padded_rows = []
    for prefix_len in range(4, 9):
        padding = [0] * (10 - prefix_len)
        padded_rows.append(ids[:prefix_len] + padding)
    return padded_rows

# Padded id batches: one built with the GPT-2 tokenizer defined earlier, one with a
# BERT tokenizer loaded by name ("bert-base-cased").
test_batch = t.tensor(create_padded_thing(tokenizer))
bert_batch = t.tensor(create_padded_thing(transformers.AutoTokenizer.from_pretrained("bert-base-cased")))

def decode_batch(batch, tokenizer):
    """Print a list holding the decoded text of every row of `batch`."""
    decoded_rows = []
    for row_idx in range(len(batch)):
        decoded_rows.append(tokenizer.decode(batch[row_idx]))
    print(decoded_rows)

In [23]:
# Run the padded batch in eval mode, then show the first two features of every
# position from the encodings that forward stored on the module.
my_gpt_modified.eval()
my_gpt_modified(test_batch);
my_gpt_modified.encoding[:,:,:2]

tensor([[[-0.0340, -0.0429],
         [ 0.2075,  0.3059],
         [-0.0342, -0.2292],
         [ 0.3671, -0.0707],
         [-0.0323, -0.3814],
         [-0.0505, -0.2386],
         [ 0.0306, -0.1637],
         [ 0.0751, -0.1712],
         [ 0.0916, -0.1977],
         [ 0.1046, -0.2181]],

        [[-0.0340, -0.0429],
         [ 0.2075,  0.3059],
         [-0.0342, -0.2292],
         [ 0.3671, -0.0707],
         [-0.1335,  0.9349],
         [ 0.3443, -0.1450],
         [ 0.2800, -0.1220],
         [ 0.2185, -0.1960],
         [ 0.2100, -0.1533],
         [ 0.2249, -0.1559]],

        [[-0.0340, -0.0429],
         [ 0.2075,  0.3059],
         [-0.0342, -0.2292],
         [ 0.3671, -0.0707],
         [-0.1335,  0.9349],
         [ 0.1363,  0.0811],
         [ 0.4627, -0.2614],
         [ 0.3536, -0.2568],
         [ 0.1176, -0.2094],
         [ 0.1202, -0.1308]],

        [[-0.0340, -0.0429],
         [ 0.2075,  0.3059],
         [-0.0342, -0.2292],
         [ 0.3671, -0.0707],
        

In [24]:
# Same experiment with the BERT implementation from bert_sol on the BERT-tokenized batch.
# NOTE(review): `_enc` is presumably the per-position encoding stored by bert_sol.Bert
# during forward — confirm against bert_sol. The trailing ';' suppresses the display.
my_bert = bert_sol.Bert(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=3072, 
    num_heads=12, num_layers=12
)
my_bert.eval()
my_bert(bert_batch);
my_bert._enc[:,:,:2];

## Fast GPT2

In [147]:
class FastMultiHeadedAttention(Module):
    """Causal multi-head self-attention with an optional key/value cache.

    A single linear layer (`attn_lin`) produces queries, keys and values for
    all heads; its output features are laid out as (qkv, head, head_dim).
    When `past_key_values` is supplied, only the new positions in `x` are
    projected and they attend to the cached positions plus themselves.

    Args:
        hidden_size: width of the input and output features.
        num_heads: number of attention heads; must divide `hidden_size`.

    Raises:
        ValueError: if `hidden_size` is not divisible by `num_heads`.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Integer division: head_size is a tensor dimension, not a float.
        self.head_size = hidden_size // num_heads
        self.attn_lin = nn.Linear(hidden_size, hidden_size*3)
        self.out_lin = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, past_key_values=None, return_key_values=False):
        """Attend over `x`, optionally extending a cache of earlier positions.

        Args:
            x: tensor of shape (batch, seq_len, hidden_size) holding the new
                positions. With a cache, batch must be 1.
            past_key_values: optional tensor of shape
                (num_heads, past_len, 2 * head_size): cached keys and values
                concatenated along the last axis, without a batch dimension.
            return_key_values: when true, also return the keys/values computed
                for the positions in `x`.

        Returns:
            The encoding of shape (batch, seq_len, hidden_size), or a tuple
            `(encoding, new_key_values)` where `new_key_values` has shape
            (batch, num_heads, seq_len, 2 * head_size).
        """
        batch, seq_len, _ = x.shape
        # (b, n, 3*hidden) -> (b, n, qkv, h, p) -> (qkv, b, h, n, p)
        qkv = self.attn_lin(x).reshape(batch, seq_len, 3, self.num_heads, self.head_size)
        queries, new_keys, new_values = qkv.permute(2, 0, 3, 1, 4)

        keys, values = new_keys, new_values
        past_len = 0
        if past_key_values is not None:
            past_keys, past_values = t.split(
                past_key_values, past_key_values.shape[-1]//2, dim=-1)
            past_len = past_keys.shape[-2]
            # The cache has no batch axis; add one so it lines up with the new tensors.
            keys = t.cat((t.unsqueeze(past_keys, dim=0), new_keys), dim=2)
            values = t.cat((t.unsqueeze(past_values, dim=0), new_values), dim=2)

        # Scores are stored as (batch, head, key_pos, query_pos).
        attn_score = t.einsum('bhkp,bhqp->bhkq', keys, queries) / sqrt(self.head_size)

        # Absolute positions: keys cover [0, past_len + seq_len), queries cover the
        # new slice only. A key after the query is in the future and is masked.
        # For a single new token on top of a cache the mask is all-False.
        key_pos = t.arange(past_len + seq_len, device=x.device)
        query_pos = t.arange(past_len, past_len + seq_len, device=x.device)
        is_future = key_pos[:, None] > query_pos[None, :]
        attn_score = attn_score.masked_fill(is_future, -1e4)

        # Normalise over the key axis so each query's weights sum to one.
        attn_pattn = t.softmax(attn_score, dim=-2)
        out_by_head = t.einsum('bhkq,bhkp->bhqp', attn_pattn, values)
        # (b, h, n, p) -> (b, n, h*p)
        out = out_by_head.transpose(1, 2).reshape(batch, seq_len, self.hidden_size)

        encoding = self.out_lin(out)
        if return_key_values:
            # Only the entries for the positions in `x`; the caller appends them to its cache.
            return encoding, t.cat((new_keys, new_values), dim=-1)
        return encoding

In [20]:
# Run the gpt_tests check for the cached attention (reports on the encoding and new key/value).
gpt_tests.test_attn_cache(FastMultiHeadedAttention)

Checking encoding:
Congrats! You've passed the test!
Checking new key and value:
Congrats! You've passed the test!


In [148]:
class FastGPT2Block(Module):
    """Pre-LayerNorm transformer block whose attention can use a key/value cache.

    Both sub-layers are wrapped in a residual connection. The `dropout`
    probability is applied to the output of each sub-layer before it is added
    back to the residual stream; it is a no-op in eval mode or when
    `dropout` is 0, so inference results are unchanged.

    Args:
        hidden_size: width of the residual stream.
        num_heads: number of attention heads.
        dropout: dropout probability for the sub-layer outputs.
        layer_norm_epsilon: epsilon passed to both LayerNorm layers.
    """

    def __init__(self, hidden_size, num_heads, dropout, layer_norm_epsilon):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.attn = FastMultiHeadedAttention(hidden_size, num_heads)
        self.ln2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size*4),
            nn.GELU(),
            nn.Linear(hidden_size*4, hidden_size))
        # Parameter-free, so state-dict keys are unaffected.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, past_key_values=None, return_key_values=False):
        """Run the block on `x`.

        `past_key_values` and `return_key_values` are forwarded to the
        attention layer unchanged.

        Returns:
            The block output, or `(output, new_key_values)` when
            `return_key_values` is true.
        """
        attn_result = self.attn(self.ln1(x), past_key_values, return_key_values)
        if return_key_values:
            attn_out, new_key_values = attn_result
        else:
            attn_out, new_key_values = attn_result, None

        after_attn = self.dropout(attn_out) + x
        output = self.dropout(self.mlp(self.ln2(after_attn))) + after_attn

        if return_key_values:
            return output, new_key_values
        return output
        

In [188]:
from dataclasses import dataclass
from torchtyping import TensorType
import numpy as np
import transformers


@dataclass
class GPT2Output:
    """Pair of tensors returned by the GPT-2 forward pass in this notebook."""
    # Scores over the vocabulary, one row per batch element.
    logits: TensorType["batch_size", "vocab_size"]
    # Encoding at the last sequence position, one row per batch element.
    final_encoding: TensorType["batch_size", "hidden_size"]

    
class FastGPT2(Module):
    """GPT-2 with an optional per-layer key/value cache for incremental decoding.

    With `use_cache=True`, the first `forward` call processes the whole input
    and stores every layer's keys/values in `self.cache`. While the cache is
    non-empty, later calls assume the input is the previous sequence plus
    exactly one new token and only process that last token. Call
    `clear_cache()` before starting a new sequence; `generate` does so itself.

    `self.cache` has shape (num_layers, num_heads, cached_len, 2 * head_size)
    and has no batch axis, so caching supports batch size 1 only.
    """

    def __init__(self, num_layers, num_heads, vocab_size, hidden_size,
                 max_position_embeddings, dropout, layer_norm_epsilon, use_cache=False):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.use_cache = use_cache
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.token_embedding = nn.Parameter(t.randn(vocab_size, hidden_size))
        self.pos_embedding = nn.Parameter(t.randn(max_position_embeddings, hidden_size))
        self.clear_cache()

        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [FastGPT2Block(hidden_size, num_heads, dropout, layer_norm_epsilon)
              for _ in range(num_layers)]
        )

        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

    def clear_cache(self):
        """Reset the key/value cache to zero cached positions."""
        # Created on the parameters' device and dtype so it can be concatenated
        # with the keys/values the blocks produce.
        self.cache = t.zeros(
            self.num_layers, self.num_heads, 0, 2*self.hidden_size//self.num_heads,
            device=self.token_embedding.device, dtype=self.token_embedding.dtype)

    def forward(self, x):
        """Compute next-token logits for the last position of `x`.

        Args:
            x: integer tensor of token ids, shape (batch, seq_len). When the
                cache is non-empty only `x[:, -1:]` is processed, at position
                `seq_len - 1`.

        Returns:
            GPT2Output with the logits and the final-position encoding.

        Raises:
            ValueError: if `use_cache` is set and the batch size is not 1.
        """
        full_len = x.shape[1]
        cached_len = self.cache.shape[-2]

        if cached_len != 0:
            # Everything but the last token is already represented in the cache.
            x = x[:,-1:]
            pos_embedding = self.pos_embedding[full_len-1:full_len]
        else:
            pos_embedding = self.pos_embedding[:full_len]

        encoding = self.dropout(self.token_embedding[x] + pos_embedding)

        if self.use_cache:
            if x.shape[0] != 1:
                raise ValueError("the key/value cache only supports batch size 1")
            fresh_key_values = []
            for layer_idx, block in enumerate(self.blocks):
                past = self.cache[layer_idx] if cached_len != 0 else None
                encoding, key_values = block(
                    encoding, past_key_values=past, return_key_values=True)
                # Drop the batch axis; detach so the cache never keeps an
                # autograd graph alive across decoding steps.
                fresh_key_values.append(key_values[0].detach())
            if fresh_key_values:
                new_cache = t.stack(fresh_key_values)
                self.cache = t.cat((self.cache.to(new_cache), new_cache), dim=-2)
        else:
            for block in self.blocks:
                encoding = block(encoding)

        final_encoding = self.layernorm(encoding)[:,-1]

        # Project onto the vocabulary with the (tied) token embedding matrix.
        logits = t.einsum('vc,bc->bv', self.token_embedding, final_encoding)

        return GPT2Output(logits=logits, final_encoding=final_encoding)

    def next_token(self, input_ids, temperature, freq_penalty=2.0):
        """Sample the next token id for the sequence `input_ids`.

        Each token's logit is divided by `temperature` and reduced by
        `freq_penalty` times the number of times that token already occurs in
        `input_ids`; the token is then sampled with numpy's global RNG.

        Args:
            input_ids: list of token ids generated so far (including the prompt).
            temperature: softmax temperature; must be non-zero.
            freq_penalty: penalty per previous occurrence of a token.

        Returns:
            The sampled token id, as returned by `np.random.choice`.
        """
        ids = t.tensor(input_ids, dtype=t.long, device=self.token_embedding.device)
        # Sampling is inference only; without no_grad every step would build
        # (and the cache would retain) an autograd graph.
        with t.no_grad():
            logits = self.forward(ids.reshape(1, len(input_ids))).logits
            freqs = t.bincount(ids, minlength=self.vocab_size)
            probs = t.softmax(logits/temperature - freqs*freq_penalty, dim=1)

        flat_probs = probs.flatten().cpu().numpy()
        return np.random.choice(flat_probs.shape[0], p=flat_probs)

    def generate(self, text, max_length=30, temperature=1.0, freq_penalty=2.0):
        """Sample a continuation of `text`, print it and return the token ids.

        Args:
            text: prompt string.
            max_length: total number of tokens (prompt included) to stop at.
            temperature: softmax temperature passed to `next_token`.
            freq_penalty: frequency penalty passed to `next_token`.

        Returns:
            List of token ids of the prompt followed by the sampled tokens.
        """
        ids = self.tokenizer(text)['input_ids']
        self.clear_cache()
        while len(ids) < max_length:
            new_token = self.next_token(ids, temperature, freq_penalty)
            # Store plain Python ints rather than numpy scalars.
            ids.append(int(new_token))
        print(self.tokenizer.decode(ids))
        return ids
        

In [189]:
# Run the gpt_tests cache check on FastGPT2 (reports result equality and timings
# with and without the cache). The two checks below are left disabled.
# gpt_tests.test_gpt_block(GPT2Block)
# gpt_tests.test_gpt(FastGPT2)
gpt_tests.test_gpt_cache(FastGPT2)

Congrats! Your GPT returns the same results with and without cache.
It took 2.351s to generate a 500-token sentence without cache and 0.756s with cache.


In [190]:
# NOTE(review): rebinds `my_gpt` (a GPT2 instance in earlier cells) to a FastGPT2 with
# caching enabled; `pretrained_gpt` is reused from the earlier cell that fetched it.
my_gpt = FastGPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768,
                 max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5, use_cache=True)
# pretrained_gpt = gpt_tests.get_pretrained_gpt()

In [191]:
def string_replace(s):
    """Rewrite a pretrained state-dict key into the name used by FastGPT2.

    The substitutions are applied left to right, each on the result of the
    previous one.
    """
    return (
        s.replace("embedding.weight", "embedding")
        .replace("linear1", "mlp.0")
        .replace("linear2", "mlp.2")
        .replace("ln.", "layernorm.")
    )

# Take a fresh copy of the reference state dict and re-key it in place so the
# names match FastGPT2's parameters.
their_dict = pretrained_gpt.state_dict()
for old_key, tensor in list(their_dict.items()):
    del their_dict[old_key]
    their_dict[string_replace(old_key)] = tensor

In [192]:
# Load the re-keyed reference tensors into the FastGPT2 instance.
my_gpt.load_state_dict(their_dict)

<All keys matched successfully>

In [248]:
# Sample a continuation in eval mode. The result depends on numpy's global RNG state,
# so it changes from run to run unless that RNG is seeded first.
my_gpt.eval()
my_gpt.generate('My favorite emoticon is:', freq_penalty=2);


My favorite emoticon is: Michael Jackson ANT

I like the NFL as a whole since it's hipsters vs. men, but @


In [249]:
# NOTE(review): `__` is IPython's "second-to-last output" variable, so what this cell
# shows depends on execution history and will differ on a fresh run; consider removing.
__

<All keys matched successfully>