# Zenith JAX Core Integration - Phase 1 Validation

**Date:** 2025-12-28  
**Purpose:** Validate Phase 1 implementation with real JAX on GPU

## Components Being Tested:
1. **Gradient Checkpointing** - `zenith.jax.checkpointing`
2. **Memory Management** - `zenith.jax.memory_manager`
3. **Mixed Precision Training** - `zenith.jax.mixed_precision`

## Hardware Requirements:
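- GPU Runtime (T4 or better; note that a T4 has no BF16 tensor-core support, so the Section 4 BF16 benchmark will likely show no speedup there — use an Ampere-or-newer GPU such as an A100 or L4 to see one)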
- JAX with GPU support

---

## Setup: Install Dependencies

In [None]:
# Setup cell: reinstall pyzenith and report what JAX can see.
#
# NOTE: the `%pip` lines are IPython/Colab magics, so this cell only runs
# inside a notebook kernel, not as a plain Python script.

# Force reinstall pyzenith from GitHub (latest with JAX Core Integration)
# Uninstall first (stderr is redirected to /dev/null so a runtime without
# pyzenith stays quiet), then install with --no-cache-dir, presumably so pip
# builds from the current repository head instead of reusing a cached wheel.
%pip uninstall pyzenith -y 2>/dev/null
%pip install --no-cache-dir git+https://github.com/vibeswithkk/ZENITH.git --quiet

# Verify JAX GPU support (pre-installed in Colab)
# These prints only report the backend; nothing here fails on a CPU-only
# runtime, so confirm "Default backend" reads "gpu" before trusting the
# GPU-specific results further down.
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

In [None]:
# Verify Zenith installation
# (If this import fails right after the install cell, restarting the runtime
# may be needed so the newly installed package is picked up.)
import zenith
print(f"Zenith version: {zenith.__version__}")

# Import Phase 1 modules
# Section 1 - gradient checkpointing API under test.
from zenith.jax.checkpointing import (
    OptimalCheckpointSelector,
    CheckpointPolicy,
    checkpoint,
)
# Section 2 - activation store, its eviction policies, and size/device helpers.
from zenith.jax.memory_manager import (
    JAXActivationStore,
    EvictionPolicy,
    compute_array_size,
    get_device_string,
)
# Section 3 - mixed-precision policies, loss scaling, and precision detection.
from zenith.jax.mixed_precision import (
    MixedPrecisionPolicy,
    DynamicLossScaler,
    LossScalerConfig,
    ZenithMixedPrecision,
    create_policy,
    detect_best_precision,
)

# Reaching this print means every name above resolved.
print("\nAll Phase 1 modules imported successfully!")

---
## Section 1: Gradient Checkpointing Validation

In [None]:
# TEST 1.1: OptimalCheckpointSelector Algorithms
#
# For several model depths, ask the selector for checkpoint sets from both the
# sqrt(n) heuristic and the dynamic-programming (DP) algorithm, report them,
# and sanity-check the results before declaring PASS.
print("="*60)
print("TEST 1.1: OptimalCheckpointSelector Algorithms")
print("="*60)

for num_layers in [4, 12, 24, 48]:
    selector = OptimalCheckpointSelector(num_layers=num_layers)
    sqrt_ckpts = selector.select_sqrt()
    dp_ckpts = selector.select_dp()
    reduction = selector.estimate_memory_reduction(sqrt_ckpts)

    print(f"Layers: {num_layers} -> sqrt: {len(sqrt_ckpts)}, DP: {len(dp_ckpts)}, reduction: {reduction:.1f}%")

    # A checkpoint set picks among the model's layers, so it can never hold
    # more entries than there are layers (assumes one entry per layer).
    assert len(sqrt_ckpts) <= num_layers, (
        f"sqrt selected {len(sqrt_ckpts)} checkpoints for {num_layers} layers"
    )
    assert len(dp_ckpts) <= num_layers, (
        f"DP selected {len(dp_ckpts)} checkpoints for {num_layers} layers"
    )
    # The reduction is reported as a percentage (see the print above).
    assert 0.0 <= reduction <= 100.0, f"Memory reduction out of range: {reduction}"

print("\n[PASS] OptimalCheckpointSelector works correctly!")

In [None]:
# TEST 1.2: JAX Checkpoint with Gradients
#
# Gradients through a jax.checkpoint-wrapped function must match the
# gradients of the plain function: checkpointing only changes what is kept
# for the backward pass (activations are recomputed), never the math.
print("="*60)
print("TEST 1.2: JAX Checkpoint with Gradients")
print("="*60)

def mlp_layer(x, w1, w2):
    """Two-layer ReLU MLP: relu(x @ w1) @ w2."""
    h = jnp.dot(x, w1)
    h = jax.nn.relu(h)
    return jnp.dot(h, w2)

# Use jax.checkpoint directly
checkpointed_mlp = jax.checkpoint(mlp_layer)

# Draw each tensor from its own subkey; reusing one key for every draw
# produces correlated inputs instead of independent ones.
key = jax.random.PRNGKey(42)
key_x, key_w1, key_w2 = jax.random.split(key, 3)
x = jax.random.normal(key_x, (32, 64))
w1 = jax.random.normal(key_w1, (64, 128))
w2 = jax.random.normal(key_w2, (128, 64))

def loss_fn(x, w1, w2):
    """Mean-squared output of the checkpointed MLP."""
    out = checkpointed_mlp(x, w1, w2)
    return jnp.mean(out ** 2)

def reference_loss_fn(x, w1, w2):
    """Same loss as loss_fn but without checkpointing (ground truth)."""
    return jnp.mean(mlp_layer(x, w1, w2) ** 2)

grads = jax.grad(loss_fn, argnums=(1, 2))(x, w1, w2)
ref_grads = jax.grad(reference_loss_fn, argnums=(1, 2))(x, w1, w2)

print(f"W1 grad shape: {grads[0].shape}, norm: {jnp.linalg.norm(grads[0]):.4f}")
print(f"W2 grad shape: {grads[1].shape}, norm: {jnp.linalg.norm(grads[1]):.4f}")

assert jnp.all(jnp.isfinite(grads[0])) and jnp.all(jnp.isfinite(grads[1]))
for name, g, g_ref in zip(("W1", "W2"), grads, ref_grads):
    # Relative error keeps the check independent of gradient scale; the bound
    # leaves room for XLA compiling the two traces with different fusions.
    rel_err = float(jnp.linalg.norm(g - g_ref) / jnp.linalg.norm(g_ref))
    assert rel_err < 1e-3, (
        f"{name} grad differs from un-checkpointed grad (rel err {rel_err:.2e})"
    )
print("\n[PASS] JAX checkpoint with gradients works!")

In [None]:
# TEST 1.3: Zenith checkpoint() wrapper
#
# Zenith's checkpoint() must be transparent: the wrapped function has to give
# the same forward output and the same gradient as the unwrapped one.
print("="*60)
print("TEST 1.3: Zenith checkpoint() wrapper")
print("="*60)

def simple_fn(x, w):
    """relu(x @ w) - the function being checkpointed."""
    return jax.nn.relu(jnp.dot(x, w))

# Use Zenith's checkpoint wrapper
ckpt_fn = checkpoint(simple_fn, policy=CheckpointPolicy.DOTS_SAVEABLE)

# Independent subkeys so x and w are not drawn from the same random stream.
key = jax.random.PRNGKey(0)
key_x, key_w = jax.random.split(key)
x = jax.random.normal(key_x, (8, 16))
w = jax.random.normal(key_w, (16, 16))

# Forward pass
out = ckpt_fn(x, w)
print(f"Forward output shape: {out.shape}")

ref_out = simple_fn(x, w)
assert out.shape == ref_out.shape, f"Shape mismatch: {out.shape} vs {ref_out.shape}"
fwd_err = float(jnp.linalg.norm(out - ref_out) / jnp.linalg.norm(ref_out))
assert fwd_err < 1e-3, f"Checkpointed forward differs from plain forward (rel err {fwd_err:.2e})"

# Gradient computation
def loss(w):
    """Mean-squared output of the checkpointed function."""
    return jnp.mean(ckpt_fn(x, w) ** 2)

def reference_loss(w):
    """Same loss without checkpointing (ground truth)."""
    return jnp.mean(simple_fn(x, w) ** 2)

grad_w = jax.grad(loss)(w)
ref_grad_w = jax.grad(reference_loss)(w)
print(f"Gradient shape: {grad_w.shape}, norm: {jnp.linalg.norm(grad_w):.4f}")

assert jnp.all(jnp.isfinite(grad_w)) and jnp.linalg.norm(grad_w) > 0
grad_err = float(jnp.linalg.norm(grad_w - ref_grad_w) / jnp.linalg.norm(ref_grad_w))
assert grad_err < 1e-3, f"Checkpointed grad differs from plain grad (rel err {grad_err:.2e})"
print("\n[PASS] Zenith checkpoint() wrapper works!")

---
## Section 2: Memory Management Validation

In [None]:
# TEST 2.1: JAXActivationStore with JAX Arrays
#
# Store five distinct 1024x1024 float32 activations (4 MB each) under a
# 100 MB budget, so nothing should need evicting, then check that every
# layer_id returns exactly its own array.
print("="*60)
print("TEST 2.1: JAXActivationStore with JAX Arrays")
print("="*60)

store = JAXActivationStore(max_memory_bytes=100 * 1024 * 1024)

# One subkey per layer: with a single shared key all five arrays would be
# identical, and a retrieve() that returned the wrong layer would still "match".
layer_keys = jax.random.split(jax.random.PRNGKey(0), 5)
arrays = {}

for i, layer_key in enumerate(layer_keys):
    arr = jax.random.normal(layer_key, (1024, 1024))
    arrays[i] = arr
    success = store.store(layer_id=i, array=arr)
    size_mb = compute_array_size(arr) / 1024 / 1024
    print(f"Stored layer {i}: {size_mb:.2f} MB, success={success}")

print(f"\nTotal memory: {store.memory_usage / 1024 / 1024:.2f} MB")

# Retrieve and verify. The 20 MB total is well inside the 100 MB budget, so
# a missing entry is a failure, not something to skip over.
for i in range(5):
    retrieved = store.retrieve(layer_id=i)
    assert retrieved is not None, f"Layer {i} missing from store!"
    match = jnp.allclose(retrieved, arrays[i])
    assert match, f"Layer {i} mismatch!"

print("\n[PASS] JAXActivationStore works with JAX arrays!")

In [None]:
# TEST 2.2: Eviction Under Memory Pressure
#
# Budget = 5 MB, but we store 10 arrays of ~1 MB each.
# This FORCES the LRU policy to evict.
print("="*60)
print("TEST 2.2: Eviction Under Memory Pressure")
print("="*60)

budget_bytes = 5 * 1024 * 1024  # 5 MB budget
num_arrays = 10

store = JAXActivationStore(
    max_memory_bytes=budget_bytes,
    eviction_policy=EvictionPolicy.LRU,
)

key = jax.random.PRNGKey(42)

# Store 10 arrays of ~1 MB each = 10 MB total, but budget is only 5 MB
for i in range(num_arrays):
    arr = jax.random.normal(key, (512, 512))  # 512*512*4 bytes (float32) = 1 MB
    store.store(layer_id=i, array=arr)

# Derive capacity from the measured array size rather than assuming 1 MB.
array_bytes = compute_array_size(arr)
max_resident = budget_bytes // array_bytes

stats = store.statistics
print(f"Stored: {stats['store_count']} arrays")
print(f"Evicted: {stats['eviction_count']} arrays")
print(f"Current memory: {stats['current_memory_mb']:.2f} MB")
print(f"Currently in store: {stats['stored_count']} arrays")

# With a 5 MB budget and 1 MB arrays, at most `max_resident` (5) arrays can
# be resident at once; the remaining ones must have been evicted.
assert stats['eviction_count'] > 0, "Eviction should have occurred!"
assert stats['current_memory_bytes'] <= budget_bytes, "Memory exceeds budget!"
assert stats['stored_count'] <= max_resident, (
    f"{stats['stored_count']} arrays resident, but only {max_resident} fit in the budget!"
)

print("\n[PASS] Eviction works correctly under memory pressure!")

In [None]:
# TEST 2.3: Device Detection
#
# Ask Zenith which device a freshly created array lives on. JAX places new
# arrays on its default device, so on a GPU runtime this is the GPU.
print("="*60)
print("TEST 2.3: Device Detection")
print("="*60)

key = jax.random.PRNGKey(0)
gpu_array = jax.random.normal(key, (100, 100))

device_str = get_device_string(gpu_array)
print(f"Array device: {device_str}")
print(f"JAX default backend: {jax.default_backend()}")

# Minimal check so PASS is not printed unconditionally. The exact format is
# Zenith-specific, so only require a non-empty string here.
assert isinstance(device_str, str) and device_str, (
    f"Unexpected device string: {device_str!r}"
)

print("\n[PASS] Device detection completed!")

---
## Section 3: Mixed Precision Validation

In [None]:
# TEST 3.1: MixedPrecisionPolicy
#
# Print the dtypes used by each preset policy, then cast through the BF16
# policy's own compute dtype so the policy object is actually exercised.
print("="*60)
print("TEST 3.1: MixedPrecisionPolicy Dtype Conversions")
print("="*60)

for mode in ['fp32', 'bf16', 'fp16']:
    policy = create_policy(mode)
    print(f"{mode.upper()}: param={policy.param_dtype}, compute={policy.compute_dtype}, scaling={policy.requires_loss_scaling}")

# Test BF16 conversion
policy_bf16 = MixedPrecisionPolicy.bf16()
# jnp.dtype() normalizes whether the policy stores a dtype object, a scalar
# type such as jnp.bfloat16, or a dtype name.
assert jnp.dtype(policy_bf16.compute_dtype) == jnp.bfloat16, (
    f"BF16 policy has compute dtype {policy_bf16.compute_dtype}"
)
arr = jnp.ones((10, 10), dtype=jnp.float32)
arr_bf16 = arr.astype(policy_bf16.compute_dtype)

print(f"\nOriginal: {arr.dtype} -> BF16: {arr_bf16.dtype}")
assert arr_bf16.dtype == jnp.bfloat16

print("\n[PASS] MixedPrecisionPolicy works!")

In [None]:
# TEST 3.2: DynamicLossScaler
#
# Run ten scaled backward passes in FP32. Nothing should overflow, so each
# step must report finite grads, the unscaled grads must equal the true
# (unscaled) gradient, and the loss scale must never back off.
print("="*60)
print("TEST 3.2: DynamicLossScaler")
print("="*60)

initial_scale = 2**15
scaler = DynamicLossScaler(LossScalerConfig(
    initial_scale=initial_scale,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=5,
))

print(f"Initial scale: {scaler.scale}")

# Independent subkeys so params and inputs are not drawn from one stream.
key = jax.random.PRNGKey(0)
key_p, key_x = jax.random.split(key)
params = jax.random.normal(key_p, (64, 64))
x = jax.random.normal(key_x, (32, 64))

def loss_fn(p, x):
    """Mean-squared output of a single linear layer."""
    return jnp.mean((jnp.dot(x, p)) ** 2)

# Ground-truth gradient without any loss scaling.
true_grads = jax.grad(loss_fn)(params, x)

for step in range(10):
    def scaled_loss(p):
        return scaler.scale_loss(loss_fn(p, x))

    grads = jax.grad(scaled_loss)(params)
    unscaled, is_finite = scaler.unscale_grads({'p': grads})
    scaler.update(is_finite)

    if step % 3 == 0:
        print(f"Step {step}: scale={scaler.scale:.0f}, finite={is_finite}")

    assert bool(is_finite), f"Step {step}: unexpected non-finite grads in FP32"
    # Take the single leaf so this does not depend on the exact container
    # type unscale_grads returns.
    unscaled_p = jax.tree_util.tree_leaves(unscaled)[0]
    rel_err = float(jnp.linalg.norm(unscaled_p - true_grads) / jnp.linalg.norm(true_grads))
    assert rel_err < 1e-3, f"Step {step}: unscaled grads differ from true grads (rel err {rel_err:.2e})"

# Every step was finite, so the scaler had no reason to back off.
assert scaler.scale >= initial_scale, f"Scale dropped to {scaler.scale} without any overflow"

print("\n[PASS] DynamicLossScaler works!")

In [None]:
# TEST 3.3: ZenithMixedPrecision End-to-End
#
# Round-trip a parameter dict FP32 -> BF16 (compute) -> FP32 (param) and
# check every leaf, not only w1.
print("="*60)
print("TEST 3.3: ZenithMixedPrecision End-to-End")
print("="*60)

mp = ZenithMixedPrecision(policy='bf16')

print(f"Policy: {mp.policy.mode.value}")
print(f"Compute dtype: {mp.policy.compute_dtype}")

# Independent subkeys so w1 and w2 hold unrelated values.
key = jax.random.PRNGKey(0)
key_w1, key_w2 = jax.random.split(key)
params = {
    'w1': jax.random.normal(key_w1, (64, 128), dtype=jnp.float32),
    'w2': jax.random.normal(key_w2, (128, 64), dtype=jnp.float32),
}

print(f"\nOriginal: w1={params['w1'].dtype}, w2={params['w2'].dtype}")

# Cast to BF16
compute_params = mp.cast_to_compute(params)
print(f"Compute: w1={compute_params['w1'].dtype}, w2={compute_params['w2'].dtype}")

for name, leaf in compute_params.items():
    assert leaf.dtype == jnp.bfloat16, f"{name} not cast to bfloat16: {leaf.dtype}"
    assert leaf.shape == params[name].shape, f"{name} shape changed: {leaf.shape}"

# Cast back
back = mp.cast_to_param(compute_params)
print(f"Back: w1={back['w1'].dtype}")

for name, leaf in back.items():
    assert leaf.dtype == jnp.float32, f"{name} not cast back to float32: {leaf.dtype}"
    # Widening BF16 -> FP32 is exact, so the restored values must equal the
    # BF16 values widened...
    assert jnp.array_equal(leaf, compute_params[name].astype(jnp.float32)), (
        f"{name} values changed when casting back to float32"
    )
    # ...and stay within BF16 rounding error of the original FP32 params.
    assert jnp.allclose(leaf, params[name], rtol=1e-2, atol=1e-6), (
        f"{name} drifted beyond BF16 rounding error"
    )

print("\n[PASS] ZenithMixedPrecision works!")

In [None]:
# TEST 3.4: Hardware Detection
#
# Report the precision Zenith recommends for this runtime alongside every
# device JAX can see. This cell only prints; it asserts nothing.
separator = "=" * 60
for header_line in (separator, "TEST 3.4: Hardware Detection", separator):
    print(header_line)

best = detect_best_precision()
print(f"Detected best precision: {best}")

device_lines = [f"  Device: {dev.platform} - {dev}" for dev in jax.devices()]
for device_line in device_lines:
    print(device_line)

print("\n[PASS] Hardware detection works!")

---
## Section 4: Performance Benchmark

In [None]:
# BENCHMARK: Mixed Precision Speedup
#
# Times a jitted SIZE x SIZE matmul in FP32 vs BF16. Expect a BF16 speedup
# only on GPUs with BF16 tensor cores (Ampere or newer); on a T4 the BF16
# path may be no faster than FP32, or even slower.
print("="*60)
print("BENCHMARK: Mixed Precision Speedup")
print("="*60)

import time

SIZE = 2048
N = 20

# Independent subkeys so the two operands are unrelated matrices.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a_fp32 = jax.random.normal(key_a, (SIZE, SIZE), dtype=jnp.float32)
b_fp32 = jax.random.normal(key_b, (SIZE, SIZE), dtype=jnp.float32)
a_bf16 = a_fp32.astype(jnp.bfloat16)
b_bf16 = b_fp32.astype(jnp.bfloat16)

matmul = jax.jit(jnp.dot)


def time_matmul_ms(a, b, iters=N):
    """Return the mean wall-clock milliseconds per ``matmul(a, b)`` call.

    One untimed call runs first so JIT compilation for this dtype/shape is
    excluded. Every call is blocked on because JAX dispatch is asynchronous.
    time.perf_counter() is used because it is monotonic and high-resolution,
    unlike time.time(), which follows the adjustable system clock.
    """
    matmul(a, b).block_until_ready()  # warmup / compile
    start = time.perf_counter()
    for _ in range(iters):
        matmul(a, b).block_until_ready()
    return (time.perf_counter() - start) / iters * 1000


t_fp32 = time_matmul_ms(a_fp32, b_fp32)
t_bf16 = time_matmul_ms(a_bf16, b_bf16)

print(f"Matrix: {SIZE}x{SIZE}")
print(f"FP32: {t_fp32:.2f} ms")
print(f"BF16: {t_bf16:.2f} ms (speedup: {t_fp32/t_bf16:.2f}x)")

print("\n[BENCHMARK COMPLETE]")

---
## Final Summary

In [None]:
# Final summary banner.
#
# This cell asserts nothing: the "All Tests Passed" text is only accurate
# when every earlier cell was run in order without raising.
rule = "=" * 70
summary = """
All Tests Passed:

 Section 1: Gradient Checkpointing
   [PASS] TEST 1.1: OptimalCheckpointSelector
   [PASS] TEST 1.2: JAX checkpoint with gradients
   [PASS] TEST 1.3: Zenith checkpoint() wrapper

 Section 2: Memory Management
   [PASS] TEST 2.1: JAXActivationStore
   [PASS] TEST 2.2: Eviction under pressure
   [PASS] TEST 2.3: Device detection

 Section 3: Mixed Precision
   [PASS] TEST 3.1: MixedPrecisionPolicy
   [PASS] TEST 3.2: DynamicLossScaler
   [PASS] TEST 3.3: ZenithMixedPrecision
   [PASS] TEST 3.4: Hardware detection

 Section 4: Performance
   [DONE] Mixed precision benchmark

Phase 1 JAX Core Integration VALIDATED!
Ready for Phase 2: XLA Backend & ONNX Export
"""

print(rule)
print("ZENITH JAX PHASE 1 VALIDATION - COMPLETE")
print(rule)
print(summary)
print(rule)