In [1]:
import torch as t
import torch
from torch import einsum
from torch import nn
from einops import rearrange, reduce, repeat
import mlab_tests

In [2]:
class Attention(t.nn.Module):
    """Multi-head causal (unidirectional) self-attention with an optional key/value cache.

    Args:
        hidden_size: model width; split into ``num_heads`` heads of ``hidden_size // num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.head_size = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # Fused projection whose output is split into queries, keys and values (in that order).
        self.attention = nn.Linear(hidden_size, hidden_size * 3)
        self.output_projection = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, tensor):
        """(batch, seq, num_heads * head_size) -> (batch, num_heads, seq, head_size)."""
        return tensor.reshape(*tensor.shape[:-1], self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, x, past_key_values=None, return_key_values=False):
        """Attend from each position of ``x`` to itself, earlier positions and any cached ones.

        Args:
            x: hidden states of shape (batch, seq_len, hidden_size).
            past_key_values: optional keys/values of earlier positions, laid out as
                ``cat((keys, values), dim=-1)`` with shape (num_heads, past_len, 2 * head_size)
                (shared across the batch) or (batch, num_heads, past_len, 2 * head_size).
            return_key_values: if True, also return the keys/values computed for ``x``.

        Returns:
            Output of shape (batch, seq_len, hidden_size); when ``return_key_values`` is True,
            ``(output, key_values)`` with ``key_values`` of shape
            (batch, num_heads, seq_len, 2 * head_size), in the same layout the cache accepts.
        """
        q, k_new, v_new = (self._split_heads(z) for z in self.attention(x).tensor_split(3, dim=-1))

        if past_key_values is None:
            k, v = k_new, v_new
            past_len = 0
        else:
            past = past_key_values
            if past.dim() == 3:
                # Unbatched cache: broadcast it over the batch dimension.
                past = past.unsqueeze(0).expand(q.shape[0], -1, -1, -1)
            k_past, v_past = past.tensor_split(2, dim=3)
            past_len = k_past.shape[2]
            k = torch.cat((k_past, k_new), dim=2)
            v = torch.cat((v_past, v_new), dim=2)

        # scores[b, h, f, t]: key position f scored against query position t.
        scores = einsum("bhtc,bhfc->bhft", q, k) / self.head_size ** 0.5

        # Query t sits at absolute position past_len + t; hide every key that comes after it.
        # This is required whenever more than one query token is processed, cached or not.
        key_pos = torch.arange(k.shape[2], device=x.device)
        query_pos = torch.arange(q.shape[2], device=x.device) + past_len
        scores = scores.masked_fill(key_pos[:, None] > query_pos[None, :], -1e4)

        pattern = scores.softmax(dim=-2)  # normalise over key positions
        out = einsum("bhft,bhfc->bhtc", pattern, v)
        out = self.output_projection(out.transpose(1, 2).flatten(-2))

        if return_key_values:
            return out, torch.cat((k_new, v_new), dim=3)
        return out
# gpt_tests.test_attn_cache(Attention)
# gpt_tests.test_unidirectional_attn(Attention)

In [3]:
import gpt_tests

# Run the gpt_tests reference check `test_unidirectional_attn` on the Attention class defined above.
gpt_tests.test_unidirectional_attn(Attention)

our out tensor([[[-2.4267e-01, -6.7535e-01,  1.6882e-03,  1.5479e-01,  1.8791e-01,
           1.0629e-01, -2.1500e-01,  5.4569e-01, -2.8674e-01,  1.5972e-01,
           1.9140e-01, -4.1335e-02,  3.4724e-01, -4.7367e-01, -3.7063e-01,
          -3.4441e-01, -1.9860e-01,  4.2505e-01,  4.3122e-01,  2.2667e-01,
          -2.7387e-03,  2.3899e-02,  1.7695e-01,  2.5153e-01],
         [-4.8123e-01, -2.6815e-01, -1.7251e-01,  9.6488e-02,  1.1933e-01,
           7.1637e-03,  1.1151e-01,  3.4887e-01,  1.9976e-01,  7.5051e-02,
           2.6465e-01, -1.3239e-01,  2.2056e-01, -1.8246e-01, -2.0666e-01,
          -2.6338e-01, -2.2939e-01,  3.9007e-01,  1.1968e-01,  3.3997e-01,
           4.3774e-01, -3.1492e-02,  5.4310e-02,  1.2460e-01],
         [-3.3502e-01, -4.1158e-01,  3.3808e-02,  4.1109e-03,  1.8947e-01,
           1.1793e-01, -1.7818e-02,  2.9381e-01,  7.5902e-02,  4.2964e-02,
           9.8588e-02, -4.1910e-01,  2.6740e-01, -1.0821e-01,  3.8267e-03,
          -9.4716e-02, -1.8388e-01,  3.13

In [4]:
class Block(nn.Module):
    """Pre-LayerNorm GPT-2 transformer block.

    Computes ``h = attention(ln1(x)) + x`` followed by ``MLP(ln2(h)) + h``.

    Args:
        hidden_size: model width.
        num_heads: number of attention heads passed to ``Attention``.
        dropout: dropout probability applied at the end of the MLP.
        layer_norm_epsilon: epsilon for both LayerNorms.
    """

    def __init__(self, hidden_size, num_heads, dropout, layer_norm_epsilon):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon

        # Registration order matters: weights are loaded by zipping state_dict keys in order.
        self.ln1 = nn.LayerNorm(self.hidden_size, eps=layer_norm_epsilon)
        self.attention = Attention(self.hidden_size, self.num_heads)
        self.ln2 = nn.LayerNorm(self.hidden_size, eps=layer_norm_epsilon)
        self.MLP = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size * 4),
                                 nn.GELU(),
                                 nn.Linear(self.hidden_size * 4, self.hidden_size),
                                 nn.Dropout(self.dropout)
                                 )

    def forward(self, x, past_key_values=None, return_key_values=False):
        """Apply the block.

        Args:
            x: hidden states of shape (batch, seq_len, hidden_size).
            past_key_values: optional cached keys/values, forwarded unchanged to ``self.attention``.
            return_key_values: if True, also return the key/value tensor produced by
                ``self.attention`` for the positions in ``x``.

        Returns:
            The block output (same shape as ``x``), or ``(output, key_values)`` when
            ``return_key_values`` is True.

        Raises:
            TypeError: if ``return_key_values`` is True but the attention layer did not return
                an ``(output, key_values)`` tuple.
        """
        attn_result = self.attention(self.ln1(x), past_key_values, return_key_values)
        if return_key_values:
            if not isinstance(attn_result, tuple):
                raise TypeError("attention must return (output, key_values) when return_key_values=True")
            attn_out, new_key_values = attn_result
        else:
            attn_out = attn_result

        layer1 = attn_out + x
        out = self.MLP(self.ln2(layer1)) + layer1
        if return_key_values:
            return out, new_key_values
        return out

# gpt_tests.test_gpt_block(Block)

In [5]:
from dataclasses import dataclass
from torchtyping import TensorType
from collections import Counter
import transformers

@dataclass
class GPT2Output:
    """Value returned by ``GPT2.forward``, holding results for the last position of each sequence."""

    # Next-token logits at the final position.
    logits: TensorType["batch_size", "vocab_size"]
    # Hidden state at the final position, after the model's final LayerNorm.
    final_encoding: TensorType["batch_size", "hidden_size"]
    # NOTE(review): this field is not implemented; the last cell reads `.L_encoding` and fails.
    # L_encoding: TensorType["batch_size", "hidden_size"]

    
class GPT2(nn.Module):
    """GPT-2 decoder with tied token-embedding/unembedding weights and an optional KV cache.

    Args:
        num_layers: number of transformer ``Block``s.
        num_heads: attention heads per block.
        vocab_size: number of token ids.
        hidden_size: model width.
        max_position_embeddings: number of learned position embeddings, i.e. the longest
            sequence the model can embed.
        dropout: dropout probability for the embeddings and inside each block.
        layer_norm_epsilon: epsilon for every LayerNorm.
        L: accepted for call-site compatibility but not used by this class.
        use_cache: if True, ``forward`` keeps per-layer keys/values in ``self.cache`` and only
            runs positions that are not cached yet (batch size 1 only). Call ``reset_cache``
            before starting an unrelated sequence; ``generate`` does this itself.
    """

    def __init__(self, num_layers, num_heads, vocab_size, hidden_size,
                 max_position_embeddings, dropout, layer_norm_epsilon, L=0, use_cache=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.head_size = hidden_size // num_heads
        self.num_heads = num_heads
        self.num_layers = num_layers
        # Registration order matters: weights are loaded by zipping state_dict keys in order.
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([Block(hidden_size, num_heads, dropout, layer_norm_epsilon)
                                     for _ in range(num_layers)])

        self.ln = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

        self.use_cache = use_cache
        # Plain tensor attribute (not a registered buffer), so it never shows up in state_dict().
        self.reset_cache()

    def reset_cache(self, device=None):
        """Empty the KV cache: shape (num_layers, num_heads, 0, 2 * head_size)."""
        self.cache = t.zeros((self.num_layers, self.num_heads, 0, 2 * self.head_size), device=device)

    def forward(self, input_ids):
        """Return next-token logits and the final hidden state for the last position.

        Args:
            input_ids: integer tensor of shape (batch, seq_len) with seq_len >= 1.

        Returns:
            GPT2Output with ``logits`` (batch, vocab_size) and ``final_encoding``
            (batch, hidden_size), both for position ``seq_len - 1``.

        Raises:
            ValueError: if ``input_ids`` is empty along the sequence dimension, or if the cache
                is enabled and the batch size is not 1.
        """
        seq_len = input_ids.shape[1]
        if seq_len == 0:
            raise ValueError("input_ids must contain at least one token")
        positions = torch.arange(seq_len, device=input_ids.device)
        embeddings = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        embeddings = self.dropout(embeddings)

        if self.use_cache:
            hidden = self._forward_cached(embeddings)
        else:
            hidden = embeddings
            for block in self.blocks:
                hidden = block(hidden)

        # LayerNorm is per-position, so normalising only the last position is equivalent.
        final_encoding = self.ln(hidden[:, -1])
        # The output projection reuses (ties) the token-embedding matrix.
        logits = einsum("bc,vc->bv", final_encoding, self.token_embeddings.weight)
        return GPT2Output(logits, final_encoding)

    def _forward_cached(self, embeddings):
        """Run every not-yet-cached position through all blocks, appending its keys/values.

        Positions are processed one at a time so each attention call sees exactly one query
        token on top of the cache. Each block's residual structure (pre-LayerNorm attention,
        then pre-LayerNorm MLP, each with a skip connection) is applied here directly so the
        attention layer's key/value output can be captured. Returns the hidden state of the
        last processed position, shape (1, 1, hidden_size).
        """
        batch_size, seq_len = embeddings.shape[:2]
        if batch_size != 1:
            raise ValueError("the KV cache only supports batch size 1")
        if self.cache.shape[2] >= seq_len:
            # Nothing new to run (same or shorter input than what is cached): treat as stale.
            self.reset_cache(device=embeddings.device)
        else:
            self.cache = self.cache.to(embeddings.device)

        start = self.cache.shape[2]
        for pos in range(start, seq_len):
            hidden = embeddings[:, pos:pos + 1]
            new_key_values = []
            for layer, block in enumerate(self.blocks):
                attn_out, key_values = block.attention(block.ln1(hidden), self.cache[layer], True)
                hidden = attn_out + hidden
                hidden = block.MLP(block.ln2(hidden)) + hidden
                # Drop the batch dim: (1, num_heads, 1, 2*head_size) -> (num_heads, 1, 2*head_size).
                new_key_values.append(key_values[0])
            self.cache = t.cat((self.cache, t.stack(new_key_values)), dim=2)
        return hidden

    def next_token(self, input_ids, temperature, freq_penalty=2.0):
        """Greedily pick the next token id for the list of ints ``input_ids``.

        Each vocabulary entry is scored as ``logits / temperature - freq_penalty * count``,
        where ``count`` is how often that id already occurs in ``input_ids``. The argmax is
        returned as a 0-dim tensor.
        """
        device = self.token_embeddings.weight.device
        ids = torch.tensor(input_ids, device=device)
        logits = self.forward(ids.unsqueeze(0)).logits.squeeze(0)
        counts = torch.bincount(ids, minlength=self.vocab_size).to(logits.dtype)
        # Softmax is monotonic, so the argmax of the raw scores picks the same token.
        return (logits / temperature - counts * freq_penalty).argmax(dim=0)

    def generate(self, text, max_length=30, temperature=1.0, freq_penalty=2.0):
        """Greedily extend ``text`` and return the decoded result.

        Generation stops once the sequence holds ``max_length`` tokens (capped at the number of
        position embeddings) or its last token is the tokenizer's end-of-text token. Runs in
        eval mode without autograd and restores the previous train/eval mode afterwards.

        Raises:
            ValueError: if ``text`` encodes to no tokens.
        """
        tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
        input_ids = tokenizer(text)["input_ids"]
        if not input_ids:
            raise ValueError("text must encode to at least one token")
        # Longer sequences have no position embedding and cannot be run.
        max_length = min(max_length, self.position_embeddings.num_embeddings)

        was_training = self.training
        self.eval()  # disable dropout while generating
        self.reset_cache(device=self.token_embeddings.weight.device)
        try:
            with t.no_grad():
                while len(input_ids) < max_length and input_ids[-1] != tokenizer.eos_token_id:
                    input_ids.append(int(self.next_token(input_ids, temperature, freq_penalty)))
        finally:
            self.train(was_training)
        return tokenizer.decode(input_ids)
        
    

In [6]:
# Despite its name, `bert` is a GPT2 instance, built here with the KV cache enabled.
bert = GPT2(
    num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768,
    max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5, L=11, use_cache=True,
)

pretrained_gpt = gpt_tests.get_pretrained_gpt()
pretrained_state_dict = pretrained_gpt.state_dict()

# Keys are paired by position: the i-th pretrained tensor is loaded under our i-th key, so the
# two models must register their parameters in the same order.
zipped_params = list(zip(pretrained_state_dict.keys(), bert.state_dict().keys()))
new_state_dict = {ours: pretrained_state_dict[pretrained] for pretrained, ours in zipped_params}

bert.load_state_dict(new_state_dict)

<All keys matched successfully>

In [None]:
# Generate a continuation of the prompt below with `bert` (the cached GPT2 built above).
# NOTE(review): max_length=1500 is larger than max_position_embeddings=1024 used to build `bert`;
# sequences longer than 1024 tokens have no position embedding.
bert.generate("""What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little "clever" comment was about to bring down upon you, maybe you would have held your fucking tongue.
Spongbob: """, max_length=1500)

In [8]:
# bert.generate("Here's my favorite joke: ", max_length=400)

In [9]:
# A second GPT2 (KV cache left at its default, off) used by the encoding comparison further down.
my_gpt = GPT2(
    num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768,
    max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5, L=11,
)

pretrained_gpt = gpt_tests.get_pretrained_gpt()
pretrained_state_dict = pretrained_gpt.state_dict()

In [10]:
# Keys are paired by position: the i-th pretrained tensor is loaded under our i-th key, so the
# two models must register their parameters in the same order.
zipped_params = list(zip(pretrained_state_dict.keys(), my_gpt.state_dict().keys()))
new_state_dict = {ours: pretrained_state_dict[pretrained] for pretrained, ours in zipped_params}

my_gpt.load_state_dict(new_state_dict)

<All keys matched successfully>

In [11]:
import transformers

from bert_sol import Bert

# BERT model from the local `bert_sol` module.
# NOTE(review): no weights are loaded into my_bert in this cell — confirm whether Bert(...) loads
# pretrained weights itself, otherwise its encodings come from a random initialisation.
my_bert = Bert(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=3072, 
    num_heads=12, num_layers=12, L = 11,
)


# NOTE(review): this BERT tokenizer's ids are also fed to my_gpt below, whose vocabulary
# (vocab_size=50257) is GPT-2's; GPT-2 inputs presumably need the GPT-2 tokenizer instead.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

# Successively longer prefixes of the same sentence.
texts = ["My life motto:",
         "My life motto: Fortune",
         "My life motto: Fortune favors",
         "My life motto: Fortune favors the",
         "My life motto: Fortune favors the bold"]


def encoding(texts, i):
    """Print the `L_encoding[:, i]` of my_gpt and my_bert for the padded batch `texts`.

    NOTE(review): GPT2Output defines no `L_encoding` field, so the my_gpt line raises
    AttributeError (see the output below this cell).
    """
    # inputs = t.zeros((len(texts), len(texts[0])))
    # Pad every text to the longest one so the batch forms a rectangular tensor.
    tokens = tokenizer(texts, padding = "longest")

    inputs = t.tensor(tokens["input_ids"])

    # Put both models in eval mode (disables dropout).
    my_gpt.eval()
    my_bert.eval()

    gpt_encodings = my_gpt(inputs).L_encoding[:,i]

    bert_encodings = my_bert(inputs).L_encoding[:,i]

    print(gpt_encodings)
    print(bert_encodings)


encoding(texts,2)

AttributeError: 'GPT2Output' object has no attribute 'L_encoding'