In [1]:
import numpy as np
import sympy

In [2]:
class Wire:
    """Abstract wire of a boolean circuit; subclasses must implement ``value``."""

    def value(self):
        """Return the value carried by this wire (overridden by subclasses)."""
        # NotImplementedError (not the NotImplemented constant, which cannot be
        # raised) is the standard signal for an abstract method.
        raise NotImplementedError(f"{type(self).__name__} must implement value()")

class InputWire(Wire):
    """A circuit input wire carrying a fixed bit, or None when used symbolically."""

    def __init__(self, val, symbol=None):
        """Create an input wire.

        val: 0 or 1, or None when the wire only takes part in building the
            constraint matrices and is never evaluated.
        symbol: optional display name.
        """
        assert val is None or val in (0, 1)
        self.symbol = symbol
        self._val = val

    def value(self):
        """Return the stored bit (None for a symbolic wire)."""
        return self._val

class OutputWire(Wire):
    """The output wire of a gate; its value is recomputed on every read."""

    def __init__(self, func, symbol=None):
        """func: zero-argument callable producing the wire's bit; symbol: optional name."""
        self._func = func
        self.symbol = symbol

    def value(self):
        """Call the stored function and return its result."""
        compute = self._func
        return compute()

class Gate:
    r"""Abstract fan-in 2 boolean gate.

    Subclasses implement ``_get_outwire`` (the evaluated output wire) and the
    linear encoding consumed by ``Circuit``: with input bits a, b and output
    bit c, the gate is satisfied iff

        \alpha a + \beta b + \gamma c + \delta \in \{0, 2\}.
    """

    def __init__(self, lwire: Wire, rwire: Wire, symbol=None):
        self._lwire, self._rwire = lwire, rwire
        self.symbol = symbol
        self._outwire = None  # created lazily by output()

    def output(self) -> Wire:
        """Return this gate's output wire, building it once and caching it."""
        if self._outwire is None:
            self._outwire = self._get_outwire()
        return self._outwire

    def _get_outwire(self) -> Wire:
        """Build the output wire; subclasses must override."""
        raise NotImplementedError

    def _linearization_coeffs(self):
        r"""Return (\alpha, \beta, \gamma): the coefficients of a, b and c."""
        raise NotImplementedError

    def _linearization_bias(self):
        r"""Return \delta, the constant term of the linearization."""
        raise NotImplementedError

class XorGate(Gate):
    """XOR gate: c = a ^ b, encoded as a + b + c in {0, 2}."""

    def _get_outwire(self):
        def compute():
            return self._lwire.value() ^ self._rwire.value()

        return OutputWire(compute, symbol=self.symbol)

    def _linearization_coeffs(self):
        # a + b + c is 0 or 2 exactly on the XOR truth table
        return (1, 1, 1)

    def _linearization_bias(self):
        return 0

class OrGate(Gate):
    """OR gate: c = a | b, encoded as -2a - 2b + 4c in {0, 2}."""

    def _get_outwire(self):
        def compute():
            return self._lwire.value() | self._rwire.value()

        return OutputWire(compute, symbol=self.symbol)

    def _linearization_coeffs(self):
        # OR is AND on negated bits: (1-a) + (1-b) - 2(1-c) in {0, 1}
        # simplifies to -a - b + 2c in {0, 1}; doubling maps it onto {0, 2}.
        return (-2, -2, 4)

    def _linearization_bias(self):
        # the constants cancel in the derivation above
        return 0

class AndGate(Gate):
    """AND gate: c = a & b, encoded as 2a + 2b - 4c in {0, 2}."""

    def _get_outwire(self):
        def compute():
            return self._lwire.value() & self._rwire.value()

        return OutputWire(compute, symbol=self.symbol)

    def _linearization_coeffs(self):
        # a + b - 2c in {0, 1} on the AND truth table; doubled to land in {0, 2}
        return (2, 2, -4)

    def _linearization_bias(self):
        return 0

