# Zenith JAX Core Integration - Phase 1 Validation

**Date:** 2025-12-28  
**Purpose:** Validate Phase 1 implementation with real JAX on GPU

## Components Being Tested:
1. **Gradient Checkpointing** - `zenith.jax.checkpointing`
2. **Memory Management** - `zenith.jax.memory_manager`
3. **Mixed Precision Training** - `zenith.jax.mixed_precision`

## Hardware Requirements:
- GPU Runtime (T4 or better)
- JAX with GPU support

---

## Setup: Install Dependencies

In [None]:
# Install pyzenith from GitHub (latest with JAX Core Integration)
%pip install git+https://github.com/vibeswithkk/ZENITH.git --quiet

# Report which JAX build is active and which devices it can see.
# (JAX comes pre-installed in Colab, so only the import is needed here.)
import jax
import jax.numpy as jnp

# Each probe is evaluated lazily, immediately before its own line is printed,
# so output order and failure point match three sequential print() calls.
_jax_probes = (
    ("JAX version", lambda: jax.__version__),
    ("Devices", jax.devices),
    ("Default backend", jax.default_backend),
)
for _label, _probe in _jax_probes:
    print(f"{_label}: {_probe()}")

In [None]:
# Verify Zenith installation
# Importing the package and reading its version confirms the pip install
# in the setup cell produced an importable distribution.
import zenith
print(f"Zenith version: {zenith.__version__}")

# Import Phase 1 modules
# All three Phase 1 submodules are imported up front so that a missing or
# renamed symbol fails here, before any test cell runs. Not every name is
# used by a later cell; the import itself is the check.

# Gradient checkpointing (Section 1).
from zenith.jax.checkpointing import (
    OptimalCheckpointSelector,
    ZenithCheckpointer,
    CheckpointConfig,
    CheckpointPolicy,
    SelectionMethod,
    checkpoint,
    checkpoint_sequential,
    remat,
)
# Activation storage and memory tracking (Section 2).
from zenith.jax.memory_manager import (
    JAXActivationStore,
    JAXMemoryManager,
    JAXMemoryConfig,
    EvictionPolicy,
    compute_array_size,
    get_device_string,
)
# Precision policies and loss scaling (Section 3).
from zenith.jax.mixed_precision import (
    MixedPrecisionPolicy,
    DynamicLossScaler,
    LossScalerConfig,
    ZenithMixedPrecision,
    PrecisionMode,
    create_policy,
    detect_best_precision,
)

# Only reached when every import above succeeded.
print("\nAll Phase 1 modules imported successfully!")

---
## Section 1: Gradient Checkpointing Validation

We will test:
1. `OptimalCheckpointSelector` algorithms (sqrt, DP, uniform)
2. `checkpoint()` function with actual `jax.grad`
3. Memory reduction measurement

In [None]:
# TEST 1.1: OptimalCheckpointSelector Algorithms
# Runs the three selection strategies (sqrt, DP, uniform) for several network
# depths and reports how many checkpoints each one picks.
print("="*60)
print("TEST 1.1: OptimalCheckpointSelector Algorithms")
print("="*60)

# Test with various network sizes
for num_layers in [4, 12, 24, 48, 96]:
    selector = OptimalCheckpointSelector(num_layers=num_layers)

    sqrt_ckpts = selector.select_sqrt()
    dp_ckpts = selector.select_dp()
    uniform_ckpts = selector.select_uniform(num_checkpoints=4)

    # The reduction estimate is reported for the sqrt selection only.
    reduction = selector.estimate_memory_reduction(sqrt_ckpts)

    print(f"\nLayers: {num_layers}")
    print(f"  sqrt: {len(sqrt_ckpts)} checkpoints")
    print(f"  DP:   {len(dp_ckpts)} checkpoints")
    # Previously computed but never shown, so the uniform strategy was
    # silently left out of the report.
    # NOTE(review): assumes select_uniform returns a sized collection like
    # the other two selectors - confirm against the Zenith API.
    print(f"  uniform: {len(uniform_ckpts)} checkpoints")
    print(f"  Memory reduction (sqrt): {reduction:.1f}%")

