In [1]:
import torch
from torch import einsum
from torch import nn 
from einops import rearrange, reduce, repeat
import math
import random
from collections import OrderedDict
import transformers
import torchtext
from tqdm import tqdm
import matplotlib.pyplot as plt
import gpt_tests
import bert_sol
import bert_tests

In [2]:
class UnidirectionalMultiheadAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads 
        self.head_size = hidden_size // num_heads
        assert self.head_size * num_heads == hidden_size
        self.attentionLL = nn.Linear(hidden_size, num_heads*self.head_size*3) 
        self.outputLL = nn.Linear(num_heads*self.head_size, hidden_size)
    
    def forward(self, x: torch.Tensor): # [batch, seq_length, hidden_size]
        # Shape: batch seq_len hidden_size*3
        KQV = self.attentionLL(x)
        KQV = rearrange(KQV, "batch seq_len (three num_heads head_size) -> batch num_heads seq_len head_size three ", num_heads=self.num_heads, three=3)
        Q = KQV[:, :, :, :, 0]
        K = KQV[:, :, :, :, 1]
        V = KQV[:, :, :, :, 2]
        # Multiplying K and Q
        attention_pattern = einsum('b n s h, b n t h -> b n s t', K, Q)
        # Scale
        attention_pattern = attention_pattern / math.sqrt(self.head_size)
        # Key (row) must be less than Query (col), if not we set it to 1e-4
        # print(torch.triu(attention_pattern))
        # print((-1e4) * torch.tril(torch.ones_like(attention_pattern), diagonal=-1))
        attention_pattern = torch.triu(attention_pattern) + (-1e4) * torch.tril(torch.ones_like(attention_pattern), diagonal=-1)        
        # Softmax: batch num_heads key_len query_len, so we want to softmax over the keys
        #  so dim=2
        attention_pattern = torch.nn.Softmax(dim=2)(attention_pattern)
        # Multiply by V
        out = einsum('b n k q, b n k h -> b n q h', attention_pattern, V)
        out = rearrange(out, 'batch num_heads seq_len head_size -> batch seq_len (num_heads head_size)')
        out = self.outputLL(out) 
        #print(out)
        return out

gpt_tests.test_unidirectional_attn(UnidirectionalMultiheadAttention)

Congrats! You've passed the test!


