In [21]:
import data
from modelling import model
import jax.numpy as jnp
import jax
import numpy as np

from importlib import reload
# Re-execute the already-imported local modules in place so that edits made to
# their source files are picked up without restarting the kernel.
# NOTE(review): `model` is reloaded before `data`; if `model` does
# `from data import ...` it may keep pre-reload objects — confirm the order.
model = reload(model)
data = reload(data)

In [15]:
# NOTE(review): 8192 is also the `max_seq_len` passed to model.Config in a later
# cell — presumably the two must agree; confirm.
dataset_sequence_length = 8192
ds = data.DNADataset(sequence_length=dataset_sequence_length)

In [16]:
# Load one record file through the dataset helper and display its first entry
# (left as the cell's last expression so the notebook renders it).
sample_record_path = 'data/tfrecords/record_40.tfrecord'
seqs = ds.load_and_retokenize_tfrecord(file_path=sample_record_path)
seqs[0]

'TGAAATTCTTCCAAAAACTTGAAGGAGAAAGAGTATTTCCAAACTCATTTTAAAAGATCAGCATTATTGTTTTTTTTTTAAAGTGATGTTCCCCTTCCTGTGTCCATGTGTTCTCATTGTCCAATTCCCACCTATGAGTGAGAACATGCACTGTTTGGTTTTTTGTCCTTGTGATAGTTTGCTGAGAATGATGGTTTCCAGCTTCATCCATGTCCCTACAAAGGACATGAACTCATCATTTTTTATGGCTGCATAGTATTCCATGGTGTATATGTGCCACATTTTCTTAATCCAGTCTATCATTGTTGGACATTTGGATTGGTTCCAAGTCTTTGCTATTGTGAATAGTGCTGCAATAAACATACGTGTGCATGTGTCTTTATAGCAGCATGATTTATAATCCTTTGGGTATATACCCAGTAATGGGATGGCTGGGTCAAATGGTATTTCTAGTTCTAGATCCCTAAGAAATCGCCACACTGACTTCCACAATGGTTGAACTAGTTTACAGTCCCACCAACAGTGTAAAAATGTTCCTATTTCTCCACATCCTCTCCAGCACCTGTTGTTTCCTGACTTTTTAATGATGGCCATTCTAACTGGTGTGAGATGGTATCTCATTGTGGTTTTGATTTGCACTTCTCTGATGGCCAGTGATGATGAGCATTTTTTCATGTGTTTTTTGGCTGCATAGATGTCTTCTTTTGAGAAGTGTCTGTTCATATCCTTTGCCCACTTTTTGATGGGGTTGTTTGTTTTTTTCTTGTAAATTTGTTTGGGTTCATTGTAGATTCCGGATATTAGCACTGGGGCCTGTTGTGGGGTGGGGGGAGGGGGGAGGGATAGCATTAGGAGATACACCTAATGTTAAATGATGAGTTAATGGGTGTAGCACACCAGCATGGCACATGTATACATATGTAACTAACCTGCACGTTGTGCACATGTACCCTAAAACTTAAAGTATAATTTAAAAAATAAATAAATAAAAATAAAAAT

In [32]:
def process_batch(batch):
    """Turn a raw batch into shifted input/target arrays of unchanged width.

    Args:
        batch: mapping with 2-D arrays under ``'x'`` and ``'segment_ids'``
            (rows are batch entries).

    Returns:
        dict with three arrays, each ending in one appended column of zeros:
          * ``'x'``: ``batch['x']`` without its last column,
          * ``'y'``: ``batch['x']`` without its first column (i.e. shifted
            left by one position relative to ``'x'``),
          * ``'segment_ids'``: ``batch['segment_ids']`` without its last column.
    """
    tokens = batch['x']
    segments = batch['segment_ids']

    # One int32 zero per row; appended so every output keeps the input width.
    # NOTE(review): 0 presumably marks padding for both tokens and segments — confirm.
    zero_column = np.zeros((tokens.shape[0], 1), dtype=np.int32)

    def _append_zero_column(arr):
        return np.concatenate([arr, zero_column], axis=-1)

    inputs = _append_zero_column(tokens[:, :-1])
    targets = _append_zero_column(tokens[:, 1:])
    shifted_segments = _append_zero_column(segments[:, :-1])
    return {'x': inputs, 'y': targets, 'segment_ids': shifted_segments}
# Batches of 8 drawn from every record file matching the glob.
# NOTE(review): the name `iter` shadows Python's builtin; it is kept because
# later cells call `next(iter)`.
tfrecord_pattern = 'data/tfrecords/record_*.tfrecord'
iter = ds.create_iterator(tfrecord_pattern, batch_size=8)

In [40]:
# Create the mesh first so the Config call below only wires values together.
device_mesh = model.create_mesh()

cfg = model.Config(
    # -- model dimensions (here query_heads * key_dim == d_model: 8 * 128 = 1024) --
    d_model=1024,
    ffw_multiplier=4,
    query_heads=8,
    key_heads=8,
    num_layers=12,
    key_dim=128,
    vocab_size=8,
    max_seq_len=8192,
    # -- attention options --
    causal=True,
    use_attn_kernel=True,
    # -- dtypes: float32 "at rest", bfloat16 "active" (field names from model.Config) --
    weight_dtype_at_rest=jnp.float32,
    active_weight_dtype=jnp.bfloat16,
    # -- sharding --
    rules=model.fsdp_rules,
    mesh=device_mesh,
    # -- learning-rate schedule parameters --
    max_lr=3e-4,
    min_lr=1e-5,
    warmup_steps=50,
    total_steps=10000,
)