class Circuit:
    """A boolean circuit built from fan-in 2 gates.

    Wires are stored in insertion order: the input wires first, then the
    output wire of each gate in the order the gates were added. That order
    fixes the row indices of ``_matrix_G`` / ``matrix_V``; gate order fixes
    the column indices of ``_matrix_G``.
    """

    def __init__(self, inputs: "list[InputWire]", symbol="C"):
        # list() is a shallow copy: add_gate appends to self._wires without
        # touching the caller's list (the wire objects themselves are shared).
        self._wires = list(inputs)
        self._gates = []
        self._output_gate = None
        self.symbol = symbol

    def add_gate(self, gate: "Gate"):
        """Append ``gate`` and register its output wire.

        Both input wires must already belong to the circuit (inputs or outputs
        of earlier gates), so a gate can never consume its own output.
        """
        assert gate._lwire in self._wires
        assert gate._rwire in self._wires

        self._gates.append(gate)
        self._wires.append(gate.output())

    def set_outgate(self, gate: "Gate"):
        """Mark an already-added ``gate`` as the circuit's output gate."""
        assert gate in self._gates
        self._output_gate = gate

    def eval(self):
        """Evaluate the circuit on the input wires' current values.

        Falls back to the most recently added gate when no output gate is set.
        """
        if self._output_gate is None:
            print("no output gate set, using last one")
            return self._gates[-1].output().value()
        return self._output_gate.output().value()

    def size(self) -> tuple[int, int]:
        """Return ``(m, n)``: the number of wires and the number of gates."""
        return len(self._wires), len(self._gates)

    def _gate_wire_idxs(self, gate):
        """Return the wire indices (left input, right input, output) of ``gate``."""
        assert gate in self._gates
        l = self._wires.index(gate._lwire)
        r = self._wires.index(gate._rwire)
        o = self._wires.index(gate.output())
        return l, r, o

    def _matrix_G(self):
        """Build the m x n gate-constraint matrix G.

        Column i holds gate i's linearization coefficients on the rows of its
        left input, right input and output wire. The output gate's output
        coefficient is shifted by -3 (paired with +3 in ``_vector_delta``):
        with output 1 the constraint is unchanged, while with output 0 the
        column value moves by +3, which for the XOR/OR/AND encodings here
        never lands in {0, 2}.
        """
        assert self._output_gate is not None, "need output gate"

        G = np.zeros(self.size())
        for i, gate in enumerate(self._gates):
            li, ri, oi = self._gate_wire_idxs(gate)
            lc, rc, oc = gate._linearization_coeffs()
            if gate is self._output_gate:
                oc -= 3
            # Accumulate rather than assign: a gate fed the same wire on both
            # inputs (li == ri) needs lc + rc on that row, not just rc.
            G[li, i] += lc
            G[ri, i] += rc
            G[oi, i] += oc
        return G

    def _vector_delta(self):
        """Return the length-n bias vector: each gate's delta, +3 for the output gate."""
        assert self._output_gate is not None, "need output gate"

        _, n = self.size()
        delta = np.zeros(n)
        for i, gate in enumerate(self._gates):
            b = gate._linearization_bias()
            if gate is self._output_gate:
                b += 3  # pairs with the -3 on the output coefficient in _matrix_G
            delta[i] = b
        return delta

    def matrix_V(self):
        """Return V = [2*I_m | G] with shape (m, m + n).

        For an assignment a, the first m entries of a @ V are 2*a_i, which lie
        in {0, 2} exactly when every wire value is 0 or 1; the remaining n
        entries carry the gate constraints.
        """
        m, _ = self.size()
        # np.concatenate works on every NumPy version; np.concat is a 2.0+ alias.
        return np.concatenate([2 * np.eye(m), self._matrix_G()], axis=1)

    def vector_b(self):
        """Return b = [0_m | delta], aligned with the columns of ``matrix_V``."""
        m, _ = self.size()
        return np.concatenate([np.zeros(m), self._vector_delta()])


In [3]:
def accept(a, V, b):
    """Return True iff every entry of ``a @ V + b`` equals 0 or 2.

    ``a`` is a wire assignment (any sequence accepted by np.array); ``V`` and
    ``b`` are the constraint matrix and vector of a circuit.
    """
    scores = np.array(a) @ V + b
    in_range = (scores == 0) | (scores == 2)
    return bool(np.all(in_range))

