# Weight Checkpoint Demo

This notebook trains a small copy-task model, saves its weights with the model's `save_weights` / `load_weights` helper methods, reloads them, and verifies that the restored parameters match the originals and produce the same evaluation metrics on the same batch.



In [1]:
import sys
import pathlib

import jax
import jax.numpy as jnp
import optax

# Put the current working directory on sys.path so the `src.*` imports below
# resolve (assumes the notebook is launched from the directory that contains
# `src/` — TODO confirm). `project_root` is also reused later as the base for
# the checkpoint directory.
project_root = pathlib.Path.cwd()
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from src.data.copy_dataset import CopyDataset
from src.models.base import ModelConfig
from src.models.lru import LinearRecurrentUnit



In [2]:
# --- Experiment configuration ------------------------------------------------
# Tunable values are named once here instead of being repeated as literals.
# NUM_CLASSES feeds both the dataset and the model's input/output widths, which
# were previously three independent literal 10s.
# NOTE(review): tying input_dim to NUM_CLASSES assumes CopyDataset emits token
# ids in [0, NUM_CLASSES) — confirm against the dataset implementation.
NUM_CLASSES = 10
HIDDEN_DIM = 64
MLP_HIDDEN_DIM = 64
LEARNING_RATE = 1e-3
SEED = 0
train_steps = 150  # number of optimizer updates performed by the training loop

dataset = CopyDataset(min_lag=5, max_lag=15, batch_size=32, num_classes=NUM_CLASSES, seq_length=8)
model_config = ModelConfig(
    input_dim=NUM_CLASSES,
    output_dim=NUM_CLASSES,
    hidden_dim=HIDDEN_DIM,
    num_layers=1,
    precision="float32",
    param_dtype="float32",
)
model = LinearRecurrentUnit(model_config, mlp_hidden_dim=MLP_HIDDEN_DIM)
optimizer = optax.adam(learning_rate=LEARNING_RATE)

# Explicit seed so parameter initialisation is reproducible across runs.
key = jax.random.PRNGKey(SEED)
params = model.initialize(key)
opt_state = optimizer.init(params)



In [3]:
def embed_inputs(token_ids: jnp.ndarray) -> jnp.ndarray:
    """One-hot encode integer token ids as float32 vectors.

    The width of the encoding is ``model_config.input_dim``, read from the
    notebook-level model configuration.
    """
    depth = model_config.input_dim
    encoded = jax.nn.one_hot(token_ids, depth, dtype=jnp.float32)
    return encoded


def shift_targets(token_ids: jnp.ndarray, mask: jnp.ndarray):
    """Shift ``token_ids`` and ``mask`` one position to the left along axis 1.

    Position ``t`` of each result holds position ``t + 1`` of the corresponding
    input; the vacated last position is filled with zeros of the input's dtype.

    Returns:
        A ``(shifted_targets, shifted_mask)`` tuple with the same shapes as the
        inputs.
    """

    def _advance(array):
        # Drop the first column and append one zero column at the end.
        tail = array[:, 1:]
        filler = jnp.zeros_like(array[:, :1])
        return jnp.concatenate((tail, filler), axis=1)

    return _advance(token_ids), _advance(mask)


