In [1]:
import gpt_tests
from bert_sol import BertEmbedding, BertBlock, mapkey
from bert_tests import get_pretrained_bert

import math

import torch as t
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import transformers
from einops import rearrange, reduce, repeat
from dataclasses import dataclass
from torch import einsum
from torchtyping import TensorType
from typing import Optional

In [2]:
# Default device for tensors this notebook allocates by hand
# (as opposed to parameters owned by an nn.Module).
DEVICE = t.device("cpu")

In [3]:
class Attention(nn.Module):
    """Multi-head causal self-attention with an optional single-token cached path.

    A single linear layer (`attn_layer`) produces queries, keys and values at
    once; `output_layer` mixes the concatenated heads back to `hidden_size`.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.attn_layer = nn.Linear(hidden_size, 3*hidden_size)
        self.output_layer = nn.Linear(hidden_size, hidden_size)

    def forward(
            self,
            input: TensorType["batch_size", "seq_len", "hidden_size"],
            past_key_values: Optional[TensorType["num_heads", "seq_len", "2*head_size"]] = None,
            return_key_values: bool = False,
            neg_inf=-1e4
        ):
        """Apply attention to `input`.

        Without `past_key_values`, every position of every batch element is
        processed and each position attends only to itself and earlier
        positions. The result has the same shape as `input`;
        `return_key_values` is ignored on this path.

        With `past_key_values`, only `input[0, 0]` (one token of the first
        batch element) is processed. It attends to all cached positions plus
        itself, so no mask is needed. The attention output then has shape
        `(hidden_size,)`.

        Args:
            input: activations of shape (batch_size, seq_len, hidden_size).
            past_key_values: cached keys and values of earlier positions,
                concatenated along the last dimension as `keys | values`.
            return_key_values: on the cached path, also return the new token's
                keys and values with shape (num_heads, 1, 2 * head_size).
            neg_inf: score written into masked (future) positions before the
                softmax.

        Returns:
            The attention output, or `(output, new_key_values)` when
            `past_key_values` is given and `return_key_values` is true.
        """
        if past_key_values is None:
            _, seq_len, hidden_size = input.shape
            queries, keys, values = t.split(self.attn_layer(input), hidden_size, dim=-1)
            queries = rearrange(queries, "b n (h p) -> b h n p", h=self.num_heads)
            keys = rearrange(keys, "b n (h p) -> b h n p", h=self.num_heads)
            values = rearrange(values, "b n (h p) -> b h n p", h=self.num_heads)
            attn_score = einsum("b h t p, b h f p -> b h t f", queries, keys) / math.sqrt(self.head_size)

            # True strictly above the diagonal, i.e. where the key position
            # lies after the query position. The mask is built on the input's
            # own device so the module keeps working after `.to(...)`, and it
            # broadcasts over the batch and head dimensions.
            future = t.ones((seq_len, seq_len), dtype=t.bool, device=input.device).triu(diagonal=1)
            attn_score = attn_score.masked_fill(future, neg_inf)

            attn_pattern = t.softmax(attn_score, dim=-1)
            attn_concat = einsum("b h t f, b h f p -> b h t p", attn_pattern, values)
            attn_concat = rearrange(attn_concat, "b h t p -> b t (h p)")
            return self.output_layer(attn_concat)

        _, _, hidden_size = input.shape
        num_heads, _, _ = past_key_values.shape
        # The cached path handles a single token: the first position of the
        # first batch element.
        token = input[0, 0]
        query, key, value = t.split(self.attn_layer(token), hidden_size, dim=-1)
        query = rearrange(query, "(h p) -> h p", h=num_heads)
        # Give the new key/value a length-1 sequence axis so they can be
        # appended to the cache.
        key = rearrange(key, "(h p) -> h p", h=num_heads).unsqueeze(-2)
        value = rearrange(value, "(h p) -> h p", h=num_heads).unsqueeze(-2)
        past_keys, past_values = t.split(past_key_values, self.head_size, dim=-1)
        keys = t.cat((past_keys, key), dim=-2)
        values = t.cat((past_values, value), dim=-2)
        attn_score = einsum("h p, h n p -> h n", query, keys) / math.sqrt(self.head_size)
        attn_pattern = t.softmax(attn_score, dim=-1)
        attn = einsum("h n, h n p -> h p", attn_pattern, values)
        output = self.output_layer(rearrange(attn, "h p -> (h p)"))
        if return_key_values:
            return output, t.cat((key, value), dim=-1)
        return output

# Run the gpt_tests checks for the plain (masked) path and the cached path of Attention.
gpt_tests.test_unidirectional_attn(Attention)
gpt_tests.test_attn_cache(Attention)

Congrats! You've passed the test!
Checking encoding:
Congrats! You've passed the test!
Checking new key and value:
Congrats! You've passed the test!


In [6]:
class Block(nn.Module):
    """Transformer block: layer norm -> attention -> residual, then layer norm -> MLP -> residual.

    The MLP expands to `4 * hidden_size`, applies GELU and projects back.
    Dropout is applied to the MLP output only.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float,
        layer_norm_epsilon: float
    ):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.attention = Attention(hidden_size, num_heads)
        self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.linear_1 = nn.Linear(hidden_size, 4*hidden_size)
        self.linear_2 = nn.Linear(4*hidden_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(
            self,
            input: TensorType["batch_size", "seq_len", "hidden_size"],
            past_key_values: Optional[TensorType["num_heads", "seq_len", "2*head_size"]] = None,
            return_key_values: bool = False,
        ):
        """Run the block on `input`.

        Args:
            input: activations of shape (batch_size, seq_len, hidden_size).
            past_key_values: when given, it is forwarded to the attention
                layer, which is then always asked for the new token's keys
                and values.
            return_key_values: when `past_key_values` is given, return
                `(output, key_values)` instead of `output` alone. Ignored
                when `past_key_values` is None.

        Returns:
            The block output, or `(output, key_values)` as described above.
        """
        normed_input = self.layer_norm_1(input)
        if past_key_values is None:
            key_values = None
            attn_out = self.attention(normed_input)
        else:
            # The attention layer returns an (output, key_values) tuple here.
            # It must be unpacked before the residual addition, because
            # adding a tensor to the tuple raises a TypeError.
            attn_out, key_values = self.attention(
                normed_input,
                past_key_values=past_key_values,
                return_key_values=True,
            )
        residual_attn = input + attn_out

        normed_attn = self.layer_norm_2(residual_attn)
        mlp = self.linear_2(F.gelu(self.linear_1(normed_attn)))
        output = residual_attn + self.dropout(mlp)

        if past_key_values is not None and return_key_values:
            return output, key_values
        return output


# Run the gpt_tests check for the transformer block.
gpt_tests.test_gpt_block(Block)

Congrats! You've passed the test!


In [None]:
@dataclass
class GPT2Output:
    """Bundle of tensors returned by the forward passes of the models in this notebook."""
    # Vocabulary scores computed from the last sequence position.
    logits: TensorType["batch_size", "vocab_size"]
    # Encoding taken at the last sequence position (index -1 along seq_len).
    final_encoding: TensorType["batch_size", "hidden_size"]
    # Encoding at every sequence position.
    encoding: TensorType["batch_size", "seq_len", "hidden_size"]

In [7]:
# Scratch check: a tensor may have a zero-length dimension.
t.zeros(2, 0)

tensor([], size=(2, 0))

In [None]:
class GPT2(nn.Module):
    """GPT-2 style decoder: token + position embeddings, a stack of `Block`s, a final layer norm.

    Logits are produced by multiplying the last position's encoding with the
    transposed token-embedding matrix (the unembedding shares the embedding
    weights).

    With `use_cache=True` the model keeps the keys and values of every
    position it has already processed in `self.cache`, of shape
    (num_layers, num_heads, cached_len, 2 * head_size). A forward call then
    only processes the positions of `input_ids` that lie beyond `cached_len`.
    """

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        vocab_size: int,
        hidden_size: int,
        max_position_embeddings: int,
        dropout: float,
        layer_norm_epsilon: float,
        use_cache: bool = False,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_position_embeddings, hidden_size)
        self.dropout = nn.Dropout(p=dropout)
        self.blocks = nn.Sequential(
            *(Block(
                hidden_size=hidden_size,
                num_heads=num_heads,
                dropout=dropout,
                layer_norm_epsilon=layer_norm_epsilon
            ) for _ in range(num_layers))
        )
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        head_size = hidden_size // num_heads
        self.use_cache = use_cache
        # Deliberately a plain attribute rather than a registered buffer, so
        # it does not appear in state_dict() and the set and order of
        # state-dict keys stay exactly as before.
        self.cache = t.zeros(num_layers, num_heads, 0, 2 * head_size, device=DEVICE)

    def reset_cache(self):
        """Forget all cached keys and values so a new sequence can be started."""
        self.cache = self.cache[:, :, :0, :]

    def forward(self, input_ids: TensorType["batch_size", "seq_length"]):
        """Encode `input_ids` and score the next token.

        Args:
            input_ids: token ids of shape (batch_size, seq_length). With
                `use_cache=True` the batch size must be 1 and the sequence
                must extend past the positions already in the cache.

        Returns:
            A `GPT2Output`. With `use_cache=True`, `encoding` only covers the
            positions processed by this call (those not yet cached);
            `final_encoding` and `logits` always refer to the last position
            of `input_ids`.

        Raises:
            ValueError: with `use_cache=True`, if the batch size is not 1 or
                `input_ids` contains no position beyond the cache.
        """
        batch_size, seq_len = input_ids.shape
        first_new = 0
        if self.use_cache:
            if batch_size != 1:
                raise ValueError(
                    f"use_cache=True supports only batch_size == 1, got {batch_size}"
                )
            first_new = self.cache.shape[2]
            if seq_len <= first_new:
                raise ValueError(
                    f"input_ids has {seq_len} positions but {first_new} are already cached; "
                    "pass a longer sequence or call reset_cache()"
                )

        # Index tensors are created on the input's device so the model keeps
        # working after it has been moved with `.to(...)`.
        positions = t.arange(first_new, seq_len, device=input_ids.device)
        token_embedding = self.token_embedding(input_ids[:, first_new:])
        pos_embedding = self.pos_embedding(positions)
        embedding = self.dropout(token_embedding + pos_embedding)

        if self.use_cache:
            block_output = self._run_blocks_with_cache(embedding)
        else:
            block_output = self.blocks(embedding)

        encoding = self.layer_norm(block_output)
        final_encoding = encoding[:, -1, :]
        logits = final_encoding @ self.token_embedding.weight.T
        return GPT2Output(
            logits=logits,
            final_encoding=final_encoding,
            encoding=encoding
        )

    def _run_blocks_with_cache(self, embedding):
        """Push the new positions through the blocks one token at a time, growing `self.cache`."""
        self.cache = self.cache.to(embedding.device)
        step_outputs = []
        for step in range(embedding.shape[1]):
            hidden = embedding[:, step:step + 1, :]
            new_key_values = []
            for layer_index, block in enumerate(self.blocks):
                # Each block is handed its own layer's cache and returns the
                # keys/values of the token it has just processed.
                hidden, key_values = block(
                    hidden,
                    past_key_values=self.cache[layer_index],
                    return_key_values=True,
                )
                # Detached so the cache does not keep autograd graphs alive.
                new_key_values.append(key_values.detach())
            # Append this token along the sequence axis of the cache.
            self.cache = t.cat((self.cache, t.stack(new_key_values)), dim=2)
            step_outputs.append(hidden)
        return t.cat(step_outputs, dim=1)