In [3]:
class GPT2Block(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, 
                dropout: float, layer_norm_epsilon: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.attn = UnidirectionalMultiheadAttention(hidden_size, num_heads)
        self.ln2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.linear1 = nn.Linear(hidden_size, hidden_size * 4)
        self.linear2 = nn.Linear(hidden_size * 4, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln1(x))
        x = x + self.dropout(self.linear2(torch.nn.functional.gelu(self.linear1(self.ln2(x)))))
        return x 

gpt_tests.test_gpt_block(GPT2Block)

Congrats! You've passed the test!


In [4]:
from dataclasses import dataclass
from torchtyping import TensorType

@dataclass
class GPT2Output:
    logits: TensorType["batch_size", "vocab_size"]
    final_encoding: TensorType["batch_size", "hidden_size"]


In [5]:
class GPT2(nn.Module):
    def __init__(self, num_layers, num_heads, vocab_size,
                hidden_size, max_position_embeddings, dropout, 
                layer_norm_epsilon):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(max_position_embeddings, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.GPTBlocks = nn.Sequential(
            *[GPT2Block(hidden_size, num_heads, dropout, layer_norm_epsilon) 
                for i in range(num_layers)]
        )
        self.last_token_encodings = None
        self.layer_norm = nn.LayerNorm(hidden_size, layer_norm_epsilon)

    def forward(self, input_ids): # [batch, seq_len]
        tokens = self.token_embedding(input_ids)
        batch, seq_len = input_ids.shape
        position_ids = repeat(torch.arange(seq_len), 's -> b s', b = batch) 
        positions = self.position_embedding(position_ids)
        embedding = tokens + positions
        x = self.dropout(embedding)
        x = self.GPTBlocks(x)
        self.last_token_encodings = x
        final_encodings = self.layer_norm(x)[:,-1,:]
        logits = einsum('b c, v c -> b v', final_encodings, self.token_embedding.weight)
        return GPT2Output(logits, final_encodings)

gpt_tests.test_gpt(GPT2) 

Checking final encodings:
Congrats! You've passed the test!
Checking logits:
Congrats! You've passed the test!


In [6]:
my_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, 
    hidden_size=768, max_position_embeddings=1024, dropout=0.1, 
    layer_norm_epsilon=1e-5)

pretrained_gpt = gpt_tests.get_pretrained_gpt()

In [7]:
new_dict = {}
pretrained_dict = pretrained_gpt.state_dict()
for a,b in zip(pretrained_dict.keys(), my_gpt.state_dict().keys()):
    new_dict[b] = pretrained_dict[a] 
my_gpt.load_state_dict(new_dict)


<All keys matched successfully>

In [14]:
tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
bert_tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

In [9]:
def generateText(model, text):
    model.eval()
    tokens = torch.tensor(tokenizer.encode(text), dtype=torch.long).unsqueeze(0)
    gpt_output = model(tokens)
    predicted_word = gpt_output.logits.argmax(dim=-1)
    new_text = text + tokenizer.decode(predicted_word.item())
    return new_text

text = "My life motto:"
for _ in range(10):
    text = generateText(my_gpt, text)
    print(text)

My life motto: "
My life motto: "Don
My life motto: "Don't
My life motto: "Don't be
My life motto: "Don't be afraid
My life motto: "Don't be afraid to
My life motto: "Don't be afraid to be
My life motto: "Don't be afraid to be yourself
My life motto: "Don't be afraid to be yourself."
My life motto: "Don't be afraid to be yourself."



In [10]:
my_bert = bert_sol.Bert(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=3072, 
    num_heads=12, num_layers=12
)
pretrained_bert = bert_tests.get_pretrained_bert()
mapped_params = {bert_sol.mapkey(k): v for k, v in pretrained_bert.state_dict().items()
                if not k.startswith('classification_head')}
my_bert.load_state_dict(mapped_params)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [11]:
# We now have my_bert, my_gpt

In [19]:
# this is to demonstrate that in gpt the logits for a token 
# only depends on earlier tokens in the sequence, whereas that's not
# true for BERT
strings = """My life motto:
My life motto: Fortune
My life motto: Fortune favors
My life motto: Fortune favors the
My life motto: Fortune favors the bold"""
strings = strings.split("\n")
strings = [bert_tokenizer.encode(s) for s in strings]
pad_length = 10
strings = torch.tensor([s + [0]*(pad_length-len(s)) for s in strings], dtype = torch.long)
print(strings)

my_bert.eval()
my_bert(strings)
print("BERT")
print(my_bert.last_token_encodings[:,3,:])

my_gpt.eval()
my_gpt(strings) 
print("GPT")
print(my_gpt.last_token_encodings[:,3,:])


tensor([[  101,  1422,  1297, 13658,   131,   102,     0,     0,     0,     0],
        [  101,  1422,  1297, 13658,   131, 14555,   102,     0,     0,     0],
        [  101,  1422,  1297, 13658,   131, 14555, 24208,   102,     0,     0],
        [  101,  1422,  1297, 13658,   131, 14555, 24208,  1103,   102,     0],
        [  101,  1422,  1297, 13658,   131, 14555, 24208,  1103,  9009,   102]])
BERT
tensor([[ 1.0108,  1.0003, -0.0554,  ...,  0.4157,  0.4293,  0.1067],
        [ 0.4482, -0.0443,  0.1862,  ..., -0.3393,  0.4696,  0.0145],
        [ 0.0276, -0.0032,  0.2345,  ...,  0.3751,  0.4144,  0.0836],
        [ 0.1360, -0.0919,  0.2080,  ...,  0.5545,  0.4430,  0.1225],
        [ 0.0215, -0.0751,  0.1236,  ...,  0.6215,  0.4317,  0.2077]],
       grad_fn=<SliceBackward0>)
torch.Size([5, 10, 768])
tensor([[ 10.6008,   0.4824,  -0.8409,  ...,   3.7044,   2.1718,  -1.7908],
        [  2.2354,   4.2186,  -8.5184,  ...,  -1.7431,   6.8043, -18.0181],
        [  2.2354,   4.2186,  -8.