In [4]:
import jax, optax, os, requests, yaml
import jax.numpy as jnp
from jax import grad, jit, vmap, random
from jax.tree_util import tree_flatten
import plotly as py
from einops import rearrange
from functional import partial
import numpy as np
import pandas as pd
from tqdm import tqdm
from src.data import data_fn

In [5]:
epochs = 100
d      = 5

In [6]:
with open('config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [183]:
def apply_fn(params, x):
    x = embed_fn(params, x)
    for block in params['blocks']:
        x = head_fn(block['head'], x) + mlp_fn(block['mlp'], x)
    logits = x # @ params['lm_head']                       # B x T x V

    return logits


def head_fn(params, x):
    x = [head_apply_fn(params[f'head_{i}'], x) for i in range(len(params))]
    x = jnp.stack(x, axis=-1).mean(axis=-1)
    return x

def head_apply_fn(params, x):
    # tril = jnp.nan_to_num(jnp.absolute(jnp.tril(jnp.ones((T, T))) - 1) * (-jnp.inf), nan=0)
    # mask = jnp.nan_to_num(jnp.triu(jnp.ones((T, T))) * (-jnp.inf), nan=0)
    k  = x @ params['key']                      # B x T x H
    q  = x @ params['query']                    # B x T x H
    z  = q @ k.transpose(0, 2, 1)               # B x T x T
    z /= jnp.sqrt(params['key'].shape[1])       # divide by sqrt to normalize
    z  = params['alpha'] * np.eye(x.shape[1]) + # params['beta'] * x              # - gamma * C  # <-- shaped attention
    return z @ x
    
def mlp_fn(params, x):  # TODO: maybe switch activation
    x = x @ params['dense1'] + params['bias1']
    x = jax.nn.relu(x)
    x = x @ params['dense2'] + params['bias2']
    return x

def embed_fn(params, x):
    n  = x.shape[1]                              # num toks in sample
    x  = params['tok_embedding'][x]              # tok embeddings
    x += params['pos_embedding'][jnp.arange(n)]  # pos embeddings
    return x


In [184]:
def init_head_fn(rng, embed_dim, n_heads, scale):
    head_size = embed_dim // n_heads
    rng, key_key, key_value, key_query = jax.random.split(rng, 4)
    params = {} 
    for i in range(n_heads):
        params[f'head_{i}'] = {
            'key'   : jax.random.normal(key_key, shape=(embed_dim, head_size)) * scale,
            'query' : jnp.zeros((embed_dim, head_size)),
            'alpha' : jnp.array(1),
            'beta'  : jnp.array(0),
            }
    return params

def init_mlp_fn(rng, embed_dim, scale=1e-2):
    rng, key1, key2 = jax.random.split(rng, 3)
    params = {
        'dense1' : jax.random.normal(key1, shape=(embed_dim, 4 * embed_dim)) * scale,
        'bias1'  : jnp.zeros((embed_dim,)),
        'dense2' : jax.random.normal(key2, shape=(4 * embed_dim, embed_dim)) * scale,
        'bias2'  : jnp.zeros((embed_dim,)),
        }
    return params

def init_block_fn(rng, embed_dim, n_heads, scale):
    rng, key1, key2 = jax.random.split(rng, 3)
    params = {
        'head'  : init_head_fn(key1, embed_dim, n_heads, scale),
        'mlp'  : init_mlp_fn(key2, embed_dim, scale),
        }
    return params

def init_fn(rng, config):
    rng, key1, key2, key3 = jax.random.split(rng, 4)
    params = {
        'tok_embedding': jax.random.normal(key1, shape=(config['vocab_size'], config['embed_dim'])) * config['scale'],
        'pos_embedding': jax.random.normal(key2, shape=(config['block_size'], config['embed_dim'])) * config['scale'],
        'lm_head': jax.random.normal(key3, shape=(config['embed_dim'], config['vocab_size'])) * config['scale'],
        'blocks': [init_block_fn(key1, config['embed_dim'], config['n_heads'], scale=config['scale']) for _ in range(config['n_layers'])],
        }
    return params

In [185]:
def loss_fn(params, xb, yb):
    # we cant to minimise cross entropy
    logits = apply_fn(params, xb) # B x T x C
    B, T, C = logits.shape
    yb = yb.reshape(-1)
    logits = logits.reshape(B * T, C)
    logits = jnp.clip(logits, -100, 100)
    loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(yb, C)))
    return loss

In [186]:
def generate_fn(rng, params, idx, block_size, length=100, temperature=1.0):
    for _ in tqdm(range(length)):
        rng, key = jax.random.split(rng)
        logits = apply_fn(params, idx[:, -block_size:])         # B x T x C
        logits = logits[:, -1, :] / temperature                 # B x C
        idx_new = jax.random.categorical(key, logits)[:, None]  # B x 1
        idx = jnp.concatenate([idx, idx_new], axis=1)           # B x T + 1
    return idx

In [187]:
def batch_fn(rng, data):
    while True:
        rng, key = random.split(rng)
        idxs = random.permutation(key, data.shape[0])
        yield data[idxs]

In [188]:
rng, key = random.split(random.PRNGKey(0))
data     = jit(partial(data_fn, d))()
batches  = iter(batch_fn(key, data))
params   = init_fn(rng, config)

In [189]:
# py.plot(params['tok_embedding'], kind='density_heatmap')

In [190]:
weights, structure = tree_flatten(params)

In [191]:
sum([w.size for w in weights])

257200

In [192]:
for batch in batches:
    pred = apply_fn(params, batch)
    print(pred.shape)
    break

ValueError: Incompatible shapes for broadcasting: shapes=[(8, 8), (9000, 8, 64)]

In [193]:
params['tok_embedding'][-1]

Array([-3.6377236e-03,  1.3665689e-02, -8.0591515e-03, -5.0637429e-03,
       -1.1708740e-02,  7.6320437e-03,  1.7120766e-02, -1.0586310e-02,
       -5.6292932e-03,  2.7307081e-03,  5.5464049e-04, -2.0349741e-02,
        1.6700088e-03,  1.9430332e-02,  2.9991686e-04, -3.7635993e-03,
       -1.0269030e-02,  1.2892984e-02,  7.4097714e-03,  4.4379453e-04,
        1.7080469e-02, -7.1085035e-03, -5.1936684e-03, -1.1268475e-02,
       -3.9847223e-03, -4.8977381e-04,  8.3787097e-03, -1.8630379e-03,
        4.9545774e-03,  5.7404102e-03, -2.4294550e-02, -3.5352125e-03,
        1.1949964e-03,  7.5188410e-03,  3.0968210e-03, -1.8419711e-02,
        3.5259563e-03, -5.6057386e-03,  2.2549220e-05,  4.8695225e-03,
       -1.2809735e-02,  8.9974515e-03, -1.4005606e-02, -9.6928440e-03,
        3.5121944e-03, -6.3550975e-03,  1.8232320e-02, -8.3840955e-03,
       -1.1785286e-02, -2.9133266e-04, -5.3220778e-03,  1.2011958e-02,
        3.9937897e-03, -2.8522441e-03, -3.4668323e-04, -5.8955420e-03,
      

In [53]:
batch

Array([[54761, 54763, 54767, ...,     0,     1,     0],
       [85711, 85713, 85717, ...,     0,     1,     0],
       [82581, 82583, 82587, ...,     0,     0,     0],
       ...,
       [12461, 12463, 12467, ...,     0,     0,     0],
       [41131, 41133, 41137, ...,     0,     0,     0],
       [29281, 29283, 29287, ...,     0,     1,     0]], dtype=int32)