print("\n[PASS] OptimalCheckpointSelector works correctly!")

In [None]:
# TEST 1.2: checkpoint() with Real JAX Gradient Computation
# Wraps a small MLP with Zenith's checkpoint() and differentiates through it,
# then checks the resulting gradients are finite and non-zero.
print("="*60)
print("TEST 1.2: checkpoint() with Real JAX Gradients")
print("="*60)

import jax
import jax.numpy as jnp

# Define a simple MLP layer
def mlp_layer(x, w1, w2):
    """Apply a two-layer MLP: linear -> ReLU -> linear.

    Args:
        x: Input batch; multiplied on the right by ``w1``.
        w1: First weight matrix.
        w2: Second weight matrix, applied after the ReLU.

    Returns:
        The result of ``relu(x @ w1) @ w2``.
    """
    h = jnp.dot(x, w1)
    h = jax.nn.relu(h)
    return jnp.dot(h, w2)

# Create checkpointed version using Zenith
# NOTE(review): DOTS_SAVEABLE presumably mirrors
# jax.checkpoint_policies.dots_saveable - confirm in zenith.jax.checkpointing.
checkpointed_mlp = checkpoint(mlp_layer, policy=CheckpointPolicy.DOTS_SAVEABLE)

# Initialize weights
# Split the root key so the input and both weight matrices are drawn from
# independent streams; reusing one key for several draws yields correlated
# values and is discouraged by the JAX PRNG design.
key = jax.random.PRNGKey(42)
x_key, w1_key, w2_key = jax.random.split(key, 3)
x = jax.random.normal(x_key, (32, 64))  # batch=32, features=64
w1 = jax.random.normal(w1_key, (64, 128))
w2 = jax.random.normal(w2_key, (128, 64))

# Loss function
def loss_fn(x, w1, w2):
    """Return the mean squared output of the checkpointed MLP."""
    out = checkpointed_mlp(x, w1, w2)
    return jnp.mean(out ** 2)

# Compute gradients with respect to both weight matrices (argnums 1 and 2);
# the input x is not differentiated.
grads = jax.grad(loss_fn, argnums=(1, 2))(x, w1, w2)

print(f"Input shape: {x.shape}")
print(f"W1 shape: {w1.shape}, W1 grad shape: {grads[0].shape}")
print(f"W2 shape: {w2.shape}, W2 grad shape: {grads[1].shape}")
print(f"Grad W1 norm: {jnp.linalg.norm(grads[0]):.4f}")
print(f"Grad W2 norm: {jnp.linalg.norm(grads[1]):.4f}")

# Verify gradients are not zero or NaN
assert jnp.all(jnp.isfinite(grads[0])), "W1 gradients contain inf/nan!"
assert jnp.all(jnp.isfinite(grads[1])), "W2 gradients contain inf/nan!"
assert jnp.linalg.norm(grads[0]) > 0, "W1 gradients are zero!"
assert jnp.linalg.norm(grads[1]) > 0, "W2 gradients are zero!"

print("\n[PASS] checkpoint() produces valid gradients!")

In [None]:
# TEST 1.3: Gradient Correctness Verification
# Differentiates the same function with and without Zenith's checkpoint()
# wrapper and requires the two gradients to agree element-wise.
print("="*60)
print("TEST 1.3: Gradient Correctness - Compare with/without checkpointing")
print("="*60)

def simple_net(x, w):
    """Return ``mean(relu(x @ w) ** 2)``, a scalar suitable for jax.grad."""
    h = jnp.dot(x, w)
    h = jax.nn.relu(h)
    return jnp.mean(h ** 2)

# Create checkpointed version
# NOTE(review): CheckpointPolicy.NOTHING presumably saves no intermediates
# and recomputes everything on the backward pass - confirm.
checkpointed_net = checkpoint(simple_net, policy=CheckpointPolicy.NOTHING)

