In [1]:
%matplotlib notebook

In [None]:
from dataclasses import dataclass
import einops
from flax import nnx
import jax.numpy as jnp
import jax
import tokenizer/tokneizer

## Jax Transformer Implementation

In [74]:
@dataclass
class TransformerConfig:
    debug: bool = True
    d_model: int = 768
    d_vocab: int = 50257
    d_head: int = 64
    n_layers: int = 12
    n_heads: int = 12
    ctx_len: int = 1024
    stddev: float = 0.02
    d_mlp: int = d_model*4

In [75]:
class LayerNorm(nnx.Module):
    def __init__(self, cfg: TransformerConfig, key, eps: float = 1e-05):
        self.cfg = cfg
        self.d_model = self.cfg.d_model
        self.w = nnx.Param(jax.random.normal(key, (self.d_model))) # [d_model]
        self.b = nnx.Param(jnp.zeros(self.d_model,)) # [d_model]
        self.eps = eps
    
    def __call__(self, residual: jax.Array):
        # resdiual: [batch x len x d_model]
        # Make mean 0 and normalize to have variance 1
        y = (residual - jnp.mean(residual, axis=1, keepdims=True)) / (jnp.sqrt(jnp.var(residual) + self.eps))
        # Scale with learned weights
        y = y * self.w
        # Translate with learned bias
        y = y + self.b
        return y

In [76]:
class Embed(nnx.Module):
    def __init__(self, cfg: TransformerConfig, key):
        self.cfg = cfg
        self.key = key
        self.W_E = nnx.Param(jax.random.normal(self.key, (self.cfg.d_vocab, self.cfg.d_model)) * self.cfg.stddev)

    def __call__(self, tokens: jnp.ndarray) -> jnp.ndarray:
        # tokens: [batch length]
        return self.W_E[tokens]

In [77]:
class PosEmbed(nnx.Module):
    def __init__(self, cfg: TransformerConfig, key):
        self.cfg = cfg
        self.key = key
        self.W_pos = nnx.Param(jax.random.normal(self.key, (cfg.ctx_len, cfg.d_model)) * self.cfg.stddev)

    def __call__(self, tokens: jnp.ndarray) -> jnp.ndarray:
        # tokens: [batch length]
        batch, length = tokens.shape
        return einops.repeat(self.W_pos[:length], 'length d_model -> batch length d_model', batch=batch)

