In [14]:
from IPython.core.interactiveshell import InteractiveShell

# Echo the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"
# (Re)load the autoreload extension; mode 2 re-imports modules before each cell
# runs, so edits to local modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [19]:
import flax.nnx as nnx

# Scratch check: display the key of the "default" stream of an Rngs built from 0.
nnx.Rngs(0).default.key

RngKey(
  value=Array((), dtype=key<fry>) overlaying:
  [0 0],
  tag='default'
)

In [15]:
from gpt2 import GPT
import optax
import flax.nnx as nnx


@nnx.jit
def train_step(model: GPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Perform a single jitted optimisation step.

    Args:
        model: The GPT module to differentiate and update.
        optimizer: Optimizer that receives the gradients via ``update``.
        metrics: Metric collection updated with ``loss``, ``logits`` and ``labels``.
        batch: A pair ``(x, y)``; ``x`` is fed to the model and ``y`` is used as
            the integer labels for the cross-entropy loss.

    Returns:
        None. ``metrics`` and ``optimizer`` are updated as side effects.
    """
    inputs, targets = batch

    def loss_fn(net: GPT):
        logits = net(inputs)
        n_classes = logits.shape[-1]
        # Flatten all leading axes so every position is scored independently.
        per_position = optax.softmax_cross_entropy_with_integer_labels(
            logits.reshape([-1, n_classes]), targets.reshape([-1])
        )
        # Logits travel along as auxiliary output so the metrics can reuse them.
        return per_position.mean(), logits

    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    metrics.update(loss=loss, logits=logits, labels=targets)
    optimizer.update(grads)

In [20]:
from gpt2 import GPT
import optax
import flax.nnx as nnx
import numpy as np
import tiktoken
import datasets

# Load the pretrained GPT-2 weights and put the module into training mode.
model = GPT.from_pretrained("gpt2")
model.train()

# AdamW transform, paired with the model it will update.
tx = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)
optimizer = nnx.Optimizer(model, tx)

# Tracks accuracy and an average of the value passed to `update` as `loss=`.
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average("loss"),
)

enc = tiktoken.get_encoding("gpt2")
batch_size = 6  # sequences drawn per batch
block_size = 1024  # tokens per sequence

# Batches are sampled through NumPy's global generator; seed it so that a
# fresh "Restart & Run All" draws the same batches and the metrics shown
# below can be reproduced.
SEED = 0
np.random.seed(SEED)

# Each split: join the "Text" field of every row with newlines, tokenise with
# the GPT-2 encoding, and store the ids compactly as uint16.
data = datasets.load_dataset(path="Trelis/tiny-shakespeare")
train_data = "\n".join([x["Text"] for x in data["train"]])
train_data = enc.encode_ordinary(train_data)
train_data = np.array(train_data, dtype=np.uint16)
val_data = "\n".join([x["Text"] for x in data["test"]])
val_data = enc.encode_ordinary(val_data)
val_data = np.array(val_data, dtype=np.uint16)


def get_batch(split):
    """Sample a random batch of token windows and their next-token targets.

    Reads the module-level ``train_data`` / ``val_data`` arrays together with
    ``batch_size`` and ``block_size``.

    Args:
        split: ``"train"`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A tuple ``(x, y)`` of int32 arrays, each stacked from ``batch_size``
        windows of ``block_size`` tokens. ``y`` is ``x`` shifted one token
        forward in the source array.
    """
    tokens = train_data if split == "train" else val_data
    # The upper bound leaves room for the one-token look-ahead used by the targets.
    starts = np.random.randint(len(tokens) - block_size, size=(batch_size,))

    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(tokens[start:stop].astype(np.int32))
        targets.append(tokens[start + 1 : stop + 1].astype(np.int32))
    return np.stack(inputs), np.stack(targets)


# First optimisation step on a freshly sampled training batch, then show the metrics.
train_step(model, optimizer, metrics, get_batch("train"))
metrics.compute()

# Second step on another sampled batch. `metrics` is not reset in between, so
# this reading presumably aggregates both steps rather than the second alone
# -- confirm against nnx.MultiMetric semantics.
train_step(model, optimizer, metrics, get_batch("train"))
metrics.compute()

loading weights from pretrained gpt: gpt2


Length of prepared JAX modules dict: 89


{'accuracy': Array(0.3976237, dtype=float32),
 'loss': Array(4.034092, dtype=float32)}

{'accuracy': Array(0.38875327, dtype=float32),
 'loss': Array(3.8402543, dtype=float32)}