# Initialize
# Independent subkeys for input and weights; drawing both from the same key
# would make them statistically dependent.
key = jax.random.PRNGKey(123)
x_key, w_key = jax.random.split(key)
x = jax.random.normal(x_key, (16, 32))
w = jax.random.normal(w_key, (32, 32))

# Compute gradients both ways (with respect to w only)
grad_no_ckpt = jax.grad(simple_net, argnums=1)(x, w)
grad_with_ckpt = jax.grad(checkpointed_net, argnums=1)(x, w)

# Compare gradients: largest absolute element-wise deviation
diff = jnp.abs(grad_no_ckpt - grad_with_ckpt).max()

print(f"Gradient difference (max): {diff:.2e}")

TOLERANCE = 1e-5
assert diff < TOLERANCE, f"Gradient difference too large: {diff}"

print(f"\n[PASS] Gradients match within tolerance {TOLERANCE}!")

---
## Section 2: Memory Management Validation

We will test:
1. `JAXActivationStore` with real JAX arrays
2. Memory tracking accuracy
3. Eviction policies

In [None]:
# TEST 2.1: JAXActivationStore with JAX Arrays
# Stores one array per layer id, then retrieves each and checks it equals the
# array originally stored under that id.
print("="*60)
print("TEST 2.1: JAXActivationStore with JAX Arrays")
print("="*60)

store = JAXActivationStore(max_memory_bytes=100 * 1024 * 1024)

# Store JAX arrays
NUM_LAYERS = 5
key = jax.random.PRNGKey(0)
# One subkey per layer: with a single shared key every array would be
# identical, and the per-layer comparison below could not notice a store
# that hands back the wrong layer's data.
layer_keys = jax.random.split(key, NUM_LAYERS)
arrays = {}

for i, layer_key in enumerate(layer_keys):
    arr = jax.random.normal(layer_key, (1024, 1024))
    arrays[i] = arr
    success = store.store(layer_id=i, array=arr)
    size_mb = compute_array_size(arr) / 1024 / 1024
    print(f"Stored layer {i}: {size_mb:.2f} MB, success={success}")

print(f"\nTotal memory usage: {store.memory_usage / 1024 / 1024:.2f} MB")
print(f"Stored arrays: {len(store)}")

# Retrieve and verify
for i, original in arrays.items():
    retrieved = store.retrieve(layer_id=i)
    # The budget (100 MB) is far above what was stored, so a missing entry is
    # a failure, not something to skip silently.
    assert retrieved is not None, f"Layer {i} missing from store!"
    match = jnp.allclose(retrieved, original)
    print(f"Layer {i} retrieved, matches original: {match}")
    assert match, f"Layer {i} data mismatch!"

print("\n[PASS] JAXActivationStore works with JAX arrays!")

In [None]:
# TEST 2.2: Eviction Under Memory Pressure
# Stores strictly more data than the configured budget and checks that the
# store evicted entries and stayed within the budget.
print("="*60)
print("TEST 2.2: Eviction Under Memory Pressure")
print("="*60)

# Small memory budget to force eviction
BUDGET_BYTES = 10 * 1024 * 1024  # 10 MB
# A (512, 512) float32 array is exactly 1 MiB, so ten of them only *equal* the
# budget and eviction would hinge on a > vs >= comparison inside the store.
# Fifteen arrays exceed it unambiguously.
# NOTE(review): assumes JAX's default float32 (x64 mode not enabled).
NUM_ARRAYS = 15

store = JAXActivationStore(
    max_memory_bytes=BUDGET_BYTES,
    eviction_policy=EvictionPolicy.LRU,
)

key = jax.random.PRNGKey(42)

# Store more than budget allows, using an independent subkey per array so the
# stored payloads are distinct.
for i, array_key in enumerate(jax.random.split(key, NUM_ARRAYS)):
    arr = jax.random.normal(array_key, (512, 512))  # 1 MiB each as float32
    store.store(layer_id=i, array=arr)