def _check_gate_truth_table(gate_cls, truth, name):
    """Check the linear-constraint encoding of a single gate against its truth table.

    Builds a two-input circuit whose only gate is ``gate_cls`` and marks it as
    the output gate. Exactly the assignments (l, r, 1) with ``truth(l, r)``
    true must be accepted; every other assignment, including all with output
    0, must be rejected.
    """
    left = InputWire(None, symbol="a")
    right = InputWire(None, symbol="b")

    circuit = Circuit([left, right], symbol="C")
    gate = gate_cls(left, right, symbol="c")
    circuit.add_gate(gate)
    circuit.set_outgate(gate)

    V, bias = circuit.matrix_V(), circuit.vector_b()
    for l in (0, 1):
        for r in (0, 1):
            for out in (0, 1):
                assignment = [l, r, out]
                expected = bool(truth(l, r)) and out == 1
                assert accept(assignment, V, bias) == expected, (name, assignment)

    print(f"{name} passed")


def test_xor():
    """Linear encoding of XorGate accepts exactly the XOR rows with output 1."""
    _check_gate_truth_table(XorGate, lambda l, r: l ^ r, "xor")


def test_or():
    """Linear encoding of OrGate accepts exactly the OR rows with output 1."""
    _check_gate_truth_table(OrGate, lambda l, r: l or r, "or")


def test_and():
    """Linear encoding of AndGate accepts exactly the AND rows with output 1."""
    _check_gate_truth_table(AndGate, lambda l, r: l and r, "and")

In [4]:
for run_check in (test_xor, test_and, test_or):
    run_check()

xor passed
and passed
or passed


In [5]:
def lagrange_basis(nodes: list[int], var, field):
    """Return the Lagrange basis polynomials for ``nodes`` over ``field``.

    The k-th polynomial is prod_{m != k} (var - nodes[m]) / (nodes[k] - nodes[m]),
    which is 1 at nodes[k] and 0 at every other node. The nodes must be
    pairwise distinct in ``field``; otherwise inverting the zero difference fails.
    """
    points = [field.convert(node) for node in nodes]

    basis = []
    for k, xk in enumerate(points):
        term = field.one
        for m, xm in enumerate(points):
            if m == k:
                continue
            term = term * ((var - xm) * pow(xk - xm, -1))
        basis.append(term.as_poly(domain=field))
    return basis

def polynomial_with_values(nodes: list[int], vals: list[int], var, field):
    """Interpolate the polynomial of degree < len(nodes) with value vals[k] at nodes[k].

    Computed as the Lagrange sum  sum_k vals[k] * L_k  over ``field``.
    """
    assert len(nodes) == len(vals)

    weighted = (
        value * basis_poly
        for value, basis_poly in zip(vals, lagrange_basis(nodes, var, field))
    )
    return sum(weighted, field.zero).as_poly()

In [6]:
import random

def prepare_field(d: int, rng=None):
    """Pick a prime field GF(p) and ``d`` distinct random evaluation points in it.

    p is the smallest prime greater than max(8, d), so the field always has
    more than d elements. The lower bound of 8 presumably keeps the small
    integer constraint values from wrapping around mod p (TODO confirm).

    Args:
        d: number of evaluation points (one per column of V); must be >= 0.
        rng: optional ``random.Random`` instance for reproducible draws;
            defaults to the module-level ``random`` generator.

    Returns:
        (rs, field): a list of d pairwise-distinct field elements and the field.
    """
    p = sympy.ntheory.generate.nextprime(max(8, d))
    field = sympy.GF(p)
    source = random if rng is None else rng
    # sample() draws without replacement, so the points are distinct by
    # construction -- no rejection loop with O(d^2) list-membership checks.
    rs = [field.convert(v) for v in source.sample(range(p), d)]
    return rs, field

def _ssp_polynomial_v0(rs, b: np.array, var, field):
    """Interpolate v_0 through the points ``rs`` with v_0(rs[j]) = b[j] - 1.

    Shifting by 1 maps the accepted values {0, 2} of ``a @ V + b`` to {-1, 1},
    so the squared combination polynomial is 1 at every point of ``rs``.
    """
    shifted = b - 1
    values = list(map(field.convert, shifted))
    return polynomial_with_values(rs, values, var, field)