def compute_metrics(logits: jnp.ndarray, target: jnp.ndarray, mask: jnp.ndarray):
    """Compute masked negative log-likelihood and accuracy.

    Args:
        logits: Unnormalised class scores; the last axis indexes classes.
        target: Integer class ids, one per position (``logits`` without its
            last axis).
        mask: Per-position weights multiplied into both metrics; positions with
            weight 0 do not contribute.

    Returns:
        A dict with keys ``"nll"`` (mask-weighted mean negative log-likelihood
        of ``target``) and ``"accuracy"`` (mask-weighted fraction of positions
        where the argmax prediction equals ``target``). When the mask sums to
        zero both values are 0 rather than NaN.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability assigned to the target class at each position.
    ll = jnp.take_along_axis(log_probs, target[..., None], axis=-1)[..., 0]
    predictions = jnp.argmax(logits, axis=-1)

    # Guard the shared denominator: an all-zero mask would otherwise yield
    # 0 / 0 = NaN, which propagates into gradients when this is used as a loss.
    # Any non-zero mask total is left untouched, so normal results are unchanged.
    mask_total = jnp.sum(mask)
    denom = jnp.where(mask_total > 0, mask_total, 1)

    nll = -jnp.sum(ll * mask) / denom
    accuracy = jnp.sum((predictions == target) * mask) / denom
    return {"nll": nll, "accuracy": accuracy}


# @jax.jit
# NOTE(review): the jit decorator above is intentionally left disabled, as in
# the original; confirm that model.apply is traceable before enabling it.
def train_step(params, opt_state, batch):
    """Run one optimizer update on a single batch.

    Args:
        params: Current model parameters.
        opt_state: Current optimizer state.
        batch: An ``(inputs, targets, mask)`` triple.

    Returns:
        ``(new_params, new_opt_state, metrics)`` where ``metrics`` is the dict
        returned by ``compute_metrics`` evaluated at the pre-update parameters.
    """
    inputs, targets, mask = batch
    next_targets, next_mask = shift_targets(targets, mask)

    def objective(candidate):
        # The NLL is the differentiated loss; the full metrics dict is carried
        # along as auxiliary output.
        scores = model.apply(candidate, embed_inputs(inputs), mask)
        stats = compute_metrics(scores, next_targets, next_mask)
        return stats["nll"], stats

    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (_, metrics), grads = grad_fn(params)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, metrics


# @jax.jit
def eval_step(params, batch):
    """Evaluate ``params`` on one ``(inputs, targets, mask)`` batch.

    Returns the metrics dict produced by ``compute_metrics`` for the model's
    logits against the left-shifted targets and mask.
    """
    inputs, targets, mask = batch
    features = embed_inputs(inputs)
    logits = model.apply(params, features, mask)
    next_targets, next_mask = shift_targets(targets, mask)
    return compute_metrics(logits, next_targets, next_mask)



In [4]:
# Run `train_steps` optimizer updates. Every 15th step, evaluate on a new
# `dataset()` batch, append a summary record to `training_log`, and print it.
training_log = []
for step in range(1, train_steps + 1):
    batch = dataset()
    params, opt_state, train_metrics = train_step(params, opt_state, batch)
    if step % 15 != 0:
        continue

    eval_metrics = eval_step(params, dataset())
    record = {
        "step": step,
        "train_nll": float(train_metrics["nll"]),
        "eval_accuracy": float(eval_metrics["accuracy"]),
    }
    training_log.append(record)
    print(
        f"step {step:04d} | train nll={record['train_nll']:.4f} | eval acc={record['eval_accuracy']:.4f}"
    )



step 0015 | train nll=2.1165 | eval acc=0.2969
step 0030 | train nll=2.1211 | eval acc=0.2148
step 0045 | train nll=2.0383 | eval acc=0.1641
step 0060 | train nll=1.9975 | eval acc=0.2500
step 0075 | train nll=1.8982 | eval acc=0.3047
step 0090 | train nll=1.7562 | eval acc=0.2578
step 0105 | train nll=1.7903 | eval acc=0.2070
step 0120 | train nll=1.8785 | eval acc=0.1641
step 0135 | train nll=1.8796 | eval acc=0.3164
step 0150 | train nll=1.8916 | eval acc=0.2422


In [5]:
# Save the trained parameters, load them back, and check the round trip.
checkpoint_dir = project_root / "weight_checkpoint_demo" / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = checkpoint_dir / "lru_copy_step_{:04d}.pkl".format(train_steps)

# NOTE(review): the .pkl suffix suggests a pickle-based format — if so, only
# load checkpoints from trusted sources. Confirm against load_weights.
saved_path = model.save_weights(params, checkpoint_path)
reloaded_params = model.load_weights(saved_path)

# Compare leaf by leaf (within atol=1e-6); tree_map itself raises if the two
# parameter trees do not share the same structure.
comparison = jax.tree_util.tree_map(
    lambda x, y: jnp.allclose(x, y, atol=1e-6), params, reloaded_params
)
all_match = all(bool(leaf) for leaf in jax.tree_util.tree_leaves(comparison))
print(f"Weights saved to {saved_path}")
print(f"Parameters identical after reload: {all_match}")

# Fail the cell (and therefore Restart & Run All) on a bad round trip instead
# of relying on someone reading the printed flag.
assert all_match, "Reloaded parameters differ from the parameters that were saved"



Weights saved to /Users/mitchellostrow/Desktop/Projects/SLT/9520_recurrent_networks/weight_checkpoint_demo/checkpoints/lru_copy_step_0150.pkl
Parameters identical after reload: True


In [6]:
# Evaluate the original and the reloaded parameters on the *same* batch so any
# difference can only come from the checkpoint round trip.
test_batch = dataset()
original_metrics = eval_step(params, test_batch)
reloaded_metrics = eval_step(reloaded_params, test_batch)

print(
    f"Original eval accuracy: {float(original_metrics['accuracy']):.4f}\n"
    f"Reloaded eval accuracy: {float(reloaded_metrics['accuracy']):.4f}"
)

# Turn the visual comparison into an enforced check over every reported metric.
# equal_nan=True keeps two identical NaN results from being flagged as a
# round-trip failure.
for metric_name, original_value in original_metrics.items():
    assert jnp.allclose(
        original_value, reloaded_metrics[metric_name], atol=1e-6, equal_nan=True
    ), f"Metric '{metric_name}' changed after reloading the checkpoint"



Original eval accuracy: 0.2578
Reloaded eval accuracy: 0.2578