stats = store.statistics
print(f"Stored: {stats['store_count']} arrays")
print(f"Evicted: {stats['eviction_count']} arrays")
print(f"Current memory: {stats['current_memory_mb']:.2f} MB")
print(f"Currently stored: {stats['stored_count']} arrays")

# Verify eviction happened
assert stats['eviction_count'] > 0, "Eviction should have occurred!"
assert stats['current_memory_bytes'] <= BUDGET_BYTES, "Memory exceeds budget!"

print("\n[PASS] Eviction works correctly under memory pressure!")

In [None]:
# TEST 2.3: Device Detection for JAX Arrays
# Reports the device string for an array on the default device and, when a
# CPU backend exists, for a copy placed on the CPU.
print("="*60)
print("TEST 2.3: Device Detection")
print("="*60)

# Create arrays on different devices
key = jax.random.PRNGKey(0)
# Lands on JAX's default device; the "gpu" name reflects the notebook's
# intended GPU runtime rather than a checked fact.
gpu_array = jax.random.normal(key, (100, 100))

gpu_device = get_device_string(gpu_array)
print(f"GPU array device: {gpu_device}")

# Try CPU if available
# Only the backend lookup is guarded: jax.devices() raises RuntimeError when
# the requested backend is unknown or cannot be initialised. Errors from
# device_put / get_device_string are real failures and must propagate instead
# of being hidden by a bare except.
try:
    cpu_devices = jax.devices('cpu')
except RuntimeError:
    cpu_devices = []

if cpu_devices:
    cpu_array = jax.device_put(gpu_array, cpu_devices[0])
    cpu_device = get_device_string(cpu_array)
    print(f"CPU array device: {cpu_device}")
else:
    print("CPU device not available for comparison")

print("\n[PASS] Device detection completed!")

---
## Section 3: Mixed Precision Validation

We will test:
1. `MixedPrecisionPolicy` dtype conversions
2. `DynamicLossScaler` with actual gradients
3. `ZenithMixedPrecision` end-to-end training

In [None]:
# TEST 3.1: MixedPrecisionPolicy Dtype Conversions
# Prints the key attributes of each named policy, then casts a float32 array
# to the dtype the BF16 policy resolves for 'bfloat16'.
_rule = "=" * 60
print(_rule)
print("TEST 3.1: MixedPrecisionPolicy Dtype Conversions")
print(_rule)

# Attributes reported for every policy, in display order.
_policy_fields = ("param_dtype", "compute_dtype", "requires_loss_scaling")

# Walk through every supported precision mode.
for mode in ('fp32', 'bf16', 'fp16'):
    policy = create_policy(mode)
    print(f"\n{mode.upper()} policy:")
    for _field in _policy_fields:
        print(f"  {_field}: {getattr(policy, _field)}")

# Check that an actual array can be converted using the policy's dtype lookup.
policy_bf16 = MixedPrecisionPolicy.bf16()
arr = jnp.ones((10, 10), dtype=jnp.float32)
bf16_dtype = policy_bf16.get_jax_dtype('bfloat16')
arr_bf16 = arr.astype(bf16_dtype)

print(f"\nOriginal dtype: {arr.dtype}")
print(f"BF16 dtype: {arr_bf16.dtype}")

assert arr_bf16.dtype == jnp.bfloat16, "BF16 conversion failed!"

print("\n[PASS] MixedPrecisionPolicy dtype conversions work!")

In [None]:
# TEST 3.2: DynamicLossScaler with Real Gradients
# Drives the scaler through ten scale -> grad -> unscale -> update cycles
# using gradients of a small quadratic loss, and reports how the scale moves.
print("="*60)
print("TEST 3.2: DynamicLossScaler with Real Gradients")
print("="*60)

scaler = DynamicLossScaler(LossScalerConfig(
    initial_scale=2**15,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=5,
))

print(f"Initial scale: {scaler.scale}")

# Simulate training loop
# Independent subkeys for the parameters and the input batch; drawing both
# from one key would correlate them.
key = jax.random.PRNGKey(0)
params_key, x_key = jax.random.split(key)
params = jax.random.normal(params_key, (64, 64))