def _ssp_polynomial_vi(rs, i: int, V: np.array, var, field):
    """Interpolate v_i (i is 1-based) through row i - 1 of ``V`` over the points ``rs``."""
    row = V[i - 1]
    values = [field.convert(entry) for entry in row]
    return polynomial_with_values(rs, values, var, field)

In [7]:
def mega_polynomial(a: np.ndarray, V: np.ndarray, b: np.ndarray):
    """Build (v_0 + sum_i a_i v_i)^2 over a freshly sampled prime field.

    Args:
        a: wire assignment, one entry per row of ``V`` (bits 0/1 expected).
        V: constraint matrix of shape (m, d).
        b: constraint vector of length d.

    Returns:
        (poly, rs): the squared polynomial in ``x`` and the d evaluation points.
        poly - 1 vanishes on every point of rs exactly when a @ V + b lies in
        {0, 2} (mod p) at each column.

    Raises:
        ValueError: if ``a`` does not have one entry per row of ``V``.
    """
    m, d = V.shape
    if len(a) != m:
        raise ValueError(f"assignment has {len(a)} entries, expected {m} (one per wire)")

    x = sympy.symbols("x")
    rs, field = prepare_field(d)

    polynomial = _ssp_polynomial_v0(rs, b, x, field)
    for i in range(m):
        # Plain Python int keeps numpy scalar types out of the sympy arithmetic.
        polynomial += int(a[i]) * _ssp_polynomial_vi(rs, i + 1, V, x, field)

    return polynomial**2, rs

In [8]:
def target_polynomial(rs, field, var=None):
    """Return t(var) = prod_j (var - rs[j]) as a Poly over ``field``.

    ``var`` defaults to the symbol ``x``, the generator ``mega_polynomial``
    uses, so the two polynomials can be divided. An empty ``rs`` yields the
    constant polynomial 1.
    """
    x = sympy.symbols("x") if var is None else var
    expr = sympy.Integer(1)  # a sympy 1 (not a Python int) so as_poly exists even for empty rs
    for rj in rs:
        expr *= (x - rj)
    # Naming the generator keeps as_poly well-defined when expr is a constant.
    return expr.as_poly(x, domain=field)

def divides(target, poly) -> bool:
    """Return True when ``target`` divides ``poly`` exactly (zero remainder)."""
    remainder = poly.rem(target)
    return remainder.is_zero

In [12]:
def _check_gate_ssp(gate_cls, truth, name):
    """Check the square-span-program polynomials of a single gate.

    Builds a two-input circuit whose only gate is ``gate_cls`` (marked as the
    output gate). For every assignment (l, r, out), the target polynomial must
    divide poly - 1 exactly when ``truth(l, r)`` holds and out == 1. Each
    assignment draws fresh random evaluation points via ``mega_polynomial``.
    """
    left = InputWire(None, symbol="a")
    right = InputWire(None, symbol="b")

    circuit = Circuit([left, right], symbol="C")
    gate = gate_cls(left, right, symbol="c")
    circuit.add_gate(gate)
    circuit.set_outgate(gate)

    V, bias = circuit.matrix_V(), circuit.vector_b()
    for l in (0, 1):
        for r in (0, 1):
            for out in (0, 1):
                assignment = np.array([l, r, out])
                poly, rs = mega_polynomial(assignment, V, bias)
                target = target_polynomial(rs, poly.domain)

                expected = bool(truth(l, r)) and out == 1
                assert divides(target, poly - 1) == expected, (name, [l, r, out])

    print(f"{name} passed")


def test_xor_ssp():
    """SSP for XorGate is satisfied exactly by XOR rows with output 1."""
    _check_gate_ssp(XorGate, lambda l, r: l ^ r, "xor")


def test_or_ssp():
    """SSP for OrGate is satisfied exactly by OR rows with output 1."""
    _check_gate_ssp(OrGate, lambda l, r: l or r, "or")


def test_and_ssp():
    """SSP for AndGate is satisfied exactly by AND rows with output 1."""
    _check_gate_ssp(AndGate, lambda l, r: l and r, "and")

In [13]:
for run_check in (test_xor_ssp, test_or_ssp, test_and_ssp):
    run_check()

xor passed
or passed
and passed
