# Investigating NaN during mixed precision training

Our hunch is that the NaN is due to the output logits diverging from the log probabilities. To check this, we take a checkpoint from right before the first NaN, run a forward pass, and inspect its output logits using `treescope`. We then do the same for an earlier checkpoint and compare the two.

In [None]:
# Use modular helpers from train_lam to restore a checkpoint and enable sowing
from typing import Dict
import os

import jax
import jax.numpy as jnp
import flax.nnx as nnx

from train_lam import (
    Args,
    build_model,
    build_optimizer,
    build_dataloader,
    build_checkpoint_manager,
    restore_checkpoint_if_needed,
    enable_sowing,
)

# TODO: set data and checkpoint directories
data_dir = "coinrun"
ckpt_dir = "checkpoint"

# Hyperparameters, grouped by concern and merged into a single Args below.
# NOTE(review): presumably these mirror the training run that produced the
# checkpoints being inspected — confirm against the original launch config.
run_cfg = dict(
    num_steps=200_000,
    seed=0,
    seq_len=16,
    image_channels=3,
    image_height=64,
    image_width=64,
    data_dir=data_dir,
    save_ckpt=False,  # inspection only: never write checkpoints
    restore_ckpt=True,
)

# Optimization
optim_cfg = dict(
    batch_size=36,
    vq_beta=0.25,
    init_lr=0.0,
    max_lr=3e-5,
    decay_end=0.0,
    # NOTE: wsd_decay_steps will only be used when using a wsd-schedule
    wsd_decay_steps=10000,
    warmup_steps=5000,
    lr_schedule="wsd",
    vq_reset_thresh=50,
)

# LAM
lam_cfg = dict(
    model_dim=512,
    ffn_dim=2048,
    latent_dim=32,
    num_latents=6,
    patch_size=16,
    num_blocks=4,
    num_heads=8,
    dropout=0.0,
    codebook_dropout=0.0,
)

# Logging
logging_cfg = dict(
    log=False,
    entity="",
    project="",
    name="train_lam",
    tags=["lam"],
    log_interval=5,
    log_image_interval=250,
    ckpt_dir=ckpt_dir,
    log_checkpoint_interval=10000,
    log_checkpoint_keep_period=20000,
    wandb_id="",
    use_flash_attention=True,
)

args = Args(**run_cfg, **optim_cfg, **lam_cfg, **logging_cfg)

# Build components (order kept as-is: model -> optimizer -> checkpoint manager
# -> dataloader). build_model hands back the rng to use from here on.
lam, rng = build_model(args, jax.random.key(args.seed))
optimizer, lr_schedule_fn = build_optimizer(lam, args)
ckpt_mgr = build_checkpoint_manager(args)
_, loader_iterator = build_dataloader(args)

Prepare the batch

In [None]:
# Pull a single batch from the dataloader; the same batch is fed to every
# checkpoint so that their activations are directly comparable.
raw_videos = next(loader_iterator)

# Convert to float32 and divide by 255.0, then cast to the configured dtype.
# NOTE(review): assumes the loader yields 8-bit pixel values — confirm.
gt = jnp.asarray(raw_videos, dtype=jnp.float32) / 255.0
videos = gt.astype(args.dtype)

# Re-derive the key from the seed so this cell yields the same `_rng` on re-run.
rng, _rng = jax.random.split(jax.random.key(args.seed), 2)

batch: Dict[str, jax.Array] = dict(videos=videos, rng=_rng)

Forward pass of checkpoint at 107k steps

In [4]:
# Load the latest checkpoint (no explicit step is passed) into the optimizer
# and dataloader iterator that were built earlier.
restored = restore_checkpoint_if_needed(args, ckpt_mgr, optimizer, loader_iterator)
step, optimizer, loader_iterator = restored
lam = optimizer.model  # the restored model is the one held by the optimizer
print(f"Restored optimizer and dataloader at step {step}.")

# NOTE(review): enable_sowing presumably makes the model record intermediate
# values for inspection — confirm in train_lam.
enable_sowing(lam)
print("Sowing enabled on encoder/decoder.")

# One forward pass on the fixed batch, then browse the encoder in treescope.
outputs = lam(batch, training=True)

nnx.display(lam.encoder)



Restored dataloader and model state from step 107000
Restored optimizer and dataloader at step 107000.
Sowing enabled on encoder/decoder.


2025-08-19 18:38:28.802968: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


Forward pass of checkpoint at 20k steps

In [8]:
# Earlier reference checkpoint (step 20000): rebuild a fresh model/optimizer
# pair, then restore the requested step into it.
lam, rng = build_model(args, rng)
optimizer, lr_schedule_fn = build_optimizer(lam, args)

restored = restore_checkpoint_if_needed(
    args, ckpt_mgr, optimizer, loader_iterator, 20000
)
step, optimizer, loader_iterator = restored
lam = optimizer.model  # the restored model is the one held by the optimizer
print(f"Restored optimizer and dataloader at step {step}.")

# NOTE(review): enable_sowing presumably makes the model record intermediate
# values for inspection — confirm in train_lam.
enable_sowing(lam)
print("Sowing enabled on encoder/decoder.")

# Same fixed batch as for the later checkpoint, so the two displays are comparable.
outputs = lam(batch, training=True)

nnx.display(lam.encoder)




Restored dataloader and model state from step 20000
Restored optimizer and dataloader at step 20000.
Sowing enabled on encoder/decoder.


As you can see, `treescope` shows that the range of the output logits is much larger in the later checkpoint than in the earlier one, which is a clear sign of a known training instability.