# Zenith JAX Core Integration - Phase 1 Validation

**Date:** 2025-12-28  
**Purpose:** Validate Phase 1 implementation with real JAX on GPU

## Components Being Tested:
1. **Gradient Checkpointing** - `zenith.jax.checkpointing`
2. **Memory Management** - `zenith.jax.memory_manager`
3. **Mixed Precision Training** - `zenith.jax.mixed_precision`

## Hardware Requirements:
- GPU Runtime (T4 or better)
- JAX with GPU support

---

## Setup: Install Dependencies

In [None]:
# Install pyzenith from PyPI
%pip install pyzenith --quiet

# Verify JAX GPU support (pre-installed in Colab)
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

In [None]:
# Verify Zenith installation
import zenith
print(f"Zenith version: {zenith.__version__}")

# Import Phase 1 modules.
# Some names below (ZenithCheckpointer, CheckpointConfig, SelectionMethod,
# remat, PrecisionMode, create_policy) are not used by the cells in this
# notebook; importing them still checks that the public API exports them.
# Gradient checkpointing API (Section 1).
from zenith.jax.checkpointing import (
    OptimalCheckpointSelector,
    ZenithCheckpointer,
    CheckpointConfig,
    CheckpointPolicy,
    SelectionMethod,
    checkpoint,
    checkpoint_sequential,
    remat,
)
# Activation storage / memory management API (Section 2).
from zenith.jax.memory_manager import (
    JAXActivationStore,
    JAXMemoryManager,
    JAXMemoryConfig,
    EvictionPolicy,
    compute_array_size,
    get_device_string,
)
# Mixed-precision API (Section 3).
from zenith.jax.mixed_precision import (
    MixedPrecisionPolicy,
    DynamicLossScaler,
    LossScalerConfig,
    ZenithMixedPrecision,
    PrecisionMode,
    create_policy,
    detect_best_precision,
)

print("\nAll Phase 1 modules imported successfully!")

---
## Section 1: Gradient Checkpointing Validation

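We will test:
1. `OptimalCheckpointSelector` algorithms (sqrt, DP, uniform) and their estimated memory reduction
2. `checkpoint()` with actual `jax.grad`
3. `checkpoint_sequential()` over multiple layers
4. Gradient equivalence with and without checkpointing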

In [None]:
# TEST 1.1: OptimalCheckpointSelector Algorithms
print("="*60)
print("TEST 1.1: OptimalCheckpointSelector Algorithms")
print("="*60)


def _preview(ckpts, limit=5):
    """Render at most ``limit`` checkpoint entries, adding '...' if truncated."""
    suffix = '...' if len(ckpts) > limit else ''
    return f"{ckpts[:limit]}{suffix}"


# Test with various network sizes
for num_layers in [4, 12, 24, 48, 96]:
    selector = OptimalCheckpointSelector(num_layers=num_layers)

    sqrt_ckpts = selector.select_sqrt()
    dp_ckpts = selector.select_dp()
    uniform_ckpts = selector.select_uniform(num_checkpoints=4)

    reduction = selector.estimate_memory_reduction(sqrt_ckpts)

    print(f"\nLayers: {num_layers}")
    print(f"  sqrt:    {len(sqrt_ckpts)} checkpoints -> {_preview(sqrt_ckpts)}")
    print(f"  DP:      {len(dp_ckpts)} checkpoints -> {_preview(dp_ckpts)}")
    # The uniform selection was previously computed but never reported.
    print(f"  uniform: {len(uniform_ckpts)} checkpoints -> {_preview(uniform_ckpts)}")
    print(f"  Memory reduction (sqrt): {reduction:.1f}%")

    # Loose sanity check: checkpoints should be layer indices. The bound
    # accepts both 0- and 1-based indexing (exact convention not confirmed).
    for name, ckpts in (("sqrt", sqrt_ckpts), ("DP", dp_ckpts), ("uniform", uniform_ckpts)):
        assert all(0 <= c <= num_layers for c in ckpts), (
            f"{name} checkpoints out of range for {num_layers} layers: {ckpts}"
        )

print("\n[PASS] OptimalCheckpointSelector works correctly!")

In [None]:
# TEST 1.2: checkpoint() with Real JAX Gradient Computation
print("="*60)
print("TEST 1.2: checkpoint() with Real JAX Gradients")
print("="*60)

import jax
import jax.numpy as jnp

# Define a simple MLP layer
def mlp_layer(x, w1, w2):
    """Simple MLP: x -> linear -> relu -> linear"""
    h = jnp.dot(x, w1)
    h = jax.nn.relu(h)
    return jnp.dot(h, w2)

# Create checkpointed version using Zenith
checkpointed_mlp = checkpoint(mlp_layer, policy=CheckpointPolicy.DOTS_SAVEABLE)

# Initialize weights. Each tensor gets its own subkey: reusing one key for
# every draw makes x, w1 and w2 come from the same random stream, so they
# are correlated (w1 and w2 have the same element count and may even hold
# identical values).
key = jax.random.PRNGKey(42)
k_x, k_w1, k_w2 = jax.random.split(key, 3)
x = jax.random.normal(k_x, (32, 64))  # batch=32, features=64
w1 = jax.random.normal(k_w1, (64, 128))
w2 = jax.random.normal(k_w2, (128, 64))

# Loss function
def loss_fn(x, w1, w2):
    out = checkpointed_mlp(x, w1, w2)
    return jnp.mean(out ** 2)

# Compute gradients w.r.t. the two weight matrices
grads = jax.grad(loss_fn, argnums=(1, 2))(x, w1, w2)

print(f"Input shape: {x.shape}")
print(f"W1 shape: {w1.shape}, W1 grad shape: {grads[0].shape}")
print(f"W2 shape: {w2.shape}, W2 grad shape: {grads[1].shape}")
print(f"Grad W1 norm: {jnp.linalg.norm(grads[0]):.4f}")
print(f"Grad W2 norm: {jnp.linalg.norm(grads[1]):.4f}")

# Verify gradients are not zero or NaN
assert jnp.all(jnp.isfinite(grads[0])), "W1 gradients contain inf/nan!"
assert jnp.all(jnp.isfinite(grads[1])), "W2 gradients contain inf/nan!"
assert jnp.linalg.norm(grads[0]) > 0, "W1 gradients are zero!"
assert jnp.linalg.norm(grads[1]) > 0, "W2 gradients are zero!"

print("\n[PASS] checkpoint() produces valid gradients!")

In [None]:
# TEST 1.3: checkpoint_sequential() with Multiple Layers
print("="*60)
print("TEST 1.3: checkpoint_sequential() with Multiple Layers")
print("="*60)

NUM_SEQ_LAYERS = 6
FEATURES = 64

# One subkey per layer plus one for the input. Reusing the parent key after
# splitting it (as before) correlates the input with the weights.
key = jax.random.PRNGKey(0)
*layer_keys, input_key = jax.random.split(key, NUM_SEQ_LAYERS + 1)

# Initialize layers' weights with He scaling (std = sqrt(2 / fan_in)) so
# activations stay O(1) through the ReLU stack and the reference comparison
# below uses a meaningful tolerance.
w_scale = (2.0 / FEATURES) ** 0.5
layers_weights = [
    jax.random.normal(k, (FEATURES, FEATURES)) * w_scale for k in layer_keys
]

# Define layer functions. The factory binds each `w` at creation time,
# avoiding Python's late-binding closure pitfall.
def make_layer_fn(w):
    def layer(x):
        return jax.nn.relu(jnp.dot(x, w))
    return layer

layer_fns = [make_layer_fn(w) for w in layers_weights]

# Input
x = jax.random.normal(input_key, (16, FEATURES))

# Run checkpoint_sequential
output = checkpoint_sequential(
    functions=layer_fns,
    input_value=x,
    segments=3,  # Use 3 segments
    policy=CheckpointPolicy.DOTS_SAVEABLE,
)

# Reference: apply the same layers without any checkpointing
expected = x
for fn in layer_fns:
    expected = fn(expected)

max_diff = float(jnp.max(jnp.abs(output - expected)))

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output mean: {jnp.mean(output):.4f}")
print(f"Output std: {jnp.std(output):.4f}")
print(f"Max |diff| vs. un-checkpointed reference: {max_diff:.2e}")

# Verify output is valid and numerically equivalent to plain sequential application
assert output.shape == x.shape, f"Shape mismatch: {output.shape} vs {x.shape}"
assert jnp.all(jnp.isfinite(output)), "Output contains inf/nan!"
assert jnp.allclose(output, expected, rtol=1e-4, atol=1e-5), (
    f"checkpoint_sequential output differs from reference (max diff {max_diff:.2e})"
)

print("\n[PASS] checkpoint_sequential() works correctly!")

In [None]:
# TEST 1.4: Gradient Correctness - Compare with/without checkpointing
print("="*60)
print("TEST 1.4: Gradient Correctness Verification")
print("="*60)

def transformer_block(x, w_attn, w_ff):
    """Simplified single-head transformer block.

    Args:
        x: activations of shape (batch, seq, dim).
        w_attn: (dim, dim) projection used to form the attention scores.
        w_ff: (dim, dim) feed-forward weight.

    Returns:
        An array with the same shape as ``x``.
    """
    dim = x.shape[-1]
    # Self-attention (simplified): scores[b, s, t] = <x[b, s] @ w_attn, x[b, t]>.
    # The previous `jnp.dot(attn, x)` on two rank-3 arrays contracted the
    # feature axis of `attn` with the *sequence* axis of `x`, which raises a
    # shape error whenever seq != dim (here 32 vs 64).
    scores = jnp.matmul(jnp.dot(x, w_attn), jnp.swapaxes(x, -1, -2)) * dim ** -0.5
    attn = jax.nn.softmax(scores, axis=-1)  # (batch, seq, seq)
    h = jnp.matmul(attn, x)                 # (batch, seq, dim)

    # FFN
    out = jnp.dot(h, w_ff)
    out = jax.nn.gelu(out)

    return out + x  # Residual

# Create checkpointed version
checkpointed_block = checkpoint(transformer_block, policy=CheckpointPolicy.NOTHING)

# Initialize with independent subkeys (reusing one key made w_attn and w_ff
# identical) and 1/sqrt(dim) weight scaling to keep the softmax unsaturated.
DIM = 64
key = jax.random.PRNGKey(123)
k_x, k_attn, k_ff = jax.random.split(key, 3)
x = jax.random.normal(k_x, (8, 32, DIM))  # (batch, seq, dim)
w_attn = jax.random.normal(k_attn, (DIM, DIM)) * DIM ** -0.5
w_ff = jax.random.normal(k_ff, (DIM, DIM)) * DIM ** -0.5

# Loss functions
def loss_no_ckpt(x, w_attn, w_ff):
    out = transformer_block(x, w_attn, w_ff)
    return jnp.mean(out ** 2)

def loss_with_ckpt(x, w_attn, w_ff):
    out = checkpointed_block(x, w_attn, w_ff)
    return jnp.mean(out ** 2)

# Compute gradients both ways
grads_no_ckpt = jax.grad(loss_no_ckpt, argnums=(1, 2))(x, w_attn, w_ff)
grads_with_ckpt = jax.grad(loss_with_ckpt, argnums=(1, 2))(x, w_attn, w_ff)


def _max_rel_diff(ref, other):
    """Max absolute difference, normalised by the reference's max magnitude."""
    return float(jnp.max(jnp.abs(ref - other)) / (jnp.max(jnp.abs(ref)) + 1e-12))


# Compare gradients. A relative measure is used because an absolute threshold
# is meaningless without knowing the gradients' scale.
diff_attn = _max_rel_diff(grads_no_ckpt[0], grads_with_ckpt[0])
diff_ff = _max_rel_diff(grads_no_ckpt[1], grads_with_ckpt[1])

print(f"Relative gradient difference (w_attn): {diff_attn:.2e}")
print(f"Relative gradient difference (w_ff): {diff_ff:.2e}")

# Tolerance for floating-point differences (recomputation may reorder ops slightly)
TOLERANCE = 1e-4
assert diff_attn < TOLERANCE, f"w_attn gradient difference too large: {diff_attn}"
assert diff_ff < TOLERANCE, f"w_ff gradient difference too large: {diff_ff}"

print(f"\n[PASS] Gradients match within relative tolerance {TOLERANCE}!")

---
## Section 2: Memory Management Validation

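We will test:
1. `JAXActivationStore` with real JAX arrays (storage, retrieval, and memory tracking)
2. Eviction under memory pressure (LRU policy)
3. Device detection for GPU and CPU arrays
4. CPU offloading with `JAXMemoryManager`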

In [None]:
# TEST 2.1: JAXActivationStore with JAX Arrays
print("="*60)
print("TEST 2.1: JAXActivationStore with JAX Arrays")
print("="*60)

store = JAXActivationStore(max_memory_bytes=100 * 1024 * 1024)  # 100 MB

# Store JAX arrays
key = jax.random.PRNGKey(0)
arrays = {}

for i in range(5):
    # fold_in gives each layer distinct data. With one shared key all five
    # arrays were identical, so a retrieve() that returned the wrong layer
    # would still have "matched".
    arr = jax.random.normal(jax.random.fold_in(key, i), (1024, 1024))  # ~4MB each (float32)
    arrays[i] = arr
    success = store.store(layer_id=i, array=arr)
    print(f"Stored layer {i}: {compute_array_size(arr) / 1024 / 1024:.2f} MB, success={success}")
    assert success, f"Failed to store layer {i}!"

print(f"\nTotal memory usage: {store.memory_usage / 1024 / 1024:.2f} MB")
print(f"Stored arrays: {len(store)}")

# Retrieve and verify. 5 x ~4 MB fits well inside the 100 MB budget, so every
# layer must still be present; a missing layer is a failure, not something to skip.
for i in range(5):
    retrieved = store.retrieve(layer_id=i)
    assert retrieved is not None, f"Layer {i} missing from store!"
    match = jnp.allclose(retrieved, arrays[i])
    print(f"Layer {i} retrieved, matches original: {match}")
    assert match, f"Layer {i} data mismatch!"

print("\n[PASS] JAXActivationStore works with JAX arrays!")

In [None]:
# TEST 2.2: Eviction Under Memory Pressure
print("="*60)
print("TEST 2.2: Eviction Under Memory Pressure")
print("="*60)

BUDGET_BYTES = 10 * 1024 * 1024  # 10 MB
# Each 512x512 float32 array is exactly 1 MiB. The previous 10 arrays only
# *filled* the 10 MiB budget, so a store that allows usage == budget would
# never need to evict. 15 arrays genuinely exceed it.
NUM_ARRAYS = 15

# Small memory budget to force eviction
store = JAXActivationStore(
    max_memory_bytes=BUDGET_BYTES,
    eviction_policy=EvictionPolicy.LRU,
)

key = jax.random.PRNGKey(42)

# Store more than budget allows
for i in range(NUM_ARRAYS):
    arr = jax.random.normal(key, (512, 512))  # ~1MB each
    store.store(layer_id=i, array=arr)

stats = store.statistics
print(f"Stored: {stats['store_count']} arrays")
print(f"Evicted: {stats['eviction_count']} arrays")
print(f"Current memory: {stats['current_memory_mb']:.2f} MB")
print(f"Peak memory: {stats['peak_memory_mb']:.2f} MB")
print(f"Currently stored: {stats['stored_count']} arrays")

# Verify eviction happened and the budget is respected
assert stats['eviction_count'] > 0, "Eviction should have occurred!"
assert stats['current_memory_bytes'] <= BUDGET_BYTES, "Memory exceeds budget!"
# Under LRU, the most recently stored array must still be resident.
assert store.retrieve(layer_id=NUM_ARRAYS - 1) is not None, (
    "Most recently stored array was evicted under LRU!"
)

print("\n[PASS] Eviction works correctly under memory pressure!")

In [None]:
# TEST 2.3: Device Detection for JAX Arrays
banner = "=" * 60
print(banner)
print("TEST 2.3: Device Detection")
print(banner)

# Place the same data on the default (accelerator) device and on the host CPU
key = jax.random.PRNGKey(0)
gpu_array = jax.random.normal(key, (100, 100))
host_cpu = jax.devices('cpu')[0]
cpu_array = jax.device_put(gpu_array, host_cpu)

gpu_device, cpu_device = (get_device_string(a) for a in (gpu_array, cpu_array))

print(f"GPU array device: {gpu_device}")
print(f"CPU array device: {cpu_device}")

# Verify device detection: the default-device array must report an accelerator
accelerator_markers = ('gpu', 'cuda', 'tpu')
gpu_device_lower = gpu_device.lower()
assert any(marker in gpu_device_lower for marker in accelerator_markers), f"Expected GPU device, got {gpu_device}"
assert 'cpu' in cpu_device.lower(), f"Expected CPU device, got {cpu_device}"

print("\n[PASS] Device detection works correctly!")

In [None]:
# TEST 2.4: JAXMemoryManager with CPU Offloading
print("="*60)
print("TEST 2.4: JAXMemoryManager with CPU Offloading")
print("="*60)

config = JAXMemoryConfig(
    max_memory_bytes=50 * 1024 * 1024,  # 50 MB
    enable_offloading=True,
    offload_threshold_bytes=5 * 1024 * 1024,  # 5 MB threshold
)

manager = JAXMemoryManager(config=config)

key = jax.random.PRNGKey(0)
k_large, k_small = jax.random.split(key)

# Store a large array (above the offload threshold, so should be offloaded)
large_array = jax.random.normal(k_large, (2048, 2048))  # ~16 MB
success = manager.store(layer_id=0, array=large_array, allow_offload=True)
print(f"Large array stored: {success}")
assert success, "Failed to store large array!"

# Store a small array (below the threshold, so should stay on device)
small_array = jax.random.normal(k_small, (256, 256))  # ~0.25 MB
success = manager.store(layer_id=1, array=small_array, allow_offload=True)
print(f"Small array stored: {success}")
assert success, "Failed to store small array!"

stats = manager.get_statistics()
print(f"\nOn-device memory: {stats['current_memory_mb']:.2f} MB")
print(f"Offloaded arrays: {stats['offloaded_count']}")
print(f"Offloaded size: {stats['offloaded_mb']:.2f} MB")
print(f"Total managed: {stats['total_managed_mb']:.2f} MB")

# Retrieve and verify
retrieved_large = manager.retrieve(layer_id=0, prefetch=True)
retrieved_small = manager.retrieve(layer_id=1)
assert retrieved_large is not None, "Large array could not be retrieved!"
assert retrieved_small is not None, "Small array could not be retrieved!"

print(f"\nLarge array retrieved shape: {retrieved_large.shape}")
print(f"Small array retrieved shape: {retrieved_small.shape}")

# Verify data integrity for BOTH arrays; the offloaded one is the case that
# matters most. Host<->device transfers do not change values, so unless the
# manager deliberately changes dtype the round trip must be lossless. Values
# are compared via host copies so arrays living on different devices
# (offloaded CPU copy vs. on-device original) can be compared safely.
assert jnp.allclose(jax.device_get(retrieved_small), jax.device_get(small_array)), "Small array data mismatch!"
assert jnp.allclose(jax.device_get(retrieved_large), jax.device_get(large_array)), (
    "Large array data mismatch after offload round-trip!"
)

print("\n[PASS] JAXMemoryManager works with CPU offloading!")

---
## Section 3: Mixed Precision Validation

We will test:
1. `MixedPrecisionPolicy` dtype conversions
2. `DynamicLossScaler` with actual gradients
3. `ZenithMixedPrecision` end-to-end training

In [None]:
# TEST 3.1: MixedPrecisionPolicy Dtype Conversions
print("="*60)
print("TEST 3.1: MixedPrecisionPolicy Dtype Conversions")
print("="*60)

# Test FP32 policy
policy_fp32 = MixedPrecisionPolicy.fp32()
print("FP32 policy:")
print(f"  param_dtype: {policy_fp32.param_dtype}")
print(f"  compute_dtype: {policy_fp32.compute_dtype}")
print(f"  requires_loss_scaling: {policy_fp32.requires_loss_scaling}")

# Test BF16 policy
policy_bf16 = MixedPrecisionPolicy.bf16()
print("\nBF16 policy:")
print(f"  param_dtype: {policy_bf16.param_dtype}")
print(f"  compute_dtype: {policy_bf16.compute_dtype}")
print(f"  requires_loss_scaling: {policy_bf16.requires_loss_scaling}")

# Test FP16 policy
policy_fp16 = MixedPrecisionPolicy.fp16()
print("\nFP16 policy:")
print(f"  param_dtype: {policy_fp16.param_dtype}")
print(f"  compute_dtype: {policy_fp16.compute_dtype}")
print(f"  requires_loss_scaling: {policy_fp16.requires_loss_scaling}")

# Check the loss-scaling flags that were previously only printed: FP16's
# narrow exponent range needs loss scaling, plain FP32 does not.
assert policy_fp16.requires_loss_scaling, "FP16 policy should require loss scaling!"
assert not policy_fp32.requires_loss_scaling, "FP32 policy should not require loss scaling!"

# Verify dtype conversions work
arr = jnp.ones((10, 10), dtype=jnp.float32)
bf16_dtype = policy_bf16.get_jax_dtype('bfloat16')
arr_bf16 = arr.astype(bf16_dtype)

print(f"\nOriginal dtype: {arr.dtype}")
print(f"BF16 dtype: {arr_bf16.dtype}")

assert arr_bf16.dtype == jnp.bfloat16, "BF16 conversion failed!"

print("\n[PASS] MixedPrecisionPolicy dtype conversions work!")

In [None]:
# TEST 3.2: DynamicLossScaler with Real Gradients
print("="*60)
print("TEST 3.2: DynamicLossScaler with Real Gradients")
print("="*60)

scaler = DynamicLossScaler(LossScalerConfig(
    initial_scale=2**15,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=5,
))

print(f"Initial scale: {scaler.scale}")

# Simulate training loop (params and inputs drawn from independent subkeys)
key = jax.random.PRNGKey(0)
k_params, k_x = jax.random.split(key)
params = jax.random.normal(k_params, (64, 64))

def loss_fn(params, x):
    return jnp.mean((jnp.dot(x, params)) ** 2)

x = jax.random.normal(k_x, (32, 64))

def scaled_loss_fn(params):
    """Loss multiplied by the scaler's current scale (read at call time)."""
    return scaler.scale_loss(loss_fn(params, x))

# Reference unscaled gradient. params are not updated in this loop, so it is
# the same at every step.
ref_grads = jax.grad(loss_fn)(params, x)

# Training steps with scaling
for step in range(10):
    # Unscaled loss, for logging
    loss = loss_fn(params, x)

    # Compute scaled gradients
    scaled_grads = jax.grad(scaled_loss_fn)(params)

    # Unscale and check
    grads, is_finite = scaler.unscale_grads({'params': scaled_grads})

    # Unscaling must recover the true gradient whenever no overflow occurred
    # (assumes unscale_grads preserves the input dict structure).
    if bool(is_finite):
        assert jnp.allclose(grads['params'], ref_grads, rtol=1e-5, atol=1e-6), (
            f"Step {step}: unscaled gradients do not match the reference!"
        )

    # Update scaler
    scaler.update(is_finite)

    if step % 2 == 0:
        print(f"Step {step}: loss={float(loss):.4f}, scale={scaler.scale:.0f}, finite={is_finite}")

print(f"\nFinal scale: {scaler.scale}")
print(f"Total overflow count: {scaler.state.overflow_count}")

print("\n[PASS] DynamicLossScaler works with real gradients!")

In [None]:
# TEST 3.3: ZenithMixedPrecision End-to-End
print("="*60)
print("TEST 3.3: ZenithMixedPrecision End-to-End Training")
print("="*60)

# Create mixed precision handler with BF16
mp = ZenithMixedPrecision(policy='bf16')

print(f"Policy: {mp.policy.mode.value}")
print(f"Compute dtype: {mp.policy.compute_dtype}")
print(f"Has scaler: {mp.scaler is not None}")

# Initialize model params in FP32 from independent subkeys (reusing one key
# made w1 and w2 share the same random stream).
key = jax.random.PRNGKey(0)
k_w1, k_w2, k_x = jax.random.split(key, 3)
params = {
    'w1': jax.random.normal(k_w1, (64, 128), dtype=jnp.float32),
    'w2': jax.random.normal(k_w2, (128, 64), dtype=jnp.float32),
}

print("\nOriginal param dtypes:")
print(f"  w1: {params['w1'].dtype}")
print(f"  w2: {params['w2'].dtype}")

# Cast to compute dtype
compute_params = mp.cast_to_compute(params)

print("\nCompute param dtypes:")
print(f"  w1: {compute_params['w1'].dtype}")
print(f"  w2: {compute_params['w2'].dtype}")

# Verify BF16
assert compute_params['w1'].dtype == jnp.bfloat16, "w1 not BF16!"
assert compute_params['w2'].dtype == jnp.bfloat16, "w2 not BF16!"

# Simulate forward pass
x = jax.random.normal(k_x, (32, 64)).astype(jnp.bfloat16)

def forward(params, x):
    h = jnp.dot(x, params['w1'])
    h = jax.nn.relu(h)
    return jnp.dot(h, params['w2'])

output = forward(compute_params, x)
print(f"\nOutput dtype: {output.dtype}")
print(f"Output shape: {output.shape}")

# The forward pass must actually run in the compute dtype
assert output.dtype == jnp.bfloat16, f"Forward output not BF16: {output.dtype}"
assert output.shape == (32, 64), f"Unexpected output shape: {output.shape}"

# Cast back to param dtype
back_to_fp32 = mp.cast_to_param(compute_params)
print("\nBack to FP32:")
print(f"  w1: {back_to_fp32['w1'].dtype}")
print(f"  w2: {back_to_fp32['w2'].dtype}")

# Check every param, not just w1
assert back_to_fp32['w1'].dtype == jnp.float32, "w1 not back to FP32!"
assert back_to_fp32['w2'].dtype == jnp.float32, "w2 not back to FP32!"

print("\n[PASS] ZenithMixedPrecision end-to-end works correctly!")

In [None]:
# TEST 3.4: FP16 with Loss Scaling Full Loop
print("="*60)
print("TEST 3.4: FP16 with Loss Scaling Full Training Loop")
print("="*60)

# Create FP16 handler (requires loss scaling)
mp_fp16 = ZenithMixedPrecision(policy='fp16')

print(f"Policy: {mp_fp16.policy.mode.value}")
print(f"Has scaler: {mp_fp16.scaler is not None}")
print(f"Initial scale: {mp_fp16.scaler.scale if mp_fp16.scaler else 'N/A'}")
# The loop reads mp_fp16.scaler.scale on every step; fail clearly here rather
# than with an AttributeError mid-loop.
assert mp_fp16.scaler is not None, "FP16 policy should create a loss scaler!"

# Initialize FP32 master params and inputs from independent subkeys
key = jax.random.PRNGKey(42)
k_w, k_x = jax.random.split(key)
params = {'w': jax.random.normal(k_w, (32, 32), dtype=jnp.float32)}
x = jax.random.normal(k_x, (16, 32), dtype=jnp.float32)
compute_x = x.astype(jnp.float16)

LEARNING_RATE = 0.01

def loss_fn(params, x):
    out = jnp.dot(x, params['w'])
    return jnp.mean(out ** 2)

def scaled_loss_fn(params):
    """FP16 loss multiplied by the scaler's current scale (read at call time)."""
    return mp_fp16.scale_loss(loss_fn(params, compute_x))

print("\nTraining steps:")
for step in range(5):
    # Cast the FP32 master params to FP16 for this step's compute
    compute_params = mp_fp16.cast_to_compute(params)

    # Unscaled loss, for logging
    loss = loss_fn(compute_params, compute_x)

    # Gradients of the scaled loss w.r.t. the FP16 params
    grads = jax.grad(scaled_loss_fn)(compute_params)

    # Unscale and check for overflow
    grads, is_finite = mp_fp16.handle_grads(grads)

    # Update scale
    mp_fp16.update_scale(is_finite)

    # Apply an SGD step to the FP32 master params, but only when the gradients
    # are finite. Overflowed steps are skipped, which is the point of dynamic
    # loss scaling. (Previously params were never updated, so every step was identical.)
    if bool(is_finite):
        params = jax.tree_util.tree_map(
            lambda p, g: p - LEARNING_RATE * g.astype(p.dtype), params, grads
        )

    print(f"  Step {step}: loss={float(loss):.4f}, scale={mp_fp16.scaler.scale:.0f}, finite={is_finite}")

print("\nFinal stats:")
print(f"  Total steps: {mp_fp16.stats.total_steps}")
print(f"  Overflow count: {mp_fp16.stats.overflow_count}")

print("\n[PASS] FP16 with loss scaling works correctly!")

In [None]:
# TEST 3.5: Hardware Detection
print("="*60)
print("TEST 3.5: Hardware-Based Precision Detection")
print("="*60)

best_precision = detect_best_precision()
print(f"Detected best precision: {best_precision}")

# No assertion: the right recommendation depends on the hardware. bf16 is
# natively supported on TPUs and on Ampere-or-newer NVIDIA GPUs, whereas a
# T4 (Turing) lacks native bf16, so fp16 would be the expected choice there.
# The device kind is printed to make the result interpretable.
print("\nDevice info:")
for device in jax.devices():
    print(f"  {device.platform}: {device} ({device.device_kind})")

print("\n[PASS] Hardware detection completed!")

---
## Section 4: Performance Benchmarks

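Measure the practical trade-offs of:
1. Gradient checkpointing — wall-clock overhead of recomputation (the memory reduction is estimated by `OptimalCheckpointSelector`, not measured)
2. Mixed precision — matmul speedup of BF16/FP16 over FP32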

In [None]:
# BENCHMARK 1: Gradient Checkpointing Memory Impact
print("="*60)
print("BENCHMARK 1: Checkpointing Memory Impact")
print("="*60)

import time

# Deep network simulation
NUM_LAYERS = 24
HIDDEN_DIM = 512
BATCH_SIZE = 32
N_RUNS = 10


def _time_ms(fn, *args, n_runs):
    """Return the mean wall-clock milliseconds per call of ``fn(*args)``.

    JAX dispatches asynchronously, so each result is blocked on before the
    next call; otherwise only dispatch time would be measured. perf_counter
    is a monotonic, high-resolution clock suited to interval timing.
    """
    start = time.perf_counter()
    for _ in range(n_runs):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_runs * 1000


key = jax.random.PRNGKey(0)
keys = jax.random.split(key, NUM_LAYERS + 1)

# Initialize weights with He scaling (std = sqrt(2 / fan_in)). With unit-variance
# weights each ReLU layer grows activations by roughly sqrt(HIDDEN_DIM / 2) ~ 16,
# so after 24 layers the loss overflows float32 and the benchmark would be
# timing inf/NaN gradients.
w_scale = (2.0 / HIDDEN_DIM) ** 0.5
weights = [jax.random.normal(keys[i], (HIDDEN_DIM, HIDDEN_DIM)) * w_scale for i in range(NUM_LAYERS)]
x = jax.random.normal(keys[NUM_LAYERS], (BATCH_SIZE, HIDDEN_DIM))

def layer_fn(x, w):
    return jax.nn.relu(jnp.dot(x, w))

# Without checkpointing
def forward_no_ckpt(x, weights):
    for w in weights:
        x = layer_fn(x, w)
    return x

def loss_no_ckpt(x, weights):
    return jnp.mean(forward_no_ckpt(x, weights) ** 2)

# With checkpointing (recompute everything in the backward pass)
checkpointed_layer = checkpoint(layer_fn, policy=CheckpointPolicy.NOTHING)

def forward_with_ckpt(x, weights):
    for w in weights:
        x = checkpointed_layer(x, w)
    return x

def loss_with_ckpt(x, weights):
    return jnp.mean(forward_with_ckpt(x, weights) ** 2)

grad_fn_no_ckpt = jax.jit(jax.grad(loss_no_ckpt, argnums=1))
grad_fn_with_ckpt = jax.jit(jax.grad(loss_with_ckpt, argnums=1))

# Warm up (compile) and wait for completion so it does not leak into the timing
jax.block_until_ready(grad_fn_no_ckpt(x, weights))
grads = jax.block_until_ready(grad_fn_with_ckpt(x, weights))

# The timings are only meaningful if the gradients are usable
assert all(bool(jnp.all(jnp.isfinite(g))) for g in grads), "Benchmark gradients contain inf/nan!"

# Benchmark
time_no_ckpt = _time_ms(grad_fn_no_ckpt, x, weights, n_runs=N_RUNS)
time_with_ckpt = _time_ms(grad_fn_with_ckpt, x, weights, n_runs=N_RUNS)

print(f"Network: {NUM_LAYERS} layers, {HIDDEN_DIM} hidden dim, batch {BATCH_SIZE}")
print(f"\nTime per backward pass:")
print(f"  Without checkpointing: {time_no_ckpt:.2f} ms")
print(f"  With checkpointing: {time_with_ckpt:.2f} ms")
print(f"  Overhead: {((time_with_ckpt / time_no_ckpt) - 1) * 100:.1f}%")

# Memory estimation (analytical estimate from the selector, not a measurement)
selector = OptimalCheckpointSelector(NUM_LAYERS)
checkpoints = selector.select_sqrt()
memory_reduction = selector.estimate_memory_reduction(checkpoints)
print(f"\nEstimated memory reduction: {memory_reduction:.1f}%")

print("\n[BENCHMARK COMPLETE]")

In [None]:
# BENCHMARK 2: Mixed Precision Speedup
print("="*60)
print("BENCHMARK 2: Mixed Precision Speedup")
print("="*60)

import time

# Matrix multiplication benchmark
SIZE = 2048
N_RUNS = 20


def _time_ms(fn, *args, n_runs):
    """Return the mean wall-clock milliseconds per call of ``fn(*args)``.

    JAX dispatches asynchronously, so each result is blocked on before the
    next call; otherwise only dispatch time would be measured. perf_counter
    is a monotonic, high-resolution clock suited to interval timing.
    """
    start = time.perf_counter()
    for _ in range(n_runs):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_runs * 1000


# Independent subkeys so `a` and `b` are different matrices
key = jax.random.PRNGKey(0)
k_a, k_b = jax.random.split(key)

a_fp32 = jax.random.normal(k_a, (SIZE, SIZE), dtype=jnp.float32)
b_fp32 = jax.random.normal(k_b, (SIZE, SIZE), dtype=jnp.float32)

a_bf16 = a_fp32.astype(jnp.bfloat16)
b_bf16 = b_fp32.astype(jnp.bfloat16)

a_fp16 = a_fp32.astype(jnp.float16)
b_fp16 = b_fp32.astype(jnp.float16)

# JIT compile
matmul = jax.jit(jnp.dot)

# Warm up (one compilation per dtype) and wait for it to finish, so queued
# warm-up work is not billed to the first timed iteration.
for a_warm, b_warm in ((a_fp32, b_fp32), (a_bf16, b_bf16), (a_fp16, b_fp16)):
    jax.block_until_ready(matmul(a_warm, b_warm))

time_fp32 = _time_ms(matmul, a_fp32, b_fp32, n_runs=N_RUNS)
time_bf16 = _time_ms(matmul, a_bf16, b_bf16, n_runs=N_RUNS)
time_fp16 = _time_ms(matmul, a_fp16, b_fp16, n_runs=N_RUNS)

print(f"Matrix size: {SIZE}x{SIZE}")
print(f"\nTime per matmul:")
print(f"  FP32: {time_fp32:.2f} ms")
print(f"  BF16: {time_bf16:.2f} ms (speedup: {time_fp32/time_bf16:.2f}x)")
print(f"  FP16: {time_fp16:.2f} ms (speedup: {time_fp32/time_fp16:.2f}x)")

print("\n[BENCHMARK COMPLETE]")

---
## Final Summary

In [None]:
# FINAL SUMMARY
# Static checklist of the cells in this notebook; it does not inspect
# results, so read the [PASS] markers printed above.
RULE = "=" * 70

SUMMARY_TEXT = """
TESTS COMPLETED:
================

Section 1: Gradient Checkpointing
  [x] TEST 1.1: OptimalCheckpointSelector algorithms
  [x] TEST 1.2: checkpoint() with real JAX gradients
  [x] TEST 1.3: checkpoint_sequential() with multiple layers
  [x] TEST 1.4: Gradient correctness verification

Section 2: Memory Management
  [x] TEST 2.1: JAXActivationStore with JAX arrays
  [x] TEST 2.2: Eviction under memory pressure
  [x] TEST 2.3: Device detection
  [x] TEST 2.4: JAXMemoryManager with CPU offloading

Section 3: Mixed Precision
  [x] TEST 3.1: MixedPrecisionPolicy dtype conversions
  [x] TEST 3.2: DynamicLossScaler with real gradients
  [x] TEST 3.3: ZenithMixedPrecision end-to-end
  [x] TEST 3.4: FP16 with loss scaling full loop
  [x] TEST 3.5: Hardware detection

Section 4: Performance Benchmarks
  [x] BENCHMARK 1: Checkpointing memory impact
  [x] BENCHMARK 2: Mixed precision speedup

CONCLUSION:
===========
If all tests above show [PASS], then Phase 1 JAX Core Integration
is validated and ready for production use.

Proceed to Phase 2: XLA Backend & ONNX Export Enhancement.
"""

print(RULE)
print("ZENITH JAX PHASE 1 VALIDATION - FINAL SUMMARY")
print(RULE)
print(SUMMARY_TEXT)
print(RULE)