In [96]:
class Attention(nnx.Module):
    def __init__(self, cfg: TransformerConfig, key):
        self.cfg = cfg
        self.key = key
        self.W_Q = nnx.Param(jax.random.normal(self.key, (cfg.n_heads, cfg.d_model, cfg.d_head))) # [num_heads, d_model, d_head]
        self.W_K = nnx.Param(jax.random.normal(self.key, (cfg.n_heads, cfg.d_model, cfg.d_head))) # [num_heads, d_model, d_head]
        self.W_V = nnx.Param(jax.random.normal(self.key, (cfg.n_heads, cfg.d_model, cfg.d_head))) # [num_heads, d_model, d_head]
        self.W_O = nnx.Param(jax.random.normal(self.key, (cfg.n_heads, cfg.d_head, cfg.d_model))) # [num_heads, d_head, d_model]
        self.b_Q = nnx.Param(jnp.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nnx.Param(jnp.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nnx.Param(jnp.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nnx.Param(jnp.zeros((cfg.d_model)))

    def __call__(self, normal_pre_resid: jnp.ndarray) -> jnp.ndarray:
        """
        b = batch
        l = length
        m = d_model
        n = num_heads
        h = d_head
        q = q_pos
        k = k_pos
        """
        # normal_pre_resid: [batch length d_model]
        q = jnp.einsum('blm, nmh -> blnh', normal_pre_resid, self.W_Q) + self.b_Q
        k = jnp.einsum('blm, nmh -> blnh', normal_pre_resid, self.W_K) + self.b_K
        v = jnp.einsum('blm, nmh -> blnh', normal_pre_resid, self.W_V) + self.b_V

        attn_scores = jnp.einsum('bqnh, bknh -> bnqk', q, k)
        attn_scores = self.apply_casual_mask(attn_scores / self.cfg.d_head ** 0.5)
        attn_probs = jax.nn.softmax(attn_scores, axis=-1) # [batch x n_heads x q_pos x k_pos]

        # [batch x q_pos x n_heads x d_head]
        z = jnp.einsum('bnqk, bknh -> bqnh', attn_probs, v)

        out = jnp.einsum('bqnh, nhm -> bqnm', z, self.W_O)
        out = jnp.einsum('bqnm -> bqm', out) + self.b_O
        return out

    def apply_casual_mask(self, attn_scores: jnp.ndarray) -> jnp.ndarray:
        # attn_scores: [batch n_heads q_pos k_pos]
        mask = jnp.triu(attn_scores).astype(bool)
        masked_attn_scores = jnp.where(mask,jax.lax.broadcast(-jnp.inf, attn_scores.shape), attn_scores)
        
        return masked_attn_scores

In [113]:
class MLP(nnx.Module):
    def __init__(self, cfg: TransformerConfig, key):
        self.cfg = cfg
        self.key = key
        self.W_in = nnx.Param(jax.random.normal(self.key, (cfg.d_model, cfg.d_mlp))) # [d_model, d_mlp]
        self.W_out = nnx.Param(jax.random.normal(self.key, (cfg.d_mlp, cfg.d_model))) # [d_mlp, d_model]
        self.b_in = nnx.Param(jnp.zeros((cfg.d_mlp)))
        self.b_out = nnx.Param(jnp.zeros((cfg.d_model)))

    def __call__(self, normal_resid_mid: jnp.ndarray) -> jnp.ndarray:
        # normal_resid_mid [batch x length x d_model]
        """
        b = batch
        l = length
        m = d_model
        p = d_mlp
        """
        out = jnp.einsum('blm, mp -> blp', normal_resid_mid, self.W_in) + self.b_in
        out = jax.nn.gelu(out)
        out = jnp.einsum('blp, pm -> blm', out, self.W_out) + self.b_out
        return out


In [121]:
class TransformerBlock(nnx.Module):
    def __init__(self, cfg: TransformerConfig, key):
        self.cfg = cfg
        self.key = key
        self.ln1 = LayerNorm(self.cfg, self.key)
        self.ln2 = LayerNorm(self.cfg, self.key)
        self.attn = Attention(self.cfg, self.key)
        self.mlp = MLP(self.cfg, self.key)

    def __call__(self, resid_pre: jnp.ndarray) -> jnp.ndarray:
        resid_mid = self.attn(self.ln1(resid_pre))
        resid_post = self.mlp(self.ln2(resid_pre))
        return(resid_post)

In [130]:
class Unembed(nnx.Module):
    def __init__(self, cfg: TransformerConfig, key):
        self.cfg = cfg
        self.key = key
        self.W_U = nnx.Param(jax.random.normal(self.key, (cfg.d_model, cfg.d_vocab)))
        self.b_U = nnx.Param(jnp.zeros(cfg.d_vocab))

    def __call__(self, normal_resid_post: jnp.ndarray) -> jnp.ndarray:
        # normal_resid_post: [batch x length x d_model]
        """
        b = batch
        l = length
        m = d_model
        b = d_vocab
        """
        return jnp.einsum('blm, mv -> blv', normal_resid_post, self.W_U) + self.b_U

In [137]:
class Transformer(nnx.Module):
    def __init__(self, cfg, key):
        self.cfg = cfg
        self.key = key
        self.embed = Embed(self.cfg, self.key)
        self.pos_embed = PosEmbed(self.cfg, self.key)
        self.blocks = [TransformerBlock(self.cfg, self.key) for _ in range(cfg.n_layers)]
        self.ln_final = LayerNorm(self.cfg, self.key)
        self.unembed = Unembed(self.cfg, self.key)

    def __call__(self, tokens: jnp.ndarray) -> jnp.ndarray:
        resid = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            resid = block(resid)
        logits = self.unembed(self.ln_final(resid))
        return logits

In [140]:
key = jax.random.key(101)
cfg = TransformerConfig(
    d_model=64,
    d_vocab=1024,
)

In [None]:
def rand_float_test(cls, key, shape):
    random_input = jax.random.uniform(key, (shape))
    print("Input shape:", random_input.shape)
    output = cls(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

def rand_int_test(cls, key, shape):
    random_input = jax.random.randint(key, (shape), 100, 1000)
    print("Input shape:", random_input.shape)
    output = cls(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

In [186]:
# LayerNorm test
ln = LayerNorm(cfg, key)
rand_int_test(ln, key, (2, 4, cfg.d_model))

# Embed test
emb = Embed(cfg, key)
rand_int_test(emb, key, (2, 128))

# PosEmbed test
pos = PosEmbed(cfg, key)
rand_int_test(emb, key, (2, 128))

# Attention test
attn = Attention(cfg, key)
rand_float_test(attn, key, (2, 128, cfg.d_model))

# MLP test
mlp = MLP(cfg, key)
rand_float_test(attn, key, (2, 128, cfg.d_model))

# TransformerBlock test
tb = TransformerBlock(cfg, key)
rand_float_test(attn, key, (2, 128, cfg.d_model))

# Unembed test
un = Unembed(cfg, key)
rand_float_test(attn, key, (2, 128, cfg.d_model))

# Transformer test
t = Transformer(cfg, key)
rand_int_test(emb, key, (2, 128))

Input shape: (2, 4, 64)
Output shape: (2, 4, 64) 

Input shape: (2, 128)
Output shape: (2, 128, 64) 

Input shape: (2, 128)
Output shape: (2, 128, 64) 

Input shape: (2, 128, 64)
Output shape: (2, 128, 64) 

Input shape: (2, 128, 64)
Output shape: (2, 128, 64) 

Input shape: (2, 128, 64)
Output shape: (2, 128, 64) 

Input shape: (2, 128, 64)
Output shape: (2, 128, 64) 

Input shape: (2, 128)
Output shape: (2, 128, 64) 



In [None]:
text = """
This is sample text that I am going to tokenize.
"""

input_ids = tokenizer(text, return_tensors="np")
input_ids, attention_mask = jnp.array(input_ids['input_ids']), jnp.array(input_ids['attention_mask'])

Tokens: ['Ċ', 'This', 'Ġis', 'Ġsample', 'Ġtext', 'Ġthat', 'ĠI', 'Ġam', 'Ġgoing', 'Ġto', 'Ġtoken', 'ize', '.', 'Ċ']


In [210]:
transformer = Transformer(cfg, key)

logits = transformer(input_ids)

print(logits)
print(logits.shape)

[[[ 9.9230218e-01  1.5166521e-02 -5.0949659e+00 ...  2.2485504e+00
    4.3776274e-02 -2.3710639e+00]
  [-4.2430086e+00 -7.9845352e+00 -5.5063744e+00 ... -6.9265957e+00
   -7.1119471e+00 -1.6937866e+00]
  [-3.8464799e+00  1.7938212e+00  6.7881411e-01 ...  1.6993171e+00
    6.0196104e+00 -3.9587855e-01]
  ...
  [ 8.0927715e+00 -2.7217324e+00 -8.9893293e+00 ... -1.2862199e+01
   -1.2313092e+00 -4.1354780e+00]
  [-1.7582830e+00 -1.7279179e+01  7.8891530e+00 ... -1.8498978e+01
    7.2134991e+00 -4.6433630e+00]
  [-2.3629048e+00 -5.6193485e+00 -3.3724360e+00 ...  9.7430378e-01
   -5.7956591e+00  8.1657858e+00]]]
(1, 14, 1024)