def loss_fn(params, x):
    """Return the mean squared output of the linear map ``x @ params``."""
    return jnp.mean((jnp.dot(x, params)) ** 2)

x = jax.random.normal(x_key, (32, 64))

def scaled_loss_fn(p):
    """Return the loss at ``p`` after the scaler's scale_loss() is applied."""
    return scaler.scale_loss(loss_fn(p, x))

# params and x are never updated below, so the unscaled loss is the same on
# every step; compute it once instead of once per iteration.
loss = loss_fn(params, x)

# Training steps with scaling
for step in range(10):
    # Compute scaled gradients; scale_loss() is invoked on each call, so the
    # scaler's state at this step is what gets used.
    scaled_grads = jax.grad(scaled_loss_fn)(params)

    # Unscale and check (only the finiteness flag is consumed here)
    grads, is_finite = scaler.unscale_grads({'params': scaled_grads})

    # Update scaler
    scaler.update(is_finite)

    if step % 3 == 0:
        print(f"Step {step}: loss={float(loss):.4f}, scale={scaler.scale:.0f}, finite={is_finite}")

print(f"\nFinal scale: {scaler.scale}")

print("\n[PASS] DynamicLossScaler works with real gradients!")

In [None]:
# TEST 3.3: ZenithMixedPrecision End-to-End
# Casts a float32 parameter dict to the policy's compute dtype and back,
# asserting the dtype of every entry at each stage.
print("="*60)
print("TEST 3.3: ZenithMixedPrecision End-to-End Training")
print("="*60)

# Create mixed precision handler with BF16
mp = ZenithMixedPrecision(policy='bf16')

print(f"Policy: {mp.policy.mode.value}")
print(f"Compute dtype: {mp.policy.compute_dtype}")
print(f"Has scaler: {mp.scaler is not None}")

# Initialize model params in FP32
# Each weight matrix gets its own subkey so the two draws are independent.
key = jax.random.PRNGKey(0)
w1_key, w2_key = jax.random.split(key)
params = {
    'w1': jax.random.normal(w1_key, (64, 128), dtype=jnp.float32),
    'w2': jax.random.normal(w2_key, (128, 64), dtype=jnp.float32),
}

print(f"\nOriginal param dtypes:")
print(f"  w1: {params['w1'].dtype}")
print(f"  w2: {params['w2'].dtype}")

# Cast to compute dtype
compute_params = mp.cast_to_compute(params)

print(f"\nCompute param dtypes:")
print(f"  w1: {compute_params['w1'].dtype}")
print(f"  w2: {compute_params['w2'].dtype}")

# Verify BF16
assert compute_params['w1'].dtype == jnp.bfloat16, "w1 not BF16!"
assert compute_params['w2'].dtype == jnp.bfloat16, "w2 not BF16!"

# Cast back to param dtype
back_to_fp32 = mp.cast_to_param(compute_params)
print(f"\nBack to FP32:")
print(f"  w1: {back_to_fp32['w1'].dtype}")
print(f"  w2: {back_to_fp32['w2'].dtype}")

# Check both entries, mirroring the BF16 assertions above; previously only
# w1 was verified on the way back.
assert back_to_fp32['w1'].dtype == jnp.float32, "w1 not back to FP32!"
assert back_to_fp32['w2'].dtype == jnp.float32, "w2 not back to FP32!"

print("\n[PASS] ZenithMixedPrecision end-to-end works correctly!")

In [None]:
# TEST 3.4: Hardware Detection
# Shows the precision mode Zenith selects for this machine, followed by the
# platform and description of every device JAX exposes.
_rule = "=" * 60
print(_rule)
print("TEST 3.4: Hardware-Based Precision Detection")
print(_rule)

best_precision = detect_best_precision()
print(f"Detected best precision: {best_precision}")

# List every visible device as "<platform>: <device>".
print("\nDevice info:")
for device in jax.devices():
    _platform = device.platform
    print(f"  {_platform}: {device}")

