In [1]:
%load_ext autoreload

In [2]:
import wandb
import random

In [4]:
%autoreload 2

from lovely_tensors.patch import monkey_patch; monkey_patch()
import torch
from transformers import GPT2Tokenizer

In [5]:
import math


def new_gelu(input):
    """Tanh approximation of the GELU activation, applied elementwise.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.

    Args:
        input: Tensor to activate.

    Returns:
        Tensor with the same shape as ``input``.
    """
    # Inner polynomial x + 0.044715 * x^3 that feeds the tanh.
    cubic_term = input + 0.044715 * torch.pow(input, 3.0)
    # 1 + tanh(...) acts as a smooth gate in the range (0, 2).
    tanh_gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * cubic_term)
    return 0.5 * input * tanh_gate


def conv_1d(x, weight, bias=None):
    """Apply a linear map to the last dimension of ``x``.

    All leading dimensions of ``x`` are flattened, the result is multiplied
    by ``weight`` (and ``bias`` is added when given), and the leading
    dimensions are restored.

    Args:
        x: Tensor whose last dimension has size ``weight.size(0)``.
        weight: 2-D tensor of shape ``(in_features, out_features)``.
        bias: Optional tensor broadcastable to ``(rows, out_features)``.
            When ``None`` no bias is added.

    Returns:
        Tensor shaped like ``x`` with the last dimension replaced by
        ``weight.size(-1)``.
    """
    size_out = x.size()[:-1] + (weight.size(-1),)
    # reshape (not view) so non-contiguous inputs are accepted as well.
    flat = x.reshape(-1, x.size(-1))
    if bias is None:
        # torch.addmm needs a tensor to add, so the declared default of
        # bias=None must take the plain matrix-multiply path.
        projected = torch.mm(flat, weight)
    else:
        projected = torch.addmm(bias, flat, weight)
    return projected.view(size_out)


def transformer_block(i, input_hidden_state, model_state, n_heads=12):
    """Run one transformer block: causal self-attention followed by an MLP.

    Each sub-block applies layer normalisation to its input, transforms it,
    and adds the result back onto the un-normalised input (residual).

    Args:
        i: Block index; selects the ``h.{i}.*`` entries of ``model_state``.
        input_hidden_state: Floating-point tensor whose last dimension is the
            hidden size and whose second-to-last dimension is the sequence.
        model_state: Mapping from parameter name to tensor. Reads
            ``ln_1``, ``attn.c_attn``, ``attn.c_proj``, ``ln_2``,
            ``mlp.c_fc`` and ``mlp.c_proj`` weights and biases for block ``i``.
        n_heads: Number of attention heads. Defaults to 12, the value that
            used to be hard-coded.

    Returns:
        Tensor with the same shape as ``input_hidden_state``.

    Raises:
        ValueError: If the hidden size is not divisible by ``n_heads``.
    """

    def block_state(key):
        return model_state[f"h.{i}.{key}"]

    def attention_state(key):
        return model_state[f"h.{i}.attn.{key}"]

    # Derive sizes from the input instead of hard-coding 768 / 64.
    hidden_size = input_hidden_state.size(-1)
    if hidden_size % n_heads != 0:
        raise ValueError(
            f"hidden size {hidden_size} is not divisible by n_heads={n_heads}"
        )
    head_dim = hidden_size // n_heads

    # ---- attention sub-block ----
    ln1 = torch.nn.functional.layer_norm(
        input=input_hidden_state,
        weight=block_state("ln_1.weight"),
        bias=block_state("ln_1.bias"),
        normalized_shape=(hidden_size,),
    )

    # c_attn stores the query, key and value projections side by side.
    w_q, w_k, w_v = attention_state("c_attn.weight").chunk(3, dim=1)
    b_q, b_k, b_v = attention_state("c_attn.bias").chunk(3, dim=0)

    q = conv_1d(ln1, w_q, b_q)
    k = conv_1d(ln1, w_k, b_k)
    v = conv_1d(ln1, w_v, b_v)

    # Split the hidden dimension into heads and put the heads on a new
    # leading axis: (n_heads, ..., seq, head_dim).
    q_heads = torch.stack(q.chunk(n_heads, dim=-1))
    k_heads = torch.stack(k.chunk(n_heads, dim=-1))
    v_heads = torch.stack(v.chunk(n_heads, dim=-1))

    # Scaled dot-product scores: (n_heads, ..., seq, seq).
    scores = torch.matmul(q_heads, k_heads.transpose(-1, -2))
    scores = scores / (head_dim**0.5)

    # Causal mask: position t may not attend to positions after t. The fill
    # value follows the score dtype so it does not overflow below float32.
    future = torch.triu(torch.ones_like(scores), diagonal=1).bool()
    scores = scores.masked_fill(future, torch.finfo(scores.dtype).min)

    attention_weights = torch.nn.functional.softmax(scores, dim=-1)
    attention_output = torch.matmul(attention_weights, v_heads)

    # Concatenate the heads back along the hidden dimension.
    combined_attention_output = torch.cat(attention_output.unbind(dim=0), dim=-1)

    # Output projection mixes information across heads.
    crosstalk = conv_1d(
        combined_attention_output,
        attention_state("c_proj.weight"),
        attention_state("c_proj.bias"),
    )
    after_residual = crosstalk + input_hidden_state

    # ---- mlp sub-block ----
    ln2 = torch.nn.functional.layer_norm(
        input=after_residual,
        weight=block_state("ln_2.weight"),
        bias=block_state("ln_2.bias"),
        normalized_shape=(hidden_size,),
    )

    after_up = conv_1d(
        ln2, block_state("mlp.c_fc.weight"), block_state("mlp.c_fc.bias")
    )
    activated = new_gelu(after_up)
    after_down = conv_1d(
        activated, block_state("mlp.c_proj.weight"), block_state("mlp.c_proj.bias")
    )

    # Residual connection around the MLP (skips ln2).
    return after_down + after_residual


