In [1]:
from zk_adventures_types import F, Polynomial
from common import Sha3_256Transcript, Transcript
from common import Domain, Oracle, NaiveOracle
from typing import List, Set

# Wiring satisfiability

In this part we'll cover the part of Plonk's protocol that guarantees that trace values corresponding to the same variable are equal across all triplets. For example, suppose our circuit is the one that computes the XOR of two binary variables. There are two input variables $x$ and $y$, and one output variable $z$. The coefficients for the equations of this circuit are: 

| L | R | M | O | C |
| ---:| :----   |:----- |:---- |:----- |
| 1     | 0     | -1    | 0     | 0     |
| 1     | 0     | -1    | 0     | 0     |
| 1     | 0     | -1    | 0     | 0     |
| 1     | 1     | -2    | -1    | 0     |

And the intended solutions are the ones where $x$, $y$ and $z$ are each either $0$ or $1$, and $\text{ XOR }(x,  y) = z$.
That is, the solutions we want are the triplets of the form (where `-` means the value is not important)


| A | B | C |
| ---:| :----   |:----- |
| x     | x     | -    |
| y     | y     | -    |
| z     | z     | -    |
| x     | y     | z    |

Every triplet of this form solves the equations above. But there are also triplets that solve them without being of that form. For example

| A | B | C |
| ---:| :----   |:----- |
| 0     | 1     | -    |
| 0     | 0     | -    |
| 1     | 1     | -    |
| 1     | 1     | 0    |

This solution does not correspond to an assignment of values to $x$, $y$ and $z = \text{ XOR }(x, y)$: for instance, the first row has $A = 0$ and $B = 1$, although both entries should hold the value of $x$. Still, the `EquationSatisfiabilityVerifier` of the previous part would accept this solution, because that protocol only enforces each triplet to be a solution to its corresponding equation. We need a way to enforce that values are consistent across triplets in the way we need, invalidating solutions like the above.

In this part we'll see how this is achieved. The idea behind it stems from a protocol for a seemingly unrelated problem. Let's start with that and later see how it is used to solve this problem.

# Shuffle satisfiability protocol

In this section we'll see a protocol with which a prover can prove that she has a private vector of values that is a shuffle of a public vector. More precisely, suppose there's a vector $V = (v_0, \dots, v_{N-1}) \in\mathbb{F}^N$ that's known to both a prover and a verifier. Suppose the prover additionally holds a private vector $W = (w_0, \dots, w_{N-1})$. The prover and the verifier can engage in a protocol that succinctly convinces the verifier that the vector $W$ is a shuffle of the vector $V$. With this protocol, the number of operations the verifier has to perform is independent of $N$: it is constant.

In [2]:
# [TEST]
# Fixture for the shuffle protocol: a public vector V and a private vector W
# holding the same multiset of values in a different order.

# Public vector of field elements
V = [49, 56, 7, 15, 56, 56, 56, 49, 7, 49, 15, 15, 56, 15, 7, 15]

# Private vector of field elements known only to the prover
W = [15, 15, 7, 56, 15, 15, 56, 56, 7, 56, 56, 49, 49, 15, 49, 7]

# The vector V is a shuffle of W (same values with the same multiplicities)
assert(sorted(V) == sorted(W))

#### Main idea
The idea will be that the vector $V$ is a shuffle of the vector $W$ with high probability if, for a random coefficient $\alpha$, the product $(v_0 + \alpha)\cdots(v_{N-1} + \alpha)$ equals $(w_0 + \alpha)\cdots(w_{N-1} + \alpha)$. This is an application of the Schwartz-Zippel lemma (think about it!).

#### A handy kind of polynomials
Any polynomial $Z$ that satisfies $$Z(\omega^i)\cdot(v_i + \alpha) = Z(\omega^{i+1})\cdot(w_i + \alpha),$$ for all $i=0,\dots, N-1$ also satisfies $$Z(\omega^{i+1}) = Z(1)\frac{(v_0 + \alpha)\cdots(v_i + \alpha)}{(w_0 + \alpha)\cdots(w_i + \alpha)}$$ for $i=0,\dots, N-1$.
In particular, for $i=N-1$, since $\omega^N = 1$, this last equality implies $$Z(1) = Z(1)\frac{(v_0 + \alpha)\cdots(v_{N-1} + \alpha)}{(w_0 + \alpha)\cdots(w_{N-1} + \alpha)}.$$
There are only two cases: either $Z(1) = 0$ or $Z(1) \neq 0$  🦆. In the latter case we can divide both sides by $Z(1)$, which gives $(v_0 + \alpha)\cdots(v_{N-1} + \alpha) = (w_0 + \alpha)\cdots(w_{N-1} + \alpha).$

#### Summary
So, putting this all together we obtain the following claim.

*Claim:* The vector $W$ is a shuffle of the vector $V$ with high probability if for a random $\alpha$ there exists a polynomial $Z$ such that:
1. $Z(\omega^i)\cdot(v_i + \alpha) = Z(\omega^{i+1})\cdot(w_i + \alpha)$ for all $i=0,\dots,N-1$.
2. $Z(1) \neq 0$.

It turns out that when $W$ is a shuffle of $V$, such a polynomial $Z$ always exists: we can construct it by interpolating the cumulative products (normalized so that $Z(1) = 1$).

In [3]:
def construct_Z_polynomial(V: List[int], W: List[int], random_coeff: int, domain: Domain):
    """
    Returns the polynomial Z of least degree such that
        * Z(1) = 1, and
        * Z(𝜔ⁱ) = ((V₀ + 𝛼)⋅⋅⋅(Vᵢ₋₁ + 𝛼)) / ((W₀ + 𝛼)⋅⋅⋅(Wᵢ₋₁ + 𝛼)), for all i = 1, ..., N-1,
    where 𝛼 is `random_coeff`, 𝜔 is `domain[1]` and N is the size of `domain`.

    `V`, `W` and `random_coeff` may be plain Python ints (for example the challenge
    returned by `simulate_send_challenge`) or elements of `F`. Every factor is
    coerced into `F`, so the division below is always field division and never
    Python's float-producing `/` on two ints.

    The interpolation uses the domain points paired (via `zip`) with the
    cumulative products, i.e. the first len(V) points of `domain`.
    """
    alpha = F(random_coeff)
    # products[i] is the value assigned to domain[i]; the first one is 1 so that Z(1) = 1.
    products = [F(1)]
    for i in range(1, len(V)):
        ratio = (F(V[i - 1]) + alpha) / (F(W[i - 1]) + alpha)
        products.append(products[i - 1] * ratio)
    return Polynomial.lagrange_polynomial(zip(domain, products))

In [4]:
# [TEST]
# Regression value: Z built from the fixture V, W with a fixed 𝛼, evaluated at a fixed point.

domain = Domain.of_size(len(V))
random_coeff = F(0xdeadbeef)
Z = construct_Z_polynomial(V, W, random_coeff, domain)
assert(Z(0xcafecafe) == 18437)

In [5]:
# [TEST]
# The constructed Z satisfies Z(d)(V_i + 𝛼) = Z(𝜔d)(W_i + 𝛼) at every domain point
# (including the wrap-around from the last point back to 1) and has Z(1) != 0.

# A polynomial Z with Z(1) != 0
domain = Domain.of_size(len(V))
random_coeff = F.random_element()
Z = construct_Z_polynomial(V, W, random_coeff, domain)
omega = domain[1]

for i, d in enumerate(domain):
    assert(Z(d) * (V[i] + random_coeff) == Z(omega * d) * (W[i] + random_coeff))
    
assert(Z(1) != 0)

In [6]:
# [TEST]
# Any multiple of X^N - 1 vanishes on the whole domain, so it satisfies every step
# constraint trivially. This is why the protocol must also require Z(1) != 0.

# A polynomial Z with Z(1) == 0
X = Polynomial.monomial(1)
Z = (X ** len(domain) - 1) * Polynomial.random_element((5, 10))

for i, d in enumerate(domain):
    assert(Z(d) * (V[i] + random_coeff) == Z(omega * d) * (W[i] + random_coeff))
    
assert(Z(1) == 0)

#### Preparing for Schwartz-Zippel
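The equations $Z(\omega^i)\cdot(v_i + \alpha) = Z(\omega^{i+1})\cdot(w_i + \alpha)$ for all $i=0,\dots,N-1$ can be expressed as a single polynomial equality using the vanishing polynomial of the domain, $X^N - 1$. If $v$ and $w$ interpolate $V$ and $W$ over the domain, the equations hold exactly when the polynomial $Z(X)(v(X) + \alpha) - Z(\omega X)(w(X) + \alpha)$ vanishes at every $\omega^i$, that is, when it is divisible by $X^N - 1$.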

In [7]:
def f(A, B, C, D):
    """
    Returns A*B - C*D.

    Used as the left-hand side of the step constraint, with
    A = Z, B = v + 𝛼, C = Z(𝜔X), D = w + 𝛼. Works for both polynomials and field elements.
    """
    lhs_term = A * B
    rhs_term = C * D
    return lhs_term - rhs_term

In [8]:
# [TEST]
# The numerator Z(X)(v + 𝛼) - Z(𝜔X)(w + 𝛼) vanishes on the domain, so it is an exact
# multiple of the vanishing polynomial X^N - 1.

v = Polynomial.lagrange_polynomial(zip(domain, V))
w = Polynomial.lagrange_polynomial(zip(domain, W))

X = Polynomial.monomial(1)
A = Z
B = v + random_coeff
C = Z(omega * Polynomial.monomial(1))
D = w + random_coeff
p = f(A, B, C, D) # numerator Z(X)(v + 𝛼) - Z(𝜔X)(w + 𝛼)
assert(p % (X ** len(domain) - 1) == 0)

t = p / (X ** len(domain) - 1)
assert(p == t * (X ** len(domain) - 1))

### The protocol

We are ready to write the Prover and the Verifier of this protocol. Just to remind, this is a protocol where:
1. There is a public vector $V$ known to both the prover and the verifier
2. The prover claims to hold a vector $W$ that's a shuffle of $V$.
3. The verifier gets convinced in constant time, independent of the sizes of $V$ and $W$ (assuming the existence of an Oracle and a one-time setup phase for preprocessed inputs).

#### Diagram

| Step  | Alice                                          | Bob                                            |
|-------|------------------------------------------------|------------------------------------------------|
| S1    |Interpolates W over a domain and obtains a polynomial $w$                        |                                                |
|    |Sends an oracle $[w]$ to Bob                        |                                                |
| S2      |                                                |                       Chooses random coefficient $\alpha$     |
|       |                                                |                       Sends $\alpha$ to Alice         |
| S3  |    Constructs $Z$ using $\alpha$       |                    |
|     |    Computes $t = \big(Z(X) (v + \alpha) - Z(\omega X) (w + \alpha)\big) \,/\, (X^N - 1)$      |                    |
|       |    Sends oracles $[Z]$ and $[t]$ to Bob                                             |                             |
| S4    |                   |   Chooses a random challenge $\zeta$                                             |
|       |                   |   Computes $a := Z(\zeta)$                                             |
|       |                   |   Computes $b := v(\zeta) + \alpha$                                             |
|       |                   |   Computes $c := Z(\zeta \omega)$                                             |
|       |                   |   Computes $d := w(\zeta) + \alpha$                                             |
|       |                   |   Computes $e := t(\zeta) \cdot (\zeta^N-1)$                                     |
|       |                   |   Checks that $ab-cd = e$                                     |
|       |                   |   Checks that $Z(1)$ is not zero                                     |


We'll need a polynomial oracle. Let's use a naive one like before. This time we'll need to ask more than one query, so we allow that.

In [9]:
import ctypes

class MultiQueryNaiveOracle(Oracle):
    """
    Naive polynomial oracle: answers evaluation queries straight from the wrapped
    polynomial, optionally under a query budget.

    `num_queries` is the budget. The default of -1 only moves further from zero
    as it is decremented, so it never runs out (unlimited queries).
    """

    def __init__(self, polynomial: Polynomial, num_queries: int = -1):
        self.__polynomial = polynomial
        self.__remaining_queries = num_queries

    def query(self, z):
        """
        Evaluates the wrapped polynomial at `z`, consuming one query.

        Raises ValueError when the budget has already been exhausted.
        """
        budget_exhausted = self.__remaining_queries == 0
        if budget_exhausted:
            raise ValueError("No more queries left.")
        self.__remaining_queries = self.__remaining_queries - 1
        return self.__polynomial(z)

    def __hash__(self):
        """Hash the polynomial's coefficients and truncate to an unsigned 32-bit integer."""
        coefficients = tuple(self.__polynomial)
        return ctypes.c_uint32(hash(coefficients)).value

In [10]:
def interpolate_values(domain: Domain, values: List[int]) -> Polynomial:
    """
    Lagrange interpolant of the points (domain[i], values[i]).

    Points are paired with `zip`, so the shorter of the two sequences bounds the
    number of interpolation points.
    """
    points = zip(domain, values)
    return Polynomial.lagrange_polynomial(points)

In [18]:
from dataclasses import dataclass
from math import ceil
from random import randint

@dataclass
class PreprocessedInput:
    """
    Public data of the shuffle protocol, built once by `ShuffleSetup.setup`.

    NOTE: a later cell defines another dataclass with this same name (for the
    mask-consistency protocol), which rebinds the global name.
    """
    # Common domain shared by prover and verifier
    domain: Domain
    # public values V
    V: list
    # interpolant of the public values over the domain
    v: Polynomial
    # oracle of `v`
    oracle_v: Oracle

class ShuffleSetup:
    """One-time setup of the shuffle protocol: everything that only depends on the public values."""

    # Capture the shuffle protocol's dataclass at definition time. A later cell rebinds
    # the global name `PreprocessedInput` to the mask-consistency dataclass (different
    # fields), so resolving the name when `setup` runs would build the wrong type
    # (TypeError on the `V=` keyword) if `setup` is called after that cell.
    _preprocessed_input_type = PreprocessedInput

    @staticmethod
    def setup(values, log_domain_size) -> PreprocessedInput:
        """
        Computes all the relevant objects that only depend on the public values.
        This is: the common domain, the interpolant of `values` over the domain, and
        the oracle of that polynomial.

        Raises ValueError if `len(values)` differs from the domain size
        `2 ** log_domain_size`. Interpolation pairs points with `zip`, so a mismatch
        would otherwise silently drop values or leave domain points unconstrained.
        """
        domain = Domain.of_size(2 ** log_domain_size)
        if len(values) != len(domain):
            raise ValueError(
                f"expected {len(domain)} public values for a domain of size "
                f"2**{log_domain_size}, got {len(values)}"
            )
        poly = interpolate_values(domain, values)
        oracle = MultiQueryNaiveOracle(poly)
        return ShuffleSetup._preprocessed_input_type(domain=domain, V=values, v=poly, oracle_v=oracle)

class ShuffleProver:
    """Prover of the shuffle protocol (steps S1-S3 of the diagram)."""

    @staticmethod
    def simulate_send_oracle(oracle: Oracle, transcript: Transcript):
        """
        Simulates sending an oracle by adding the big endian representation of `hash(oracle)`
        to the transcript.
        """
        oracle_hash_int = hash(oracle)
        byte_length = ceil(oracle_hash_int.bit_length() / 8)
        oracle_hash_bytes = oracle_hash_int.to_bytes(byte_length, "big")
        transcript.append(oracle_hash_bytes)

    def prove(self, private_values, preprocessed_input: PreprocessedInput, transcript: Transcript):
        """
        Returns the oracles ([w], [Z], [t]) claiming that `private_values` is a
        shuffle of the public values `preprocessed_input.V`.

        The challenge 𝛼 is derived from `transcript` after absorbing [w], so the
        verifier can recompute it by replaying the same transcript.
        """
        d = preprocessed_input.domain
        X = Polynomial.monomial(1)

        # S1: interpolate W and "send" [w] by absorbing its hash into the transcript.
        w = Polynomial.lagrange_polynomial(zip(d, private_values))
        oracle_w = MultiQueryNaiveOracle(w)
        ShuffleProver.simulate_send_oracle(oracle_w, transcript)

        # S2: the verifier's random coefficient 𝛼, sampled from the transcript.
        r = ShuffleVerifier.simulate_send_challenge(transcript)

        # S3: cumulative-product polynomial Z and quotient t.
        Z = construct_Z_polynomial(preprocessed_input.V, private_values, r, d)
        oracle_Z = MultiQueryNaiveOracle(Z)

        omega = d[1]
        N = len(d)
        # Numerator Z(X)(v + 𝛼) - Z(𝜔X)(w + 𝛼), the same expression the verifier checks.
        # It vanishes on the whole domain (hence is divisible by X^N - 1) only when
        # prod(v_i + 𝛼) == prod(w_i + 𝛼). Floor division is deliberate: for a
        # non-shuffle the remainder is dropped and the verifier's check fails.
        numerator = Z * (preprocessed_input.v + r) - Z(X * omega) * (w + r)
        t = numerator // (X ** N - 1)
        oracle_t = MultiQueryNaiveOracle(t)

        return (oracle_w, oracle_Z, oracle_t)
        
class ShuffleVerifier:
    """Verifier of the shuffle protocol (step S4 of the diagram)."""

    @staticmethod
    def simulate_send_challenge(transcript: Transcript):
        """
        Simulates a random challenge: bytes sampled from the transcript, read as a
        big-endian integer and reduced modulo the field order.
        """
        sampled_bytes = transcript.sample()
        return int.from_bytes(sampled_bytes, "big") % F.order()

    def verify(self, proof, preprocessed_input, transcript: Transcript):
        """
        Checks a proof ([w], [Z], [t]) produced by `ShuffleProver.prove`.

        Replays the transcript to recover 𝛼, picks a random point ζ and returns True
        iff Z(ζ)(v(ζ) + 𝛼) - Z(ζ𝜔)(w(ζ) + 𝛼) == t(ζ)(ζ^N - 1) and Z(1) != 0.
        Always queries [Z] three times and [w], [t] and the public [v] once each.
        """
        oracle_w, oracle_Z, oracle_t = proof

        # Same transcript operations as the prover, so 𝛼 matches.
        ShuffleProver.simulate_send_oracle(oracle_w, transcript)
        alpha = ShuffleVerifier.simulate_send_challenge(transcript)

        domain = preprocessed_input.domain
        omega = domain[1]
        N = len(domain)

        zeta = randint(1, F.order() - 1)

        a = oracle_Z.query(zeta)
        b = preprocessed_input.oracle_v.query(zeta) + alpha
        c = oracle_Z.query(omega * zeta)
        d = oracle_w.query(zeta) + alpha
        e = oracle_t.query(zeta) * (zeta ** N - 1)
        identity_holds = a * b - c * d == e

        z_at_one_nonzero = oracle_Z.query(1) != 0
        return identity_holds and z_at_one_nonzero

In [19]:
# [TEST]
# End-to-end: an honest prover with a shuffle of the public values. The verifier
# starts from an identically seeded transcript, so it recovers the same challenge.

from random import shuffle

public_values = [49, 56, 7, 15, 56, 56, 56, 49, 7, 49, 15, 15, 56, 15, 7, 15]
preprocessed_input = ShuffleSetup.setup(public_values, 4)

prover = ShuffleProver()
private_values = [15, 15, 7, 56, 15, 15, 56, 56, 7, 56, 56, 49, 49, 15, 49, 7]
transcript = Sha3_256Transcript(b"1234")
proof = prover.prove(private_values, preprocessed_input, transcript)

verifier = ShuffleVerifier()
transcript = Sha3_256Transcript(b"1234")
assert(verifier.verify(proof, preprocessed_input, transcript))

In [20]:
# [TEST]
# Runs prover and verifier on several candidate private vectors. `ground_truth[k]` is
# the expected verifier outcome for Ws[k] (True when Ws[k] is a shuffle of public_values).

Ws = [
    [7, 7, 49, 15, 56, 49, 15, 56, 49, 56, 15, 7, 15, 56, 56, 15],
    [7, 7, 15, 49, 15, 56, 56, 56, 56, 15, 56, 15, 49, 7, 15, 49],
    [56, 56, 56, 15, 15, 15, 56, 49, 15, 7, 15, 7, 49, 49, 7, 56],
    [56, 7, 7, 15, 49, 15, 56, 7, 15, 49, 15, 56, 15, 15, 56, 56],
    [15, 56, 56, 49, 7, 56, 56, 15, 7, 7, 15, 49, 15, 56, 15, 49],
    [49, 7, 7, 56, 49, 15, 56, 15, 49, 56, 7, 15, 56, 15, 15, 56],
    [15, 49, 7, 15, 15, 56, 56, 49, 7, 49, 7, 56, 56, 56, 15, 15], 
    [15, 49, 56, 49, 15, 56, 15, 56, 56, 56, 15, 7, 49, 7, 15, 49],
    [56, 7, 56, 7, 15, 49, 15, 49, 15, 15, 49, 56, 15, 7, 56, 56],
    [56, 49, 15, 49, 56, 15, 7, 15, 49, 56, 15, 56, 7, 7, 56, 7],
]

ground_truth = [True, True, True, False, True, True, True, False, True, False]

from functools import reduce
# Transcript seed: the public values, each encoded as 8 big-endian bytes and concatenated.
init_bytes = [int(coeff).to_bytes(8, "big") for coeff in preprocessed_input.V]
init_bytes = bytes([o for part in init_bytes for o in part])
preprocessed_input = ShuffleSetup.setup(public_values, 4)

proofs = [
    prover.prove(W, preprocessed_input, Sha3_256Transcript(init_bytes)) 
    for W in Ws
]

for proof, expected_result in zip(proofs, ground_truth):
    transcript = Sha3_256Transcript(init_bytes)
    result = verifier.verify(proof, preprocessed_input, transcript)
    assert(result == expected_result)

In [21]:
# [TEST]
# All-zero oracles satisfy the polynomial identity trivially but have Z(1) = 0, so the
# verifier must reject. The query budgets (1 for w, 3 for Z, 1 for t) match how many
# times `verify` queries each oracle.

q = Polynomial.zero()
proof = (MultiQueryNaiveOracle(q, 1), MultiQueryNaiveOracle(q, 3), MultiQueryNaiveOracle(q, 1))
assert(not verifier.verify(proof, preprocessed_input, Sha3_256Transcript(init_bytes)))

## Wirings

The ideas of the shuffle proving protocol can be used to prove the wiring constraints we need for Plonk.

In this context we want to produce constraints that guarantee that a matrix of values in $\mathbb{F}_p$ has a specific shape. In the example at the beginning of this notebook we wanted to restrict to matrices of the form

| A | B | C |
| ---:| :----   |:----- |
| x     | x     | -    |
| y     | y     | -    |
| z     | z     | -    |
| x     | y     | z    |

To simplify, instead of working with matrices with 3 columns, let's work with a single vector. So, the vectors we want to restrict to in the example are the vectors of the form $(x, x, y, y, z, z, x, y, z)$. What do we mean by this? We want a protocol in which one party (the prover) has a vector of that shape, say 

$$W = (77, 77, 83294, 83294, 1283, 1283, 77, 83294, 1283)$$

and wants to prove to another party (the verifier) that the vector has that shape, without revealing the actual values. This by itself is not very useful, since the vector of all zeroes has that shape, but it becomes interesting when combined with the equation satisfiability protocol. Being in possession of a matrix that has the correct shape and solves a set of equations is in fact a proof of execution of a program.

But let's go back to the simple version of the protocol to prove only that a vector is of the correct shape.

The shape can be described by an array of indices, called a mask, and a vector of the correct shape is one where the same value sits for places where the same index is. In the example the mask would be 

$$M = (0, 0, 1, 1, 2, 2, 0, 1, 2).$$

A vector $W$ is of the correct shape if $W[i] = W[j]$ whenever $M[i] = M[j]$.

For implementation reasons, we'll only consider vectors whose length is a power of two. So let's pad the mask with zeros, and pad the solution vectors with their first value (the value of the variable labelled $0$) so that they remain valid. The mask is now 

$$M = (0, 0, 1, 1, 2, 2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0)$$

and a solution would be for example

$$W = (77, 77, 83294, 83294, 1283, 1283, 77, 83294, 1283, 77, 77, 77, 77, 77, 77, 77)$$



In [30]:
# [TEST]
# Padded mask of the XOR example: label 0 is x, 1 is y, 2 is z.

mask = [0, 0, 1, 1, 2, 2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0]

In [35]:
import itertools

def is_of_correct_shape(W, mask):
    """
    Returns True if `mask[i] == mask[j]` implies `W[i] == W[j]`, or False otherwise.

    Works in a single pass: the first value seen for each mask label is remembered,
    and every later position with that label must carry the same value. Because
    equality is transitive, this is equivalent to comparing all pairs (i, j), but it
    runs in O(n) instead of O(n²).

    Mask labels must be hashable. `W` must be at least as long as `mask`; indexing
    past its end raises IndexError.
    """
    value_for_label = {}
    for i, label in enumerate(mask):
        value = W[i]
        if label not in value_for_label:
            value_for_label[label] = value
        elif value_for_label[label] != value:
            return False
    return True


In [36]:
# [TEST]
# W_invalid puts x's value (77) at index 8, where the mask expects z's value.

W_invalid = [77, 77, 83294, 83294, 1283, 1283, 77, 83294, 77, 77, 77, 77, 77, 77, 77, 77]
assert(not is_of_correct_shape(W_invalid, mask))

W = [77, 77, 83294, 83294, 1283, 1283, 77, 83294, 1283, 77, 77, 77, 77, 77, 77, 77]
assert(is_of_correct_shape(W, mask))

There's another way to decide whether a vector $W = (w_0, \dots, w_{n-1})$ is a solution, which is better suited for applying the ideas of the shuffle protocol. The idea is to construct from $W$ and $M$ two vectors $V_1$ and $V_2$ such that $V_1$ is a shuffle of $V_2$ if and only if $W$ is of the correct shape according to $M$.

To construct those vectors we'll need to derive from $M$ a partition $\Pi$ of the set of indices $\{0, \dots, n-1\}$, and from that partition a permutation of the same set of indices.

### Constructing the partition

The mask $M$ defines a partition of the set of indices: the subsets of indices that share the same value in the mask. In our example, where the mask is $(0, 0, 1, 1, 2, 2, 0, 1, 2, 0,0,0,0,0,0,0)$, the partition is


$$\Pi = \{\{0, 1, 6, 9, 10, 11, 12, 13, 14, 15\}, \{2, 3, 7\}, \{4, 5, 8\}\}$$

In [39]:
class Partition:
    """A partition of a set of indices into disjoint parts (order of parts and elements is irrelevant)."""

    def __init__(self, parts: List[Set[int]]):
        self._parts = parts

    @classmethod
    def from_mask(cls, mask):
        """
        Returns the partition of indexes determined by `mask`: indexes holding the
        same mask value form one part. Parts appear in order of first occurrence.
        For example, if `mask = [0,0,1,2,0,1,0]`, then
        the associated partition is `[{0,1,4,6}, {2,5}, {3}]`.
        Mask values must be hashable.
        """
        parts = {}
        for index, label in enumerate(mask):
            parts.setdefault(label, set()).add(index)
        return cls(list(parts.values()))

    def __iter__(self):
        return iter(self._parts)

    def _canonical(self):
        """Parts as sorted lists, themselves sorted, so comparisons ignore ordering."""
        return sorted(sorted(part) for part in self._parts)

    def __eq__(self, other):
        # Returning NotImplemented (instead of failing on `other._parts`) lets Python
        # fall back to its default comparison, so `partition == 5` is simply False.
        if not isinstance(other, Partition):
            return NotImplemented
        return self._canonical() == other._canonical()

In [40]:
# [TEST]
# The padded XOR mask yields three parts: the indices of x, of y and of z.

mask = [0, 0, 1, 1, 2, 2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0]
expected_partition = Partition([
    [0, 1, 6, 9, 10, 11, 12, 13, 14, 15],
    [2, 3, 7],
    [4, 5, 8]
])

partition = Partition.from_mask(mask)
assert(partition == expected_partition)

### Constructing the permutation

Given that partition, let's construct a permutation (a bijective function) of the indices whose cycles are exactly the parts of the partition. There are many such permutations; any of them will do.

In the example, one such function is $\sigma: \{0, 1, \dots, 15\} \to \{0, 1, \dots, 15\}$ with $\sigma(2) = 3$, $\sigma(3) = 7$, $\sigma(7) = 2$, and similarly for the other parts.

In [41]:
class Permutation:
    """A permutation of the indices 0, ..., n-1, stored as the list of images."""

    def __init__(self, values):
        length = len(values)
        if set(values) != set(range(length)):
            raise ValueError(f"{values!r} is not a permutation of range({length})")
        self._values = values

    @classmethod
    def from_partition(cls, partition: "Partition"):
        """
        Constructs a permutation whose set of cycles is the given partition.

        Each part [i_0, ..., i_{k-1}] (in its iteration order) becomes the cycle
        i_0 -> i_1 -> ... -> i_{k-1} -> i_0. The size of the permutation is the total
        number of indices in the partition, so the parts must cover exactly
        0, ..., n-1; otherwise ValueError is raised.

        `partition` may be any iterable of iterables of indices (e.g. a `Partition`).
        """
        parts = [list(part) for part in partition]
        # Derive the size from the partition itself rather than from any outside state.
        size = sum(len(part) for part in parts)
        values = [-1] * size
        for part in parts:
            for position, index in enumerate(part):
                if index not in range(size):
                    raise ValueError(
                        f"index {index!r} is outside range({size}); "
                        f"the parts must cover 0, ..., {size - 1}"
                    )
                values[index] = part[(position + 1) % len(part)]
        return cls(values)

    @classmethod
    def from_mask(cls, mask):
        """Constructs a permutation whose cycles are the parts of `Partition.from_mask(mask)`."""
        return cls.from_partition(Partition.from_mask(mask))

    def __call__(self, i):
        if i not in range(len(self._values)):
            raise ValueError(f"index {i!r} is out of range for a permutation of size {len(self._values)}")
        return self._values[i]

In [42]:
# [TEST]

def cycle_of(index, σ: Permutation):
    """
    Returns the set of indices reached by repeatedly applying σ starting from
    `index`, stopping as soon as an already-visited index comes up again.
    For a permutation this is exactly the cycle containing `index`.
    """
    visited = set()
    current = index
    while current not in visited:
        visited.add(current)
        current = σ(current)
    return visited

# The cycles of σ through 0, 2 and 4 (one index from each part) recover the partition.
σ = Permutation.from_partition(partition)

assert(partition == Partition([cycle_of(0, σ), cycle_of(2, σ), cycle_of(4, σ)]))

### Constructing the vectors

With all this machinery, we can check whether a vector $W = (w_0, w_1, \dots, w_{n-1})$ is of the correct shape according to the mask $M$ by checking if the following two vectors are shuffles of each other:
$$
\begin{aligned}
&A = ((0, w_0), (1, w_1), (2, w_2), \dots, (n-1, w_{n-1})) \\
&B = ((\sigma(0), w_0), (\sigma(1), w_1), (\sigma(2), w_2), \dots, (\sigma(n-1), w_{n-1}))
\end{aligned}
$$

That is, $A$ is a shuffle of $B$ if and only if $W$ is a vector of the correct shape according to $M$.

In [43]:
# [TEST]
# (index, value) pairs of W, before and after relabelling each index i as σ(i).

pairs_sorted = [(i, w) for i, w in enumerate(W)] # = list(enumerate(W))
pairs_permuted = [(σ(i), w) for i, w in enumerate(W)]

In [44]:
# [TEST]
# For a correctly shaped W, σ only moves indices within a part, where all values are
# equal, so the multiset of (index, value) pairs is unchanged.

assert(pairs_sorted == [
    (0, 77),
    (1, 77),
    (2, 83294),
    (3, 83294),
    (4, 1283),
    (5, 1283),
    (6, 77),
    (7, 83294),
    (8, 1283),
    (9, 77),
    (10, 77),
    (11, 77),
    (12, 77),
    (13, 77),
    (14, 77),
    (15, 77)
])

assert(pairs_permuted == [
    (1, 77),
    (6, 77),
    (3, 83294),
    (7, 83294),
    (5, 1283),
    (8, 1283),
    (9, 77),
    (2, 83294),
    (4, 1283),
    (10, 77),
    (11, 77),
    (12, 77),
    (13, 77),
    (14, 77),
    (15, 77),
    (0, 77)
])

# `pairs_sorted` is a shuffle of `pairs_permuted`
assert(sorted(pairs_sorted) == sorted(pairs_permuted))

### Flattening the vectors

To be able to use the techniques of the Shuffle satisfiability protocol we need to transform the vectors $A$ and $B$ into vectors in $\mathbb{F}^n$. One way of doing that is to make a random linear combination of their coordinates. Suppose $\beta\in\mathbb{F}$ is a random element. We can map the pair $(i, w)$ to the element $\beta\cdot i + w$. But the way this is done in Plonk is a little different. The element $(i, w)$ is mapped to $\beta \omega^i + w$. Then, mapping every element of $A$ and $B$ this way we obtain $V_1$ and $V_2$ respectively.

$$
\begin{aligned}
&V_1 = (\beta\omega^0 + w_0, \beta \omega^1 + w_1, \beta \omega^2 + w_2, \dots, \beta \omega^{n-1} + w_{n-1}) \\
&V_2 = (\beta \omega^{\sigma(0)} + w_0, \beta \omega^{\sigma(1)} + w_1, \beta \omega^{\sigma(2)} + w_2, \dots, \beta \omega^{\sigma(n-1)} + w_{n-1})
\end{aligned}
$$

Let $D$ be the domain $(1, \omega, \omega^2, \dots, \omega^{n-1})$. Then there's a shorthand notation for $V_1$ and $V_2$

$$
\begin{aligned}
&V_1 = \beta D + W\\
&V_2 = \beta \sigma(D) + W,
\end{aligned}
$$
where $\sigma(D)$ means the permuted domain $\sigma(D) = (\omega^{\sigma(0)}, \omega^{\sigma(1)}, \dots, \omega^{\sigma(n-1)})$

There's a very tiny probability that $B$ was not a shuffle of $A$ but $V_1$ ends up being a shuffle of $V_2$, just by chance and the effect of flattening. So things start to be probabilistic from now on.

**In summary**, if for a random $\beta\in\mathbb{F}$ the vector $\beta D + W$ is a shuffle of $\beta\sigma(D) + W$, then $W$ is of the correct shape with high probability according to the mask that defines $\sigma$.

In [47]:
def construct_V1_and_V2(domain, σ: Permutation, W: List[int], random_coeff):
    """
    Flattens the pairs (i, W[i]) and (σ(i), W[i]) into single field elements.

    Returns the vectors
        `V_1 := domain * random_coeff + W` and,
        `V_2 := σ(domain) * random_coeff + W`,
    that is, V_1[i] = domain[i] * random_coeff + W[i] and
    V_2[i] = domain[σ(i)] * random_coeff + W[i] for every i in range(len(W)).
    """
    V_1, V_2 = [], []
    for i, w_i in enumerate(W):
        V_1.append(domain[i] * random_coeff + w_i)
        V_2.append(domain[σ(i)] * random_coeff + w_i)
    return V_1, V_2

In [48]:
# [TEST]
# Regression values for β = 0xcafe. Since W has the correct shape, V1 and V2 are
# shuffles of each other.

random_coeff = F(0xcafe)
domain = Domain.of_size(len(W))
V1, V2 = construct_V1_and_V2(domain, σ, W, random_coeff)

assert(V1[5] == 47043)
assert(V2[5] == 14854)

assert(sorted(V1) == sorted(V2))

### The protocol
Public to both the Prover and the Verifier is a mask $M$ and the permutation $\sigma$ derived from it. 

| Step  | Prover                                          | Verifier                                            |
|-------|------------------------------------------------|------------------------------------------------|
| S1    |Interpolates W over a domain and obtains a polynomial $w$                        |                                                |
|    |Sends an oracle $[w]$ to the Verifier                        |                                                |
| S2      |                                                |                       Chooses random coefficients $\alpha$ and $\beta$     |
|       |                                                |                       Sends $\alpha$ and $\beta$ to the Prover         |
| S3  |    Constructs $V_1 = \beta D + W$ and $V_2 = \beta\sigma(D) + W$       |                    |
|   |    Constructs $Z$ for $V_1$ and $V_2$ using $\alpha$       |                    |
|     |    Computes $t = \big(Z(X) (v_1 + \alpha) - Z(\omega X) (v_2 + \alpha)\big) \,/\, (X^N - 1)$      |                    |
|       |    Sends oracles $[Z]$ and $[t]$ to the Verifier                                             |                             |
| S4    |                   |   Chooses a random challenge $\zeta$                                             |
|       |                   |   Computes $a := Z(\zeta)$                                             |
|       |                   |   (*) Computes $b := v_1(\zeta) + \alpha$                                             |
|       |                   |   Computes $c := Z(\zeta \omega)$                                             |
|       |                   |   (*) Computes $d := v_2(\zeta) + \alpha$                                             |
|       |                   |   Computes $e := t(\zeta) \cdot (\zeta^N-1)$                                     |
|       |                   |   Checks that $ab-cd = e$                                     |
|       |                   |   Checks that $Z(1)$ is not zero                                     |

After all this ceremony the Verifier gets convinced that the oracle $[w]$ he received at the beginning interpolates some values $W$ such that $\beta D + W$ is a shuffle of $\beta\sigma(D) + W$. So $W$ is of the correct shape according to the public mask $M$ with high probability.

(*) The values $v_1(\zeta)$ and $v_2(\zeta)$ can be computed from $[w](\zeta)$, the domain and the permutation. All of this is available to the verifier.

### (*) Evaluating $v_1(\zeta)$ and $v_2(\zeta)$

Suppose you are the verifier and have received an oracle $[w]$ of some polynomial $w$. You also know the permutation $\sigma$ and the random coefficient $\beta$ used by the prover to build $V_1$ and $V_2$. You also know the domain of interpolation. Let $v_1$ and $v_2$ be the interpolants of $V_1$ and $V_2$. You sample a random element $\zeta \in \mathbb{F}$. Your goal is to find the values of $v_1(\zeta)$ and $v_2(\zeta)$ only from $[w]$ and $\sigma$.

In [57]:
# [TEST]
# Fixture for `evaluate_v1_and_v2`: what the verifier knows vs. what only the prover knows.
# The oracle of w allows a single query.

# You know this
mask = [0, 0, 1, 1, 2, 2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0]
permutation = Permutation.from_mask(mask)
domain = Domain.of_size(len(mask))
beta = F(0xcafe)

# You don't know this
W = [77, 77, 83294, 83294, 1283, 1283, 77, 83294, 1283, 77, 77, 77, 77, 77, 77, 77]
V_1, V_2 = construct_V1_and_V2(domain, permutation, W, beta)
w = interpolate_values(domain, W)
v_1 = interpolate_values(domain, V_1)
v_2 = interpolate_values(domain, V_2)

# You know this
zeta = F(0xdeadbeef)
oracle_w = MultiQueryNaiveOracle(w, 1)

In [58]:
def evaluate_v1_and_v2(domain, permutation, oracle_w, beta, zeta):
    """
    Returns (v_1(zeta), v_2(zeta)), the evaluations at `zeta` of the interpolants of
        `V_1 := domain * beta + W` and,
        `V_2 := permutation(domain) * beta + W`,
    using a single query to `oracle_w`, the oracle of the interpolant w of W.

    The interpolant of the domain itself is X, so v_1(zeta) = zeta * beta + w(zeta).
    For v_2, the polynomial S that interpolates the permuted domain is built here,
    and v_2(zeta) = S(zeta) * beta + w(zeta).
    """
    S = Polynomial.lagrange_polynomial(
        zip(domain, [domain[permutation(i)] for i in range(len(domain))])
    )
    w_at_zeta = oracle_w.query(zeta)
    return zeta * beta + w_at_zeta, S(zeta) * beta + w_at_zeta

In [59]:
# [TEST]
# The values computed from one query to [w] match the true interpolants v_1 and v_2.

eval_v1, eval_v2 = evaluate_v1_and_v2(domain, permutation, oracle_w, beta, zeta)

assert(eval_v1 == v_1(zeta))
assert(eval_v2 == v_2(zeta))

## Prover and Verifier

In [64]:
from dataclasses import dataclass
from math import ceil
from collections import defaultdict

@dataclass
class PreprocessedInput:
    """
    Public data of the mask-consistency protocol, built once by
    `MaskConsistencySetup.setup`.

    NOTE: this rebinds the global name `PreprocessedInput`, shadowing the shuffle
    protocol's dataclass (fields domain, V, v, oracle_v) defined in an earlier cell.
    """
    # Common interpolation domain
    domain: Domain
    # Permutation σ derived from the public mask
    permutation: Permutation
    # Polynomial taking the value domain[σ(i)] at domain[i]
    permutation_polynomial: Polynomial
    # Oracle of `permutation_polynomial`
    permutation_oracle: Oracle

class MaskConsistencySetup:
    """One-time setup of the mask-consistency protocol."""

    # Capture this protocol's dataclass now: the global name `PreprocessedInput` is
    # shared with the shuffle protocol's dataclass (same positional arity, different
    # meaning), so resolving it at call time could silently build the wrong type.
    _preprocessed_input_type = PreprocessedInput

    @staticmethod
    def setup(mask, log_domain_size) -> PreprocessedInput:
        """
        Computes the preprocessed input that is derived from all the public values.
        That is, the common interpolation domain, the permutation associated with the mask,
        the polynomial that interpolates the values of the permuted domain, and its oracle.

        Raises ValueError if `len(mask)` differs from the domain size
        `2 ** log_domain_size`. With a longer mask, part of it would otherwise be
        silently ignored.
        """
        domain = Domain.of_size(2 ** log_domain_size)
        if len(mask) != len(domain):
            raise ValueError(
                f"expected a mask of length {len(domain)} for a domain of size "
                f"2**{log_domain_size}, got {len(mask)}"
            )
        permutation = Permutation.from_mask(mask)
        permuted_domain = [domain[permutation(i)] for i in range(len(domain))]
        polynomial = Polynomial.lagrange_polynomial(zip(domain, permuted_domain))
        oracle = MultiQueryNaiveOracle(polynomial)
        return MaskConsistencySetup._preprocessed_input_type(domain, permutation, polynomial, oracle)

In [65]:
# [TEST]
# Checks the setup on a 4-label mask. Label 2 covers indices {4, 5, 7}, so σ(4) = 5.
# The permutation polynomial maps domain[i] to domain[σ(i)] and agrees with its oracle.

mask = [0, 0, 1, 1, 2, 2, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0]
public_input = MaskConsistencySetup.setup(mask, 4)

domain = public_input.domain
permutation = public_input.permutation

assert(public_input.domain[2] == 4096)
assert(public_input.permutation(4) == 5)
assert(public_input.permutation_polynomial(domain[5]) == domain[permutation(5)])
assert(public_input.permutation_polynomial(F(0xcafe)) == public_input.permutation_oracle.query(F(0xcafe)))

In [88]:
@dataclass
class Proof:
    """Oracles sent by `MaskConsistencyProver.prove`."""
    # Oracle of w, the interpolant of the private values
    oracle_w: Oracle
    # Oracle of the cumulative-product polynomial Z
    oracle_Z: Oracle
    # Oracle of the quotient t
    oracle_t: Oracle

class MaskConsistencyProver:
    """Prover of the mask-consistency (wiring) protocol."""

    @staticmethod
    def simulate_send_oracle(oracle: Oracle, transcript: Transcript):
        """
        Simulates sending an oracle by appending the big-endian bytes of
        `hash(oracle)` to the transcript.
        """
        digest = hash(oracle)
        n_bytes = ceil(digest.bit_length() / 8)
        transcript.append(digest.to_bytes(n_bytes, "big"))

    def prove(self, private_values, preprocessed_input: PreprocessedInput, transcript: Transcript):
        """
        Returns a `Proof` ([w], [Z], [t]) claiming that `private_values` has the shape
        given by the mask behind `preprocessed_input.permutation`.

        The challenges β and then 𝛼 are sampled from `transcript` after absorbing [w],
        in the same order the verifier samples them.
        """
        domain = preprocessed_input.domain
        X = Polynomial.monomial(1)

        # S1: interpolate W and "send" [w].
        w = Polynomial.lagrange_polynomial(zip(domain, private_values))
        oracle_w = MultiQueryNaiveOracle(w)
        MaskConsistencyProver.simulate_send_oracle(oracle_w, transcript)

        # S2: challenges β (flattening) and 𝛼 (shuffle).
        beta = MaskConsistencyVerifier.simulate_send_challenge(transcript)
        alfa = MaskConsistencyVerifier.simulate_send_challenge(transcript)

        # S3: V_1 = βD + W, V_2 = βσ(D) + W, then Z and t for the shuffle V_1 ~ V_2.
        V_1, V_2 = construct_V1_and_V2(domain, preprocessed_input.permutation, private_values, beta)
        Z = construct_Z_polynomial(V_1, V_2, alfa, domain)
        oracle_Z = MultiQueryNaiveOracle(Z)

        # Interpolants of V_1 and V_2: the domain interpolates to X, σ(D) to S.
        v_1 = X * beta + w
        v_2 = preprocessed_input.permutation_polynomial * beta + w

        omega = domain[1]
        N = len(domain)
        numerator = f(Z, v_1 + alfa, Z(omega * X), v_2 + alfa)
        t = numerator // (X ** N - 1)
        oracle_t = MultiQueryNaiveOracle(t)
        return Proof(oracle_w, oracle_Z, oracle_t)

class MaskConsistencyVerifier:
    """Verifier of the mask-consistency (wiring) protocol."""

    @staticmethod
    def simulate_send_challenge(transcript: Transcript):
        """
        Simulates a random challenge: bytes sampled from the transcript, read as a
        big-endian integer and reduced modulo the field order.
        """
        p = F.order()
        return int.from_bytes(transcript.sample(), "big") % p

    def verify(self, proof, preprocessed_input, transcript: Transcript):
        """
        Checks a `Proof` produced by `MaskConsistencyProver.prove`.

        Replays the transcript to recover β and 𝛼, then at a random ζ checks that
            Z(ζ)(v_1(ζ) + 𝛼) - Z(𝜔ζ)(v_2(ζ) + 𝛼) == t(ζ)(ζ^N - 1)   and   Z(1) != 0,
        where v_1(ζ) = βζ + w(ζ) and v_2(ζ) = βS(ζ) + w(ζ).

        S(ζ) comes from the preprocessed `permutation_oracle`, so the verifier's work
        does not depend on the domain size. Evaluating the degree-(N-1) polynomial
        directly would cost O(N).
        """
        oracle_w = proof.oracle_w
        oracle_Z = proof.oracle_Z
        oracle_t = proof.oracle_t

        MaskConsistencyProver.simulate_send_oracle(oracle_w, transcript)
        beta = MaskConsistencyVerifier.simulate_send_challenge(transcript)
        alfa = MaskConsistencyVerifier.simulate_send_challenge(transcript)

        omega = preprocessed_input.domain[1]
        N = len(preprocessed_input.domain)

        zeta = F.random_element()

        eval_w = oracle_w.query(zeta)
        eval_S = preprocessed_input.permutation_oracle.query(zeta)
        eval_v1 = zeta * beta + eval_w
        eval_v2 = eval_S * beta + eval_w
        eval_f = f(oracle_Z.query(zeta), eval_v1 + alfa, oracle_Z.query(omega * zeta), eval_v2 + alfa)

        equal = eval_f == oracle_t.query(zeta) * (zeta ** N - 1)
        not_zero = oracle_Z.query(1) != 0
        return equal and not_zero


In [89]:
# [TEST]
# Regression values for the proof oracles under a fixed transcript seed. This test only
# pins oracle evaluations. `private_values` does NOT match this mask (e.g. index 6 has
# label 1 but holds 77), so a verifier would be expected to reject this proof.

mask = [0, 0, 1, 1, 2, 2, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0]
public_input = MaskConsistencySetup.setup(mask, 4)

private_values = [77, 77, 83294, 83294, 1283, 1283, 77, 83294, 1283, 77, 77, 77, 77, 77, 77, 77]
transcript = Sha3_256Transcript(b"1234")
proof = MaskConsistencyProver().prove(private_values, public_input, transcript)

assert(proof.oracle_w.query(0xdeadbeef) == 6039)
assert(proof.oracle_Z.query(0xdeadbeef) == 6030)
assert(proof.oracle_t.query(0xdeadbeef) == 15349)

In [90]:
# [TEST]
# End-to-end: private values are built from the mask (one random element per label),
# so they have the correct shape by construction and the verifier must accept.

mask = [0, 0, 1, 1, 2, 2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0]
public_input = MaskConsistencySetup.setup(mask, 4)

elements = [F.random_element(), F.random_element(), F.random_element()]
private_values = [elements[i] for i in mask]
transcript = Sha3_256Transcript(b"1234")
proof = MaskConsistencyProver().prove(private_values, public_input, transcript)

verifier = MaskConsistencyVerifier()
transcript = Sha3_256Transcript(b"1234")
assert(verifier.verify(proof, public_input, transcript))