In [None]:
import flax.nnx as nnx
import jax
import jax.numpy as jnp
import optax
from tqdm.auto import tqdm

from src.model import Transformer

In [None]:
# Bias-free linear projection from 768 features to twice that width (1536);
# the doubled output is split into two equal halves in a later cell.
rngs = nnx.Rngs(0)
layer = nnx.Linear(in_features=768, out_features=2 * 768, use_bias=False, rngs=rngs)

In [None]:
# Draw a reproducible (4, 768) standard-normal input from a fixed key and run
# it through the projection layer.
input_key = jax.random.PRNGKey(0)
x = jax.random.normal(input_key, shape=(4, 768))
temp = layer(x)

In [None]:
# Split the projection output into two equal halves along the last axis.
# NOTE: this rebinds `x` (previously the layer input) to the first half.
x, gate = jnp.split(temp, 2, axis=-1)

# Print the shapes explicitly: a bare expression that is not the last line of
# a cell is evaluated and then discarded, so it would never be displayed.
print(x.shape, gate.shape)

# Sanity check: concatenating the two halves reproduces the original output.
# Kept as the final expression so the notebook displays the boolean result.
jnp.allclose(temp, jnp.concatenate([x, gate], axis=-1))

In [None]:
# One row of per-position document ids: runs of id 0 (x3), id 1 (x2), id 2 (x5).
document_ids = jnp.asarray([[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]])

# Causal mask derived from the id row.
causal_mask = nnx.make_causal_mask(document_ids)
# Pairwise mask comparing every query id with every key id for equality.
doc_mask = nnx.make_attention_mask(document_ids, document_ids, pairwise_fn=jnp.equal)
# Combination of the causal mask and the same-document mask.
final_mask = nnx.combine_masks(causal_mask, doc_mask)

# Show the three masks one after another: causal, document, combined.
print(causal_mask, doc_mask, final_mask, sep="\n")

In [None]:
def loss_fn(model, x):
    """Unreduced cross-entropy of the model's logits against its own input ids.

    Args:
        model: Callable that maps the integer ids ``x`` to logits whose last
            axis indexes the vocabulary.
        x: Integer token ids; used both as the model input and as the labels.

    Returns:
        A 1-D array with one loss value per element of ``x`` (flattened).
    """
    logits = model(x)
    # Read the vocabulary size from the logits instead of hard-coding 10000,
    # so the loss works for a model with any vocabulary size.
    vocab_size = logits.shape[-1]
    # NOTE(review): labels are the unshifted inputs (position t is scored
    # against token t, not token t + 1) -- confirm this is intended.
    return optax.softmax_cross_entropy_with_integer_labels(
        logits.reshape(-1, vocab_size), x.reshape(-1)
    )


@nnx.jit
def train_step(model, optimizer, x):
    """Run one optimizer step on the full batch ``x`` and return the mean loss.

    Args:
        model: Model handed to ``loss_fn``.
        optimizer: Object whose ``update(grads)`` applies the gradients.
        x: Batch passed to ``loss_fn`` in a single call.

    Returns:
        Mean of the losses from ``loss_fn``, computed before the update.
    """

    def mean_loss(net, batch):
        # Collapse the per-token losses into the scalar that is differentiated.
        return loss_fn(net, batch).mean()

    loss, grads = nnx.value_and_grad(mean_loss)(model, x)
    optimizer.update(grads)
    return loss


@nnx.jit
def vmap_train_step(model, optimizer, x):
    """Run one optimizer step with ``loss_fn`` vmapped over the leading axis of ``x``.

    Args:
        model: Model handed to ``loss_fn``; not mapped (``in_axes`` entry ``None``).
        optimizer: Object whose ``update(grads)`` applies the gradients.
        x: Batch whose axis 0 is mapped over, one ``loss_fn`` call per slice.

    Returns:
        Mean over all losses from the vmapped ``loss_fn``, computed before the update.
    """
    # Same model for every slice; only the data is mapped along axis 0.
    per_example_loss = nnx.vmap(loss_fn, in_axes=(None, 0))

    def mean_loss(net, batch):
        # Collapse the stacked per-example losses into the scalar objective.
        return per_example_loss(net, batch).mean()

    loss, grads = nnx.value_and_grad(mean_loss)(model, x)
    optimizer.update(grads)
    return loss

In [None]:
# Run configuration. `batch_size` and `context_length` are also read by the
# vmapped run in the following cell, so their names must stay as they are.
vocab_size = 10000
batch_size = 64
context_length = 1024

model = Transformer(
    vocab_size=vocab_size,
    num_layers=4,
    dim=512,
    dim_ff=2048,
    num_heads=8,
    rngs=nnx.Rngs(params=0, dropout=0),
    context_length=context_length,
    ff_activation="gelu",
    ff_dropout=0.0,
    attention_dropout=0.0,
    residual_dropout=0.0,
    use_bias=False,
    norm_class="rmsnorm",
    use_glu=True,
)
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-4))

# The PRNG key is fixed, so every step sees the same batch. Build it once
# instead of regenerating an identical array on each iteration.
tokens = jax.random.randint(jax.random.PRNGKey(0), (batch_size, context_length), 0, vocab_size)

# First call outside the progress bar; its result is discarded, but it does
# apply one optimizer update. Presumably a warm-up so jit compilation is not
# included in the timed loop -- TODO confirm.
train_step(model, optimizer, tokens)

losses = []
for _ in tqdm(range(500)):
    losses.append(train_step(model, optimizer, tokens))
print(losses[:10])

In [None]:
# Vocabulary size shared by the model and the synthetic batch below.
vocab_size = 10000

# NOTE(review): dim=768 here, whereas the non-vmapped run uses dim=512 --
# confirm this is intentional if the two runs are meant to be compared.
# `batch_size` and `context_length` come from the previous training cell.
model = Transformer(
    vocab_size=vocab_size,
    num_layers=4,
    dim=768,
    dim_ff=2048,
    num_heads=8,
    rngs=nnx.Rngs(params=0, dropout=0),
    context_length=1024,
    ff_activation="gelu",
    ff_dropout=0.0,
    attention_dropout=0.0,
    residual_dropout=0.0,
    use_bias=False,
    norm_class="rmsnorm",
    use_glu=True,
)
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-4))

# The PRNG key is fixed, so every step sees the same batch. Build it once
# instead of regenerating an identical array on each iteration.
tokens = jax.random.randint(jax.random.PRNGKey(0), (batch_size, context_length), 0, vocab_size)

# First call outside the progress bar; its result is discarded, but it does
# apply one optimizer update. Presumably a warm-up so jit compilation is not
# included in the timed loop -- TODO confirm.
vmap_train_step(model, optimizer, tokens)

losses = []
for _ in tqdm(range(500)):
    losses.append(vmap_train_step(model, optimizer, tokens))
print(losses[:10])