# Square Span Programs

**Goal**: transform any problem in NP into a square span program.

Steps:
- Program
- Circuit
- Affine Map Constraints
- Polynomials
- SSP

Representing the program as a circuit is out of scope for now; we assume we are given a circuit.

## Step 1: Translate to a set of affine constraints

Every logic gate with fan-in 2 can be _linearized_. For example, the gate
$$
    c = g(a,b) = \neg (a \wedge b)
$$
is equivalent to
$$
    a + b - 2\bar{c} \in \{0,1\}.
$$
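
As a quick sanity check of this equivalence, the following sketch enumerates all eight values of $(a, b, c)$ and compares the constraint against the NAND truth table. The helper names `nand` and `nand_constraint_holds` are made up for this illustration, and $\bar{c}$ is read as $1 - c$.

```python
from itertools import product

def nand(a: int, b: int) -> int:
    return 1 - (a & b)

def nand_constraint_holds(a: int, b: int, c: int) -> bool:
    # \bar{c} = 1 - c is the negated output wire
    return a + b - 2 * (1 - c) in (0, 1)

# The constraint should hold exactly when c is the NAND of a and b.
for a, b, c in product((0, 1), repeat=3):
    assert nand_constraint_holds(a, b, c) == (c == nand(a, b))
```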

\[DFGK14\] gives a nice overview for the different possible gates:
![Table 1](./table1.png)

Theorem 1 formalizes this: it guarantees that a circuit $C$ can be turned into an affine constraint on the vector $a$ of wire values, of the form:
$$
aV + b \in \{0,2\}^d.
$$

The proof explicitly constructs $V$ and $b$. So let's build them! :)

In [None]:
import numpy as np
import sympy

In [None]:
# Some basic types for later, in this case wires
class Wire:
    """represents an abstract Wire"""
    def value(self):
        raise NotImplementedError

class InputWire(Wire):
    """represents an input wire"""
    def __init__(self, val, symbol=None):
        """val may be None if this is used in a 'symbolic' context"""
        assert val == 0 or val == 1 or val is None
        self._val = val
        self.symbol = symbol

    def value(self):
        return self._val

class OutputWire(Wire):
    """represents the output wire of a gate"""
    def __init__(self, func, symbol=None):
        self.symbol = symbol
        self._func = func

    def value(self):
        return self._func()

In [None]:
# and the Gate base class
class Gate:
    """represents a fan-in 2 binary gate"""
    def __init__(self, lwire: Wire, rwire: Wire, symbol=None):
        self.symbol = symbol
        self._lwire = lwire
        self._rwire = rwire
        self._outwire = None

    def output(self) -> Wire:
        if self._outwire is not None:
            return self._outwire

        self._outwire = self._get_outwire()
        return self._outwire

    def _get_outwire(self) -> Wire:
        raise NotImplementedError
        
    def _linearization_coeffs(self):
        """returns the coefficients of the linearization of this gate.
        
        That is, if the gate is represented as 
            alpha * a + beta * b + gamma * c + delta \\in {0,2},
        
        this returns (alpha, beta, gamma)"""
        
        raise NotImplementedError

    def _linearization_bias(self):
        """returns the bias, in the notation from above that is delta."""
        raise NotImplementedError


As a first example, let's build an XOR-Gate. From the table above, we can see that the equation for XOR is:
$$
c = a \oplus b \ \iff \ a + b + c \in \{0,2\}
$$

In [None]:
class XorGate(Gate):
    def _get_outwire(self):
        return OutputWire(
            lambda: self._lwire.value() ^ self._rwire.value(),
            symbol=self.symbol
        )

    def _linearization_coeffs(self):
        # a + b + c \in {0,2}
        return 1, 1, 1

    def _linearization_bias(self):
        return 0

A single gate is a bit boring, so let's start building circuits. We start with the boilerplate part...

In [None]:
class Circuit:
    def __init__(self, inputs: list[InputWire], symbol="C"):
        self._wires = list(inputs)  # shallow copy, so the caller's list is not mutated
        self._gates = []
        self._output_gate = None
        self.symbol = symbol

    def add_gate(self, gate: Gate):
        assert gate._lwire in self._wires
        assert gate._rwire in self._wires
        
        self._gates.append(gate)
        self._wires.append(gate.output())

    def set_outgate(self, gate: Gate):
        assert gate in self._gates
        self._output_gate = gate

    def eval(self):
        assert self._output_gate is not None, "need output gate"
        
        return self._output_gate.output().value()

    def size(self) -> tuple[int, int]:
        # m x n
        # wires x gates
        return len(self._wires), len(self._gates)

    def _gate_wire_idxs(self, gate):
        """helper function that returns the indices of the gate's wires.
        
        -> left input, right input, output"""
        
        assert gate in self._gates
        l = self._wires.index(gate._lwire)
        r = self._wires.index(gate._rwire)
        o = self._wires.index(gate.output())

        return l, r, o

Now we can turn to the proof to see how $V$ and $b$ should be constructed. In fact, 

$$ 
V = [ \ 2I  \ | \ G \ ], \quad b = (\ 0 \ | \ \delta \ ),
$$

where $G$ and $\delta$ come from the linearizations. The $2I$ block ensures that every wire value is binary: entry $i$ of $aV + b$ is $2a_i$, which lies in $\{0,2\}$ exactly when $a_i \in \{0,1\}$.

$G$ has one row for every _wire_ and one column for every _gate_. So for a gate at index $i$ with the affine equation
$$
K_l x_l + K_r x_r + K_o x_o + B_i \in \{0,2\}
$$ 
we know the entries:
$$
    G_{li} = K_l
$$ $$
    G_{ri} = K_r
$$ $$
    G_{oi} = K_o
$$

and similarly $\delta_i = B_i$.

Lastly, we need an extra condition on the output gate to make sure it has the value 1 -- after all we want a satisfied circuit! We omit the details here.
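
The `_matrix_G` and `_vector_delta` functions below handle this by adding $3(1 - c)$ to the output gate's constraint, where $c$ is the output wire (the `oc -= 3` and `b += 3` lines). A minimal sketch of the effect on a lone XOR output gate, with a made-up helper name:

```python
from itertools import product

def xor_output_constraint(a: int, b: int, c: int) -> bool:
    # plain XOR constraint a + b + c, plus the extra term 3 * (1 - c)
    return a + b + c + 3 * (1 - c) in (0, 2)

accepted = [t for t in product((0, 1), repeat=3) if xor_output_constraint(*t)]
print(accepted)  # [(0, 1, 1), (1, 0, 1)]
```

For $c = 1$ the extra term vanishes and the original constraint is unchanged; for $c = 0$ the left-hand side is at least 3, so no assignment with output 0 passes.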

With this setup, we can now implement the matrices $G, V$ and vectors $\delta, b$:

In [None]:
def _matrix_G(self):
    assert self._output_gate is not None, "need output gate"
    
    G = np.zeros(self.size())
    
    for i, gate in enumerate(self._gates):
        li, ri, oi = self._gate_wire_idxs(gate)
        lc, rc, oc = gate._linearization_coeffs()
        if gate == self._output_gate:
            oc -= 3
        G[li, i] = lc
        G[ri, i] = rc
        G[oi, i] = oc

    return G

def _vector_delta(self):
    assert self._output_gate is not None, "need output gate"

    _, n = self.size()
    delta = np.zeros(n)
    
    for i, gate in enumerate(self._gates):
        b = gate._linearization_bias()
        if gate == self._output_gate:
            b += 3
        delta[i] = b

    return delta

def matrix_V(self):
    m, n = self.size()
    # np.concatenate also works on NumPy versions that predate the np.concat alias
    return np.concatenate([2*np.eye(m), self._matrix_G()], axis=1)

def vector_b(self):
    m, _ = self.size()
    return np.concatenate([np.zeros(m), self._vector_delta()])


# yay monkey-patching classes
Circuit._matrix_G = _matrix_G
Circuit.matrix_V = matrix_V
Circuit._vector_delta = _vector_delta
Circuit.vector_b = vector_b

So we can finally look at a simple example! We take the one from the paper, where the circuit is simply
$$
a_3 = a_1 \oplus a_2
$$

In [None]:
def simple_xor_circuit():
    a1 = InputWire(None, "a1")
    a2 = InputWire(None, "a2")
    a3 = XorGate(a1, a2, "a3")

    C = Circuit([a1,a2], "C")
    C.add_gate(a3)
    C.set_outgate(a3)

    return C

In [None]:
C0 = simple_xor_circuit()
V0 = C0.matrix_V()
b0 = C0.vector_b()

In [None]:
V0

In [None]:
b0

Hurray! We've recovered the example from the paper. Let's try some examples:

<font color="green">**Question 1:**
    What will the output of the following example be?
</font>

-  A. $(0, 2, 2, 2)$
-  B. $(2, 0, 5, 3)$
-  C. Impossible to tell

In [None]:
sat_ex_0 = [0,1,1]
sat_ex_0 @ V0 + b0

In [None]:
sat_ex_1 = [1,0,1]
sat_ex_1 @ V0 + b0

In [None]:
# correct gate, but not satisfied
non_sat_ex_0 = [1,1,0]
non_sat_ex_0 @ V0 + b0

In [None]:
# gate is not correctly computed
non_sat_ex_1 = [0,0,1]
non_sat_ex_1 @ V0 + b0

In [None]:
# input wires aren't binary
non_sat_ex_2 = [2,3,1]
non_sat_ex_2 @ V0 + b0

In [None]:
# We can wrap this test in a convenience function:
def accept(a, V, b):
    out = a @ V + b
    test = np.logical_or(out == 2, out == 0)
    return bool(test.all())

In [None]:
accept(sat_ex_0, V0, b0), accept(non_sat_ex_1, V0, b0)
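
Since the XOR circuit only has three wires, we can also enumerate every binary assignment. The sketch below is dependency-free and hard-codes the values that `C0.matrix_V()` and `C0.vector_b()` should produce according to the construction above; `V0_rows`, `b0_vals`, `affine_image` and `accepts` are names made up for this illustration.

```python
from itertools import product

# Expected V and b for a3 = a1 xor a2 (hard-coded so the sketch runs on its own)
V0_rows = [
    [2, 0, 0, 1],   # a1
    [0, 2, 0, 1],   # a2
    [0, 0, 2, -2],  # a3, the output wire: 1 - 3
]
b0_vals = [0, 0, 0, 3]

def affine_image(a, V, b):
    """computes a @ V + b with plain lists"""
    return [sum(a[i] * V[i][j] for i in range(len(a))) + b[j] for j in range(len(b))]

def accepts(a, V, b):
    return all(y in (0, 2) for y in affine_image(a, V, b))

satisfying = [a for a in product((0, 1), repeat=3) if accepts(a, V0_rows, b0_vals)]
print(satisfying)  # [(0, 1, 1), (1, 0, 1)]
```

Only the two assignments with $a_1 \oplus a_2 = 1$ and $a_3 = 1$ survive, which is exactly the set of satisfying assignments of the circuit.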

Before we move on to the next transformation, let's quickly drop in implementations for more gates, so we can build more fun circuits later.

In [None]:
class OrGate(Gate):
    def _get_outwire(self):
        return OutputWire(
            lambda: self._lwire.value() | self._rwire.value(),
            symbol=self.symbol
        )
    def _linearization_coeffs(self):
        #      !a + !b - 2!c \in {0,1}
        # <=>  1 -a  + 1 -b -2 + 2c \in {0,1}
        # <=>  -a -b +2c +0 \in {0,1}
        # <=> -2a -2b +4c \in {0,2}
        return -2, -2, 4

    def _linearization_bias(self):
        # see coeffs
        return 0


class AndGate(Gate):
    def _get_outwire(self):
        return OutputWire(
            lambda: self._lwire.value() & self._rwire.value(),
            symbol=self.symbol
        )

    def _linearization_coeffs(self):
        #     a + b - 2c \in {0,1}
        # <=> 2a + 2b - 4c \in {0,2}
        return 2, 2, -4

    def _linearization_bias(self):
        return 0

class NotXandYGate(Gate):
    """represents a gate with formula: !x and y"""
    def _get_outwire(self):
        return OutputWire(
            # (1 - x) & y keeps the result an int (0 or 1) instead of a bool
            lambda: (1 - self._lwire.value()) & self._rwire.value(),
            symbol=self.symbol
        )

    def _linearization_coeffs(self):
        #     !a + b - 2c \in {0, 1}
        # <=> (1-a) + b - 2c \in {0,1}
        # <=> -a + b - 2c + 1 \in {0,1}
        # <=> -2a + 2b -4c + 2 \in {0, 2}
        return -2, 2, -4

    def _linearization_bias(self):
        return 2
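
It is easy to get a sign wrong in these derivations, so here is a small brute-force check. It copies the coefficients returned by the gate classes above into a table (the names `GATES` and `linearization_is_exact` are made up for this sketch) and verifies that each constraint holds exactly on the gate's truth table.

```python
from itertools import product

# (alpha, beta, gamma, delta) as returned by the gate classes above,
# paired with the boolean function each gate is meant to compute
GATES = {
    "xor": ((1, 1, 1, 0), lambda a, b: a ^ b),
    "or": ((-2, -2, 4, 0), lambda a, b: a | b),
    "and": ((2, 2, -4, 0), lambda a, b: a & b),
    "not_x_and_y": ((-2, 2, -4, 2), lambda a, b: (1 - a) & b),
}

def linearization_is_exact(coeffs, func) -> bool:
    alpha, beta, gamma, delta = coeffs
    # the constraint must hold if and only if c equals the gate output
    return all(
        (alpha * a + beta * b + gamma * c + delta in (0, 2)) == (c == func(a, b))
        for a, b, c in product((0, 1), repeat=3)
    )

for name, (coeffs, func) in GATES.items():
    print(name, linearization_is_exact(coeffs, func))
```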

As a slightly bigger example, let's consider this circuit from [PAZK]:

![circuit](circuit.png)

In [None]:
# Let's quickly see a slightly bigger example

def bigger_example_circuit():
    x1 = InputWire(None, symbol='x1')
    x2 = InputWire(None, symbol='x2')
    x3 = InputWire(None, symbol='x3')
    x4 = InputWire(None, symbol='x4')
    
    C = Circuit([x1, x2, x3, x4], symbol="phi")
    
    g1 = NotXandYGate(x1, x2, symbol="g1")
    C.add_gate(g1)
    
    g2 = OrGate(x3, x4, symbol = "g2")
    C.add_gate(g2)
    
    g_out = AndGate(g1.output(), g2.output(), symbol="g_out")
    C.add_gate(g_out)
    C.set_outgate(g_out)

    return C

In [None]:
C1 = bigger_example_circuit()
V1 = C1.matrix_V()
b1 = C1.vector_b()

print(V1)
print()
print(b1)

In [None]:
# Happy case: satisfying assignment and correct gate values
a1_good = np.array([
    0, # x1
    1, # x2
    1, # x3
    1, # x4
    1, # !x1 and x2
    1, # x3 or x4
    1, # output
])

a1_good @ V1 + b1

In [None]:
# Problem 1: non-satisfying assignment (but correct 'computations')
a1_sad  = np.array([
    0, # x1
    1, # x2
    0, # x3
    0, # x4
    1, # !x1 and x2
    0, # x3 or x4
    0, # output
])
a1_sad @ V1 + b1

## Step 2: Transform V, b into polynomials

Notation:
- $m$ number of wires
- $n$ number of gates
- $d = n + m$
- $p$ a prime with $p \geqslant \max(d, 8)$

From now on, we work over $\mathbb{Z}_p$. 

**Question:** Why can we do that?

### Constructing the polynomials
We fix $d$ distinct elements in $\mathbb{Z}_p$ and call them $r_1, ..., r_d$. We want to encode all of $V$ and $b$ in polynomials.

Define $v_0(x), ..., v_m(x)$ such that:
$$
    v_0(r_j) = b_j - 1,
$$ $$
    v_i(r_j) = V_{ij}.
$$

Hence, the circuit $C$ is now satisfiable iff there is an $a \in \mathbb{Z}_p^m$ that satisfies, for all $r_j$:
$$
\left( v_0(r_j) + \sum_{i=1}^m a_i v_i (r_j) \right)^2 = 1.
$$

**Question:** Why is that?
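
One way to approach this question: by construction, $v_0(r_j) + \sum_i a_i v_i(r_j)$ is entry $j$ of $aV + b$ minus 1. Writing $y$ for that entry, the condition becomes $(y - 1)^2 = 1$, and in a field of odd characteristic the only square roots of 1 are $\pm 1$, so $y \in \{0, 2\}$. The sketch below checks this numerically for a few primes; the function name is made up for this illustration.

```python
def square_condition_solutions(p: int) -> list[int]:
    """all y in Z_p with (y - 1)^2 = 1 (mod p)"""
    return [y for y in range(p) if pow(y - 1, 2, p) == 1]

for p in (11, 13, 17):
    print(p, square_condition_solutions(p))  # [0, 2] in each case
```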


Before we move on, let's verify this fact in code. We first define functions that help us construct such polynomials; the details are unimportant here.


In [None]:
def lagrange_basis(nodes: list[int], var, field):
    polys = []

    for j, val_j in enumerate(nodes):
        val_j = field.convert(val_j)
        expr = field.one
        for m, val_m in enumerate(nodes):
            if m == j:
                continue
            val_m = field.convert(val_m)

            expr *= (var - val_m) * ((val_j - val_m)**-1)
        polys.append(expr.as_poly(domain=field))

    return polys

def polynomial_with_values(nodes: list[int], vals: list[int], var, field):
    assert len(nodes) == len(vals)
    
    basis = lagrange_basis(nodes, var, field)
    expr = field.zero
    for v, b in zip(vals, basis):
        expr += v * b

    return expr.as_poly()
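
For readers who do want to see the details, here is a dependency-free sketch of the same idea (Lagrange interpolation over $\mathbb{Z}_p$). It is not the implementation used in this notebook: `interpolate_mod_p` is a made-up name, the nodes are arbitrary distinct values chosen for the example, and it returns a plain Python function instead of a sympy polynomial. It relies on `pow(x, -1, p)` for modular inverses, which needs Python 3.8 or later.

```python
def interpolate_mod_p(nodes, vals, p):
    """returns f with f(nodes[j]) = vals[j] mod p, as a Python function"""
    def f(x):
        total = 0
        for j, (r_j, y_j) in enumerate(zip(nodes, vals)):
            term = y_j % p
            for k, r_k in enumerate(nodes):
                if k != j:
                    # multiply by (x - r_k) / (r_j - r_k) in Z_p
                    term = term * (x - r_k) * pow(r_j - r_k, -1, p) % p
            total = (total + term) % p
        return total
    return f

p = 11
nodes = [3, 5, 7, 10]           # arbitrary distinct elements of Z_11
vals = [-1, -1, -1, 2]          # b - 1 for the XOR example
v0_sketch = interpolate_mod_p(nodes, vals, p)
print([v0_sketch(r) for r in nodes])  # [10, 10, 10, 2], i.e. -1, -1, -1, 2 mod 11
```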

In [None]:
# with the helper functions from above, defining v_0 and v_i is not that hard
def _ssp_polynomial_v0(rs, b: np.array, var, field):
    b_finite = [field.convert(x) for x in b-1]
    return polynomial_with_values(rs, b_finite, var, field)

def _ssp_polynomial_vi(rs, i: int, V: np.array, var, field):
    i -= 1
    vi_finite = [field.convert(x) for x in V[i]]
    return polynomial_with_values(rs, vi_finite, var, field)

In [None]:
# we also define a helper function that sets up the field and
# gives us the r_j

import random

def prepare_field(d: int):
    p = sympy.ntheory.generate.nextprime(max(8, d))
    field = sympy.GF(p)
    rs = []

    while len(rs) < d:
        el = field.convert(random.randrange(p))
        if el in rs:
            continue
        rs.append(el)

    return rs, field

In [None]:
# Let's pick up the simple example again
C = simple_xor_circuit()
V = C.matrix_V()
b = C.vector_b()

In [None]:
m, n = C.size()
d = m + n

In [None]:
random.seed(42)
rs, field = prepare_field(d)
x = sympy.symbols("x")

In [None]:
v0 = _ssp_polynomial_v0(rs, b, x, field)
v0

In [None]:
vi = []
for i in range(m):
    vi.append(_ssp_polynomial_vi(rs, i+1, V, x, field))
    print(vi[-1])

In [None]:
print(b)

<font color="green">**Question 2:**
    What should the output of the next cell be?
</font>

-  A. `0 0 0 3`
-  B. `-1 -1 -1 2`
-  C. `5 0 -5 -1`
-  D. Impossible to tell

In [None]:
# Let's first see again that v_0 satisfies its condition

for rj in rs:
    print(v0(rj), end=' ')

In [None]:
# Let's now build the large squared polynomial

def mega_polynomial(a: np.array, V: np.array, b: np.array, rs=None, field=None, var=None):
    m, d = V.shape
    
    if var is None:
        var = sympy.symbols("x")
    if rs is None or field is None:
        rs, field = prepare_field(d)
        
    polynomial = _ssp_polynomial_v0(rs, b, var, field)

    for i in range(m):
        polynomial += (a[i] * _ssp_polynomial_vi(rs, i+1, V, var, field))

    return polynomial**2, rs

In [None]:
poly_good, _ = mega_polynomial([1,0,1], V, b, rs, field, x)
poly_good

In [None]:
for rj in rs:
    print(poly_good(rj))

In [None]:
poly_bad, _ = mega_polynomial([1,1,1], V, b, rs, field, x)
poly_bad

In [None]:
for rj in rs:
    print(poly_bad(rj))

In [None]:
# Let's check also with the bigger circuit

C = bigger_example_circuit()
V = C.matrix_V()
b = C.vector_b()

In [None]:
happy_assignment = [
    0, # x1
    1, # x2
    0, # x3
    1, # x4
    1, # !x1 and x2
    1, # x3 or x4
    1  # output
]

big_poly_good, rs = mega_polynomial(happy_assignment, V, b)
big_poly_good

In [None]:
for rj in rs:
    print(big_poly_good(rj))

In [None]:
sad_assignment = [
    1, # x1
    1, # x2
    0, # x3
    1, # x4
    1, # !x1 and x2
    1, # x3 or x4
    1  # output
]

big_poly_sad, rs = mega_polynomial(sad_assignment, V, b)
big_poly_sad

In [None]:
for rj in rs:
    print(big_poly_sad(rj))

### Framing this as an SSP

Let's check out the (general) definition the authors give of an SSP:

> **Definition 1 (Square span program).** A square span program $Q$ over the
field $\mathbb{F}$ consists of <font color="cyan"> $m + 1$ polynomials $v_0(x), v_1(x), \ldots, v_m(x)$ and a target polynomial $t(x)$ </font> such that $\deg(v_i(x)) \leqslant \deg(t(x))$ for all $i = 0, \ldots, m$.
> 
> We say that the square span program $Q$ has size $m$ and degree $d = \deg(t(x))$.
We say that $Q$ <font color="cyan"> **accepts** an input $(a_1, \ldots, a_\ell) \in \mathbb{F}^\ell$ if and only if there exist
$a_{\ell+1}, \ldots, a_m \in \mathbb{F}$ satisfying
> $$
 t(x) \quad \text{  divides  } \quad \left( v_0(x) + \sum_{i=1}^m a_i v_i(x) \right)^2 - 1
$$ </font>
>
> We say that $Q$ verifies a boolean function $f : \{0, 1\}^\ell \rightarrow \{0, 1\}$ if it accepts
exactly those inputs $a \in \mathbb{F}^\ell$ that satisfy $a \in \{0, 1\}^\ell$ and $f(a) = 1$.

(emphasis added by us)

The goal now is to use the polynomials $v_i$ that we defined before and see them as an instance of an SSP.

Indeed, we note from before that for a satisfying assignment $a \in \mathbb{F}_p^m$, the "mega-polynomial" (minus 1)
$$
P(x) = \left( v_0(x) + \sum_{i=1}^m a_i v_i (x) \right)^2 - 1
$$

has roots $r_1, ..., r_d$. Thus,
$$
P(x) = R(x) \cdot \prod_{j=1}^d (x - r_j),
$$
meaning $\prod_{j=1}^d (x - r_j)$ divides $P(x)$!

Therefore, we can choose that as our target polynomial $t(x)$ and get an SSP that is equivalent to the original circuit. Nice!
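
The step from "vanishes at every $r_j$" to "is divisible by $\prod_j (x - r_j)$" is the factor theorem, and it also holds in the other direction. The following sketch illustrates it with plain coefficient lists (lowest degree first) over $\mathbb{Z}_{11}$; all helper names, the roots and the example polynomials are made up for this illustration and are independent of the sympy-based code in this notebook.

```python
def poly_mul(f, g, p):
    out = [0] * (len(f) + len(g) - 1)
    for i, fi in enumerate(f):
        for j, gj in enumerate(g):
            out[i + j] = (out[i + j] + fi * gj) % p
    return out

def poly_rem(f, g, p):
    """remainder of f divided by g over Z_p (leading coefficient of g must be nonzero)"""
    f = [c % p for c in f]
    lead_inv = pow(g[-1], -1, p)
    while len(f) >= len(g):
        factor = f[-1] * lead_inv % p
        shift = len(f) - len(g)
        for i, gi in enumerate(g):
            f[shift + i] = (f[shift + i] - factor * gi) % p
        f.pop()  # the leading coefficient is now zero
    return f

def evaluate(f, x, p):
    return sum(c * pow(x, i, p) for i, c in enumerate(f)) % p

p = 11
roots = [2, 5, 7]
t = [1]
for r in roots:
    t = poly_mul(t, [-r, 1], p)     # t(x) = (x - 2)(x - 5)(x - 7)

P_good = poly_mul(t, [3, 1], p)     # a multiple of t, so it vanishes at every root
P_bad = list(P_good)
P_bad[0] = (P_bad[0] + 1) % p       # shifted by 1, so it vanishes at no root

for P in (P_good, P_bad):
    vanishes = all(evaluate(P, r, p) == 0 for r in roots)
    divisible = not any(poly_rem(P, t, p))
    print(vanishes, divisible)      # True True, then False False
```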

Again, we can test that in code:

In [None]:
def target_polynomial(rs, var, field):
    expr = 1
    for rj in rs:
        expr *= (var - rj)

    return expr.as_poly(domain=field)

def divides(target, poly) -> bool:
    """true if target divides polynomial"""

    _, rem = poly.div(target)
    return rem == 0

In [None]:
C = simple_xor_circuit()
V = C.matrix_V()
b = C.vector_b()

sat_poly, rs = mega_polynomial([1,0,1], V, b)
sat_target = target_polynomial(rs, sat_poly.gen, sat_poly.domain)

In [None]:
sat_poly

In [None]:
sat_target

In [None]:
# important: do not forget the -1 ;)
divides(sat_target, sat_poly - 1)

In [None]:
C = bigger_example_circuit()
V = C.matrix_V()
b = C.vector_b()

In [None]:
bad_assignment = [
    0, # x1
    1, # x2
    0, # x3
    0, # x4
    1, # !x1 and x2
    0, # x3 or x4
    0  # output
]
bad_poly, rs = mega_polynomial(bad_assignment, V, b)
bad_target = target_polynomial(rs, bad_poly.gen, bad_poly.domain)

In [None]:
# as above, the test is on the polynomial minus 1
divides(bad_target, bad_poly - 1)