# JAX Transformations

This notebook demonstrates how visu-hlo visualizes the computational graphs produced by JAX's function transformations. JAX provides composable transformations for automatic differentiation (`grad`), vectorization (`vmap`), just-in-time compilation (`jit`), and parallelization (`pmap`).

## Setup

First, let's import the necessary libraries:

In [None]:
import os

# Force the CPU backend; this must be set before jax is imported
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
from visu_hlo import show

## Automatic Differentiation with `grad`

JAX's `grad` transformation computes gradients of scalar-valued functions:

In [None]:
def simple_function(x):
    """Simple quadratic function."""
    return x**2 + 3 * x + 1


# Gradient function
grad_fn = jax.grad(simple_function)

print('Original function f(x) = x² + 3x + 1:')
show(simple_function, jnp.array(2.0))

print("\nGradient f'(x) = 2x + 3:")
show(grad_fn, jnp.array(2.0))
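
As a quick sanity check alongside the visualization, the value returned by `grad` can be compared with the hand-derived derivative. This is a minimal sketch; the names `x0`, `autodiff_value`, and `analytic_value` are introduced here for illustration only.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    """Same quadratic as above: f(x) = x^2 + 3x + 1."""
    return x**2 + 3 * x + 1


grad_fn = jax.grad(simple_function)

x0 = jnp.array(2.0)
autodiff_value = grad_fn(x0)   # derivative computed by JAX
analytic_value = 2 * x0 + 3    # hand-derived f'(x) = 2x + 3

print(autodiff_value, analytic_value)
```

Both values should agree at `x = 2`, where `f'(2) = 7`.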

## Multivariate Gradients

In [None]:
def multivariate_function(params):
    """Function of multiple variables."""
    x, y, z = params
    return x**2 + y * z + jnp.sin(x * y)


grad_multivariate = jax.grad(multivariate_function)

test_params = jnp.array([1.0, 2.0, 3.0])
print('Multivariate function gradient:')
show(grad_multivariate, test_params)
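
For an array input, `grad` returns an array of partial derivatives with the same shape as the input. The sketch below compares that result with partial derivatives worked out by hand; `analytic_gradient` is a helper written for this illustration, not part of JAX or visu-hlo.

```python
import jax
import jax.numpy as jnp


def multivariate_function(params):
    x, y, z = params
    return x**2 + y * z + jnp.sin(x * y)


def analytic_gradient(params):
    """Hand-derived partials: [2x + y cos(xy), z + x cos(xy), y]."""
    x, y, z = params
    c = jnp.cos(x * y)
    return jnp.array([2 * x + y * c, z + x * c, y])


test_params = jnp.array([1.0, 2.0, 3.0])
autodiff_gradient = jax.grad(multivariate_function)(test_params)
expected_gradient = analytic_gradient(test_params)

print(autodiff_gradient)
print(expected_gradient)
```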

## Partial Derivatives with `argnums`

In [None]:
def loss_function(params, x, y):
    """Loss function with respect to parameters."""
    w, b = params
    prediction = w * x + b
    return (prediction - y) ** 2


# Gradient with respect to parameters (argnums=0)
grad_wrt_params = jax.grad(loss_function, argnums=0)

# Gradient with respect to input x (argnums=1)
grad_wrt_x = jax.grad(loss_function, argnums=1)

params = jnp.array([2.0, 1.0])  # w=2, b=1
x = jnp.array(3.0)
y = jnp.array(5.0)

print('Gradient with respect to parameters:')
show(grad_wrt_params, params, x, y)

print('\nGradient with respect to input x:')
show(grad_wrt_x, params, x, y)
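
`argnums` also accepts a tuple, in which case `grad` returns one gradient per selected argument. A small sketch under the same setup as above (the name `grad_both` is illustrative):

```python
import jax
import jax.numpy as jnp


def loss_function(params, x, y):
    w, b = params
    prediction = w * x + b
    return (prediction - y) ** 2


# Differentiate with respect to both params (argument 0) and x (argument 1)
grad_both = jax.grad(loss_function, argnums=(0, 1))

params = jnp.array([2.0, 1.0])  # w=2, b=1
x = jnp.array(3.0)
y = jnp.array(5.0)

g_params, g_x = grad_both(params, x, y)
print(g_params, g_x)
```

With these inputs the residual is `2`, so the gradients should be `[12, 4]` for `(w, b)` and `8` for `x`.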

## Higher-Order Derivatives

In [None]:
def polynomial(x):
    """Polynomial function for higher-order derivatives."""
    return x**4 + 2 * x**3 - 3 * x**2 + x + 1


# First derivative
first_deriv = jax.grad(polynomial)
# Second derivative
second_deriv = jax.grad(jax.grad(polynomial))
# Third derivative
third_deriv = jax.grad(jax.grad(jax.grad(polynomial)))

x = jnp.array(2.0)

print('Second derivative:')
show(second_deriv, x)

print('\nThird derivative:')
show(third_deriv, x)
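
The nested `grad` calls can be checked numerically against the derivatives of the polynomial computed by hand. This is only a sanity check of the values, independent of the visualization:

```python
import jax
import jax.numpy as jnp


def polynomial(x):
    return x**4 + 2 * x**3 - 3 * x**2 + x + 1


first_deriv = jax.grad(polynomial)
second_deriv = jax.grad(first_deriv)
third_deriv = jax.grad(second_deriv)

x = jnp.array(2.0)
d1, d2, d3 = first_deriv(x), second_deriv(x), third_deriv(x)

# Hand-derived:
#   p'(x)   = 4x^3 + 6x^2 - 6x + 1
#   p''(x)  = 12x^2 + 12x - 6
#   p'''(x) = 24x + 12
print(d1, d2, d3)
```

At `x = 2` the hand-derived values are 45, 66, and 60.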

## `value_and_grad` for Efficiency

In [None]:
def expensive_function(x):
    """Function where we want both value and gradient."""
    return jnp.sum(x**3) + jnp.sum(jnp.sin(x))


# Get both value and gradient in one pass
value_and_grad_fn = jax.value_and_grad(expensive_function)

x = jnp.array([1.0, 2.0, 3.0])
print('value_and_grad (more efficient than separate calls):')
show(value_and_grad_fn, x)
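
`value_and_grad` returns a `(value, gradient)` pair. The sketch below confirms that the pair matches what two separate calls would produce; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp


def expensive_function(x):
    return jnp.sum(x**3) + jnp.sum(jnp.sin(x))


x = jnp.array([1.0, 2.0, 3.0])

# One call returning both results
value, gradient = jax.value_and_grad(expensive_function)(x)

# Two separate calls, for comparison
separate_value = expensive_function(x)
separate_gradient = jax.grad(expensive_function)(x)

print(value, gradient)
```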

## Vectorization with `vmap`

Automatically vectorize functions to work on batches:

In [None]:
def single_example_function(x):
    """Function that works on a single example."""
    return jnp.sum(x**2) + jnp.mean(x)


# Vectorize to work on batches
batched_function = jax.vmap(single_example_function)

# Single example
single_input = jnp.array([1.0, 2.0, 3.0])
print('Single example function:')
show(single_example_function, single_input)

# Batch of examples
batch_input = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
print('\nVectorized function (vmap):')
show(batched_function, batch_input)
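
Semantically, `vmap` over the leading axis gives the same result as calling the function on each row and stacking the outputs. A minimal sketch of that equivalence (the explicit loop is only for comparison):

```python
import jax
import jax.numpy as jnp


def single_example_function(x):
    return jnp.sum(x**2) + jnp.mean(x)


batch_input = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

# Vectorized version
batched = jax.vmap(single_example_function)(batch_input)

# Reference: explicit Python loop over rows, then stack
looped = jnp.stack([single_example_function(row) for row in batch_input])

print(batched.shape, batched)
```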

## Advanced `vmap` with `in_axes`

In [None]:
def matrix_vector_mult(matrix, vector):
    """Matrix-vector multiplication."""
    return jnp.dot(matrix, vector)


# Vectorize over the vector argument (axis 0) but not the matrix
batch_matvec = jax.vmap(matrix_vector_mult, in_axes=(None, 0))

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
vectors = jnp.array([[1.0, 1.0], [2.0, 3.0], [0.5, 1.5]])

print('Vectorized matrix-vector multiplication:')
print(f'Matrix shape: {matrix.shape}, Vectors shape: {vectors.shape}')
show(batch_matvec, matrix, vectors)
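
With `in_axes=(None, 0)` the matrix is shared across the batch while each row of `vectors` is treated as one example. For this particular function, the result should match the plain matrix product `vectors @ matrix.T`, as the following sketch checks:

```python
import jax
import jax.numpy as jnp


def matrix_vector_mult(matrix, vector):
    return jnp.dot(matrix, vector)


# None: do not map over the matrix; 0: map over axis 0 of the vectors
batch_matvec = jax.vmap(matrix_vector_mult, in_axes=(None, 0))

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
vectors = jnp.array([[1.0, 1.0], [2.0, 3.0], [0.5, 1.5]])

batched_result = batch_matvec(matrix, vectors)  # shape (3, 2)
reference = vectors @ matrix.T                  # same computation as one matmul

print(batched_result)
```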

## Nested `vmap` for Multiple Batch Dimensions

In [None]:
def pairwise_distance(x, y):
    """Euclidean distance between two points."""
    return jnp.sqrt(jnp.sum((x - y) ** 2))


# Inner vmap maps over y, outer vmap maps over x: result has shape (len(x), len(y))
vectorized_distances = jax.vmap(jax.vmap(pairwise_distance, in_axes=(None, 0)), in_axes=(0, None))

points_x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
points_y = jnp.array([[1.0, 0.0], [0.0, 1.0]])

print('Nested vmap for pairwise distances:')
print(f'Points X shape: {points_x.shape}, Points Y shape: {points_y.shape}')
show(vectorized_distances, points_x, points_y)
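
The nested `vmap` produces one distance per `(x, y)` pair, so the output has shape `(3, 2)` for the inputs above. As a cross-check, the same matrix can be written with explicit broadcasting; this comparison is a sketch added for illustration.

```python
import jax
import jax.numpy as jnp


def pairwise_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))


vectorized_distances = jax.vmap(
    jax.vmap(pairwise_distance, in_axes=(None, 0)),  # inner: over points in y
    in_axes=(0, None),                               # outer: over points in x
)

points_x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
points_y = jnp.array([[1.0, 0.0], [0.0, 1.0]])

distances = vectorized_distances(points_x, points_y)

# Reference using manual broadcasting: (3, 1, 2) - (1, 2, 2) -> (3, 2, 2)
diff = points_x[:, None, :] - points_y[None, :, :]
reference = jnp.sqrt(jnp.sum(diff**2, axis=-1))

print(distances.shape)
print(distances)
```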

## Combining `grad` and `vmap`

Vectorized gradients for batch processing:

In [None]:
def loss_per_example(params, x, y):
    """Loss for a single example."""
    w, b = params
    prediction = jnp.dot(w, x) + b
    return (prediction - y) ** 2


# Gradient for a single example
grad_single = jax.grad(loss_per_example)

# Vectorized gradient for batch of examples
grad_batch = jax.vmap(grad_single, in_axes=(None, 0, 0))

params = (jnp.array([1.0, 2.0]), jnp.array(0.5))  # (w, b)
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch_y = jnp.array([3.0, 7.0, 11.0])

print('Vectorized gradients for batch:')
show(grad_batch, params, batch_x, batch_y)
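
Because `params` is a tuple, `vmap(grad(...))` returns a tuple of per-example gradients, each with a leading batch axis. Averaging them over that axis should reproduce the gradient of the mean loss. The sketch below checks this; `mean_loss` is a helper introduced here for the comparison.

```python
import jax
import jax.numpy as jnp


def loss_per_example(params, x, y):
    w, b = params
    prediction = jnp.dot(w, x) + b
    return (prediction - y) ** 2


def mean_loss(params, batch_x, batch_y):
    """Mean of the per-example losses over the batch (illustrative helper)."""
    losses = jax.vmap(loss_per_example, in_axes=(None, 0, 0))(params, batch_x, batch_y)
    return jnp.mean(losses)


params = (jnp.array([1.0, 2.0]), jnp.array(0.5))  # (w, b)
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch_y = jnp.array([3.0, 7.0, 11.0])

# Per-example gradients: dw has shape (3, 2), db has shape (3,)
grad_batch = jax.vmap(jax.grad(loss_per_example), in_axes=(None, 0, 0))
per_example_grads = grad_batch(params, batch_x, batch_y)

# Average over the batch axis and compare with the gradient of the mean loss
avg_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), per_example_grads)
mean_grads = jax.grad(mean_loss)(params, batch_x, batch_y)

print(per_example_grads[0].shape, per_example_grads[1].shape)
```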

## Combining Multiple Transformations

JAX transformations are composable:

In [None]:
def neural_network_layer(params, x):
    """Simple neural network layer."""
    W, b = params
    return jnp.tanh(jnp.dot(W, x) + b)


def loss_fn(params, batch_x, batch_y):
    """Loss function for the neural network."""
    # Vectorize over the batch
    batch_predictions = jax.vmap(neural_network_layer, in_axes=(None, 0))(params, batch_x)
    # Mean squared error
    return jnp.mean((batch_predictions - batch_y) ** 2)


# Combine JIT and grad
jit_grad_loss = jax.jit(jax.grad(loss_fn))

# Parameters
W = jnp.array([[0.1, 0.2], [0.3, 0.4]])
b = jnp.array([0.1, 0.1])
params = (W, b)

# Batch data
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batch_y = jnp.array([[0.5, 0.8], [0.2, 0.9]])

print('JIT + grad + vmap combination:')
show(jit_grad_loss, params, batch_x, batch_y)

## Advanced: Jacobian Computation

In [None]:
def vector_function(x):
    """Vector-valued function for Jacobian computation."""
    return jnp.array([x[0] ** 2 + x[1], x[0] * x[1], x[1] ** 2])


# Jacobian using forward-mode automatic differentiation (jacfwd)
jacobian_fn = jax.jacfwd(vector_function)

x = jnp.array([2.0, 3.0])
print('Jacobian computation:')
show(jax.jit(jacobian_fn), x)
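
JAX also offers `jax.jacrev`, which builds the same Jacobian with reverse-mode differentiation. The two should agree numerically, and both can be compared with the Jacobian derived by hand. A brief sketch:

```python
import jax
import jax.numpy as jnp


def vector_function(x):
    return jnp.array([x[0] ** 2 + x[1], x[0] * x[1], x[1] ** 2])


x = jnp.array([2.0, 3.0])

jac_forward = jax.jacfwd(vector_function)(x)  # forward mode
jac_reverse = jax.jacrev(vector_function)(x)  # reverse mode

# Hand-derived Jacobian: rows are outputs, columns are inputs
#   [[2*x0, 1], [x1, x0], [0, 2*x1]]
expected = jnp.array([[4.0, 1.0], [3.0, 2.0], [0.0, 6.0]])

print(jac_forward)
```

As a rule of thumb, forward mode tends to suit functions with few inputs and many outputs, and reverse mode the opposite; for a small function like this one either works.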

## Hessian Computation

In [None]:
def scalar_function_for_hessian(x):
    """Scalar function for Hessian computation."""
    return x[0] ** 3 + x[1] ** 2 + x[0] * x[1]


# Hessian via jax.hessian (forward-over-reverse differentiation)
hessian_fn = jax.hessian(scalar_function_for_hessian)

x = jnp.array([1.0, 2.0])
print('Hessian computation:')
show(jax.jit(hessian_fn), x)
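
The Hessian can also be assembled by hand as the forward-mode Jacobian of the reverse-mode gradient. The sketch below compares that composition with `jax.hessian` and with the second derivatives worked out on paper:

```python
import jax
import jax.numpy as jnp


def scalar_function_for_hessian(x):
    return x[0] ** 3 + x[1] ** 2 + x[0] * x[1]


x = jnp.array([1.0, 2.0])

hessian_builtin = jax.hessian(scalar_function_for_hessian)(x)
hessian_composed = jax.jacfwd(jax.jacrev(scalar_function_for_hessian))(x)

# Hand-derived Hessian: [[6*x0, 1], [1, 2]]
expected = jnp.array([[6.0, 1.0], [1.0, 2.0]])

print(hessian_builtin)
```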

## Custom Transformations with `custom_vjp`

Defining custom vector-Jacobian products:

In [None]:
@jax.custom_vjp
def smooth_abs(x):
    """Smooth approximation to absolute value."""
    return jnp.sqrt(x**2 + 1e-8)


def smooth_abs_fwd(x):
    """Forward pass."""
    y = smooth_abs(x)
    return y, x


def smooth_abs_bwd(x, g):
    """Backward pass with custom gradient."""
    return (g * x / jnp.sqrt(x**2 + 1e-8),)


smooth_abs.defvjp(smooth_abs_fwd, smooth_abs_bwd)


# Use in a function with grad
def function_with_custom_grad(x):
    return jnp.sum(smooth_abs(x))


grad_custom = jax.grad(function_with_custom_grad)

x = jnp.array([-1.0, 0.5, 2.0])
print('Function with custom VJP:')
show(jax.jit(grad_custom), x)
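
In this example the hand-written backward rule is the analytic derivative of the same formula, so the custom gradient should agree with what JAX derives automatically. The sketch below checks that, and also evaluates the gradient at `x = 0`, where the smoothing term keeps the result finite. `EPS` is a name introduced here for the constant `1e-8`.

```python
import jax
import jax.numpy as jnp

EPS = 1e-8  # same smoothing constant as in the cell above


@jax.custom_vjp
def smooth_abs(x):
    return jnp.sqrt(x**2 + EPS)


def smooth_abs_fwd(x):
    # Return the output plus the residual (x) needed by the backward pass
    return smooth_abs(x), x


def smooth_abs_bwd(x, g):
    # Arguments: saved residual first, then the incoming cotangent g
    return (g * x / jnp.sqrt(x**2 + EPS),)


smooth_abs.defvjp(smooth_abs_fwd, smooth_abs_bwd)

x = jnp.array([-1.0, 0.5, 2.0])
custom_grad = jax.grad(lambda v: jnp.sum(smooth_abs(v)))(x)

# Reference: the same formula differentiated without custom_vjp
reference_grad = jax.grad(lambda v: jnp.sum(jnp.sqrt(v**2 + EPS)))(x)

# At zero the custom rule gives 0 * g / sqrt(EPS)
grad_at_zero = jax.grad(smooth_abs)(jnp.array(0.0))

print(custom_grad, grad_at_zero)
```

Away from zero the gradient is close to `sign(x)`, as expected for an approximation of the absolute value.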

## Optimization with Transformations

A complete optimization example combining multiple transformations:

In [None]:
def optimization_step(params, batch_x, batch_y, learning_rate):
    """Single optimization step combining multiple transformations."""

    def model(params, x):
        W1, b1, W2, b2 = params
        h = jnp.tanh(jnp.dot(W1, x) + b1)
        return jnp.dot(W2, h) + b2

    def batch_loss(params, batch_x, batch_y):
        # Vectorize model over batch
        predictions = jax.vmap(model, in_axes=(None, 0))(params, batch_x)
        return jnp.mean((predictions - batch_y) ** 2)

    # Get both loss and gradients efficiently
    loss, grads = jax.value_and_grad(batch_loss)(params, batch_x, batch_y)

    # Update parameters
    new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)

    return new_params, loss


# Initialize parameters
W1 = jnp.array([[0.1, 0.2], [0.3, 0.4]])
b1 = jnp.array([0.0, 0.0])
W2 = jnp.array([[0.5, 0.6]])
b2 = jnp.array([0.0])
params = (W1, b1, W2, b2)

# Batch data
batch_x = jnp.array([[1.0, 0.5], [0.8, 1.2]])
batch_y = jnp.array([[0.7], [0.9]])
learning_rate = 0.01

print('Complete optimization step (JIT + value_and_grad + vmap):')
show(jax.jit(optimization_step), params, batch_x, batch_y, learning_rate)
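
To see the step doing useful work, it can be applied repeatedly in a plain Python loop. The sketch below restates the same model at module level so the block is self-contained, runs ten steps, and records the loss; the loop length and the `losses` list are choices made for this illustration.

```python
import jax
import jax.numpy as jnp


def model(params, x):
    W1, b1, W2, b2 = params
    h = jnp.tanh(jnp.dot(W1, x) + b1)
    return jnp.dot(W2, h) + b2


def batch_loss(params, batch_x, batch_y):
    predictions = jax.vmap(model, in_axes=(None, 0))(params, batch_x)
    return jnp.mean((predictions - batch_y) ** 2)


@jax.jit
def optimization_step(params, batch_x, batch_y, learning_rate):
    loss, grads = jax.value_and_grad(batch_loss)(params, batch_x, batch_y)
    new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss


params = (
    jnp.array([[0.1, 0.2], [0.3, 0.4]]),  # W1
    jnp.array([0.0, 0.0]),                # b1
    jnp.array([[0.5, 0.6]]),              # W2
    jnp.array([0.0]),                     # b2
)
batch_x = jnp.array([[1.0, 0.5], [0.8, 1.2]])
batch_y = jnp.array([[0.7], [0.9]])

losses = []
for _ in range(10):
    # Each call returns updated parameters and the loss before the update
    params, loss = optimization_step(params, batch_x, batch_y, 0.01)
    losses.append(float(loss))

print(losses)
```

With this small learning rate the recorded loss should shrink from one step to the next.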

## Summary

This notebook demonstrated JAX's powerful function transformations and their visualizations:

### Core Transformations:
- **`grad`**: Automatic differentiation for computing gradients
- **`vmap`**: Automatic vectorization for batch processing
- **`jit`**: Just-in-time compilation for performance
- **`pmap`**: Parallel mapping across devices (listed for completeness; not demonstrated in this notebook)

### Advanced Features:
- **Higher-order derivatives**: Nested `grad` calls
- **`value_and_grad`**: Efficient computation of both value and gradient
- **Jacobians and Hessians**: `jacfwd`, `jacrev`, `hessian`
- **Custom transformations**: `custom_vjp` for specialized gradients

### Composition:
- Transformations are **composable**: `jit(grad(vmap(...)))`
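- Order matters: `vmap(grad(f))` yields per-example gradients, whereas `grad` needs a scalar output, so `grad(vmap(f))` only works after a reduction such as `jnp.sum` or `jnp.mean`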
- Can be combined for complex workflows
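
The point about ordering can be made concrete with a small sketch. `vmap(grad(f))` differentiates each example separately, while `grad` applied to a batched function first needs the batch of outputs reduced to a scalar. The function `f` and the input `xs` are hypothetical examples chosen for this illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    """Scalar-valued function of a single example."""
    return jnp.sum(x**2)


xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # batch of two examples

# vmap(grad(f)): one gradient per example, shape (2, 2)
per_example = jax.vmap(jax.grad(f))(xs)

# grad of a batched function: reduce the batch outputs to a scalar first
summed = jax.grad(lambda batch: jnp.sum(jax.vmap(f)(batch)))(xs)

print(per_example)
print(summed)
```

Here the examples are independent, so the two results coincide (both equal `2 * xs`); with a mean instead of a sum, the second gradient would be scaled by the batch size.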

### Key Benefits:
- **Functional**: Transformations operate on pure functions without side effects
- **Performant**: JIT compilation and vectorization
- **Flexible**: Works with Python functions that are pure and traceable by JAX
- **Scalable**: Efficient batch processing and parallelization

The HLO visualizations reveal how these high-level transformations are compiled into efficient computational graphs, showing the optimizations performed by XLA, the compiler that JAX targets.