def transformer(token_ids, model_state, n_layers=12):
    """Embed ``token_ids`` and run them through the stack of transformer blocks.

    Args:
        token_ids: Integer token ids as a tensor (or anything
            ``torch.as_tensor`` accepts, such as a list of ints). The last
            dimension is treated as the sequence; extra leading dimensions
            are assumed to be batch dimensions.
        model_state: Mapping from parameter name to tensor. Reads
            ``wte.weight``, ``wpe.weight``, ``ln_f.weight`` and ``ln_f.bias``
            here; the per-block entries are read by ``transformer_block``.
        n_layers: Number of blocks to apply. Defaults to 12, the value that
            used to be hard-coded.

    Returns:
        Final layer-normalised hidden states, shaped like ``token_ids`` with
        a trailing hidden dimension appended.
    """
    token_ids = torch.as_tensor(token_ids)
    token_embeddings = model_state["wte.weight"][token_ids]

    # Positions index the sequence (last) axis. len(token_ids) would return
    # the batch size for 2-D input and pair tokens with the wrong positions.
    positions = torch.arange(token_ids.size(-1))  # [0, 1, 2, 3, ...]
    position_embeddings = model_state["wpe.weight"][positions]

    hs = token_embeddings + position_embeddings
    for i in range(n_layers):
        hs = transformer_block(i, hs, model_state)

    # Final layer norm; the hidden size is taken from the data, not hard-coded.
    return torch.nn.functional.layer_norm(
        input=hs,
        weight=model_state["ln_f.weight"],
        bias=model_state["ln_f.bias"],
        normalized_shape=(hs.size(-1),),
    )

In [6]:
def language_model(token_ids, model_state):
    """Compute vocabulary logits for every position of ``token_ids``.

    The hidden states returned by ``transformer`` are projected with the
    transpose of the token-embedding matrix ``wte.weight``, so the same
    matrix is used for both embedding and un-embedding.

    Args:
        token_ids: Token ids passed straight through to ``transformer``.
        model_state: Mapping from parameter name to tensor.

    Returns:
        Logits tensor; its last dimension has one entry per row of
        ``wte.weight``.
    """
    hidden_states = transformer(token_ids, model_state)
    unembedding = model_state["wte.weight"].T
    return torch.matmul(hidden_states, unembedding)

In [7]:
# Small fixed prompt of raw token ids used to smoke-test the model.
token_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source; on PyTorch versions that
# support it, prefer passing weights_only=True.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine and keeps the weights on the same device as token_ids.
model_params = torch.load("pytorch_model.bin", map_location="cpu")

In [8]:
# Run the hand-written model on the sample token ids; the bare `out` below
# makes the notebook display the resulting logits tensor.
out = language_model(token_ids, model_params)
out

# Optional per-position visualisations, left disabled.
# NOTE(review): `.chans` is presumably provided by the lovely_tensors
# monkey patch applied at import time — confirm before re-enabling.
# out[0].chans(scale=4, cmap="seismic")
# out[1].chans(scale=4, cmap="seismic")
# out[2].chans(scale=4, cmap="seismic")
# out[3]

tensor[10, 50257] n=502570 (1.9Mb) x∈[-110.039, -29.114] μ=-76.757 σ=15.804