In [2]:
import torch as t
import gpt_tests
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange, reduce, repeat
from dataclasses import dataclass
from torchtyping import TensorType
import sys
sys.path.append('../w2d1')
from bert_sol import Bert
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
bert_tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")


In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
class UniMultiAttention(t.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.attention = t.nn.Linear(hidden_size, hidden_size * 3)
        self.softmax = t.nn.Softmax(dim=-1)
        self.projection = t.nn.Linear(hidden_size, hidden_size)
    
    def forward(self, x, past_key_values=None, return_key_values=False): # [batch, seq_len, hidden_size]
        if past_key_values is not None:
            past_k, past_v = t.split(past_key_values, self.head_size, dim=-1)
            x = self.attention(x)
            q, k, v = t.split(x, self.hidden_size, dim=-1)
            q = rearrange(q, "b n (h c) -> b h n c", h = self.num_heads)[0]
            k = rearrange(k, "b n (h c) -> b h n c", h = self.num_heads)[0]
            v = rearrange(v, "b n (h c) -> b h n c", h = self.num_heads)[0]
            new_kv = t.cat((k, v), dim=-1).unsqueeze(0)
            k = t.cat((past_k, k), dim=-2)
            v = t.cat((past_v, v), dim=-2)
            x = t.einsum("hnc,hmc->hnm", q, k)
            x /= np.sqrt(self.head_size)
            x = t.einsum("hnm,hmc->hnc", self.softmax(x), v)
            res = rearrange(x, "h n c -> 1 n (h c)")
        else:
            x = self.attention(x)
            q, k, v = t.split(x, self.hidden_size, dim=-1)
            q = rearrange(q, "b n (h c) -> b h n c", h = self.num_heads)
            k = rearrange(k, "b n (h c) -> b h n c", h = self.num_heads)
            v = rearrange(v, "b n (h c) -> b h n c", h = self.num_heads)
            x = t.einsum("bhnc,bhmc->bhnm", q, k)
            x /= np.sqrt(self.head_size)

            mask = t.arange(0, x.shape[-1]) <= t.arange(0, x.shape[-1]).unsqueeze(1)
            x = t.where(mask, x, t.tensor(-1e4, dtype=t.float))
            x = t.einsum("bhnm,bhmc->bhnc", self.softmax(x), v)
            res = rearrange(x, "b h n c -> b n (h c)")
        if return_key_values:
            return (self.projection(res), new_kv)
        return self.projection(res)

gpt_tests.test_unidirectional_attn(UniMultiAttention)
gpt_tests.test_attn_cache(UniMultiAttention)

Congrats! You've passed the test!
Checking encoding:
Congrats! You've passed the test!
Checking new key and value:
Congrats! You've passed the test!


In [5]:
class GPT2Block(t.nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, dropout: float, layer_norm_epsilon: float) -> None:
        super().__init__()
        self.layer_norm = t.nn.LayerNorm(hidden_size, layer_norm_epsilon)
        self.attention = UniMultiAttention(hidden_size, num_heads)
        self.layers2 = t.nn.Sequential(
            t.nn.LayerNorm(hidden_size, layer_norm_epsilon),
            t.nn.Linear(hidden_size, hidden_size * 4),
            t.nn.GELU(),
            t.nn.Linear(hidden_size*4, hidden_size),
            t.nn.Dropout(dropout)
        )

    def forward(self, x, past_key_values=None, return_key_values=False):
        y = self.layer_norm(x)
        if return_key_values:
            y, kv = self.attention(y, past_key_values, return_key_values)
        else:
            y = self.attention(y)
        z = x + y
        z = z + self.layers2(z)
        if return_key_values:
            return z, kv
        return z
gpt_tests.test_gpt_block(GPT2Block)


Congrats! You've passed the test!


In [6]:
@dataclass
class GPT2Output:
    logits: TensorType["batch_size", "seq_length", "vocab_size"]
    final_encoding: TensorType["batch_size", "hidden_size"]

In [13]:
class GPT2(t.nn.Module):
    def __init__(self, num_layers, num_heads, vocab_size, hidden_size, max_position_embeddings, dropout, layer_norm_epsilon, use_cache = False, tokenizer=None) -> None:
        super().__init__()
        self.token_embedding_layer = t.nn.Embedding(vocab_size, hidden_size)
        self.position_embedding_layer = t.nn.Embedding(max_position_embeddings, hidden_size)
        self.gpt_layers = [GPT2Block(hidden_size, num_heads, dropout, layer_norm_epsilon) for x in range(num_layers)]
        self.dropout = t.nn.Dropout(dropout)
        self.layers = t.nn.Sequential(
            *self.gpt_layers,
        )
        self.layer_norm = t.nn.LayerNorm(hidden_size, layer_norm_epsilon)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cached_kv = t.zeros((num_layers, num_heads, 0, 2 * hidden_size // num_heads))
        self.cached_encodings = t.zeros((0, hidden_size))
        self.tokenizer = tokenizer
    
    def clear_cache(self):
        self.cached_kv = t.zeros((self.num_layers, self.num_heads, 0, 2 * self.hidden_size // self.num_heads))
        self.cached_encodings = t.zeros((0, self.hidden_size))

    def forward(self, input_ids, use_cache = False): # [batch, seq_len]
        input_ids = self.token_embedding_layer(input_ids) + self.position_embedding_layer(t.arange(input_ids.shape[-1]))
        input_ids = self.dropout(input_ids)
        if use_cache:
            # if self.cached_kv.shape[2] == 0: # cache empty

            temp = t.zeros((1, self.num_heads, 0, 2 * self.hidden_size // self.num_heads))
            for i, layer in enumerate(self.gpt_layers):
                input_ids, kv = layer(input_ids, past_key_values = self.cached_kv[i], return_key_values = True)
                temp = t.cat((temp, kv), dim=-2)
            temp = rearrange(temp, "1 heads (layers seq) kv -> layers heads seq kv", layers=self.num_layers)
            self.cached_kv = t.cat((self.cached_kv, temp), dim=-2)
        else:
            input_ids = self.layers(input_ids)
        self._enc = input_ids
        if use_cache:
            self.cached_encodings = t.cat((self.cached_encodings, self._enc[0]), dim=0)
            # input_ids = self.cached_encodings
        input_ids = self.layer_norm(input_ids)
        x = input_ids @ self.token_embedding_layer.weight.T
        return GPT2Output(x[...,x.shape[-2]-1,:], input_ids[...,input_ids.shape[-2]-1,:])
    
    def next_token(self, input_ids, temperature, freq_penalty=2.0): # input_ids : [seq_len]
        out = self.forward(input_ids, use_cache=True)
        logits = out.logits[-1]
        id_frequencies = t.zeros(logits.shape)
        id_frequencies[input_ids] += 1
        probs = t.softmax(logits/temperature - id_frequencies * freq_penalty, dim=-1)
        return t.argmax(probs)
    
    def generate(self, text, max_length=30, temperature=1.0, freq_penalty=2.0):
        self.clear_cache()
        tokens = self.tokenizer.encode(text)
        for i in range(max_length):
            next = self.next_token(t.tensor(tokens).unsqueeze(0), temperature, freq_penalty)
            tokens.append(next)
            if next == tokenizer.eos_token_id:
                break
        return tokenizer.decode(tokens)



gpt_tests.test_gpt(GPT2)
gpt_tests.test_gpt_cache(GPT2)

Checking logits:
Congrats! You've passed the test!
Checking final encodings:
Congrats! You've passed the test!
Congrats! Your GPT returns the same results with and without cache.
It took 2.176s to generate a 500-token sentence without cache and 1.873s with cache.


In [14]:
my_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768, max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5, tokenizer=tokenizer, use_cache = True)

pretrained_gpt = gpt_tests.get_pretrained_gpt()


In [15]:
print(len([x for x in my_gpt.state_dict()]))
print(len([x for x in pretrained_gpt.state_dict()]))
mapping = {x: y for x, y in zip(my_gpt.state_dict(), pretrained_gpt.state_dict())}
state_dict = pretrained_gpt.state_dict()
final_dict = {}
for x, y in mapping.items():
    final_dict[x] = state_dict[y]
my_gpt.load_state_dict(final_dict)

148
148


<All keys matched successfully>

In [16]:
my_gpt.eval()
print()




In [26]:
print(my_gpt.generate("my life motto is: ", max_length = 60, freq_penalty=50))

In [96]:
my_gpt.generate("Hello, my name is ")

'Hello, my name is?"\n\n\n\n\n"I am the name\n"is it?\n" is "the name\nThe name\nIs it\n"'

In [23]:
print(my_gpt.generate("Hello, I am ", max_length = 60))

Hello, I am  a man.
I'm a man.
I am a man. I am a man. I am a man. I am a man. I am a man. I am a man. I am a man. I am a man. I am a man. I am a man


In [25]:
print(my_gpt.generate("Hello, I am ", max_length = 60, freq_penalty = 20))

Hello, I am  a man.
I'm a woman. And this is the one of course it's not so, and that she was born to be my wife who are both men were all women in her husband as well for me."<|endoftext|>


In [92]:
tokenizer.eos_token_id

50256

In [163]:
bert = Bert(
    num_layers=12, num_heads=12, vocab_size=bert_tokenizer.vocab_size, hidden_size=768, intermediate_size=768,
    max_position_embeddings=1024, dropout=0.1, type_vocab_size=2
    )

In [171]:

my_gpt.eval()

def gpt_encodings(text):
    my_gpt(t.tensor(tokenizer.encode(text)).unsqueeze(0))
    return t.argmax(t.softmax(my_gpt._enc, dim=-1))

def bert_encodings(text):
    bert(t.tensor(bert_tokenizer.encode(text)).unsqueeze(0))
    return t.softmax(bert._enc, dim=-1)

def gpt_pred(text):
    logits = my_gpt(t.tensor(tokenizer.encode(text)).unsqueeze(0)).logits
    logits = t.nn.functional.softmax(logits, dim=-1)[0]
    print(logits)
    print(t.argmax(logits))
    print(logits[t.argmax(logits)])
    return tokenizer.decode(t.argmax(logits))

#gpt_pred("My life motto: \"")

In [194]:
# #print([" ".join([str(y) for y in x]) for x in gpt_encodings("My life motto:")])
# print
# print(gpt_encodings("motto"))

# print(bert_encodings("My life motto:"))
# print(bert_encodings("motto"))


In [224]:
s1 = "My life motto:"
s2 = "My life motto: Fortune favors"
motto1 = gpt_encodings(s1)[:,1]
motto2 = gpt_encodings(s2)[:,1]
bmotto1 = bert_encodings(s1)[:,3]
bmotto2 = bert_encodings(s2)[:,3]


In [225]:
bert_tokenizer.decode(bert_tokenizer.encode(s1)[3])

'motto'

In [226]:
print("GPT", (motto1 == motto2).all())
print("BERT", (bmotto1 == bmotto2).all())

GPT tensor(True)
BERT tensor(False)


In [165]:
bert(t.tensor(bert_tokenizer.encode("testing testing a")).unsqueeze(0))

tensor([[[-0.5171,  0.6661, -0.6133,  ..., -0.3239, -1.1028,  0.7504],
         [-0.2633,  0.4830, -0.1012,  ...,  0.0960, -0.0229,  0.0053],
         [-0.7542,  0.5852, -0.5270,  ...,  0.6227, -0.1321, -0.0600],
         [-0.3993,  0.9214, -0.8342,  ..., -0.3123, -0.5261,  0.4948],
         [-0.2962,  1.1724, -0.3336,  ..., -0.1051, -0.3679,  0.2982]]],
       grad_fn=<AddBackward0>)

In [166]:
bert._enc.shape

torch.Size([1, 5, 768])

In [193]:
tokenizer.encode("this is a motto")

[5661, 318, 257, 33600]

In [None]:
motto1 = gpt_encodings("hello world")
motto2 = gpt_encodings("hello world how")