# Symbolic Regression Tutorial: Binary Search Circuit Recovery

## Introduction

This tutorial demonstrates how to recover an arithmetic circuit for binary search using symbolic regression. We'll show how the binary search algorithm can be "unrolled" into polynomial constraints and use pySR to recover each component.

## Mathematical Framework

Our binary search circuit operates on a sorted array $A$ and target value $t$. For each iteration $i$, the circuit consists of:

### 1. State Variables
- Left endpoint: $L_i$ 
- Right endpoint: $R_i$

### 2. Midpoint Computation
The midpoint $m_i$ is computed with the constraint:

\begin{equation}
(2m_i - L_i - R_i)^2 = 0
\end{equation}

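Over the reals, this holds exactly when $m_i$ is the arithmetic mean of the endpoints. The simulation code in this tutorial takes the floor of that mean to obtain an integer array index, so the residual $2m_i - L_i - R_i$ is $-1$ rather than $0$ whenever $L_i + R_i$ is odd. The ideal relation is: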

\begin{equation}
m_i = \frac{L_i + R_i}{2}
\end{equation}
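
A quick numerical check makes the gap between the ideal relation and an integer index concrete. The helper below is only an illustration (the name `midpoint_residual` is not part of the tutorial's code); it assumes non-negative integer endpoints and uses floor division, as the simulation does.

```python
def midpoint_residual(L, R):
    """Return the floored midpoint m and the residual 2*m - L - R."""
    m = (L + R) // 2          # integer index, obtained by floor division
    return m, 2 * m - L - R   # 0 if L + R is even, -1 if it is odd

for L, R in [(0, 6), (0, 7), (4, 7)]:
    m, residual = midpoint_residual(L, R)
    print(f"L={L}, R={R}: m={m}, residual={residual}")
```

One common way to make such a constraint exact, not used in this tutorial, is to add a Boolean remainder $r_i$ and require $2m_i + r_i = L_i + R_i$.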

### 3. Branch Indicators
We have three Boolean indicators:
- Equality: $\delta_i^=$
- Less than: $\delta_i^<$
- Greater than: $\delta_i^>$

These satisfy:

\begin{equation}
\delta_i^= + \delta_i^< + \delta_i^> = 1
\end{equation}
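
The sketch below shows one way to produce and validate the indicators for a single comparison. The Booleanity check $\delta(\delta - 1) = 0$ is an assumption added here (a standard way to force a variable into $\{0, 1\}$ in an arithmetic circuit); the text above states only the sum constraint. The function names are illustrative.

```python
def branch_indicators(a_mid, target):
    """Return (delta_eq, delta_lt, delta_gt) for comparing A[m] with the target."""
    delta_eq = int(a_mid == target)
    delta_lt = int(a_mid < target)   # A[m] < t: search continues to the right
    delta_gt = int(a_mid > target)   # A[m] > t: search continues to the left
    return delta_eq, delta_lt, delta_gt

def satisfies_indicator_constraints(deltas):
    """Check Booleanity of each indicator (assumed) and the sum-to-one constraint."""
    booleanity = all(d * (d - 1) == 0 for d in deltas)
    one_hot = sum(deltas) == 1
    return booleanity and one_hot

for a_mid in (8, 12, 14):
    deltas = branch_indicators(a_mid, 12)
    print(a_mid, deltas, satisfies_indicator_constraints(deltas))
```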

### 4. State Updates
The left endpoint moves just past the midpoint when the less-than branch is active ($\delta_i^< = 1$) and is otherwise unchanged:

\begin{equation}
L_{i+1} = \delta_i^< \cdot (m_i + 1) + (1 - \delta_i^<)L_i
\end{equation}
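
The simulation code in this tutorial applies the mirror-image rule to the right endpoint. Written in the same notation, that rule reads:

\begin{equation}
R_{i+1} = \delta_i^> \cdot (m_i - 1) + (1 - \delta_i^>)R_i
\end{equation}

When the equality branch is active, both $\delta_i^<$ and $\delta_i^>$ are 0, so both endpoints are carried over unchanged.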

### 5. Output Computation
The final output is:

\begin{equation}
o = \sum_{i=0}^{T-1} \delta_i^= \cdot m_i + \left(1 - \sum_{i=0}^{T-1}\delta_i^=\right)(-1)
\end{equation}

where $T = \lceil \log_2(n) \rceil + 1$ is the number of iterations.
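
As a sanity check on this depth, the sketch below compares $T$ with the largest number of midpoints a floor-midpoint binary search actually inspects. The helper names and the test arrays (even numbers, probed with every integer target in and around them) are assumptions made for illustration.

```python
import math

def num_iterations(n):
    """T = ceil(log2(n)) + 1, the unrolling depth used in this tutorial."""
    if n < 1:
        raise ValueError("array must be non-empty")
    return math.ceil(math.log2(n)) + 1

def worst_case_probes(n):
    """Largest number of midpoints inspected over hits and misses on an n-element array."""
    A = list(range(0, 2 * n, 2))
    worst = 0
    for target in range(-1, 2 * n):
        lo, hi, probes = 0, n - 1, 0
        while lo <= hi:
            mid = (lo + hi) // 2
            probes += 1
            if A[mid] == target:
                break
            if A[mid] < target:
                lo = mid + 1
            else:
                hi = mid - 1
        worst = max(worst, probes)
    return worst

for n in (1, 7, 8, 9):
    print(n, num_iterations(n), worst_case_probes(n))
```

For sizes that are not powers of two, $T$ can exceed the number of probes actually needed, so some trailing iterations are idle.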

We simulate the standard binary search. Once the target is found in some iteration, we freeze the state for the remaining iterations: $\delta_i^=$ is set to 1 in that iteration only, and all three indicators are set to 0 in every later iteration (the same happens if the search interval becomes empty). As a result, the branch-sum constraint holds only for iterations in which a comparison is actually made, and $\sum_{i=0}^{T-1}\delta_i^=$ in the output formula above is either 0 or 1.

Thus, if the target is found in exactly one iteration, the sum is just that $m_i$ (the index where the target appears); if not found (so all $\delta_i^= = 0$), then $o = -1$.
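
The output formula can be evaluated directly on a list of per-iteration values. The following sketch uses an illustrative function name and hand-written witness values for $T = 4$; it also shows what goes wrong if the equality indicator is left at 1 after the freeze.

```python
def circuit_output(delta_eq, m):
    """o = sum(delta_eq[i] * m[i]) + (1 - sum(delta_eq)) * (-1)."""
    hit_term = sum(d * mi for d, mi in zip(delta_eq, m))
    miss_term = (1 - sum(delta_eq)) * (-1)
    return hit_term + miss_term

# Target found at iteration 1 with midpoint 5; later iterations are frozen.
print(circuit_output([0, 1, 0, 0], [3, 5, 5, 5]))
# Target never found.
print(circuit_output([0, 0, 0, 0], [3, 5, 4, 4]))
# Invalid witness: equality indicator set twice, so the index is counted twice.
print(circuit_output([0, 1, 1, 0], [3, 5, 5, 5]))
```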

We then define a function to check all the polynomial constraints (midpoint, branch sum, comparison, and state update) using SymPy.

In [None]:

# %% [code]
import math
import sympy as sp

# Define symbolic variables for T iterations.
def define_symbols(T):
    L_syms = sp.symbols('L0:'+str(T), real=True)
    R_syms = sp.symbols('R0:'+str(T), real=True)
    m_syms = sp.symbols('m0:'+str(T), real=True)
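    # The numeric range must sit between prefix and suffix:
    # 'delta0:4_eq' expands to delta0_eq, delta1_eq, delta2_eq, delta3_eq.
    delta_eq_syms = sp.symbols(f'delta0:{T}_eq', real=True)
    delta_lt_syms = sp.symbols(f'delta0:{T}_lt', real=True)
    delta_gt_syms = sp.symbols(f'delta0:{T}_gt', real=True)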
    return L_syms, R_syms, m_syms, delta_eq_syms, delta_lt_syms, delta_gt_syms

# simulate_binary_search now works for any sorted A and target.
def simulate_binary_search(A, target):
    n = len(A)
    # Use T = ceil(log2(n)) + 1 iterations.
    # math.log2 avoids the rounding error that math.log(n, 2) can introduce.
    T = math.ceil(math.log2(n)) + 1
    L_syms, R_syms, m_syms, delta_eq_syms, delta_lt_syms, delta_gt_syms = define_symbols(T)
    
    assign = {}
    found = False   # flag: whether we've found the target already.
    found_iter = None  # which iteration first found the target.
    
    # Initialize L0 and R0:
    assign[L_syms[0]] = 0
    assign[R_syms[0]] = n - 1
    
    for i in range(T):
        # If the state has already been "frozen" because we found the target,
        # then for later iterations we simply copy the previous state and set branch indicators to 0.
        if found:
            assign[m_syms[i]] = assign[m_syms[i-1]]
            assign[L_syms[i]] = assign[L_syms[i-1]]
            assign[R_syms[i]] = assign[R_syms[i-1]]
            assign[delta_eq_syms[i]] = 0
            assign[delta_lt_syms[i]] = 0
            assign[delta_gt_syms[i]] = 0
            continue
        
        L_val = assign[L_syms[i]]
        R_val = assign[R_syms[i]]
        # If the interval is empty (L > R), then no valid search remains.
        # In that case, we simply set m to an arbitrary value (say, L_val) and no branch indicator.
        if L_val > R_val:
            assign[m_syms[i]] = L_val
            assign[delta_eq_syms[i]] = 0
            assign[delta_lt_syms[i]] = 0
            assign[delta_gt_syms[i]] = 0
        else:
            m_val = (L_val + R_val) // 2  # floor division
            assign[m_syms[i]] = m_val
            # Decide the branch based on A[m_val]
            if A[m_val] == target:
                assign[delta_eq_syms[i]] = 1
                assign[delta_lt_syms[i]] = 0
                assign[delta_gt_syms[i]] = 0
                found = True
                found_iter = i
            elif A[m_val] < target:
                assign[delta_eq_syms[i]] = 0
                assign[delta_lt_syms[i]] = 1
                assign[delta_gt_syms[i]] = 0
            else:  # A[m_val] > target
                assign[delta_eq_syms[i]] = 0
                assign[delta_lt_syms[i]] = 0
                assign[delta_gt_syms[i]] = 1
        
        # State update for next iteration, if not the last iteration.
        if i < T - 1:
            # If we already found the target, then copy state.
            if found:
                assign[L_syms[i+1]] = assign[L_syms[i]]
                assign[R_syms[i+1]] = assign[R_syms[i]]
            else:
                # For left update: if delta_lt is 1, set L_{i+1} = m_i + 1; otherwise, L remains.
                L_next = assign[delta_lt_syms[i]] * (assign[m_syms[i]] + 1) + (1 - assign[delta_lt_syms[i]]) * assign[L_syms[i]]
                # For right update: if delta_gt is 1, set R_{i+1} = m_i - 1; otherwise, R remains.
                R_next = assign[delta_gt_syms[i]] * (assign[m_syms[i]] - 1) + (1 - assign[delta_gt_syms[i]]) * assign[R_syms[i]]
                assign[L_syms[i+1]] = L_next
                assign[R_syms[i+1]] = R_next
    # Compute the output o.
    o_val = 0
    sum_delta = 0
    for i in range(T):
        o_val += assign[delta_eq_syms[i]] * assign[m_syms[i]]
        sum_delta += assign[delta_eq_syms[i]]
    # If target never found, output -1.
    o_val = o_val + (1 - sum_delta) * (-1)
    assign['o'] = o_val
    vars_dict = {
        'L': L_syms,
        'R': R_syms,
        'm': m_syms,
        'delta_eq': delta_eq_syms,
        'delta_lt': delta_lt_syms,
        'delta_gt': delta_gt_syms
    }
    return assign, vars_dict

# Evaluate each circuit constraint on the witness and print its residual.
def check_constraints(assign, vars_dict, A, target):
    L_syms = vars_dict['L']
    R_syms = vars_dict['R']
    m_syms = vars_dict['m']
    delta_eq_syms = vars_dict['delta_eq']
    delta_lt_syms = vars_dict['delta_lt']
    delta_gt_syms = vars_dict['delta_gt']
    
    T = len(m_syms)
    print("Checking constraints for each iteration:")
    for i in range(T):
        print(f"\nIteration {i}:")
        # Midpoint residual 2*m_i - L_i - R_i. Because m is floored, a non-empty
        # interval gives 0 when L_i + R_i is even and -1 when it is odd.
        eq_mid = sp.simplify(2*m_syms[i] - L_syms[i] - R_syms[i])
        val_mid = eq_mid.subs(assign)
        print(f"  Midpoint: 2*m_{i} - L_{i} - R_{i} = {val_mid}")
        
        # Branch sum residual: delta_eq + delta_lt + delta_gt - 1.
        # This is 0 while searching and -1 in frozen or empty-interval iterations.
        eq_branch = sp.simplify(delta_eq_syms[i] + delta_lt_syms[i] + delta_gt_syms[i] - 1)
        val_branch = eq_branch.subs(assign)
        print(f"  Branch sum: delta_eq_{i} + delta_lt_{i} + delta_gt_{i} - 1 = {val_branch}")
        
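        # Comparison products: the equality product is 0 for a valid witness;
        # the other two are positive exactly when their branch is active.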
        m_val = assign[m_syms[i]]
        diff = A[m_val] - target
        eq_comp_eq = sp.simplify(delta_eq_syms[i] * diff)
        val_comp_eq = eq_comp_eq.subs(assign)
        print(f"  Equality compare: delta_eq_{i}*(A[m_{i}]-target) = {val_comp_eq}")
        eq_comp_lt = sp.simplify(delta_lt_syms[i] * (target - A[m_val]))
        val_comp_lt = eq_comp_lt.subs(assign)
        print(f"  Less-than compare: delta_lt_{i}*(target-A[m_{i}]) = {val_comp_lt}")
        eq_comp_gt = sp.simplify(delta_gt_syms[i] * (A[m_val]-target))
        val_comp_gt = eq_comp_gt.subs(assign)
        print(f"  Greater-than compare: delta_gt_{i}*(A[m_{i}]-target) = {val_comp_gt}")
        
        # State update constraints (if not the last iteration)
        if i < T - 1:
            expr_L_next = sp.simplify(delta_lt_syms[i]*(m_syms[i] + 1) + (1 - delta_lt_syms[i])*L_syms[i])
            eq_update_L = sp.simplify(L_syms[i+1] - expr_L_next)
            val_update_L = eq_update_L.subs(assign)
            print(f"  L update: L_{i+1} - [delta_lt_{i}*(m_{i}+1) + (1-delta_lt_{i})*L_{i}] = {val_update_L}")
            
            expr_R_next = sp.simplify(delta_gt_syms[i]*(m_syms[i] - 1) + (1 - delta_gt_syms[i])*R_syms[i])
            eq_update_R = sp.simplify(R_syms[i+1] - expr_R_next)
            val_update_R = eq_update_R.subs(assign)
            print(f"  R update: R_{i+1} - [delta_gt_{i}*(m_{i}-1) + (1-delta_gt_{i})*R_{i}] = {val_update_R}")
    # Compute and print the output.
    o_expr = sum(delta_eq_syms[i] * m_syms[i] for i in range(T)) + (1 - sum(delta_eq_syms[i] for i in range(T)))*(-1)
    o_computed = sp.simplify(o_expr).subs(assign)
    print(f"\nComputed output o = {o_computed}")
    return o_computed


In [None]:
A1 = [2, 4, 6, 8, 10, 12, 14, 16]
target1 = 12
assign1, vars_dict1 = simulate_binary_search(A1, target1)
print("Example 1: Target is in A")
print("Witness assignment:")
for key, val in assign1.items():
    print(f"  {key} = {val}")
print("\nConstraint check:")
o1 = check_constraints(assign1, vars_dict1, A1, target1)
print("\nFinal output (should be 5, the index of 12):", o1)

In [None]:
A2 = [2, 4, 6, 8, 10, 12, 14, 16]
target2 = 9
assign2, vars_dict2 = simulate_binary_search(A2, target2)
print("\nExample 2: Target is NOT in A")
print("Witness assignment:")
for key, val in assign2.items():
    print(f"  {key} = {val}")
print("\nConstraint check:")
o2 = check_constraints(assign2, vars_dict2, A2, target2)
print("\nFinal output (should be -1):", o2)


# Binary Search Circuit Recovery via Symbolic Regression

In this tutorial we explain step by step how the binary search algorithm can be "unrolled" into an arithmetic circuit—that is, a set of polynomial constraints. We then illustrate how one might use a symbolic regression tool (pySR) to recover each small module (or "gate") of the circuit.

Our binary search circuit (for a given sorted array $A$ and target $t$) has several layers.

For each iteration $i$, we have:

1. **State variables:** $L_i$ (left endpoint), $R_i$ (right endpoint)

2. **Computation node:** $m_i$, the midpoint, with the constraint:
   \begin{equation}
   (2m_i - L_i - R_i)^2 = 0
   \end{equation}
   (In the ideal circuit, $m_i = \frac{L_i + R_i}{2}$)

3. **Decision nodes:** Boolean branch indicators $\delta_i^=, \delta_i^<, \delta_i^>$ satisfying:
   \begin{equation}
   \delta_i^= + \delta_i^< + \delta_i^> = 1
   \end{equation}
   They "select" which branch is taken, based on comparing $A[m_i]$ with $t$.

4. **State update nodes:** For example, if the less-than branch is active ($\delta_i^< = 1$), the left endpoint is updated as:
   \begin{equation}
   L_{i+1} = \delta_i^< \cdot (m_i + 1) + (1 - \delta_i^<)L_i
   \end{equation}

5. **Output node:** Finally, the output is defined as:
   \begin{equation}
   o = \left(\sum_{i=0}^{T-1} \delta_i^= \cdot m_i\right) + \left(1 - \sum_{i=0}^{T-1} \delta_i^=\right)(-1)
   \end{equation}
   so that if an equality branch is taken, the corresponding $m_i$ is output; otherwise $o = -1$.

## Key Components to Recover

We'll focus on recovering three essential modules:

1. **Midpoint Function**
   \begin{equation}
   m = \frac{L + R}{2}
   \end{equation}

2. **Left State Update**
   \begin{equation}
   L_\text{next} = \delta_\text{lt}(m+1) + (1-\delta_\text{lt})L
   \end{equation}

3. **Output Function**
   \begin{equation}
   o = \delta_\text{eq} \cdot m + (1-\delta_\text{eq})(-1)
   \end{equation}

Each module will be learned using symbolic regression on simulated data.

## Recovering the Midpoint Function

In a correct binary search the midpoint is computed as:

\begin{equation}
m = \frac{L + R}{2}
\end{equation}

We generate data by sampling random values for $L$ and $R$ (with $L < R$) and then setting $m$ to $\frac{L+R}{2}$. Our goal is to see if pySR can recover this linear relationship.

In [None]:

# %% [code]
import numpy as np
from pysr import PySRRegressor
import matplotlib.pyplot as plt

# %% [code]
# Number of samples for training
num_samples = 1000

# Generate random L and R values
L_vals = np.random.uniform(0, 5, num_samples)
# Ensure R > L by adding an offset drawn uniformly from (1, 6), so R lies in (L+1, L+6)
R_vals = L_vals + np.random.uniform(1, 6, num_samples)
m_vals = (L_vals + R_vals) / 2

# Training data: inputs are [L, R] and output is m.
X_mid = np.column_stack([L_vals, R_vals])
y_mid = m_vals

# Setup pySR to recover the midpoint function.
model_mid = PySRRegressor(
    niterations=1000,
    binary_operators=["+", "-", "*", "/"],
    unary_operators=[],
    model_selection="best",
    maxsize=7,
    loss="loss(x, y) = (x - y)^2",
    verbosity=1,
)
model_mid.fit(X_mid, y_mid)
print("Recovered midpoint function:")
print(model_mid)
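
A symbolic regression search may return any expression that is algebraically equivalent to the target, written in terms of the input columns (named `x0` and `x1` here). The sketch below is a plain-Python way to judge candidate expressions against the known midpoint on fresh samples. The candidate strings are hypothetical examples and not actual search output, and the helper name is illustrative.

```python
import random

def max_abs_error(candidate, reference, trials=1000, seed=0):
    """Largest absolute gap between candidate and reference on random (L, R) pairs."""
    rng = random.Random(seed)
    worst = 0.0
    for _ in range(trials):
        L = rng.uniform(0, 5)
        R = rng.uniform(L + 1, L + 6)   # same sampling scheme as the training data
        worst = max(worst, abs(candidate(L, R) - reference(L, R)))
    return worst

def reference(L, R):
    return (L + R) / 2

# Hypothetical candidates: two exact rewrites and one low-complexity near miss.
candidates = {
    "(x0 + x1) * 0.5": lambda x0, x1: (x0 + x1) * 0.5,
    "x0 + (x1 - x0) / 2": lambda x0, x1: x0 + (x1 - x0) / 2,
    "x1 - 1.75": lambda x0, x1: x1 - 1.75,
}
errors = {name: max_abs_error(f, reference) for name, f in candidates.items()}
for name, err in errors.items():
    print(f"{name}: max abs error = {err:.3g}")
```

An error at the level of floating-point rounding indicates an exact rewrite; a visibly larger error indicates an expression that only fits the sampled range on average.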

## Recovering the Left State Update Function

Next, we recover the function that updates the left endpoint. In our arithmetic circuit, if we are in the less-than branch, the left endpoint is updated as:

\begin{equation}
L_{\text{next}} = \delta_{\text{lt}} \cdot (m + 1) + (1 - \delta_{\text{lt}}) \cdot L
\end{equation}

Here $\delta_{\text{lt}}$ is a Boolean variable (0 or 1). We simulate training data by randomly generating values for $L$ and $m$, and a random binary value for $\delta_{\text{lt}}$, then computing $L_{\text{next}}$ accordingly.
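
Expanding the selector gives an equivalent form with fewer terms, which is worth keeping in mind when reading the search output:

\begin{equation}
\delta_{\text{lt}} \cdot (m + 1) + (1 - \delta_{\text{lt}}) \cdot L = L + \delta_{\text{lt}} \cdot (m + 1 - L)
\end{equation}

Counting one unit per operator, variable, and constant (an assumption about how expression size is measured), the left-hand form has 11 nodes and the right-hand form has 9, so under a size limit of 10 only the right-hand form could be returned.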

In [None]:


# %% [code]
num_samples = 1000
L_vals = np.random.uniform(0, 5, num_samples)
m_vals = np.random.uniform(5, 10, num_samples)
# δ_lt is either 0 or 1
delta_lt = np.random.randint(0, 2, num_samples)

L_next = delta_lt * (m_vals + 1) + (1 - delta_lt) * L_vals

# Training data: inputs are [L, m, δ_lt] and output is L_next.
X_state = np.column_stack([L_vals, m_vals, delta_lt])
y_state = L_next

model_state = PySRRegressor(
    niterations=1000,
    binary_operators=["+", "-", "*", "/"],
    unary_operators=[],
    model_selection="best",
    maxsize=10,
    loss="loss(x, y) = (x - y)^2",
    verbosity=1,
)
model_state.fit(X_state, y_state)
print("\nRecovered left update function:")
print(model_state)

## Recovering the Final Output Function

Finally, we try to recover the output function. Suppose that at the relevant layer we have a branch indicator $\delta_{\text{eq}}$ for the equality branch and a midpoint $m$. Our output is defined as:

\begin{equation}
o = \delta_{\text{eq}} \cdot m + (1 - \delta_{\text{eq}}) \cdot (-1)
\end{equation}

Thus, if $\delta_{\text{eq}} = 1$ the output is $m$; if $\delta_{\text{eq}} = 0$ the output is $-1$.
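
The same definition can be collected into a shorter equivalent form, which is one of the shapes a search over $+, -, \times, /$ might plausibly return:

\begin{equation}
\delta_{\text{eq}} \cdot m + (1 - \delta_{\text{eq}}) \cdot (-1) = \delta_{\text{eq}} \cdot (m + 1) - 1
\end{equation}

Substituting $\delta_{\text{eq}} = 1$ gives $m$ and $\delta_{\text{eq}} = 0$ gives $-1$, matching the two cases above.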

We simulate data by choosing random $m$ values and random binary values for $\delta_{\text{eq}}$.

In [None]:


# %% [code]
num_samples = 1000
m_vals = np.random.uniform(0, 10, num_samples)
delta_eq = np.random.randint(0, 2, num_samples)

o_vals = delta_eq * m_vals + (1 - delta_eq) * (-1)

# Training data: inputs are [m, δ_eq] and output is o.
X_out = np.column_stack([m_vals, delta_eq])
y_out = o_vals

model_out = PySRRegressor(
    niterations=1000,
    binary_operators=["+", "-", "*", "/"],
    unary_operators=[],
    model_selection="best",
    maxsize=10,
    loss="loss(x, y) = (x - y)^2",
    verbosity=1,
)
model_out.fit(X_out, y_out)
print("\nRecovered output function:")
print(model_out)

## Custom Operators and Recovered Module Functions

We define our custom operator and the functions for the modules as recovered previously:

\begin{align*}
\text{sel}(x,y,z) &= xy + (1-x)z \\
\text{mid}(L,R) &= \frac{L + R}{2} \\
\text{left\_update}(L,m,\delta_{\text{lt}}) &= \text{sel}(\delta_{\text{lt}}, m+1, L) \\
\text{output\_module}(m,\delta_{\text{eq}}) &= \text{sel}(\delta_{\text{eq}}, m, -1)
\end{align*}

In [None]:

def sel(x, y, z):
    """Custom operator: if x is 1, returns y; if x is 0, returns z."""
    return x * y + (1 - x) * z

def mid(L, R):
    """Midpoint function: returns (L + R)/2."""
    return (L + R) / 2

def left_update(L, m, delta_lt):
    """Left state update: L_next = sel(delta_lt, m+1, L)."""
    return sel(delta_lt, m + 1, L)

def output_module(m, delta_eq):
    """Output module: o = sel(delta_eq, m, -1)."""
    return sel(delta_eq, m, -1)
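
A short check of these modules on Boolean selectors follows. The block repeats the definitions so that it runs on its own; the sample values are arbitrary. The last line illustrates that `sel` interpolates when its first argument is not Boolean, which is why the indicators must be restricted to 0 or 1.

```python
def sel(x, y, z):
    """Selector: y when x is 1, z when x is 0."""
    return x * y + (1 - x) * z

def left_update(L, m, delta_lt):
    return sel(delta_lt, m + 1, L)

def output_module(m, delta_eq):
    return sel(delta_eq, m, -1)

print(left_update(L=0, m=3, delta_lt=1))   # less-than branch: L moves to m + 1
print(left_update(L=0, m=3, delta_lt=0))   # otherwise L is unchanged
print(output_module(m=5, delta_eq=1))      # hit: output the midpoint
print(output_module(m=5, delta_eq=0))      # miss: output -1
print(sel(0.5, 7, 3))                      # non-Boolean selector blends both inputs
```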

## Full Circuit Composition

We'll compose the recovered modules using a custom selection operator:

\begin{equation}
\text{sel}(x,y,z) = xy + (1-x)z
\end{equation}

For our example array $A = [2, 4, 6, 8, 10, 12, 14, 16]$, the ideal mapping is:

\begin{equation}
f(t) = \begin{cases}
\frac{t-2}{2} & \text{if } t \in A \\
-1 & \text{otherwise}
\end{cases}
\end{equation}

We'll use symbolic regression to recover an approximation of this piecewise function.
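
For reference, the ideal mapping for this particular array can be written directly in Python. The function name is illustrative, and the closed form $(t-2)/2$ is specific to the evenly spaced example array.

```python
A = [2, 4, 6, 8, 10, 12, 14, 16]

def ideal_mapping(t, A=A):
    """Index of t in the example array, or -1 if t is absent."""
    return (t - 2) // 2 if t in A else -1

for t in range(1, 18):
    print(t, ideal_mapping(t))
```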

In [None]:
#### TODO Understand this section a bit more.... does it generalize to other inputs?

## Full Circuit Simulation: Generating Input–Output Pairs

We now define a function that simulates the full binary search circuit. The array $A$ is hard-wired, and the only public input is the target $t$.

The circuit is unrolled for $T$ iterations (with $T = \left\lceil \log_2(n) \right\rceil + 1$). At each iteration, the circuit computes the midpoint and then decides the branch:

- If $A[m] = t$, we set $\delta_{\text{eq}} = 1$ and freeze the state
- If $A[m] < t$, we set $\delta_{\text{lt}} = 1$
- If $A[m] > t$, we set $\delta_{\text{gt}} = 1$

The left endpoint is updated using $\text{left\_update}$, and the right endpoint by applying $\text{sel}$ directly with $\delta_{\text{gt}}$. Finally, the output is computed using $\text{output\_module}$ at the iteration where $\delta_{\text{eq}} = 1$.

We then generate $(t, o)$ pairs for targets in a specified range.

In [None]:

# %% [code]
import math
import numpy as np
import sympy as sp
from pysr import PySRRegressor



# %% [code]
def simulate_full_circuit(A, t):
    """
    Simulate a full binary search circuit for a fixed sorted array A and target t.
    Returns the final output o (index if t is in A, -1 otherwise).
    """
    n = len(A)
    T = math.ceil(math.log2(n)) + 1  # Unroll for T iterations (log2 is exact for powers of two).
    L = 0
    R = n - 1
    found = False
    output = -1
    for i in range(T):
        # Compute midpoint (using floor for index); mid takes the two endpoints separately.
        m_val = math.floor(mid(L, R))
        # Decide branch if within array bounds.
        if not found and 0 <= m_val < n:
            if A[m_val] == t:
                delta_eq = 1
                delta_lt = 0
                delta_gt = 0
                found = True
            elif A[m_val] < t:
                delta_eq = 0
                delta_lt = 1
                delta_gt = 0
            else:
                delta_eq = 0
                delta_lt = 0
                delta_gt = 1
        else:
            delta_eq = 0
            delta_lt = 0
            delta_gt = 0
        # Update state: if target not found, update using our modules.
        if i < T - 1:
            if not found:
                L = left_update(L, m_val, delta_lt)
                # For simplicity, update R using a similar rule (here we use a direct update)
                R = sel(delta_gt, m_val - 1, R)
            # Once found, state remains frozen.
    # Final output: if found, use output_module from the iteration where it was found.
    if found:
        output = output_module(m_val, 1)  # since δ_eq was set to 1.
    else:
        output = -1
    return output
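
Before generating training data, it can help to confirm that an unrolled, selector-based search agrees with a trusted reference. The sketch below is a compact standalone variant written for this check (it is not the tutorial's `simulate_full_circuit`), compared against `bisect_left` from the standard library. It assumes a sorted array with distinct values.

```python
import math
from bisect import bisect_left

def reference_search(A, t):
    """Index of t in sorted A, or -1, using the standard library."""
    i = bisect_left(A, t)
    return i if i < len(A) and A[i] == t else -1

def unrolled_search(A, t):
    """Fixed-depth binary search written with selector arithmetic."""
    n = len(A)
    T = math.ceil(math.log2(n)) + 1
    L, R, out = 0, n - 1, -1
    for _ in range(T):
        if out != -1 or L > R:
            continue  # frozen after a hit, or nothing left to search
        m = (L + R) // 2
        d_eq, d_lt, d_gt = int(A[m] == t), int(A[m] < t), int(A[m] > t)
        out = d_eq * m + (1 - d_eq) * (-1)
        L = d_lt * (m + 1) + (1 - d_lt) * L
        R = d_gt * (m - 1) + (1 - d_gt) * R
    return out

A = [2, 4, 6, 8, 10, 12, 14, 16]
mismatches = [t for t in range(1, 18)
              if unrolled_search(A, t) != reference_search(A, t)]
print("mismatches:", mismatches)
```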



## Symbolic Regression to Recover the Full Mapping

Our goal is to use pySR to recover a symbolic expression for the overall mapping $f(t) = o$. The ideal mapping for our toy array is piecewise:

\begin{equation}
f(t) = \begin{cases}
\frac{t - 2}{2} & \text{if } t \in A \\
-1 & \text{otherwise}
\end{cases}
\end{equation}

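We now let pySR search for an expression in terms of $t$ that fits the data.

The code below also passes a dictionary containing our module functions:
- $\text{sel}$
- $\text{mid}$
- $\text{left\_update}$
- $\text{output\_module}$

As written, this dictionary is supplied only through `extra_sympy_mappings`, which tells pySR how to translate operator names into SymPy when exporting equations. It does not add the modules to the search space, so the search below uses only `+`, `-`, `*`, and `/`. For a custom operator to take part in the search, it would also have to be declared, in Julia syntax, in `binary_operators` or `unary_operators`.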

**Note:** Recovering a piecewise function exactly is challenging; pySR might find an expression that approximates the behavior over our small target range.

In [None]:
# %% [code]

# Generate data for a range of target values.
A_fixed = [2, 4, 6, 8, 10, 12, 14, 16]
target_values = np.arange(1, 18)  # t = 1,2,...,17
outputs = [simulate_full_circuit(A_fixed, t) for t in target_values]

data = np.column_stack([target_values, outputs])
print("Full circuit simulation data (target, output):")
print(data)
X_full = target_values.reshape(-1, 1)
y_full = np.array(outputs)

custom_ops_full = {
    "sel": sel,
    "mid": mid,
    "left_update": left_update,
    "output_module": output_module
}

model_full = PySRRegressor(
    niterations=5000,
    binary_operators=["+", "-", "*", "/"],
    unary_operators=[],
    extra_sympy_mappings=custom_ops_full,
    model_selection="best",
    maxsize=12,
    loss="loss(x, y) = (x - y)^2",
    verbosity=1,
)
model_full.fit(X_full, y_full)
print("\nRecovered full mapping from target t to output o:")
print(model_full)
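
When reading the recovered expressions, it is useful to have simple baselines to compare against, because a low squared error does not imply that the piecewise structure was found. The sketch below scores two hand-written baselines on the same 17 targets; the baseline choices and helper names are assumptions made for illustration, not search output.

```python
A = [2, 4, 6, 8, 10, 12, 14, 16]
targets = list(range(1, 18))
truth = [A.index(t) if t in A else -1 for t in targets]

def mse(candidate):
    """Mean squared error of a candidate t -> o mapping on the 17 targets."""
    return sum((candidate(t) - o) ** 2 for t, o in zip(targets, truth)) / len(targets)

def exact_match_rate(candidate):
    """Fraction of targets on which the candidate reproduces the output exactly."""
    hits = sum(abs(candidate(t) - o) < 1e-9 for t, o in zip(targets, truth))
    return hits / len(targets)

baselines = {
    "constant -1": lambda t: -1.0,          # always reports "not found"
    "(t - 2) / 2": lambda t: (t - 2) / 2,   # ignores the membership test
}
for name, f in baselines.items():
    print(f"{name}: mse={mse(f):.2f}, exact={exact_match_rate(f):.2f}")
```

Each baseline is exact on only about half of the targets, so a recovered expression should be judged on exact matches as well as on its loss.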

## Discussion

In this tutorial we demonstrated an end-to-end approach for recovering the entire arithmetic circuit that encodes the binary search algorithm:

1. We assumed that symbolic regression had already recovered the building blocks (modules) such as:
   
   \begin{align*}
   \text{mid}(L, R) &= \frac{L + R}{2} \\
   \text{left\_update}(L, m, \delta_{\text{lt}}) &= \text{sel}(\delta_{\text{lt}}, m+1, L) \\
   \text{output\_module}(m, \delta_{\text{eq}}) &= \text{sel}(\delta_{\text{eq}}, m, -1)
   \end{align*}
   
   using our custom operator $\text{sel}$.

2. We composed these modules into a full circuit by simulating an unrolled binary search run (for $T$ iterations) on a fixed sorted array $A$ and a public target $t$.

3. We generated $(t, o)$ pairs from many runs and then applied pySR (with our custom operators) to recover a symbolic expression for the overall mapping from target to output.

The ideal mapping for our evenly spaced array $A = [2, 4, 6, 8, 10, 12, 14, 16]$ is:

\begin{equation}
f(t) = \text{sel}(\text{inA}(t), \frac{t - 2}{2}, -1)
\end{equation}

where $\text{inA}(t)$ is an indicator function that equals 1 when $t \in A$ and 0 otherwise.
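
One way to see why this mapping is hard to express with arithmetic alone is to write the membership test through the degree-8 polynomial that vanishes exactly on $A$. The sketch below is an illustration under that framing; the names `vanishing` and `in_A` are hypothetical, and the zero test is evaluated directly in Python and is not itself a polynomial.

```python
from math import prod

A = [2, 4, 6, 8, 10, 12, 14, 16]

def vanishing(t, A=A):
    """Product of (t - a) over all a in A; zero exactly when t is in A."""
    return prod(t - a for a in A)

def in_A(t, A=A):
    """Membership indicator, evaluated through the vanishing polynomial."""
    return 1 if vanishing(t, A) == 0 else 0

def sel(x, y, z):
    return x * y + (1 - x) * z

def f(t):
    """Ideal mapping for the example array: sel(inA(t), (t - 2)/2, -1)."""
    return sel(in_A(t), (t - 2) / 2, -1)

for t in range(1, 18):
    print(t, f(t))
```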

Although recovering the exact piecewise function is challenging, the recovered expression (or candidate expressions) should reflect the structure of the full circuit and the composition of the modules.

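This workflow shows how symbolic regression can be applied, first module by module and then to the composed input–output map, when attempting to recover an arithmetic circuit from input–output pairs. Whether the full piecewise mapping is actually recovered depends on the operators made available to the search and on the expressions it returns.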

Happy exploring!