print("\n[PASS] Hardware detection completed!")

---
## Section 4: Performance Benchmark

In [None]:
# BENCHMARK: Mixed Precision Speedup
# Times a jitted square matmul in float32 and in bfloat16 and reports the
# average latency per call plus the FP32/BF16 ratio.
print("="*60)
print("BENCHMARK: Mixed Precision Speedup")
print("="*60)

import time


def _avg_call_ms(fn, lhs, rhs, n_runs):
    """Return the mean wall-clock time of ``fn(lhs, rhs)`` in milliseconds.

    Every call is followed by ``block_until_ready()`` so the measurement
    covers the computation itself, not just its asynchronous dispatch.
    ``time.perf_counter`` is used because it is the monotonic,
    highest-resolution clock intended for measuring short intervals.
    """
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(lhs, rhs).block_until_ready()
    return (time.perf_counter() - start) / n_runs * 1000


# Matrix multiplication benchmark
SIZE = 2048
key = jax.random.PRNGKey(0)
# Separate subkeys: with one shared key and identical shapes the two operands
# would be the very same matrix.
a_key, b_key = jax.random.split(key)

a_fp32 = jax.random.normal(a_key, (SIZE, SIZE), dtype=jnp.float32)
b_fp32 = jax.random.normal(b_key, (SIZE, SIZE), dtype=jnp.float32)

a_bf16 = a_fp32.astype(jnp.bfloat16)
b_bf16 = b_fp32.astype(jnp.bfloat16)

# JIT compile
matmul = jax.jit(jnp.dot)

# Warm up once per dtype so compilation happens outside the timed region
# (the jitted function is specialised per input dtype).
_ = matmul(a_fp32, b_fp32).block_until_ready()
_ = matmul(a_bf16, b_bf16).block_until_ready()

N_RUNS = 20

# FP32
time_fp32 = _avg_call_ms(matmul, a_fp32, b_fp32, N_RUNS)

# BF16
time_bf16 = _avg_call_ms(matmul, a_bf16, b_bf16, N_RUNS)

print(f"Matrix size: {SIZE}x{SIZE}")
print(f"\nTime per matmul:")
print(f"  FP32: {time_fp32:.2f} ms")
print(f"  BF16: {time_bf16:.2f} ms (speedup: {time_fp32/time_bf16:.2f}x)")

print("\n[BENCHMARK COMPLETE]")

---
## Final Summary

In [None]:
# FINAL SUMMARY
# Prints a static checklist of the tests in this notebook. The checklist is
# fixed text: it does not inspect the results of the earlier cells.
_summary_rule = "=" * 70

for _header_line in (
    _summary_rule,
    "ZENITH JAX PHASE 1 VALIDATION - FINAL SUMMARY",
    _summary_rule,
):
    print(_header_line)

_summary_text = """
TESTS COMPLETED:
================

Section 1: Gradient Checkpointing
  [x] TEST 1.1: OptimalCheckpointSelector algorithms
  [x] TEST 1.2: checkpoint() with real JAX gradients
  [x] TEST 1.3: Gradient correctness verification

Section 2: Memory Management
  [x] TEST 2.1: JAXActivationStore with JAX arrays
  [x] TEST 2.2: Eviction under memory pressure
  [x] TEST 2.3: Device detection

Section 3: Mixed Precision
  [x] TEST 3.1: MixedPrecisionPolicy dtype conversions
  [x] TEST 3.2: DynamicLossScaler with real gradients
  [x] TEST 3.3: ZenithMixedPrecision end-to-end
  [x] TEST 3.4: Hardware detection

Section 4: Performance Benchmarks
  [x] BENCHMARK: Mixed precision speedup

CONCLUSION:
===========
If all tests above show [PASS], then Phase 1 JAX Core Integration
is validated and ready for production use.

Proceed to Phase 2: XLA Backend & ONNX Export Enhancement.
"""
print(_summary_text)

print(_summary_rule)