# Control Flow

This notebook demonstrates how visu-hlo visualizes JAX's structured control flow primitives. JAX provides functional control flow operations that can be compiled and differentiated.

## Setup

First, let's import the necessary libraries:

In [None]:
import os

# Force JAX onto the CPU backend. JAX picks up JAX_PLATFORMS when it
# initializes its backends, so set it before importing/using jax below.
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
from jax import lax
# visu_hlo.show(fn, *args) is used below to visualize fn's HLO for the given
# arguments (library API; its exact behavior is not shown in this notebook).
from visu_hlo import show

## Conditional Execution with `lax.cond`

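JAX's `lax.cond` provides structured conditional execution. Both branches are traced and compiled, so they must accept the same operands and return outputs with identical shapes and dtypes; the predicate only selects which branch runs: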

In [None]:
@jax.jit
def simple_conditional(x):
    """Return ``x**2`` when ``x > 0`` and ``-x`` otherwise (zero included)."""

    def _square(value):
        return value**2

    def _negate(value):
        return -value

    is_positive = x > 0
    return lax.cond(is_positive, _square, _negate, x)


# Sample input 5.0 > 0, so at runtime the true branch (x**2) would run;
# lax.cond compiles both branches regardless of the value.
print('Simple conditional (x > 0 ? x² : -x):')
show(simple_conditional, jnp.array(5.0))

## Complex Conditional Logic

In [None]:
@jax.jit
def complex_conditional(x, y):
    """Return ``x*y + sin(x)`` if ``x > y``, else ``x - y + cos(y)``.

    Both operands travel to the selected branch as a single ``(x, y)`` tuple.
    """

    def _when_greater(operands):
        a, b = operands
        return a * b + jnp.sin(a)

    def _otherwise(operands):
        a, b = operands
        return a - b + jnp.cos(b)

    operands = (x, y)
    return lax.cond(x > y, _when_greater, _otherwise, operands)


# x=2.0 > y=1.0, so at runtime the x*y + sin(x) branch would be selected.
print('Complex conditional with trigonometric functions:')
show(complex_conditional, jnp.array(2.0), jnp.array(1.0))

## Multi-way Conditionals with `lax.switch`

For choosing among several branches with an integer index (out-of-range indices are clamped to the valid range):

In [None]:
@jax.jit
def multi_way_switch(index, x):
    """Apply one of four transforms to ``x``, chosen by ``index``.

    0 -> x + 1, 1 -> 2x, 2 -> x**2, 3 -> sqrt(x). ``lax.switch`` clamps an
    out-of-range ``index`` into [0, 3].
    """
    return lax.switch(
        index,
        (
            lambda v: v + 1,  # case 0
            lambda v: v * 2,  # case 1
            lambda v: v**2,  # case 2
            lambda v: jnp.sqrt(v),  # case 3
        ),
        x,
    )


# index=2 selects the x**2 branch at runtime (4.0 -> 16.0); all four
# branches are compiled into the switch.
print('Multi-way switch (case 2: x²):')
show(multi_way_switch, 2, jnp.array(4.0))

## Loops with `lax.fori_loop`

Fixed-iteration loops:

In [None]:
@jax.jit
def simple_loop(n, init_val):
    """Add ``2 * i`` to ``init_val`` for every ``i`` in ``range(n)``."""

    def _accumulate(step, running):
        return running + step * 2

    return lax.fori_loop(0, n, _accumulate, init_val)


# n=5 is passed as a traced jit argument, so the trip count is not known at
# trace time and fori_loop is lowered via while_loop (see fori_loop docs).
# At runtime the result would be 0 + 0 + 2 + 4 + 6 + 8 = 20.
print('Simple fori_loop (accumulate i * 2):')
show(simple_loop, 5, jnp.array(0.0))

In [None]:
@jax.jit
def matrix_power_loop(matrix, n):
    """Compute the matrix power ``matrix**n`` by repeated multiplication.

    Args:
        matrix: Square 2-D array. The loop carry must keep a fixed shape, and
            it starts as an identity of size ``matrix.shape[0]``.
        n: Number of factors. ``n <= 0`` performs no multiplications and
            returns the identity.

    Returns:
        The product of ``n`` copies of ``matrix``, with ``matrix``'s dtype.
    """
    # Start from the identity so that n multiplications yield matrix**n.
    # Seeding the carry with `matrix` itself would produce matrix**(n + 1).
    identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)

    def body_fun(_, result):
        return jnp.dot(result, matrix)

    return lax.fori_loop(0, n, body_fun, identity)


# 2x2 input matrix; the exponent 3 is passed as a traced jit argument, so
# the loop bound is dynamic in the compiled program.
test_matrix = jnp.array([[1.1, 0.1], [0.1, 1.1]])
print('Matrix power using fori_loop:')
show(matrix_power_loop, test_matrix, 3)

## While Loops with `lax.while_loop`

Condition-based loops:

In [None]:
@jax.jit
def convergence_loop(x):
    """Repeatedly scale ``x`` by 0.8 while ``|x| > 0.01``; return the result.

    NOTE: an infinite ``x`` never satisfies the exit test (inf * 0.8 is inf),
    so such an input would loop forever.
    """

    def _keep_going(value):
        return jnp.abs(value) > 0.01

    def _shrink(value):
        return value * 0.8

    return lax.while_loop(_keep_going, _shrink, x)


# Starting at 10.0, 31 multiplications by 0.8 are needed before |x| <= 0.01.
print('While loop until convergence:')
show(convergence_loop, jnp.array(10.0))

In [None]:
@jax.jit
def newton_iteration(x):
    """Newton's method for finding the square root of 2.

    Repeats ``x <- 0.5 * (x + 2 / x)`` from the initial guess ``x`` until two
    successive iterates differ by at most ``1e-6``.

    Args:
        x: Scalar initial guess. The loop condition must be a scalar boolean.
            Integer/boolean guesses are converted to the floating dtype that
            the Newton update produces for them.

    Returns:
        The last Newton iterate.
    """
    target = 2.0  # Finding sqrt(2)

    # lax.while_loop requires the carry to keep the same dtype on every
    # iteration. The Newton update always yields an inexact value, so a
    # (strongly typed) integer guess would change dtype after one step and
    # fail. Convert such guesses up front to the dtype the update produces.
    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(jnp.result_type(x, 0.0))

    def cond_fun(state):
        _, error = state
        return error > 1e-6

    def body_fun(state):
        prev_x, _ = state
        new_x = 0.5 * (prev_x + target / prev_x)
        error = jnp.abs(new_x - prev_x)
        return new_x, error

    # Seed the error above the tolerance so the body runs at least once.
    init_state = (x, jnp.array(1.0))
    final_x, _ = lax.while_loop(cond_fun, body_fun, init_state)
    return final_x


# Initial guess 1.5 for sqrt(2).
print("Newton's method for square root:")
show(newton_iteration, jnp.array(1.5))

## Scan Operations with `lax.scan`

Loops that thread a carry value through each step and stack the per-step outputs into an array:

In [None]:
@jax.jit
def cumulative_sum_scan(xs):
    """Return the inclusive running sum of ``xs``.

    The carry starts as the scalar ``0.0`` and must keep that shape, so each
    element scanned over (``xs[i]``) is a scalar, i.e. ``xs`` is 1-D.
    """

    def _step(running_total, value):
        running_total = running_total + value
        return running_total, running_total

    sums = lax.scan(_step, 0.0, xs)[1]
    return sums


# Five-element input; test_array is also reused by the running-average cell.
test_array = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
print('Cumulative sum using scan:')
show(cumulative_sum_scan, test_array)

In [None]:
@jax.jit
def running_average_scan(xs):
    """Return ``mean(xs[:k + 1])`` for every position ``k``.

    The scan carry is ``(count, total)``, both starting at ``0.0``.
    """

    def _step(state, value):
        seen, acc = state
        seen = seen + 1
        acc = acc + value
        return (seen, acc), acc / seen

    averages = lax.scan(_step, (0.0, 0.0), xs)[1]
    return averages


# Reuses test_array defined in the cumulative-sum cell.
print('Running average using scan:')
show(running_average_scan, test_array)

## Recurrent Neural Network with Scan

A simple RNN implementation using scan:

In [None]:
@jax.jit
def simple_rnn(params, inputs):
    """Run a tanh recurrent cell over ``inputs`` and return every hidden state.

    Each step computes ``h = tanh(W_h @ h + W_x @ x_t + b)`` starting from a
    zero hidden state of length ``W_h.shape[0]``. ``inputs`` is scanned along
    its leading axis, and the stacked hidden states are returned.

    Args:
        params: Tuple ``(W_h, W_x, b)``.
        inputs: Sequence of input vectors stacked along axis 0.
    """
    W_h, W_x, b = params

    def _cell(hidden, x_t):
        pre_activation = jnp.dot(W_h, hidden) + jnp.dot(W_x, x_t) + b
        hidden = jnp.tanh(pre_activation)
        return hidden, hidden

    initial_hidden = jnp.zeros(W_h.shape[0])
    return lax.scan(_cell, initial_hidden, inputs)[1]


# RNN parameters
# hidden_size / input_size describe the shapes below (W_h: 3x3, W_x: 3x2,
# b: 3); they are not passed to simple_rnn, which infers sizes from W_h.
hidden_size = 3
input_size = 2
W_h = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
W_x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
b = jnp.array([0.1, 0.1, 0.1])
params = (W_h, W_x, b)

# Input sequence: 3 time steps, each an input_size-dimensional vector.
inputs = jnp.array([[1.0, 0.5], [0.8, 1.2], [0.3, 0.9]])

print('Simple RNN with scan:')
show(simple_rnn, params, inputs)

## Nested Control Flow

Combining different control flow primitives:

In [None]:
@jax.jit
def nested_control_flow(x, condition):
    """Run a fori_loop inside one branch of a lax.cond.

    If ``condition`` is true, multiply ``x`` by 1.1 five times. Otherwise
    return ``sqrt(|x|)``.
    """

    def _grow(value):
        # Five fixed iterations, each scaling the running value by 1.1.
        def _scale(_, acc):
            return acc * 1.1

        return lax.fori_loop(0, 5, _scale, value)

    def _shrink(value):
        return jnp.sqrt(jnp.abs(value))

    return lax.cond(condition, _grow, _shrink, x)


# condition=True would select the fori_loop branch at runtime; both branches
# are compiled into the conditional.
print('Nested control flow (condition=True):')
show(nested_control_flow, jnp.array(2.0), True)

## Dynamic Programming Example

Computing Fibonacci numbers using scan:

In [None]:
def fibonacci_scan(n):
    """Compute Fibonacci sequence using scan.

    Args:
        n: Number of terms. It sets the scan length, so it must be a Python
            int known at trace time (it is marked static in the jit below).

    Returns:
        Array ``[F(1), F(2), ..., F(n)]``. The values use JAX's default
        integer dtype (int32 unless x64 is enabled), so entries past F(46)
        overflow in that default.
    """

    def fib_step(carry, _):
        a, b = carry
        return (b, a + b), b

    # Initialize with F(0)=0, F(1)=1
    init_carry = (0, 1)
    # No per-step inputs are needed, so scan over `length` steps instead of
    # materialising an arange that the step function would ignore.
    _, fib_sequence = lax.scan(fib_step, init_carry, None, length=n)
    return fib_sequence


# A traced `n` cannot set a scan length (jnp.arange(n) / length=n would raise
# a concretization error under plain @jax.jit), so `n` is a static argument.
fibonacci_scan = jax.jit(fibonacci_scan, static_argnums=0)


# Request 8 Fibonacci terms.
print('Fibonacci sequence using scan:')
show(fibonacci_scan, 8)

## Optimization Loop

Simple gradient descent optimization:

In [None]:
@jax.jit
def gradient_descent_loop(params, learning_rate, n_steps):
    """Minimize ``(x - 2)**2 + 1`` with ``n_steps`` of plain gradient descent.

    Each step applies ``x <- x - learning_rate * d/dx objective(x)``, starting
    from ``params``.
    """

    def objective(x):
        return (x - 2.0) ** 2 + 1.0

    def update_step(_, current):
        slope = jax.grad(objective)(current)
        return current - learning_rate * slope

    return lax.fori_loop(0, n_steps, update_step, params)


# Start at x=0.0 with learning rate 0.1 for 10 steps; the objective's minimum
# is at x=2.0.
print('Gradient descent optimization loop:')
show(gradient_descent_loop, jnp.array(0.0), 0.1, 10)

## Summary

This notebook demonstrated JAX's structured control flow primitives and their visualization:

- **Conditionals**: `lax.cond` for if-then-else logic, `lax.switch` for multi-way branching
- **Loops**: `lax.fori_loop` for fixed iterations, `lax.while_loop` for condition-based loops
- **Scan**: `lax.scan` for efficient loops with intermediate results
- **Applications**: RNNs, dynamic programming, optimization algorithms
- **Nested structures**: Combining different control flow primitives

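All these control flow operations are:
- **Functional**: Branch and loop bodies are pure functions with no side effects
- **Compilable**: Can be JIT compiled for performance
- **Differentiable**: `lax.cond`, `lax.switch`, and `lax.scan` support both forward- and reverse-mode automatic differentiation; `lax.while_loop` (and `lax.fori_loop` with a trip count that is not known at trace time) supports forward mode only
- **Portable**: The same program compiles for CPU, GPU, and TPU backends (this notebook pins execution to the CPU)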

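The HLO visualizations show how these high-level control structures are compiled into low-level operations. `lax.cond` and `lax.switch` typically appear as HLO `conditional` ops. `lax.fori_loop`, `lax.while_loop`, and `lax.scan` typically appear as HLO `while` loops, although XLA's optimizer may simplify or unroll small loops.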