In [None]:
# Run the gpt_tests check for the full model.
gpt_tests.test_gpt(GPT2)

In [None]:
# Our implementation, freshly initialised with 12 layers, 12 heads and hidden size 768.
our_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768, max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5)
# Reference model supplied by gpt_tests — presumably carrying pretrained weights; confirm in gpt_tests.
their_gpt = gpt_tests.get_pretrained_gpt()

In [None]:
def align_state_dicts(our_model, their_model):
    """Print both models' state-dict sizes, then their keys side by side.

    The first line shows the two entry counts. Each following line pairs the
    i-th key of `our_model` (padded to 40 characters) with the i-th key of
    `their_model`; pairing stops at the shorter of the two.
    """
    our_names = list(our_model.state_dict())
    their_names = list(their_model.state_dict())
    print(f"{len(our_names)} {len(their_names)}")

    shared = min(len(our_names), len(their_names))
    for index in range(shared):
        our_key, their_key = our_names[index], their_names[index]
        print(f"{our_key:40}{their_key}")

# align_state_dicts(our_gpt, their_gpt)

In [None]:
def transfer_their_weights_to_ours(our_gpt, their_gpt):
    """Load `their_gpt`'s tensors into `our_gpt`, matching entries by position.

    The i-th key of `our_gpt.state_dict()` receives the i-th value of
    `their_gpt.state_dict()`; key names of `their_gpt` are not consulted.
    Pairing stops at the shorter state dict, after which
    `our_gpt.load_state_dict` validates the resulting mapping.
    """
    our_names = our_gpt.state_dict().keys()
    their_tensors = their_gpt.state_dict().values()
    remapped = dict(zip(our_names, their_tensors))
    our_gpt.load_state_dict(remapped)

