# JAXFrame JIT Printing Examples

This notebook demonstrates the JIT-compatible printing functionality in JAXFrame, including tracer formatting that prints compact labels such as `f32[3]` instead of the default `JitTracer<float32[3]>` representation.

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jaxframe import DataFrame, MaskedArray
from jaxframe.jitprint import jit_print_dataframe, jit_print_masked_array

print("JAX version:", jax.__version__)
print("JAXFrame JIT printing ready!")

JAX version: 0.7.1
JAXFrame JIT printing ready!


## Example 1: DataFrame Printing with Static Arguments

Here we print a JAXFrame DataFrame from inside a JIT-compiled function by passing it as a static argument, so that it is available as a regular Python object while the function is traced.

In [3]:
# Create a simple DataFrame
df = DataFrame({
    'id': [1, 2, 3],
    'value': [10.5, 20.7, 30.9],
    'active': [True, False, True]
}, name="sample_data")

print("Original DataFrame:")
print(df)

def process_with_dataframe_info(data_array, metadata_df):
    """Function that processes data and prints DataFrame info inside JIT."""
    jax.debug.print("=== Processing started ===")
    
    # Print the DataFrame structure
    jit_print_dataframe(metadata_df)
    
    # Process the data
    result = jnp.sum(data_array) * 2
    
    jax.debug.print("Sum result: {}", result)
    jax.debug.print("=== Processing complete ===")
    
    return result

# Use static_argnames to pass the DataFrame as a static argument
process_jit = jax.jit(process_with_dataframe_info, static_argnames=['metadata_df'])

# Run the JIT function
input_data = jnp.array([1.0, 2.0, 3.0])
result = process_jit(input_data, df)

print(f"\nFinal result: {result}")

Original DataFrame:
DataFrame 'sample_data'(3 rows, 3 columns)
Columns: id, value, active
  [0]: {'id': '1', 'value': '10.500', 'active': 'True'}
  [1]: {'id': '2', 'value': '20.700', 'active': 'False'}
  [2]: {'id': '3', 'value': '30.900', 'active': 'True'}
=== Processing started ===
Columns: id, value, active
=== Processing complete ===
  [2]: {'id': '3', 'value': '30.900', 'active': 'True'}
  [1]: {'id': '2', 'value': '20.700', 'active': 'False'}
  [0]: {'id': '1', 'value': '10.500', 'active': 'True'}
DataFrame 'sample_data'(3 rows, 3 columns)
Sum result: 12.0

Final result: 12.0
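In the output above, the lines printed inside the jitted function arrive out of order: `=== Processing complete ===` appears before the DataFrame rows, the rows come out in reverse, and the `DataFrame 'sample_data'` header prints last. `jax.debug.print` is unordered by default (`ordered=False`), so separate calls may run in any order relative to each other. Assuming `jit_print_dataframe` emits its lines through such unordered calls (the interleaving suggests so, but this is not confirmed here), the printed order should not be relied on. For your own messages, `jax.debug.print` accepts `ordered=True`, as in this minimal sketch in plain JAX:

```python
import jax
import jax.numpy as jnp


@jax.jit
def process(data_array):
    # ordered=True makes these prints execute in program order relative
    # to each other (and to any other ordered debug prints).
    jax.debug.print("=== Processing started ===", ordered=True)
    result = jnp.sum(data_array) * 2
    jax.debug.print("Sum result: {}", result, ordered=True)
    jax.debug.print("=== Processing complete ===", ordered=True)
    return result


result = process(jnp.array([1.0, 2.0, 3.0]))
# Block until all pending debug callbacks have finished.
jax.effects_barrier()
print(f"Final result: {result}")
```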


## Example 2: Beautiful Tracer Formatting

This example shows how `jit_print_dataframe_data` formats a dict of traced arrays: each cell is rendered by its dtype (e.g. `i32`, `f32`, `bool`) instead of a raw tracer representation.

In [4]:
from jaxframe.jitprint import jit_print_dataframe_data

@jax.jit
def demonstrate_tracer_formatting():
    """Show how JAX tracers are formatted cleanly in DataFrame data."""
    
    # These arrays will be tracers inside the JIT context
    ids = jnp.array([1, 2, 3, 4])
    scores = jnp.array([95.5, 87.2, 92.1, 89.8])
    is_passing = jnp.array([True, True, True, True])
    categories = jnp.array([0, 1, 0, 1])  # categorical data
    
    # Create a DataFrame-like structure with tracers
    data = {
        'student_id': ids,
        'score': scores, 
        'passing': is_passing,
        'category': categories
    }
    columns = ['student_id', 'score', 'passing', 'category']
    
    jax.debug.print("\n=== DataFrame with JAX Tracers ===")
    jax.debug.print("Note: Tracers show as clean 'dtype[shape]' format, not ugly representations")
    
    # Per-row cells are printed as their dtype only, e.g. 'i32', 'f32', 'bool'
    jit_print_dataframe_data(data, columns, 4, "student_grades")
    
    # Show some computation with the data
    avg_score = jnp.mean(scores)
    total_passing = jnp.sum(is_passing)
    
    jax.debug.print("\nComputations:")
    jax.debug.print("Average score: {}", avg_score)
    jax.debug.print("Students passing: {}", total_passing)
    
    return avg_score

# Run the demonstration
result = demonstrate_tracer_formatting()
print(f"\nAverage score: {result:.2f}")


=== DataFrame with JAX Tracers ===
Note: Tracers show as clean 'dtype[shape]' format, not ugly representations
Columns: student_id, score, passing, category

Computations:
  [3]: {'student_id': 'i32', 'score': 'f32', 'passing': 'bool', 'category': 'i32'}
  [2]: {'student_id': 'i32', 'score': 'f32', 'passing': 'bool', 'category': 'i32'}
  [1]: {'student_id': 'i32', 'score': 'f32', 'passing': 'bool', 'category': 'i32'}
  [0]: {'student_id': 'i32', 'score': 'f32', 'passing': 'bool', 'category': 'i32'}
DataFrame 'student_grades'(4 rows, 4 columns)
Students passing: 4
Average score: 91.1500015258789

Average score: 91.15
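The per-row output above contains only each cell's dtype, because the values do not exist yet while the function is being traced. If you need the actual numbers at run time, one option is to pass the arrays themselves to `jax.debug.print`, which fills in its placeholders when the compiled function executes. A minimal sketch in plain JAX; the function name `print_rows` is hypothetical and not part of JAXFrame:

```python
import jax
import jax.numpy as jnp


@jax.jit
def print_rows(student_id, score, passing):
    # The row count comes from the static shape, so this Python loop is
    # unrolled at trace time into one debug print per row.
    for i in range(student_id.shape[0]):
        # The row index is baked in with an f-string; the doubled braces
        # become placeholders that jax.debug.print fills at run time.
        jax.debug.print(
            f"  [{i}]: student_id={{sid}}, score={{sc}}, passing={{p}}",
            sid=student_id[i], sc=score[i], p=passing[i],
            ordered=True,
        )
    return jnp.mean(score)


avg = print_rows(
    jnp.array([1, 2, 3, 4]),
    jnp.array([95.5, 87.2, 92.1, 89.8]),
    jnp.array([True, True, True, True]),
)
jax.effects_barrier()
```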


## Example 3: MaskedArray Printing in JIT

This example prints a MaskedArray from inside a JIT-compiled function, again passing it as a static argument.

In [6]:
# Create a MaskedArray with some missing data
data = jnp.array([[1.5, 2.7, 3.2], 
                  [4.1, 5.9, 6.3],
                  [7.8, 8.1, 9.4]])

# Mask: True = valid data, False = missing/invalid data  
mask = np.array([[True, False, True],
                 [True, True, False], 
                 [False, True, True]])

index_df = DataFrame({'sample': ['A', 'B', 'C']})
ma = MaskedArray(data, mask, index_df)

print("Original MaskedArray:")
print(ma)

def process_masked_data(computation_data, metadata_ma):
    """Process data while printing MaskedArray metadata."""
    jax.debug.print("\n=== Processing Masked Data ===")
    
    # Print the MaskedArray structure
    jit_print_masked_array(metadata_ma)
    
    # Do some computation
    result = jnp.sum(computation_data ** 2)
    jax.debug.print("Sum of squares: {}", result)
    
    return result

# JIT compile with static MaskedArray
process_jit = jax.jit(process_masked_data, static_argnames=['metadata_ma'])

# Run with computation data
computation_input = jnp.array([[1.0, 2.0], [3.0, 4.0]])
result = process_jit(computation_input, ma)

print(f"\nComputation result: {result}")

Original MaskedArray:
MaskedArray(3 rows, 3 columns)
Valid values: 6/9 (66.7%)
Index DataFrame: 3 rows, 1 columns

=== Processing Masked Data ===
Valid values: 6/9 (66.7%)
MaskedArray(3 rows, 3 columns)
Index DataFrame: 3 rows, 1 columns
Sum of squares: 30.0

Computation result: 30.0
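Examples 1 and 3 both pass JAXFrame objects through `static_argnames`. JAX uses static arguments as part of its compilation cache key, so they must be hashable (implementing `__hash__` and `__eq__`), and passing an unequal object triggers a new trace and compilation. The sketch below uses a frozen dataclass named `TableMeta` as a hypothetical stand-in for a JAXFrame object and counts how often the function is traced:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class TableMeta:
    # Hypothetical stand-in for a static JAXFrame object; frozen=True
    # generates the __hash__ and __eq__ that static arguments need.
    name: str
    columns: tuple


trace_count = 0


def weighted_sum(data, meta):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX traces
    # meta is a plain Python object here, so it can be read at trace time.
    message = f"table {meta.name!r} has {len(meta.columns)} columns"
    # Escape braces so jax.debug.print does not treat them as placeholders.
    jax.debug.print(message.replace("{", "{{").replace("}", "}}"))
    return jnp.sum(data) * len(meta.columns)


weighted_sum_jit = jax.jit(weighted_sum, static_argnames=["meta"])

meta = TableMeta("sample_data", ("id", "value", "active"))
r1 = weighted_sum_jit(jnp.array([1.0, 2.0, 3.0]), meta)
# Equal static argument and same input shape/dtype: cache hit, no retrace.
r2 = weighted_sum_jit(jnp.array([4.0, 5.0, 6.0]), meta)
# A different static argument forces a new trace and compilation.
r3 = weighted_sum_jit(jnp.array([1.0, 2.0, 3.0]), TableMeta("other", ("a",)))
jax.effects_barrier()
print("traces:", trace_count)
```

Because each distinct static object compiles its own version of the function, passing many different DataFrames this way can lead to repeated compilation.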


## Example 4: Before and After - Tracer Formatting Improvement

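This example compares the default string representation of tracers (`str(tracer)`) with the output of JAXFrame's formatting helper `_format_value_for_jit_print`. The leading underscore marks it as a private helper; it is imported here only for demonstration.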

In [7]:
from jaxframe.jitprint import _format_value_for_jit_print

@jax.jit  
def compare_formatting():
    """Compare ugly default tracer strings vs our clean formatting."""
    
    float_array = jnp.array([1.1, 2.2, 3.3])
    int_array = jnp.array([10, 20, 30])
    scalar = jnp.array(42.0)
    matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
    
    jax.debug.print("\n=== UGLY: Default tracer representations ===")
    jax.debug.print("Float array: {}", str(float_array))
    jax.debug.print("Int array: {}", str(int_array))
    jax.debug.print("Scalar: {}", str(scalar))
    jax.debug.print("Matrix: {}", str(matrix))
    
    jax.debug.print("\n=== CLEAN: Our improved formatting ===")
    jax.debug.print("Float array: {}", _format_value_for_jit_print(float_array))
    jax.debug.print("Int array: {}", _format_value_for_jit_print(int_array))
    jax.debug.print("Scalar: {}", _format_value_for_jit_print(scalar))
    jax.debug.print("Matrix: {}", _format_value_for_jit_print(matrix))
    
    jax.debug.print("\n✨ Much better! Clean, readable, and concise.")
    
    return jnp.sum(float_array)

result = compare_formatting()
print(f"\nSum result: {result}")


=== UGLY: Default tracer representations ===
Float array: JitTracer<float32[3]>
Int array: JitTracer<int32[3]>
Scalar: JitTracer<~float32[]>
Matrix: JitTracer<float32[2,2]>

=== CLEAN: Our improved formatting ===
Float array: f32[3]
Int array: i32[3]
Scalar: f32
Matrix: f32[2x2]

✨ Much better! Clean, readable, and concise.

Sum result: 6.599999904632568
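The implementation of `_format_value_for_jit_print` is not shown in this notebook. The sketch below is one way such a formatter could work, assuming only that the value exposes `.dtype`, `.shape`, and `.ndim`, which JAX tracers do at trace time. The name `format_shape_dtype` and the abbreviation table are hypothetical; the label format mimics the `f32[3]`, `f32`, and `f32[2x2]` output above.

```python
import jax
import jax.numpy as jnp

# Hypothetical abbreviation table; unlisted dtypes fall back to their full name.
_SHORT_DTYPE = {
    "float16": "f16", "bfloat16": "bf16", "float32": "f32", "float64": "f64",
    "int8": "i8", "int16": "i16", "int32": "i32", "int64": "i64",
    "uint8": "u8", "uint32": "u32", "bool": "bool",
}


def format_shape_dtype(x):
    """Return a compact label such as 'f32[3]', 'f32', or 'f32[2x2]'."""
    dtype_name = jnp.dtype(x.dtype).name
    short = _SHORT_DTYPE.get(dtype_name, dtype_name)
    if x.ndim == 0:
        return short  # scalars show the dtype only
    return f"{short}[{'x'.join(str(d) for d in x.shape)}]"


labels = {}


@jax.jit
def describe(float_array, int_array, scalar, matrix):
    # This body runs at trace time: the arguments are tracers whose dtype
    # and shape are already known, even though their values are not.
    labels["Float array"] = format_shape_dtype(float_array)
    labels["Int array"] = format_shape_dtype(int_array)
    labels["Scalar"] = format_shape_dtype(scalar)
    labels["Matrix"] = format_shape_dtype(matrix)
    for key, label in labels.items():
        # Labels are plain strings, so they can be baked into the format string.
        jax.debug.print(f"{key}: {label}")
    return jnp.sum(float_array)


total = describe(
    jnp.array([1.1, 2.2, 3.3]),
    jnp.array([10, 20, 30]),
    jnp.array(42.0),
    jnp.array([[1.0, 2.0], [3.0, 4.0]]),
)
jax.effects_barrier()
```

The integer label assumes JAX's default 32-bit mode; with `jax_enable_x64` turned on, `jnp.array([10, 20, 30])` would be `int64` and label as `i64[3]`.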


## Summary

The JAXFrame JIT printing functionality provides:

✅ **JIT-Compatible Printing**: Use `jit_print_dataframe()` and `jit_print_masked_array()` inside `@jax.jit` functions

✅ **Static Arguments**: Pass JAXFrame objects as static arguments using `static_argnames`

✅ **Clean Tracer Formatting**: JAX tracers display as readable formats like `f32[3]` instead of ugly `JitTracer<float32[3]>` strings

✅ **Broad Support**: Handles both DataFrame and MaskedArray objects, and formats float, integer, and boolean arrays, from scalars to vectors and matrices

✅ **Seamless Integration**: Use these functions in place of Python's `print` inside JIT-compiled functions, where a plain `print` runs only once at trace time rather than on every call

This makes debugging and monitoring JAX-compiled functions much more pleasant and readable!