In [5]:
import jax, optax, yaml
import jax.numpy as jnp
from jax import grad, jit, vmap, random
from jax.tree_util import tree_flatten

from einops import rearrange
from functional import partial

import numpy as np
import yaml
import numpy as np
from tqdm import tqdm
from src.data import data_fn, prime_fn

In [36]:
def get_config():
    pass

n      = 100000
base   = 10
n_toks = base

In [9]:
rng, key = random.split(random.PRNGKey(0))
data     = data_fn(key, n, base)

In [10]:
with open('config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [11]:
def apply_fn(params, x):
    x = embed_fn(params, x)
    for block in params['blocks']:
        x = head_fn(block['head'], x)  # + mlp_fn(block['mlp'], x)
    logits = x # @ params['lm_head']                       # B x T x V
    return logits

def head_fn(params, x):
    x = [head_apply_fn(params[f'head_{i}'], x) for i in range(len(params))]
    x = jnp.stack(x, axis=-1).mean(axis=-1)
    return x

def head_apply_fn(params, x):
    # tril = jnp.nan_to_num(jnp.absolute(jnp.tril(jnp.ones((T, T))) - 1) * (-jnp.inf), nan=0)
    # mask = jnp.nan_to_num(jnp.triu(jnp.ones((T, T))) * (-jnp.inf), nan=0)
    k  = x @ params['key']                      # B x T x H
    q  = x @ params['query']                    # B x T x H
    z  = q @ rearrange(k, "b t h -> b h t")  # k.transpose(0, 2, 1)               # B x T x T
    z /= jnp.sqrt(params['key'].shape[1])       # divide by sqrt to normalize
    z  = params['alpha'] * np.eye(x.shape[1]) # +  params['beta'] * x              # - gamma * C  # <-- shaped attention
    return z @ x
    
def mlp_fn(params, x):  # TODO: maybe switch activation
    x = x @ params['dense1'] + params['bias1']
    x = jax.nn.relu(x)
    x = x @ params['dense2'] + params['bias2']
    return x

def embed_fn(params, x):
    n  = x.shape[1]                              # num toks in sample
    x  = params['tok_embedding'][x]              # tok embeddings
    x += params['pos_embedding'][jnp.arange(n)]  # pos embeddings
    return x

In [12]:
def init_head_fn(rng, embed_dim, n_heads, scale):
    head_size = embed_dim // n_heads
    rng, key_key, key_value, key_query = jax.random.split(rng, 4)
    params = {} 
    for i in range(n_heads):
        params[f'head_{i}'] = {
            'key'   : jax.random.normal(key_key, shape=(embed_dim, head_size)) * scale,
            'query' : jnp.zeros((embed_dim, head_size)),
            'alpha' : jnp.array(1),
            'beta'  : jnp.array(0),
            }
    return params

def init_mlp_fn(rng, embed_dim, scale=1e-2):
    rng, key1, key2 = jax.random.split(rng, 3)
    params = {
        'dense1' : jax.random.normal(key1, shape=(embed_dim, 4 * embed_dim)) * scale,
        'bias1'  : jnp.zeros((embed_dim,)),
        'dense2' : jax.random.normal(key2, shape=(4 * embed_dim, embed_dim)) * scale,
        'bias2'  : jnp.zeros((embed_dim,)),
        }
    return params

def init_block_fn(rng, embed_dim, n_heads, scale):
    rng, key1, key2 = jax.random.split(rng, 3)
    params = {
        'head'  : init_head_fn(key1, embed_dim, n_heads, scale),
        'mlp'  : init_mlp_fn(key2, embed_dim, scale),
        }
    return params

def init_fn(rng, config):
    rng, key1, key2, key3 = jax.random.split(rng, 4)
    params = {
        'tok_embedding': jax.random.normal(key1, shape=(config['vocab_size'], config['embed_dim'])) * config['scale'],
        'pos_embedding': jax.random.normal(key2, shape=(config['block_size'], config['embed_dim'])) * config['scale'],
        'lm_head': jax.random.normal(key3, shape=(config['embed_dim'], config['vocab_size'])) * config['scale'],
        'blocks': [init_block_fn(key1, config['embed_dim'], config['n_heads'], scale=config['scale']) for _ in range(config['n_layers'])],
        }
    return params

In [13]:
def loss_fn(params, xb, yb):
    # we cant to minimise cross entropy
    logits = apply_fn(params, xb) # B x T x C
    B, T, C = logits.shape
    yb = yb.reshape(-1)
    logits = logits.reshape(B * T, C)
    logits = jnp.clip(logits, -100, 100)
    loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(yb, C)))
    return loss

In [14]:
def generate_fn(rng, params, idx, block_size, length=100, temperature=1.0):
    for _ in tqdm(range(length)):
        rng, key = jax.random.split(rng)
        logits = apply_fn(params, idx[:, -block_size:])         # B x T x C
        logits = logits[:, -1, :] / temperature                 # B x C
        idx_new = jax.random.categorical(key, logits)[:, None]  # B x 1
        idx = jnp.concatenate([idx, idx_new], axis=1)           # B x T + 1
    return idx

In [15]:
def batch_fn(rng, data):
    while True:
        rng, key = random.split(rng)
        idxs = random.permutation(key, data.shape[0])
        yield data[idxs]

In [17]:
rng, key = random.split(random.PRNGKey(0))
# data     = jit(partial(data_fn, d))()
# batches  = iter(batch_fn(key, data))
params   = init_fn(rng, config)

In [None]:
# py.plot(params['tok_embedding'], kind='density_heatmap')

In [23]:
weights, structure = tree_flatten(params)

In [24]:
sum([w.size for w in weights])

256048

In [25]:
import matplotlib.pyplot as plt

In [33]:
[w.shape for w in weights]

[(),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (64,),
 (64,),
 (64, 256),
 (256, 64),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (64,),
 (64,),
 (64, 256),
 (256, 64),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (64,),
 (64,),
 (64, 256),
 (256, 64),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (64,),
 (64,),
 (64, 256),
 (256, 64),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (64,),
 (64,),
 (64, 256),
 (256, 64),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (),
 (),
 (64, 16),
 (64, 16),
 (64,),
 (64,),
 (64, 256),
 (25

In [35]:
print(structure)

PyTreeDef({'blocks': [{'head': {'head_0': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_1': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_2': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_3': {'alpha': *, 'beta': *, 'key': *, 'query': *}}, 'mlp': {'bias1': *, 'bias2': *, 'dense1': *, 'dense2': *}}, {'head': {'head_0': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_1': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_2': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_3': {'alpha': *, 'beta': *, 'key': *, 'query': *}}, 'mlp': {'bias1': *, 'bias2': *, 'dense1': *, 'dense2': *}}, {'head': {'head_0': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_1': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_2': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_3': {'alpha': *, 'beta': *, 'key': *, 'query': *}}, 'mlp': {'bias1': *, 'bias2': *, 'dense1': *, 'dense2': *}}, {'head': {'head_0': {'alpha': *, 'beta': *, 'key': *, 'query': *}, 'head_1': {'al