# Overwrite our_gpt's freshly initialised parameters with the reference model's tensors.
transfer_their_weights_to_ours(our_gpt=our_gpt, their_gpt=their_gpt)

In [None]:
class Bert(nn.Module):
    """BERT encoder with a language-modelling head, returning a `GPT2Output`.

    The head is linear -> GELU -> layer norm -> linear to the vocabulary.
    The raw output of the block stack is kept on `self._enc`; the `encoding`
    reported in the result is taken after the first head layer (`self.lin`).
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size,
                 dropout, intermediate_size, num_heads, num_layers):
        super().__init__()
        self.embed = BertEmbedding(
            vocab_size,
            hidden_size,
            max_position_embeddings,
            type_vocab_size,
            dropout,
        )
        layers = []
        for _ in range(num_layers):
            layers.append(BertBlock(hidden_size, intermediate_size, num_heads, dropout))
        self.blocks = nn.Sequential(*layers)
        self.lin = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.unembed = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids):
        """Encode `input_ids`, using token type 0 for every position."""
        segment_ids = t.zeros_like(input_ids, dtype=int)
        embedded = self.embed(input_ids, segment_ids)

        block_output = self.blocks(embedded)
        self._enc = block_output

        projected = self.lin(block_output)
        activated = F.gelu(projected)
        vocab_scores = self.unembed(self.layer_norm(activated))

        # Only the last position's scores and encoding are reported separately.
        return GPT2Output(
            logits=vocab_scores[:, -1, :],
            final_encoding=projected[:, -1, :],
            encoding=projected,
        )

# Hyperparameters passed as keyword arguments to Bert.
config = {
    "vocab_size": 28996,
    "intermediate_size": 3072,
    "hidden_size": 768,
    "num_layers": 12,
    "num_heads": 12,
    "max_position_embeddings": 512,
    "dropout": 0.1,
    "type_vocab_size": 2,
}

our_bert = Bert(**config)
pretrained_bert = get_pretrained_bert()
# Rename the pretrained keys with mapkey (presumably onto our module names — see bert_sol)
# and drop every entry whose key starts with 'classification_head'.
mapped_state_dict = { mapkey(k): v for k, v in pretrained_bert.state_dict().items() if not k.startswith('classification_head') }
our_bert.load_state_dict(mapped_state_dict)

In [None]:
# Five prompts, each extending the previous one by a single word.
sentences = [
    "My life motto:",
    "My life motto: Fortune",
    "My life motto: Fortune favors",
    "My life motto: Fortune favors the",
    "My life motto: Fortune favors the bold"
]

# Tokenize with the BERT vocabulary; shorter sentences are padded to the longest one
# so that the ids form a rectangular tensor.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
input_ids = t.tensor(tokenizer(sentences, padding="longest").input_ids)

In [None]:
# Evaluation mode, so dropout is inactive for both models.
our_gpt.eval()
our_bert.eval()

# NOTE(review): the same ids, produced by the bert-base-cased tokenizer, are fed to both
# models — for GPT they index a different vocabulary; confirm this is intended.
encodings_gpt = our_gpt(input_ids).encoding
encodings_bert = our_bert(input_ids).encoding

# For each sentence, show the first 10 features of the encoding at position 3.
for i in range(len(sentences)):
    print(f"Sentence #{i + 1}")
    print(f"GPT: {encodings_gpt[i, 3, :10]}\nBERT: {encodings_bert[i, 3, :10]}")

### Is the embedding matrix orthogonal?

In [None]:
# W^T W for the token-embedding matrix W: the matrix of dot products between
# the columns (hidden dimensions) of W. Detached so it carries no autograd history.
token_embedding_weight = our_gpt.token_embedding.weight
matrix = t.detach(token_embedding_weight.T @ token_embedding_weight)

In [None]:
# Standardise using the mean and standard deviation taken over all entries of the matrix.
normed_matrix = (matrix - matrix.mean()) / matrix.std()

In [None]:
# Eigenvalues of the standardised matrix. `Tensor.eig` has been removed from
# current PyTorch releases; `torch.linalg.eigvals` is the maintained
# replacement and returns the eigenvalues as a complex tensor.
t.linalg.eigvals(normed_matrix)

In [None]:
# Mean diagonal entry. Dividing by the matrix's own size rather than a
# hard-coded 768 keeps this correct for any hidden size.
normed_matrix.trace() / normed_matrix.shape[0]

In [None]:
# Norm of the difference between the standardised matrix and a scaled identity.
# NOTE(review): 762.1544 looks like the mean diagonal value from the previous cell
# pasted by hand — confirm, and recompute it if the weights change.
t.norm(normed_matrix - 762.1544 * t.eye(*matrix.shape))

In [None]:
# Sum of the absolute diagonal entries divided by the mean absolute row sum.
matrix.abs().trace() / matrix.abs().sum(dim=1).mean()

In [None]:
# Heat map of |matrix| with every column divided by that column's norm.
plt.imshow(t.abs(matrix) / t.norm(matrix, dim=0, keepdim=True))

In [None]:
# Identity matrix scaled by the same column norms, drawn for comparison with the plot above.
# NOTE(review): the diagonal here is 1 / column norm, not 1 — confirm this is the intended reference.
plt.imshow(t.eye(*matrix.shape) / t.norm(matrix, dim=0, keepdim=True))