# Naive Plonk

Let's write the most naive and simple Plonk-like protocol. It will be very naive and, on its own, useless. The actual protocol is much more convoluted: it's full of details and optimizations that get in the way and hide the core ideas. So let's start as simple as we can. Even in this minimalistic setting, we can start seeing some of the magic happen.

We will start by proving the satisfiability of a set of equations. This is only half of what is needed to prove the execution of arithmetic circuits, but it will serve as a starting point.

We need the following ingredients for now. Everything lives in a prime field $\mathbb{F}_p$.

Public, known to both the prover and the verifier:
- A set of degree two polynomial expressions $\{E_i\}_{i=1}^N$ of the form $E_i(X, Y, Z) = Q_L X + Q_R Y + Q_M X Y + Q_O Z + Q_C$, where the coefficients are elements of $\mathbb{F}_p$ and may differ from one expression to the next. We'll call these *equations* since we'll only be interested in them for the solutions $(A, B, C)$ to the equation $E_i(A, B, C) = 0$.

Known only to the prover:
- A set of triplets $\{(A_i, B_i, C_i)\}_{i=1}^N$ such that each triplet is a solution to the corresponding equation. More precisely: $E_i(A_i, B_i, C_i) = 0$ for all $i=1,\dots,N$.

The goal is to describe a "protocol" between a prover and a verifier in which the prover convinces the verifier that he is in possession of such a set of triplets.

Let's start by implementing `Equation` and `Triplet`.


<b>Why do we talk about "sets of equations" and not "systems of equations"?</b>

When talking about a system of equations it's implicit that by "a solution" we mean values that satisfy each of the equations of the system simultaneously. For example, a solution to the system of equations $$\begin{aligned}X + XY + Z &= 0 \\ X + Y - Z &= 0 \end{aligned}$$ is a triplet $(A, B, C)$ that satisfies *both* equations, such as $(A, B, C) = (1, -1, 0)$.

In our context we have independent equations, each with its own independent solution. So if $E_1 = X + XY + Z$ and $E_2 = X + Y - Z$, then we are interested in pairs of triplets $(A_1, B_1, C_1)$ and $(A_2, B_2, C_2)$ such that $A_1 + A_1 B_1 + C_1 = 0$ and $A_2 + B_2 - C_2 = 0$. For example $\{(0, 1, 0), (0, 1, 1)\}$.

That's why we talk about sets and not systems. We have sets of independent equations and sets of independent solutions to each one.
    
<b>Eventually for a working Plonk version one actually needs to see these somehow as systems of equations. Checking that solution values are consistent across equations is a rabbit hole in itself that we'll cover later on. For this quest we are only interested in them as sets.</b>
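
The difference between the two readings can be checked numerically. The following is a minimal sketch in plain Python, independent of the notebook's types: the modulus `P` and the helper `evaluate` are assumptions made only for this illustration, and the two equations are $X + XY + Z$ and $X + Y - Z$ from the system shown above.

```python
P = 65537  # assumed prime modulus, used only for this illustration

def evaluate(coeffs, triplet):
    """Evaluate Q_L*X + Q_R*Y + Q_M*X*Y + Q_O*Z + Q_C at (A, B, C) modulo P."""
    q_l, q_r, q_m, q_o, q_c = coeffs
    a, b, c = triplet
    return (q_l * a + q_r * b + q_m * a * b + q_o * c + q_c) % P

E1 = (1, 0, 1, 1, 0)   # X + XY + Z
E2 = (1, 1, 0, -1, 0)  # X + Y - Z

# As a system: one triplet must satisfy both equations at once.
shared = (1, -1, 0)
system_ok = evaluate(E1, shared) == 0 and evaluate(E2, shared) == 0

# As a set: each equation gets its own triplet.
independent = [(0, 1, 0), (0, 1, 1)]
set_ok = all(evaluate(e, t) == 0 for e, t in zip((E1, E2), independent))

# The triplet chosen for E1 is not required to solve E2, and here it does not.
cross_ok = evaluate(E2, independent[0]) == 0
```

Here `system_ok` and `set_ok` are both true, while `cross_ok` is false: the solutions of a set are free to disagree with each other.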

# Setting up

In [83]:
from zk_adventures_types import F, Polynomial

# Equations and triplets

In [85]:
from enum import IntEnum

class Equation:
    """An expression of form Q_L X + Q_R Y + Q_M X Y + Q_O Z + Q_C on variables X, Y and Z"""
    def __init__(self, Q_L: int, Q_R: int, Q_M: int, Q_O: int, Q_C: int):
        self._values = (F(Q_L), F(Q_R), F(Q_M), F(Q_O), F(Q_C))
        
    def values(self):
        return self._values
        
    def __getitem__(self, index):
        if not isinstance(index, self.Index):
            raise ValueError
        return self._values[index]
    
    class Index(IntEnum):
        L = 0
        R = 1
        M = 2
        O = 3
        C = 4

class Triplet:
    """A triplet of values (A, B, C) in the finite field"""
    def __init__(self, A: int, B: int, C: int):
        self._values = (F(A), F(B), F(C))
    
    def values(self):
        return self._values
        
    def __getitem__(self, index):
        if not isinstance(index, self.Index):
            raise ValueError
        return self._values[index]

    class Index(IntEnum):
        A = 0
        B = 1
        C = 2

def f(Q_L, Q_R, Q_M, Q_O, Q_C, A, B, C):
    """
    Multivariate polynomial encoding correct satisfiability of solutions to the equations
    """
    return Q_L * A + Q_R * B + Q_M * A * B + Q_O * C + Q_C

def is_solution(triplet: Triplet, equation: Equation) -> bool:
    """Check whether `triplet` is a solution to `equation`"""
    Q_L = equation[Equation.Index.L] 
    Q_R = equation[Equation.Index.R]
    Q_M = equation[Equation.Index.M]
    Q_O = equation[Equation.Index.O]
    Q_C = equation[Equation.Index.C]
    A = triplet[Triplet.Index.A]
    B = triplet[Triplet.Index.B]
    C = triplet[Triplet.Index.C]
    return f(Q_L,Q_R,Q_M,Q_O,Q_C,A,B,C) == 0

In [86]:
# Equations that model the program z = xor(x, y)
equations = [
    Equation(1, 0, -1, 0, 0),
    Equation(1, 0, -1, 0, 0),
    Equation(1, 0, -1, 0, 0),
    Equation(1, 1, -2, -1, 0)
]

In [87]:
triplets = [
    Triplet(1, 1, 1),
    Triplet(1, 1, 1),
    Triplet(0, 0, 1),
    Triplet(1, 1, 0)
]

In [88]:
for triplet, equation in zip(triplets, equations):
    assert(is_solution(triplet, equation))
    
assert(not is_solution(Triplet(2, 2, 0), equations[0]))
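
To see why these coefficients relate to xor, here is a small sketch over plain integers. It rests on one plausible reading of the list, which the notebook does not spell out: `Equation(1, 1, -2, -1, 0)` is the xor gate $X + Y - 2XY - Z = 0$, and `Equation(1, 0, -1, 0, 0)`, used with $A = B$, forces a value to be 0 or 1. The function names are hypothetical.

```python
from itertools import product

def xor_gate(x, y, z):
    """Left-hand side of X + Y - 2XY - Z = 0, i.e. Equation(1, 1, -2, -1, 0)."""
    return x + y - 2 * x * y - z

def boolean_gate(x):
    """Left-hand side of X - XY = 0 with Y set to X, i.e. Equation(1, 0, -1, 0, 0)."""
    return x - x * x

# For boolean inputs the xor gate vanishes exactly when z == x xor y.
truth_table = {
    (x, y, z): xor_gate(x, y, z) == 0
    for x, y, z in product((0, 1), repeat=3)
}

# The boolean gate vanishes at 0 and 1 and rejects other small values.
boolean_checks = {x: boolean_gate(x) == 0 for x in range(4)}
```

Note that nothing in the set of equations forces the values in different triplets to match, which is the consistency issue postponed earlier.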

# Polynomial interpolation

A key component in what follows is encoding vectors as polynomials through polynomial interpolation. Let's describe the context and motivation.

Let $\{E_i\}_{i=1}^N$ be a set of equations and $\{(A_i, B_i, C_i)\}_{i=1}^N$ a corresponding solution for each one. Each equation $E_i$ has its own coefficients $Q_L^i, Q_R^i$, etc. Let $H$ be a set of the form $H=\{1, \omega, \omega^2, \dots, \omega^{N-1}\}$, where $\omega$ is an element of $\mathbb{F}_p$ of *order* $N$: that is, $\omega^N=1$ and $\omega^k \neq 1$ for $0 < k < N$, so that $H$ has exactly $N$ distinct elements. We want to interpolate the coefficients and entries of the equations and triplets at $H$ in the following sense.

Let $q_L$ be the polynomial such that $q_L(\omega^i) = Q_L^i$ for all $i$. Similarly with $q_R, q_M, q_C$ and $q_O$. Let also $a$ be the polynomial such that $a(\omega^i) = A_i$ for all $i$. And similarly for $b$ and $c$.
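
As an illustration of what "interpolating at $H$" means, the sketch below builds the polynomial $a$ for the values $A_i = 1, 1, 0, 1$ with textbook Lagrange interpolation in plain Python. It is not the notebook's `Polynomial.lagrange_polynomial`: the modulus `P`, the generator 3 and the helpers `poly_eval`, `poly_mul` and `lagrange` are assumptions of this sketch, and polynomials are stored as lists of coefficients from low to high degree.

```python
P = 65537  # assumed prime modulus for this sketch

def poly_eval(coeffs, x):
    """Evaluate a polynomial given by low-to-high coefficients at x (Horner's rule)."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

def poly_mul(a, b):
    """Multiply two coefficient lists modulo P."""
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def lagrange(points):
    """Return the unique polynomial of degree < len(points) through `points`."""
    result = [0] * len(points)
    for i, (xi, yi) in enumerate(points):
        basis, denom = [1], 1
        for j, (xj, _) in enumerate(points):
            if j != i:
                basis = poly_mul(basis, [-xj % P, 1])  # multiply by (X - xj)
                denom = denom * (xi - xj) % P
        scale = yi * pow(denom, P - 2, P) % P  # divide by the denominator
        for k, c in enumerate(basis):
            result[k] = (result[k] + scale * c) % P
    return result

N = 4
omega = pow(3, (P - 1) // N, P)  # 3 generates the nonzero elements modulo 65537
domain = [pow(omega, i, P) for i in range(N)]
A_values = [1, 1, 0, 1]
a = lagrange(list(zip(domain, A_values)))
evaluations = [poly_eval(a, h) for h in domain]
```

After running it, `evaluations` equals `A_values`: the vector of $A_i$ has been encoded as a polynomial of degree less than $N$.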

By composing the multilinear polynomial $f$ with these univariate polynomials we obtain $$g = f(q_L, q_R, q_M, q_O, q_C, a, b, c).$$ The polynomial $g$ is univariate and, by construction, $g(\omega^i) = E_i(A_i, B_i, C_i)$ for all $i$. Therefore the set $\{(A_i, B_i, C_i)\}_{i=1}^N$ is a set of solutions to the equations $E_i$ if and only if $$g(\omega^i) = 0$$ for all $i$. In turn, $g$ vanishes on all of $H$ if and only if there exists a polynomial $t$ such that $$g = (X^N - 1)t.$$
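
For readers who want to see where the factor $X^N - 1$ comes from, here is the short argument, relying on the fact that $\omega$ has order $N$ so that the $N$ elements of $H$ are distinct. Each of them is a root of $X^N - 1$, a polynomial of degree $N$ with leading coefficient 1, hence
$$X^N - 1 = \prod_{i=0}^{N-1} (X - \omega^i).$$
If $g(\omega^i) = 0$ for all $i$, then every factor $X - \omega^i$ divides $g$, and since these factors are pairwise coprime, so does their product. Conversely,
$$g = (X^N - 1)\,t \implies g(\omega^i) = \big((\omega^i)^N - 1\big)\,t(\omega^i) = 0.$$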

Why go down this road? Two distinct polynomials of degree at most $d$ agree on at most $d$ points, a fact that the Schwartz-Zippel lemma generalizes to several variables. So if 
$$ g(z) = (z^N - 1)t(z) $$ holds for a random element $z$ of $\mathbb{F}_p$, where $d$ bounds the degree of both sides, then except with probability at most $d/p$ the two sides are equal as polynomials and $g$ has such a decomposition. Unrolling all the reasoning back, by just checking that equality at one point $z$ we get with high probability that the polynomials $a$, $b$ and $c$ interpolate solutions to the equations $E_i$. This means that a single point check implies, with high probability, a global satisfiability of a set of solutions.
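
The counting fact behind the single point check can be observed directly in a field small enough to try every element. The sketch below is only an illustration: the prime `P = 97` and the two polynomials are arbitrary choices, not values from the protocol.

```python
P = 97  # small assumed prime so that every field element can be tried

def poly_eval(coeffs, x):
    """Evaluate a polynomial given by low-to-high coefficients at x (Horner's rule)."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % P
    return acc

# lhs = (x - 1)(x - 2)(x - 3) + x**2 and rhs = x**2, low-to-high coefficients.
lhs = [-6 % P, 11, -5 % P, 1]
rhs = [0, 0, 1]

# Two different polynomials can only agree at the roots of their difference.
agreeing_points = [z for z in range(P) if poly_eval(lhs, z) == poly_eval(rhs, z)]
degree = max(len(lhs), len(rhs)) - 1
```

The two polynomials agree only at 1, 2 and 3, so a uniformly random $z$ would fail to tell them apart with probability $3/97$. In a field with a much larger $p$ that fraction becomes negligible.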

In [89]:
class Domain:
    def __init__(self, omega: int):
        """Produces the set of all powers of `omega` modulo `p` and stores them in `self._elements`"""
        omega = F(omega)
        size = omega.multiplicative_order()
        # COMPLETE
        self._elements = []
        for k in range(size):
            self._elements.append(omega ** k)
    
    @classmethod
    def of_size(cls, size: int):
        """Returns a domain of size `size`."""
        # generator of the full units group of 𝔽. That is, the powers 
        # of `generator` produce all nonzero elements of 𝔽
        generator = F.multiplicative_generator()
        p = F.order()
        if size <= 0 or (p - 1) % size != 0:
            raise ValueError
        # COMPLETE
        # Exact integer division; `size` divides `p - 1` by the check above
        k = (p - 1) // size
        return cls(generator ** k)
    
    def __len__(self):
        return len(self._elements)
    
    def __getitem__(self, index):
        return self._elements[index]

In [90]:
assert(list(Domain.of_size(8)) == [1, 4096, 65281, 16, 65536, 61441, 256, 65521])
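
The idea behind `Domain.of_size` can be reproduced in plain Python: raising a generator of the nonzero field elements to the power $(p-1)/N$ gives an element of order exactly $N$. The sketch below is a hedged illustration with an assumed prime `P = 65537` and assumed generator 3, and it does not claim to reproduce the exact values asserted above. It also shows why $\omega^N = 1$ alone is not enough.

```python
P = 65537      # assumed prime; P - 1 = 2**16, so subgroups of size 2**k exist
GENERATOR = 3  # 3 generates the nonzero elements modulo 65537

def powers(omega, n):
    """Return [1, omega, ..., omega**(n - 1)] modulo P."""
    return [pow(omega, k, P) for k in range(n)]

N = 8
omega = pow(GENERATOR, (P - 1) // N, P)  # element of order exactly N
H = powers(omega, N)

# omega**2 also satisfies x**N == 1, but its order is only N // 2,
# so its powers repeat and cover just half of H.
degenerate = powers(pow(omega, 2, P), N)
```

`H` has 8 distinct elements, all of them roots of $X^8 - 1$, whereas `degenerate` only contains 4 distinct values.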

In [110]:
def interpolate_triplets(domain: Domain, triplets: list[Triplet], index: Triplet.Index) -> Polynomial:
    """Returns the polynomial `p` such that `p(domain[i]) = triplets[i][index]`"""
    size = len(triplets)
    # Use the domain that was passed in instead of rebuilding one
    if len(domain) != size:
        raise ValueError("domain and triplets must have the same size")
    points = []
    for i in range(size):
        points.append((domain[i], triplets[i][index]))
    return Polynomial.lagrange_polynomial(points, 'divided_difference')
    
def interpolate_equations(domain: Domain, equations: list[Equation], index: Equation.Index) -> Polynomial:
    """Returns the polynomial `p` such that `p(domain[i]) = equations[i][index]`"""
    size = len(equations)
    # Use the domain that was passed in instead of rebuilding one
    if len(domain) != size:
        raise ValueError("domain and equations must have the same size")
    points = []
    for i in range(size):
        points.append((domain[i], equations[i][index]))
    return Polynomial.lagrange_polynomial(points, 'divided_difference')

In [111]:
domain = Domain.of_size(size=len(triplets))

a = interpolate_triplets(domain, triplets, Triplet.Index.A)
b = interpolate_triplets(domain, triplets, Triplet.Index.B)
c = interpolate_triplets(domain, triplets, Triplet.Index.C)

q_L = interpolate_equations(domain, equations, Equation.Index.L)
q_R = interpolate_equations(domain, equations, Equation.Index.R)
q_M = interpolate_equations(domain, equations, Equation.Index.M)
q_O = interpolate_equations(domain, equations, Equation.Index.O)
q_C = interpolate_equations(domain, equations, Equation.Index.C)

In [113]:
X = Polynomial.monomial(1)
fpol = f(q_L, q_R, q_M, q_O, q_C, a, b, c)
# Vanishing polynomial of the domain: X^N - 1
r = X ** len(domain) - 1
t = fpol // r

In [114]:
# The division is exact because the triplets are solutions
assert(t * r == fpol)
assert(t(0xfeca) == 49096)

### Oracles

The idea will be that the prover constructs all these polynomials and somehow communicates them to the verifier so that he can sample a random $z$ and perform the single point check. We talk about *oracles* when we want to abstract away from how that communication is done.

The polynomials are as big as the set of solutions $(A_i, B_i, C_i)$: each has on the order of $N$ coefficients. So sending the whole list of coefficients is silly, because reading it would cost the verifier about as much as reading the solutions themselves. In the end, this is solved with what are called *Polynomial Commitment Schemes*. But introducing them here would add so much complexity that the idea we are trying to convey about single point checks would get lost along the way.

So let's imagine there's a thing called a Polynomial Oracle that the prover can send to the verifier. Let's assume that it is lightweight. And it can be used to query the value taken by a polynomial at any point. Right now we'll use a *naive oracle*, which holds the entire polynomial in a secret attribute that the verifier can't see (and the communication of this instance to the verifier is cheap because the verifier has access to the memory of the Python interpreter of the prover).

In [115]:
class Oracle:
    def __init__(self, polynomial: Polynomial):
        raise NotImplementedError("subclass responsibility")
        
    def query(self, z):
        raise NotImplementedError("subclass responsibility")

In [117]:
class NaiveOracle(Oracle):
    def __init__(self, polynomial: Polynomial):
        self._polynomial = polynomial
    
    def query(self, z):
        """
        Single-use evaluation. The first call returns the value of the
        polynomial at `z` and then discards the polynomial, so every
        subsequent call returns `None`.
        """
        if self._polynomial is None:
            return None
        y = self._polynomial(z)
        self._polynomial = None
        return y

In [118]:
random_polynomial = Polynomial.random_element()
oracle = NaiveOracle(random_polynomial)
assert(oracle.query(10) == random_polynomial(10))
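
The single-use behaviour is easy to miss, so here is a self-contained sketch of it. `OneShotOracle` is a hypothetical stand-in for `NaiveOracle` that wraps any callable instead of a `Polynomial`, and the modulus is an assumption. One plausible motivation for the design, not stated in the notebook, is that the verifier should learn a single evaluation and nothing else about the polynomial.

```python
class OneShotOracle:
    """Hypothetical stand-in for `NaiveOracle`: answers a single evaluation query."""

    def __init__(self, evaluate):
        self._evaluate = evaluate  # any callable z -> value

    def query(self, z):
        if self._evaluate is None:
            return None  # the polynomial was discarded after the first query
        value = self._evaluate(z)
        self._evaluate = None
        return value

P = 65537  # assumed modulus
oracle = OneShotOracle(lambda z: (3 * z * z + 2 * z + 1) % P)
first = oracle.query(10)   # evaluates 3*100 + 2*10 + 1
second = oracle.query(10)  # the oracle is already spent
```

A practical consequence is that a proof made of such oracles can be verified only once; a second verification needs fresh oracles.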

In [125]:
class PlonkEquationSatisfiabilityProver:
    def prove(self, equations: list[Equation], triplets: list[Triplet]):
        if len(equations) != len(triplets):
            raise ValueError

        domain = Domain.of_size(size=len(equations))

        a = interpolate_triplets(domain, triplets, Triplet.Index.A)
        b = interpolate_triplets(domain, triplets, Triplet.Index.B)
        c = interpolate_triplets(domain, triplets, Triplet.Index.C)

        q_L = interpolate_equations(domain, equations, Equation.Index.L)
        q_R = interpolate_equations(domain, equations, Equation.Index.R)
        q_M = interpolate_equations(domain, equations, Equation.Index.M)
        q_O = interpolate_equations(domain, equations, Equation.Index.O)
        q_C = interpolate_equations(domain, equations, Equation.Index.C)

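        # g = f(q_L, ..., c) vanishes on the domain exactly when the triplets are solutions
        g = f(q_L, q_R, q_M, q_O, q_C, a, b, c)
        X = Polynomial.monomial(1)
        # Vanishing polynomial of the domain: X^N - 1
        roots = X ** len(equations) - 1
        # `//` discards any remainder; if `g` is not divisible by `roots`,
        # the resulting `t` fails the verifier's check with high probability
        t = g // roots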

        return [NaiveOracle(a), NaiveOracle(b), NaiveOracle(c), NaiveOracle(t)]
    
class PlonkEquationSatisfiabilityVerifier:    
    def verify(self, equations: list[Equation], proof: list[NaiveOracle]) -> bool:
        if len(proof) != 4:
            raise ValueError
                
        domain = Domain.of_size(size=len(equations))

        q_L = interpolate_equations(domain, equations, Equation.Index.L)
        q_R = interpolate_equations(domain, equations, Equation.Index.R)
        q_M = interpolate_equations(domain, equations, Equation.Index.M)
        q_O = interpolate_equations(domain, equations, Equation.Index.O)
        q_C = interpolate_equations(domain, equations, Equation.Index.C)
        
        z = F.random_element()
        
        oracle_a, oracle_b, oracle_c, oracle_t = proof
        a_z = oracle_a.query(z)
        b_z = oracle_b.query(z)
        c_z = oracle_c.query(z)
        t_z = oracle_t.query(z)

        X = Polynomial.monomial(1)
        roots = X ** len(equations) - 1
        
        # Point check implies global satisfiability with high probability
        left_hand_side = f(q_L(z), q_R(z), q_M(z), q_O(z), q_C(z), a_z, b_z, c_z)
        right_hand_side = roots(z) * t_z
        
        return left_hand_side == right_hand_side

In [126]:
prover = PlonkEquationSatisfiabilityProver()
proof = prover.prove(equations, triplets)

verifier = PlonkEquationSatisfiabilityVerifier()
assert(verifier.verify(equations, proof))
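
The honest run above always passes. To see the other half of the story, the following self-contained sketch replays the protocol in plain Python and tries every possible challenge $z$ against a cheating prover. It is an illustration under stated assumptions, not the notebook's implementation: the field is the small prime field with `P = 257` (so that all challenges can be enumerated), polynomials are coefficient lists, and all helper names are hypothetical.

```python
P = 257  # small assumed prime so that every challenge z can be tried

def ev(poly, x):
    """Evaluate low-to-high coefficients at x with Horner's rule."""
    acc = 0
    for coeff in reversed(poly):
        acc = (acc * x + coeff) % P
    return acc

def add(a, b):
    n = max(len(a), len(b))
    a, b = a + [0] * (n - len(a)), b + [0] * (n - len(b))
    return [(x + y) % P for x, y in zip(a, b)]

def mul(a, b):
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % P
    return out

def interpolate(xs, ys):
    """Lagrange interpolation through the points (xs[i], ys[i])."""
    result = [0]
    for i, xi in enumerate(xs):
        term, denom = [ys[i] % P], 1
        for j, xj in enumerate(xs):
            if j != i:
                term = mul(term, [-xj % P, 1])
                denom = denom * (xi - xj) % P
        result = add(result, mul(term, [pow(denom, P - 2, P)]))
    return result

def quotient_by_vanishing(g, n):
    """Quotient of g by X**n - 1; the remainder is discarded, like `//`."""
    g = list(g)
    quotient = [0] * (len(g) - n)
    for k in range(len(g) - 1, n - 1, -1):
        quotient[k - n] = g[k]
        g[k - n] = (g[k - n] + g[k]) % P
    return quotient

def prove(qs, triplets, domain):
    a, b, c = (interpolate(domain, [t[k] for t in triplets]) for k in range(3))
    g = add(mul(qs[0], a), mul(qs[1], b))
    g = add(g, mul(mul(qs[2], a), b))
    g = add(g, add(mul(qs[3], c), qs[4]))
    return a, b, c, quotient_by_vanishing(g, len(domain))

def verify(qs, proof, n, z):
    q_l, q_r, q_m, q_o, q_c = (ev(q, z) for q in qs)
    a, b, c, t = (ev(poly, z) for poly in proof)
    left = (q_l * a + q_r * b + q_m * a * b + q_o * c + q_c) % P
    return left == (pow(z, n, P) - 1) * t % P

equations = [(1, 0, -1, 0, 0)] * 3 + [(1, 1, -2, -1, 0)]
honest = [(1, 1, 1), (1, 1, 1), (0, 0, 1), (1, 1, 0)]
cheating = [(1, 1, 1), (1, 1, 1), (0, 0, 1), (1, 1, 1)]  # claims xor(1, 1) = 1

omega = pow(3, (P - 1) // 4, P)  # 3 generates the nonzero elements modulo 257
domain = [pow(omega, i, P) for i in range(4)]
qs = [interpolate(domain, [e[k] for e in equations]) for k in range(5)]

honest_proof = prove(qs, honest, domain)
cheating_proof = prove(qs, cheating, domain)
honest_rejected = [z for z in range(P) if not verify(qs, honest_proof, 4, z)]
cheating_accepted = [z for z in range(P) if verify(qs, cheating_proof, 4, z)]
```

The honest proof is accepted for every $z$. The cheating proof is accepted for only 3 of the 257 possible challenges, namely the three domain points whose equations the cheating triplets do satisfy. With a much larger field, the chance of hitting such a point becomes negligible.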

### Slight improvement: Preprocessed input

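You may have noticed that both the prover and the verifier recompute the polynomials $q_L, q_R$, etc. on every run. These only depend on the set of equations, so they can be computed once in a setup phase to make everyone's life easier: the prover receives the polynomials themselves and the verifier receives oracles to them.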

In [127]:
class PlonkEquationSatisfiabilitySetup:
    def setup(self, equations: list[Equation]) -> tuple:
        
        domain = Domain.of_size(size=len(equations))

        q_L = interpolate_equations(domain, equations, Equation.Index.L)
        q_R = interpolate_equations(domain, equations, Equation.Index.R)
        q_M = interpolate_equations(domain, equations, Equation.Index.M)
        q_O = interpolate_equations(domain, equations, Equation.Index.O)
        q_C = interpolate_equations(domain, equations, Equation.Index.C)
        
        prover_input = (domain, (q_L, q_R, q_M, q_O, q_C))
        verifier_input = (domain, tuple(map(NaiveOracle, (q_L, q_R, q_M, q_O, q_C))))
        
        return (prover_input, verifier_input)

In [143]:
class PlonkEquationSatisfiabilityProver:
    def prove(self, preprocessed_input: tuple, triplets: list[Triplet]):
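        domain, qs = preprocessed_input
        # Use the preprocessed domain instead of the global `equations`
        if len(domain) != len(triplets):
            raise ValueError
        
        a = interpolate_triplets(domain, triplets, Triplet.Index.A)
        b = interpolate_triplets(domain, triplets, Triplet.Index.B)
        c = interpolate_triplets(domain, triplets, Triplet.Index.C)

        # g = f(q_L, ..., c) vanishes on the domain exactly when the triplets are solutions
        g = f(*qs, a, b, c)

        X = Polynomial.monomial(1)
        # Vanishing polynomial of the domain: X^N - 1
        roots = X ** len(domain) - 1
        t = g // roots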
            
        return [NaiveOracle(a), NaiveOracle(b), NaiveOracle(c), NaiveOracle(t)]
    
class PlonkEquationSatisfiabilityVerifier:
    def verify(self, preprocessed_input: tuple, proof: list[NaiveOracle]) -> bool:
        if len(proof) != 4:
            raise ValueError

        domain, oracles_q = preprocessed_input
        z = F.random_element()
        oracle_a, oracle_b, oracle_c, oracle_t = proof
        a_z = oracle_a.query(z)
        b_z = oracle_b.query(z)
        c_z = oracle_c.query(z)
        t_z = oracle_t.query(z)

        X = Polynomial.monomial(1)
        # The verifier never sees the triplets; the domain gives N
        roots = X ** len(domain) - 1
        
        # Evaluate the preprocessed coefficient polynomials at z
        q_z = [oracle.query(z) for oracle in oracles_q]
            
        # Point check implies global satisfiability with high probability
        left_hand_side = f(*q_z, a_z, b_z, c_z)
        right_hand_side = roots(z) * t_z
        
        return left_hand_side == right_hand_side

In [145]:
setup = PlonkEquationSatisfiabilitySetup()
prover_input, verifier_input = setup.setup(equations)

prover = PlonkEquationSatisfiabilityProver()
proof = prover.prove(prover_input, triplets)

verifier = PlonkEquationSatisfiabilityVerifier()
assert(verifier.verify(verifier_input, proof))

# Constant time verification
Notice that, once the input is preprocessed, the work of the verifier barely depends on the size of the set of equations: it makes a fixed number of oracle queries (nine) and evaluates $z^N - 1$, which can be done with on the order of $\log N$ multiplications. This means that the set can have 4 equations or $2^{20}$ and the verifier will do essentially the same amount of work. This is still silly and useless, since for it to be true the verifier needs to read the instances of the oracles directly from the memory of the prover's Python interpreter. If these had been sent over the network, the verifier would need to read them, and that would be expensive.
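
The only part of the check whose cost depends on $N$ is the evaluation of $z^N - 1$. A minimal sketch of square-and-multiply, with a hypothetical helper `power_with_count` and an assumed modulus, shows how slowly that cost grows when going from 4 equations to $2^{20}$.

```python
def power_with_count(z, n, p):
    """Square-and-multiply: return (z**n mod p, number of modular multiplications)."""
    result, base, count = 1, z % p, 0
    while n > 0:
        if n & 1:
            result = result * base % p
            count += 1
        n >>= 1
        if n > 0:
            base = base * base % p
            count += 1
    return result, count

P = 65537  # assumed prime modulus
z = 12345  # arbitrary challenge for the illustration
small_value, small_cost = power_with_count(z, 4, P)
large_value, large_cost = power_with_count(z, 2 ** 20, P)
```

With this counting convention, the exponent 4 costs 3 multiplications and the exponent $2^{20}$ costs 21, even though the second set of equations is over two hundred thousand times larger.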