In [9]:
import gpt2 as nn
import jax.numpy as jnp
import jax
try:
  import tiktoken
except:
  !pip install tiktoken
  import tiktoken
import numpy as np
from importlib import reload
reload(nn)

<module 'gpt2' from '/Users/ryanbarouki/Documents/Coding/gpt_mini_jax/gpt2/gpt2.py'>

In [10]:
def count_params(params):
  """Return the total element count summed over every leaf of the `params` pytree."""
  total = 0
  for leaf in jax.tree.leaves(params):
    total += leaf.size
  return total

In [11]:
# Root PRNG key for the notebook; the initialisation cell below splits it to
# obtain a dedicated subkey for parameter init.
key = jax.random.PRNGKey(42069)
# Config() is called with no arguments, so the model uses whatever defaults
# the gpt2 module defines.
config = nn.Config()
model = nn.GPT2(config)
# tiktoken's 'gpt2' encoding; in this view it is referenced only by the
# commented-out text-generation cell.
enc = tiktoken.get_encoding('gpt2')

In [12]:
# Generate initial text
#key, init_key = jax.random.split(key)
#dummy = jnp.ones((1,1), dtype=int)
#params = model.init(init_key, dummy)
#
#start = "Hello, I'm a language model,"
#inputs = jnp.array(enc.encode(start)).reshape(1,-1)
#key, gen_key = jax.random.split(key)
#outs = model.generate_batch(gen_key, params, inputs, 30)
#enc.decode(outs[0])

In [13]:
# Loss function
import optax.losses as ll

def loss_fn(params, forward, x, y):
  """Mean integer-label softmax cross-entropy of `forward(params, x)` against `y`.

  Args:
    params: Parameters handed unchanged to `forward`.
    forward: Callable invoked as `forward(params, x)`; its result is used as the logits.
    x: Model input passed to `forward`.
    y: Integer class labels (presumably next-token ids; confirm against the data loader).

  Returns:
    The mean of the per-element cross-entropy values.
  """
  per_element = ll.softmax_cross_entropy_with_integer_labels(forward(params, x), y)
  return per_element.mean()

#print(loss_fn(params, model.apply, x, y))

In [14]:
# Create a train state
from flax.training import train_state
import optax

def create_train_state(model, init_key, learning_rate):
  """Build a `TrainState` holding freshly initialised parameters and an AdamW optimizer.

  Args:
    model: Object exposing `init(key, inputs)` and `apply`.
    init_key: PRNG key passed to `model.init`.
    learning_rate: Learning rate passed to `optax.adamw`.

  Returns:
    The result of `train_state.TrainState.create` with `model.apply` as `apply_fn`.
  """
  # A (1, 1) array of integer ones is used purely to drive `model.init`.
  sample_tokens = jnp.ones((1,1), dtype=int)
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=model.init(init_key, sample_tokens),
      tx=optax.adamw(learning_rate),
  )

In [15]:
# Create train step
@jax.jit
def train_step(state, x, y):
  """Run one gradient update and return the new state together with the loss.

  Args:
    state: Train state exposing `apply_fn`, `params` and `apply_gradients`.
    x: Model inputs forwarded to `state.apply_fn`.
    y: Integer labels the logits are scored against.

  Returns:
    A `(state, loss)` tuple: the state returned by `apply_gradients`, and the
    loss evaluated at the parameters from before the update.
  """
  def loss_f(params):
    # Delegate to the notebook's shared `loss_fn` so the training objective is
    # defined in exactly one place instead of being duplicated here.
    return loss_fn(params, state.apply_fn, x, y)

  loss, grads = jax.value_and_grad(loss_f)(state.params)
  state = state.apply_gradients(grads=grads)
  return state, loss

In [16]:
# Initialise model and training state
# Split the root key: `init_key` is consumed by parameter initialisation and
# `key` is rebound to the other half of the split.
# NOTE(review): because `key` is rebound here, re-running this cell on its own
# produces a different `init_key` (and so different initial parameters) unless
# the cell that creates `key` is re-run first.
key, init_key = jax.random.split(key)
# 3e-4 is the learning rate that create_train_state hands to optax.adamw.
state = create_train_state(model, init_key, learning_rate=3e-4)

In [19]:
# Training LOOP
import dataloader as dl
# Batches are read from ../data/input.txt. B and T are presumably batch size
# and sequence length: the recorded output's "2640 batches" per epoch equals
# 338025 // (4 * 32). Confirm against the dataloader module.
loader = dl.DataLoaderLite(B=4, T=32, fname='../data/input.txt')
# Run 100 optimisation steps. `state` is rebound on every iteration, so
# re-running this cell continues from the current state rather than restarting.
for _ in range(100):
  x,y = loader.next_batch()
  state, loss = train_step(state, x, y)
  # Print the loss returned by this step.
  print(loss)

loaded 338025 tokens
1 epoch = 2640 batches
10.943641
9.806313
9.102635
9.236937
8.669986
8.464756
8.996381
8.647453
8.193598
8.035293
8.38256
7.494884
7.8570457
7.4582863
7.5563455
7.3741255
7.5158014
8.280799
7.2334743
7.824947
7.5578537
7.898467
6.5368
6.9025664
6.93479
6.7480125
6.8465033
7.5895596
7.17289
6.9268947
6.8994384
7.2328315
7.1461525
6.9585185
7.964043
7.808412
7.6431413
7.627168
7.8731794
7.372366
7.4931045
6.7793202
6.944989
7.0636287
6.993957
7.0524063
5.9528337
6.2677107
6.9738917
6.742378
7.230375
6.8494043
5.7705455
7.221552
6.671711
6.4116793
6.9540753
6.8828444
6.233543
6.66222
6.2355337
6.6517463
6.227757
6.049005
6.633167
6.8364606
6.594459
6.063694
6.459332
6.6033936
6.822003
6.7523365
6.460961
6.405623
7.6293488
6.595389
6.9009953
6.026963
6.2947736
5.8505692
6.382288
6.963936
6.574122
5.669996
5.8705583
5.979701
6.2532315
6.9510736
7.424315
7.566949
7.1645794
6.85993
7.595907
6.5034065
5.9603977
6.4760723
6.1941314
6.400194
6.021373
6.527379
