In [13]:
import torch as t
import torch
from torch import einsum
from torch import nn
from einops import rearrange, reduce, repeat
import mlab_tests

In [14]:
class Attention(t.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.head_size = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        
        self.attention = nn.Linear(hidden_size,hidden_size*3)
        self.output_projection = nn.Linear(hidden_size, hidden_size)
    
    def forward(self, x, past_key_values=None, return_key_values=False):
        
        if past_key_values is not None:
            k, v = torch.tensor_split(past_key_values.unsqueeze(0), 2, 3)
            out = self.attention(x)
            q, k_final, v_final = out.tensor_split(3,dim=-1)
            q = rearrange(q, "b n (h c) -> b h n c", h = self.num_heads)
            k_final = rearrange(k_final, "b n (h c) -> b h n c", h = self.num_heads)
            v_final = rearrange(v_final, "b n (h c) -> b h n c", h = self.num_heads)
            
            k = torch.cat((k, k_final), 2)
            v = torch.cat((v, v_final), 2)
            attention_score = einsum("bhtc,bhfc->bhft", q, k)

            
        else:     
            out = self.attention(x)
            q, k, v = out.tensor_split(3,dim=-1)
            
            q = rearrange(q, "b n (h c) -> b h n c", h = self.num_heads)
            k = rearrange(k, "b n (h c) -> b h n c", h = self.num_heads)
            v = rearrange(v, "b n (h c) -> b h n c", h = self.num_heads)

            attention_score = einsum("bhtc,bhfc->bhft", q, k)
                        
            attention_score = t.triu(attention_score) - 1e4*t.tril(t.ones_like(attention_score), diagonal = -1)
                
        scaled_attn_score = 1/torch.sqrt(t.tensor(self.head_size)) * attention_score
                
        attention_pattern = t.nn.Softmax(dim = -2)(scaled_attn_score)
            
        out = einsum("bhft,bhfc -> bhtc", attention_pattern, v)
                
        out = rearrange(out, "b h n c -> b n (h c)")
        
        out = self.output_projection(out)
        
        if return_key_values and past_key_values is not None:
            return out, torch.cat((k_final, v_final), 3)
        else:
            return out
gpt_tests.test_attn_cache(Attention)
# gpt_tests.test_unidirectional_attn(Attention)

Checking encoding:
Congrats! You've passed the test!
Checking new key and value:
Congrats! You've passed the test!


In [15]:
import gpt_tests

gpt_tests.test_unidirectional_attn(Attention)

our out tensor([[[ 3.2229e-01, -3.7651e-01,  3.6197e-01,  1.8525e-02,  2.0513e-01,
          -4.9755e-01, -3.2238e-01,  1.7564e-01,  6.9831e-02, -4.7030e-01,
           8.4727e-02, -1.8117e-01,  3.2580e-01,  3.3301e-01, -3.8834e-01,
           4.4791e-01, -4.6027e-01, -5.3693e-01,  1.3190e-01, -9.3719e-02,
          -1.3495e-01, -5.6973e-01,  3.2986e-01, -3.2485e-03],
         [ 4.7690e-01,  1.6798e-01, -1.6052e-03,  1.7268e-01,  3.5946e-01,
          -3.4603e-01,  2.7828e-01,  8.2329e-05,  4.7646e-02, -1.1094e-01,
           4.2770e-01,  7.3739e-02,  4.5244e-01, -9.9570e-02, -2.9142e-01,
           2.8042e-01, -4.2251e-02,  1.3468e-01, -1.6897e-01, -3.7332e-01,
           6.9931e-02,  1.4339e-02,  2.0342e-01,  4.3749e-02],
         [ 1.4101e-01, -8.8313e-03,  5.5020e-02,  2.6923e-01,  1.0635e-01,
          -2.5896e-01,  3.7684e-02,  1.9169e-01,  1.2130e-01, -4.1290e-01,
           1.7931e-01, -1.3171e-01,  8.7645e-02, -1.7535e-01, -3.1550e-01,
           1.5904e-01,  1.8276e-02,  2.49

In [9]:
class Block(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout, layer_norm_epsilon):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        
        self.ln1 = nn.LayerNorm(self.hidden_size, eps=layer_norm_epsilon)
        self.attention = Attention(self.hidden_size, self.num_heads)
        self.ln2 = nn.LayerNorm(self.hidden_size, eps=layer_norm_epsilon)
        self.MLP = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size * 4),
                                   nn.GELU(),
                                   nn.Linear(self.hidden_size * 4, self.hidden_size),
                                   nn.Dropout(self.dropout)
                                  )
        
    def forward(self, x, past_key_values = None, return_key_values = False):
        if return_key_values:
            layer1, new_key_values = self.attention(self.ln1(x)) 
            layer1 = layer1 + x
            return self.ML
        P(self.ln2(layer1)) + layer1, new_key_values
            
        layer1 = self.attention(self.ln1(x), past_key_values) + x
        return self.MLP(self.ln2(layer1)) + layer1

gpt_tests.test_gpt_block(Block)

Congrats! You've passed the test!


In [10]:
from dataclasses import dataclass
from torchtyping import TensorType

@dataclass
class GPT2Output:
    logits: TensorType["batch_size", "vocab_size"]
    final_encoding: TensorType["batch_size", "hidden_size"]
    # L_encoding: TensorType["batch_size", "hidden_size"]

    
class GPT2(nn.Module):
    def __init__(self, num_layers, num_heads, vocab_size, hidden_size, 
                 max_position_embeddings, dropout, layer_norm_epsilon, L=0, use_cache=False):
        super().__init__()
        self.head_size = hidden_size // num_heads
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([Block(hidden_size, num_heads, dropout, layer_norm_epsilon)
                                      for _ in range(num_layers)])
        
        self.ln = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        
        self.use_cache = use_cache
        if self.use_cache:
            self.cache = t.zeros((num_layers, num_heads, 0, 2*self.head_size))
        
    def forward(self, input_ids): # [batch, seq_len]
        positions = torch.arange(input_ids.shape[1], device=input_ids.get_device() 
                                 if input_ids.get_device() >= 0 else 'cpu')
        embeddings = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        embeddings = self.dropout(embeddings)
        
        if self.use_cache and self.cache.shape[1] == 0:
                        
            nl, nh, _, hs = self.cache.shape
            cache_updates = t.tensor((0, nh, 1, hs))
            
            if self.cache.shape[2] == 0:
                tokens_to_process = range(input_ids.shape[1])
            
            else:
                tokens_to_process = [-1]
            
            for i in tokens_to_process:
                
                out = embeddings[:,i]
                for block_num, block in enumerate(self.blocks):
                    out, new_keys = block(out, cache[block_num], True)
                    cache_updates = t.cat((cache_updates,new_keys),0)
                
                self.cache = t.cat((self.cache, cache_updates), 2)
        else:
            out = embeddings
            for block_num, block in enumerate(self.blocks):
                out = block(out)        
        
        embeddings = self.ln(out)
        output = GPT2Output(einsum("bnc, vc -> bnv", embeddings, self.token_embeddings.weight)[:, -1],
                            embeddings[:, -1])
        
        
        return output

gpt_tests.test_gpt_cache(GPT2)

Congrats! Your GPT returns the same results with and without cache.
It took 4.069s to generate a 500-token sentence without cache and 2.895s with cache.


In [21]:
ten = t.tensor((1,0,2))
ten

tensor([1, 0, 2])

In [9]:
my_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768, 
                 max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5, L = 11)

pretrained_gpt = gpt_tests.get_pretrained_gpt()
pretrained_state_dict = pretrained_gpt.state_dict()
#my_gpt.load_state_dict(pretrained_state_dict)

In [22]:
zipped_params = [param for param in zip(pretrained_state_dict.keys(), my_gpt.state_dict().keys())]
new_state_dict = {}
for pretrained, ours in zipped_params:
    new_state_dict[ours] = pretrained_state_dict[pretrained]

my_gpt.load_state_dict(new_state_dict)

<All keys matched successfully>

In [10]:
import transformers

from bert_sol import Bert

my_bert = Bert(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=3072, 
    num_heads=12, num_layers=12, L = 11,
)


tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

texts = ["My life motto:",
         "My life motto: Fortune",
         "My life motto: Fortune favors",
         "My life motto: Fortune favors the",
         "My life motto: Fortune favors the bold"]


def encoding(texts, i):
    # inputs = t.zeros((len(texts), len(texts[0])))
    tokens = tokenizer(texts, padding = "longest")
    
    inputs = t.tensor(tokens["input_ids"])
    
    my_gpt.eval()
    my_bert.eval()
    
    gpt_encodings = my_gpt(inputs).L_encoding[:,i]
    
    bert_encodings = my_bert(inputs).L_encoding[:,i]
    
    print(gpt_encodings)
    print(bert_encodings)
    
    
encoding(texts,2)

tensor([[ 1.1860,  1.5728,  2.9378,  ..., -0.6505, -0.9438,  1.3388],
        [ 1.1860,  1.5728,  2.9378,  ..., -0.6505, -0.9438,  1.3388],
        [ 1.1860,  1.5728,  2.9378,  ..., -0.6505, -0.9438,  1.3388],
        [ 1.1860,  1.5728,  2.9378,  ..., -0.6505, -0.9438,  1.3388],
        [ 1.1860,  1.5728,  2.9378,  ..., -0.6505, -0.9438,  1.3388]],
       grad_fn=<SelectBackward0>)
tensor([[-0.1044,  0.0473, -0.3324,  ...,  0.3400, -0.0433,  1.3392],
        [-0.0790, -0.0920, -0.2450,  ...,  0.3143, -0.1944,  1.3053],
        [ 0.1155, -0.2108, -0.4075,  ...,  0.2970, -0.1911,  1.2119],
        [ 0.1917, -0.3081, -0.5358,  ...,  0.3087, -0.3592,  1.1066],
        [ 0.1922, -0.5265, -0.5954,  ...,  0.3476, -0.3509,  1.0415]],
       grad_fn=<SelectBackward0>)
