# JaxFrames Test Suite

This notebook runs the complete test and benchmark suite for JaxFrames, including the new Stage 3 parallel algorithms (sorting, groupby aggregation, and merge/join).

## Quick Start

Run all cells in order to:
1. Install JaxFrames from GitHub
2. Set up TPU environment (if applicable)
3. Run functionality tests
4. Run performance benchmarks
5. Execute full test suite

**Note**: This notebook is designed to work on:
- Local CPU/GPU environments
- TPU VMs
- Google Colab with TPU runtime

## Installation

In [None]:
# Install JaxFrames from GitHub
!pip install --upgrade pip
!pip install git+https://github.com/solenyaPickleman/jaxframes.git

# Install additional dependencies if needed
!pip install numpy pandas pytest

print("JaxFrames and dependencies installed!")

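## Environment Setup

**TPU note:** JAX chooses its backend the first time devices are queried, and that choice is cached for the lifetime of the kernel. TPU cleanup/setup therefore only takes effect if it runs *before* the device check below. On a TPU VM or Colab TPU runtime, start from a fresh kernel and run the TPU-specific setup cell first. If the device listing shows only CPU devices, restart the kernel and try again.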

In [None]:
# Setup and imports
import sys
import os
import time
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
from jax.sharding import Mesh

# Report the JAX version and every device visible to this process.
# The device list is fetched once and reused for all of the lines below.
visible_devices = jax.devices()

print(f"JAX version: {jax.__version__}")
print(f"Number of devices: {len(visible_devices)}")
print(f"Device types: {[dev.device_kind for dev in visible_devices]}")
print("Device details:")
for idx, dev in enumerate(visible_devices):
    print(f"  Device {idx}: {dev}")

In [None]:
# TPU-specific setup (if running on TPU VM)
import os

# Kill any existing TPU processes if needed (for TPU VMs)
if 'TPU_NAME' in os.environ:
    print("Running on TPU VM, cleaning up processes...")
    !pkill -f libtpu.so || true
    print("TPU cleanup complete")
    
# For Colab TPU runtime
if 'COLAB_TPU_ADDR' in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
    print("Colab TPU setup complete")

In [None]:
# Verify JaxFrames installation.
# A failed import is re-raised after the diagnostic message. Later cells use
# `JaxFrame`, so continuing would only surface as confusing NameErrors
# further down instead of stopping here with the real cause.
try:
    from jaxframes import JaxFrame
    from jaxframes.distributed import DistributedJaxFrame
    from jaxframes.distributed.sharding import row_sharded
    print("✓ JaxFrames imported successfully!")

    # Report the package version; a development install may not define __version__.
    import jaxframes
    print(f"  Version: {getattr(jaxframes, '__version__', 'development')}")

except ImportError as e:
    print(f"✗ Failed to import JaxFrames: {e}")
    print("  Please ensure installation completed successfully")
    raise

## 1. Basic Functionality Tests

In [None]:
# Test basic JaxFrame creation, scalar arithmetic, and column reductions.
# Unlike a print-only smoke run, the assertions below make the cell fail
# if the results are wrong.
print("Testing basic JaxFrame functionality...")

# Create a simple DataFrame: 5 rows x 3 columns
basic_data = {
    'a': jnp.array([1, 2, 3, 4, 5]),
    'b': jnp.array([10, 20, 30, 40, 50]),
    'c': jnp.array([100.0, 200.0, 300.0, 400.0, 500.0])
}

df = JaxFrame(basic_data)
print(f"Created JaxFrame: {df}")
print(f"Shape: {df.shape}")
print(f"Columns: {df.columns}")
assert df.shape == (5, 3), f"Unexpected shape: {df.shape}"

# Test arithmetic operations: a scalar should be added to every column
df_plus_10 = df + 10
df_plus_10_pd = df_plus_10.to_pandas()
print("\nAfter adding 10:")
print(df_plus_10_pd)
for col, values in basic_data.items():
    np.testing.assert_allclose(
        df_plus_10_pd[col].to_numpy(), np.asarray(values) + 10,
        err_msg=f"Scalar addition wrong for column {col!r}",
    )

# Test reduction operations. The return type of sum()/mean() is not pinned
# here, so these are displayed for inspection only.
sums = df.sum()
print(f"\nColumn sums: {sums}")

means = df.mean()
print(f"Column means: {means}")

print("✓ Basic operations verified")

## 2. Parallel Algorithms Tests

In [None]:
# Test sort_values. A correct sort must (1) produce non-decreasing keys and
# (2) move every row intact, without dropping or duplicating rows.
print("Testing sort_values...")

# Create unsorted data. 'value' is the original row position, which makes
# row alignment checkable after sorting.
np.random.seed(42)
sort_data = {
    'key': jnp.array(np.random.randint(0, 100, size=20)),
    'value': jnp.arange(20)
}

df_unsorted = JaxFrame(sort_data)
print("Unsorted DataFrame:")
print(df_unsorted.to_pandas().head(10))

# Sort by key
df_sorted = df_unsorted.sort_values('key')
print("\nSorted by 'key':")
print(df_sorted.to_pandas().head(10))

# Verify sorting
sorted_keys = np.asarray(df_sorted.data['key'])
sorted_rows = np.asarray(df_sorted.data['value'])
original_keys = np.asarray(sort_data['key'])

assert np.all(np.diff(sorted_keys) >= 0), "Sorting failed!"
# Every original row must appear exactly once ...
assert np.array_equal(np.sort(sorted_rows), np.arange(len(original_keys))), \
    "Sort dropped or duplicated rows!"
# ... and must still carry its own key (i.e. columns moved together).
assert np.array_equal(original_keys[sorted_rows], sorted_keys), \
    "Sort did not keep row values aligned with their keys!"
print("✓ Sorting verified")

In [None]:
# Test groupby aggregations (sum and mean) against a pandas reference.
print("Testing groupby aggregations...")

# Create data with groups
group_data = {
    'group': jnp.array([1, 2, 1, 3, 2, 1, 3, 2]),
    'value1': jnp.array([10, 20, 30, 40, 50, 60, 70, 80]),
    'value2': jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
}

df_groups = JaxFrame(group_data)
print("Original DataFrame:")
print(df_groups.to_pandas())

# Test groupby sum
grouped_sum = df_groups.groupby('group').sum()
print("\nGroupBy Sum:")
print(grouped_sum.to_pandas())

# Test groupby mean
grouped_mean = df_groups.groupby('group').mean()
print("\nGroupBy Mean:")
print(grouped_mean.to_pandas())

# Reference results from pandas. pandas orders group keys ascending by
# default; this comparison assumes JaxFrame does the same (TODO confirm if
# this check ever fails on ordering alone).
reference_groups = pd.DataFrame(
    {name: np.asarray(col) for name, col in group_data.items()}
).groupby('group')
expected_sum = reference_groups.sum()
expected_mean = reference_groups.mean()

for col in ('value1', 'value2'):
    np.testing.assert_allclose(
        np.asarray(grouped_sum.data[col]), expected_sum[col].to_numpy(),
        err_msg=f"GroupBy sum wrong for {col!r}",
    )
    # Loose rtol: JAX computes in float32 by default, pandas in float64.
    np.testing.assert_allclose(
        np.asarray(grouped_mean.data[col]), expected_mean[col].to_numpy(),
        rtol=1e-5, err_msg=f"GroupBy mean wrong for {col!r}",
    )

print("✓ GroupBy operations completed")

In [None]:
# Test merge/join operations: an inner join on 'key' keeps only keys present
# in both frames, with each side's value columns aligned to the key.
print("Testing merge operations...")

# Create left DataFrame
left_data = {
    'key': jnp.array([1, 2, 3, 4]),
    'left_value': jnp.array([10, 20, 30, 40])
}
df_left = JaxFrame(left_data)

# Create right DataFrame
right_data = {
    'key': jnp.array([2, 3, 4, 5]),
    'right_value': jnp.array([200, 300, 400, 500])
}
df_right = JaxFrame(right_data)

print("Left DataFrame:")
print(df_left.to_pandas())
print("\nRight DataFrame:")
print(df_right.to_pandas())

# Perform inner join
df_merged = df_left.merge(df_right, on='key', how='inner')
print("\nMerged (inner join):")
print(df_merged.to_pandas())

# Verify against pandas. The row order of the JaxFrame result is not assumed,
# so both sides are sorted by key before comparing.
expected_merged = (
    pd.DataFrame({name: np.asarray(col) for name, col in left_data.items()})
    .merge(
        pd.DataFrame({name: np.asarray(col) for name, col in right_data.items()}),
        on='key', how='inner',
    )
    .sort_values('key')
    .reset_index(drop=True)
)
actual_merged = df_merged.to_pandas().sort_values('key').reset_index(drop=True)

assert len(actual_merged) == len(expected_merged), \
    f"Inner join returned {len(actual_merged)} rows, expected {len(expected_merged)}"
for col in ('key', 'left_value', 'right_value'):
    np.testing.assert_array_equal(
        actual_merged[col].to_numpy(), expected_merged[col].to_numpy(),
        err_msg=f"Merge result wrong for column {col!r}",
    )

print("✓ Merge operations completed")

## 3. Performance Benchmarks

In [None]:
# Run performance benchmarks: the full benchmark script when running from a
# repository checkout, otherwise a small inline sort benchmark.
print("Running performance benchmarks...")
print("=" * 60)

import subprocess
import os
import sys
import time

BENCHMARK_SCRIPT = "benchmarks/benchmark_parallel_algorithms.py"

# Check if we're in the repo or installed via pip
if os.path.exists(BENCHMARK_SCRIPT):
    # Running from repo. Use the kernel's own interpreter: a bare "python" on
    # PATH may belong to a different environment without jaxframes.
    result = subprocess.run(
        [sys.executable, BENCHMARK_SCRIPT],
        capture_output=True,
        text=True
    )
    print(result.stdout)
    if result.stderr:
        print("Benchmark errors:", result.stderr)
    if result.returncode != 0:
        print(f"✗ Benchmark script exited with return code {result.returncode}")
else:
    # Running from pip install - run inline benchmarks
    print("Running inline benchmarks (full benchmark script not available in pip install)")

    from jaxframes.distributed.parallel_algorithms import parallel_sort

    size = 100_000
    n_repeats = 3
    np.random.seed(42)
    bench_keys = jnp.array(np.random.randint(0, size, size=size))

    # Warm-up call: the first invocation may include tracing/compilation,
    # which would otherwise dominate the measurement.
    sorted_data = jax.block_until_ready(parallel_sort(bench_keys))

    # JAX dispatches work asynchronously, so block on the result inside the
    # timed region. Otherwise only dispatch time would be measured.
    timings = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        sorted_data = jax.block_until_ready(parallel_sort(bench_keys))
        timings.append(time.perf_counter() - start)
    jax_time = min(timings)

    print(f"\nSort benchmark ({size:,} elements, best of {n_repeats}):")
    print(f"  JAX parallel sort: {jax_time:.4f}s")
    print(f"  Throughput: {size/jax_time/1e6:.2f}M elements/s")

## 4. Run Full Test Suite

In [None]:
# Run the pytest suite when a repository checkout is available. Otherwise
# fall back to a few inline smoke tests against the pip-installed package.
print("Running pytest test suite...")
print("=" * 60)

import subprocess
import os
import sys

# Check if tests directory exists
if os.path.exists("tests/"):
    # Running from repo. Invoke pytest through the kernel's own interpreter;
    # a bare "python" on PATH may be a different environment that lacks
    # jaxframes or pytest.
    result = subprocess.run(
        [sys.executable, "-m", "pytest", "tests/", "-v", "--tb=short"],
        capture_output=True,
        text=True
    )

    print(result.stdout)
    if result.stderr:
        print("Errors:")
        print(result.stderr)

    if result.returncode == 0:
        print("\n✓ All tests passed!")
    else:
        print(f"\n✗ Tests failed with return code {result.returncode}")
else:
    print("Test directory not found (running from pip install)")
    print("Running basic verification tests...")

    # Run some basic tests inline. The distinct names avoid overwriting
    # `df`/`df2` from earlier sections of the notebook.
    from jaxframes import JaxFrame

    # Test 1: Basic creation
    smoke_df = JaxFrame({'a': jnp.array([1, 2, 3])})
    assert smoke_df.shape == (3, 1), "Shape test failed"
    print("✓ Basic creation test passed")

    # Test 2: Sort
    smoke_unsorted = JaxFrame({'x': jnp.array([3, 1, 2])})
    smoke_sorted = smoke_unsorted.sort_values('x')
    assert jnp.allclose(smoke_sorted.data['x'], jnp.array([1, 2, 3])), "Sort test failed"
    print("✓ Sort test passed")

    # Test 3: GroupBy (group 1 -> 10 + 30, group 2 -> 20 + 40)
    smoke_groups = JaxFrame({
        'group': jnp.array([1, 2, 1, 2]),
        'value': jnp.array([10, 20, 30, 40])
    })
    smoke_grouped = smoke_groups.groupby('group').sum()
    assert jnp.allclose(smoke_grouped.data['value'], jnp.array([40, 60])), "GroupBy test failed"
    print("✓ GroupBy test passed")

    print("\n✓ Basic verification tests completed!")

## 5. Summary

In [None]:
# Print the run summary: environment versions, the feature checklist, and
# the project roadmap. The checklist is static text. Reaching this cell under
# Run All only means the earlier cells did not raise.
banner_rule = "=" * 60
print("\n" + banner_rule)
print("TEST SUITE COMPLETE")
print(banner_rule)

summary_devices = jax.devices()
device_count = len(summary_devices)

print("\nEnvironment:")
for label, version_text in (
    ("JAX version", jax.__version__),
    ("Devices", f"{device_count} x {summary_devices[0].device_kind}"),
    ("NumPy version", np.__version__),
    ("Pandas version", pd.__version__),
):
    print(f"  {label}: {version_text}")

print("\nFeatures tested:")
for feature in (
    "Basic JaxFrame operations",
    "Parallel radix sort",
    "Sort-based groupby aggregations",
    "Parallel sort-merge joins",
):
    print(f"  ✓ {feature}")

# Multi-device coverage depends on how many devices JAX can see.
distributed_line = (
    f"  ✓ Distributed operations across {device_count} devices"
    if device_count >= 2
    else "  ⚠ Distributed operations (single device only)"
)
print(distributed_line)

print("\nStage 3 Implementation Status: COMPLETE")
print("\nNext steps:")
for roadmap_item in (
    "Stage 4: Lazy Execution Engine (8 weeks)",
    "Stage 5: API Completeness & Advanced Features (6 weeks)",
    "Stage 6: Validation & Documentation (4 weeks)",
):
    print(f"  - {roadmap_item}")