In [41]:
# Initialise the parameters (fixed PRNG seed 0) and the matching optimizer state.
# The sharding rules are taken from `cfg.rules` — the same source the batch
# sharding cells use — instead of `model.fsdp_rules` directly, so changing
# `rules=` in the config cell is enough to change them everywhere.
weights = model.Weights.init(cfg, jax.random.PRNGKey(0), cfg.mesh, cfg.rules)
opt_state = model.init_adam_state(weights)

In [42]:
# Pull one batch to serve as the example input for ahead-of-time compilation.
example_batch = process_batch(next(iter))

# `cfg` is marked static, so it is fixed when the function is lowered below.
jitted_update = jax.jit(model.update_step, static_argnames='cfg')

# Place the example batch on devices using the shardings derived from the config.
batch = jax.device_put(example_batch, model.input_shardings(cfg.mesh, cfg.rules))

# Lower with concrete example arguments (step index 0), then compile; `step`
# ends up bound to the compiled executable.
lowered_update = jitted_update.lower(
    weights,
    batch['x'],
    batch['segment_ids'],
    batch['y'],
    opt_state,
    0,
    cfg=cfg,
)
step = lowered_update.compile()

In [43]:
# Does the loss go down?
# Short training run: each iteration pulls a batch, shards it, applies the
# compiled `step`, and prints the loss and accuracy.
# NOTE(review): `weights` and `opt_state` are rebound on every iteration while
# the step index restarts at 0, so re-running this cell continues from the
# current weights with the step counter reset.
num_steps = 1000

# The sharding spec depends only on cfg, so build it once rather than per step.
batch_shardings = model.input_shardings(cfg.mesh, cfg.rules)

for i in range(num_steps):
    batch = jax.device_put(process_batch(next(iter)), batch_shardings)
    # `step` is the compiled executable; `cfg` was supplied when it was lowered,
    # so it is not passed again here.
    loss, weights, opt_state, internals = step(weights, batch['x'], batch['segment_ids'], batch['y'], opt_state, i)
    print(loss, internals['accuracy'])

39.600475 0.0015413258
42.542957 0.000656208
24.45315 0.026782444
12.467861 0.30420583
16.69549 0.276172
13.469164 0.21499206
6.736226 0.28287143
13.2519245 0.31150043
11.449053 0.26455867
12.979078 0.2126114
10.387452 0.3179099
27.142069 0.27792698
18.599087 0.2771792
27.184969 0.22326334
28.56374 0.22718532
55.92567 0.30896714
51.755417 0.27913257
49.067627 0.3103406
47.243843 0.31592602
38.434044 0.23988219
18.431517 0.2569741
50.29581 0.20176108
40.111393 0.29326394
42.14412 0.27371505
46.94265 0.3032139
51.457863 0.29218045
45.39944 0.21384752
16.628838 0.34763765
32.040314 0.31752837
58.09972 0.19815956
49.791286 0.24151507
14.049938 0.26950312
25.736164 0.31082895
43.409527 0.3056556
65.297554 0.1886827
52.14601 0.22802466
49.1439 0.2910664
69.36028 0.3074411
58.23304 0.2982542
52.43035 0.2998718
30.879581 0.30229825
45.45823 0.23144305
48.89489 0.24120986
50.66997 0.27081552
53.23142 0.2860762
36.237957 0.29088327
23.888727 0.21483946
43.20402 0.31082895
34.800243 0.28601512
32

In [38]:
# Display the diagnostics dict returned by the most recent `step` call
# (bare last expression, so the notebook renders its repr).
# NOTE(review): this prints every per-layer entry under 'grad_norms'; consider
# showing a summary instead if the output gets unwieldy.
internals

{'accuracy': Array(0.34821755, dtype=float32),
 'grad_norms': Weights(layers=[Layer(q=Array(0.07459337, dtype=float32), k=Array(0.08246642, dtype=float32), v=Array(0.57063633, dtype=float32), proj=Array(0.62406945, dtype=float32), w1=Array(0.55819076, dtype=float32), w2=Array(0.6721479, dtype=float32), gamma1=Array(0.02530021, dtype=float32), gamma2=Array(0.02334251, dtype=float32)), Layer(q=Array(0.04922511, dtype=float32), k=Array(0.03818633, dtype=float32), v=Array(0.4549637, dtype=float32), proj=Array(0.5107503, dtype=float32), w1=Array(0.7790565, dtype=float32), w2=Array(0.8085995, dtype=float32), gamma1=Array(0.01952285, dtype=float32), gamma2=Array(0.03541267, dtype=float32)), Layer(q=Array(0.03584538, dtype=float32), k=Array(0.03503671, dtype=float32), v=Array(0.72344947, dtype=float32), proj=Array(0.6892184, dtype=float32), w1=Array(1.5681542, dtype=float32), w2=Array(1.4255437, dtype=float32), gamma1=Array(0.03180821, dtype=float32), gamma2=Array(0.0905989, dtype=float32)), L