In [None]:
import logging
import sys

import jax.experimental
import jax.experimental.mesh_utils
import jax.experimental.multihost_utils

# Configure the root logger: INFO and above, written to stdout, with each record rendered as
# "<timestamp> - <level> - <logger name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
import argparse
import glob
import time
from pathlib import Path

import flax.nnx as nnx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import torch
from tqdm.auto import tqdm

from src.lr_scheduler import LEARNING_RATE_SCHEDULES, get_learning_rate_scheduler
from src.model import Transformer

In [None]:
# Instantiate the project's Transformer. The meaning of the five positional arguments is
# defined in src.model and is not visible in this notebook; 10000 matches the token-id range
# of the sample batch built further down, so it is presumably the vocabulary size — TODO confirm.
# The "params" and "dropout" RNG streams are seeded separately (0 and 1).
model = Transformer(10000, 4, 512, 2048, 4, rngs=nnx.Rngs(params=0, dropout=1), context_length=1024)

In [None]:
# Build a 1-D device mesh with a single "data" axis over this process's devices.
num_devices = jax.local_device_count()
# Pass the local devices explicitly: create_device_mesh otherwise defaults to every device
# visible to JAX, whose count differs from local_device_count() when more than one process
# participates, and the shape check inside create_device_mesh would then fail.
mesh = jax.sharding.Mesh(
    jax.experimental.mesh_utils.create_device_mesh((num_devices,), devices=jax.local_devices()),
    ("data",),
)
# Empty PartitionSpec: no array axis is split over the mesh.
model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
# The leading array axis is split across the "data" mesh axis.
data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))

In [None]:
# Learning-rate schedule from the project helper; how the warmup/decay/total step counts
# interact is defined in src.lr_scheduler.
learning_rate_fn = get_learning_rate_scheduler(
    "linear_warmup_cosine_decay",
    total_steps=10000,
    warmup_steps=400,
    decay_steps=400,
    lr=0.01,
)

# Gradient transformation: clip the global gradient norm to 1.0, then apply AdamW driven by
# the schedule above.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=learning_rate_fn, weight_decay=0.1, b1=0.9, b2=0.95),
)

# Extract the state, place it with model_sharding, and write the placed arrays back into the
# live objects.
# NOTE(review): `optimizer` at this point is the raw optax transformation; the nnx.Optimizer
# wrapper is only constructed in the following cell, after this placement has run.
state = jax.device_put(nnx.state((model, optimizer)), model_sharding)
nnx.update((model, optimizer), state)

In [None]:
# Wrap the model and the optax transformation in an nnx.Optimizer.
# NOTE(review): this rebinds `state`, which the previous cell used for the extracted
# model/optimizer state; from here on `state` is the Optimizer and `state.model` is read later.
state = nnx.Optimizer(model, optimizer)

In [None]:
def loss_fn(model, x, y):
    """Mean softmax cross-entropy between the model's logits and integer labels.

    Args:
        model: Callable mapping ``x`` to logits whose last axis indexes the classes.
        x: Input batch passed to ``model``.
        y: Integer class labels, one per logit row once both are flattened.

    Returns:
        The mean of the per-position cross-entropy values.
    """
    logits = model(x)
    # Read the class count from the logits instead of hard-coding it, so a model with a
    # different output width is neither rejected nor silently reshaped into the wrong rows.
    num_classes = logits.shape[-1]
    loss = optax.softmax_cross_entropy_with_integer_labels(logits.reshape(-1, num_classes), y.reshape(-1))
    return loss.mean()


def train_step(model, x, y):
    """Evaluate ``loss_fn`` on one batch together with its gradient.

    Args:
        model: Model handed to ``loss_fn`` as its first argument.
        x: Input batch.
        y: Target labels.

    Returns:
        The ``(loss, grad)`` pair produced by ``nnx.value_and_grad(loss_fn)``. No parameter
        update is applied here.
    """
    return nnx.value_and_grad(loss_fn)(model, x, y)

In [None]:
# Sample batch of token ids in [0, 10000). Seeded so the batch is identical on every run, and
# sized as 4 rows per device so the leading axis always divides evenly over the "data" axis.
x = np.random.default_rng(0).integers(0, 10000, (4 * num_devices, 1024))
# NOTE(review): targets are the inputs themselves (no shift) — fine for inspecting sharding,
# confirm before using this for real training.
y = x

# Split the batch across devices along its leading axis and show the resulting layout.
x, y = jax.device_put((x, y), data_sharding)
jax.debug.visualize_array_sharding(x)
# `state` is the nnx.Optimizer built earlier; run one loss/gradient evaluation on its model.
loss, grad = train_step(state.model, x, y)
# Show how the gradient of the token-embedding table is laid out across devices.
jax.debug.visualize_array_sharding(grad["token_emb"]["embedding"].value)