In [1]:
import galois
import numpy as np

In [2]:
class PlonkEnvironment:
    """Finite field GF(prime) together with a multiplicative evaluation domain.

    Attributes:
        GF: the galois field class GF(prime).
        d: number of points in the evaluation domain (the `degree` argument).
        w: a primitive d-th root of unity in GF(prime).
        omega_domain: Omega = [w^0, w^1, ..., w^(d-1)] as a field array.
    """

    def __init__(self, prime: int, degree: int):
        field = galois.GF(prime)
        self.GF = field
        self.d = degree
        # A primitive d-th root of unity exists only when d divides prime - 1.
        self.w = field.primitive_root_of_unity(degree)
        # Omega: successive powers of w, one per domain point.
        self.omega_domain = field([self.get_element(k) for k in range(degree)])

    def lagrange_interpolate(self, x_values, y_values):
        """Interpolate the points (x_values[k], y_values[k]) via galois.lagrange_poly."""
        interpolated = galois.lagrange_poly(x_values, y_values)
        return interpolated

    def get_element(self, power):
        """Return w**power."""
        return pow(self.w, power)

In [3]:
class CircuitDefinition:
    """
    Holds the specific data for the circuit: (x_1+x_2)(x_2+w_1).
    Contains the Computational Trace, Selector config, and Wiring.

    Trace layout over Omega = {w^0, ..., w^11} (slot k <-> w^k):
      * slots 0..2  : Gate 0 (L, R, O) = (x_1, x_2, x_1 + x_2)
      * slots 3..5  : Gate 1 (L, R, O) = (x_2, w_1, x_2 + w_1)
      * slots 6..8  : Gate 2 (L, R, O) = (gate 0 out, gate 1 out, product)
      * slot 9      : witness w_1
      * slots 10, 11: public inputs x_2, x_1 (i.e. w^-2, w^-1)
    """

    # Copy constraints: every tuple lists trace slots that must carry the same
    # value. The wiring permutation rotates each tuple (a -> b -> c -> a), so a
    # trace that is invariant under the permutation satisfies all of them.
    COPY_CLASSES = (
        (0, 11),     # x_1: gate 0 left  <-> public-input slot w^11
        (1, 3, 10),  # x_2: gate 0 right, gate 1 left, public-input slot w^10
        (4, 9),      # w_1: gate 1 right <-> witness slot w^9
        (2, 6),      # gate 0 output -> gate 2 left
        (5, 7),      # gate 1 output -> gate 2 right
    )

    def __init__(self, env: PlonkEnvironment):
        self.env = env

        # 1. The Raw Trace (L, R, O) serialized into a single list, slot k <-> w^k.
        self.trace_values = self.env.GF([5, 6, 11, 6, 1, 7, 11, 7, 77, 1, 6, 5])
        if len(self.trace_values) != self.env.d:
            raise ValueError(
                f"Trace has {len(self.trace_values)} slots but the evaluation "
                f"domain has {self.env.d} points"
            )

        # 2. Public Inputs: expected values of T at the last len(public_inputs)
        #    domain points (see `input_domain`).
        self.public_inputs = self.env.GF([6, 5])

        # 3. Selector Configuration (1 for Add, 0 for Mul)
        #    Gate 0 (Add), Gate 1 (Add), Gate 2 (Mul)
        self.selector_values = self.env.GF([1, 1, 0])

        # 4. Wiring Permutation (Sigma/W): wiring_permutation[k] is the domain
        #    point that slot k is wired to. Derived from COPY_CLASSES instead of
        #    hard-coding field elements, so it does not depend on which root of
        #    unity the field library picked.
        self.wiring_permutation = self.env.omega_domain[self._wiring_indices()]

    def _wiring_indices(self):
        """Return sigma as a list of slot indices: slot k is wired to slot sigma[k]."""
        sigma = list(range(self.env.d))  # slots not in any class map to themselves
        for group in self.COPY_CLASSES:
            rotated = group[1:] + group[:1]
            for src, dst in zip(group, rotated):
                sigma[src] = dst
        return sigma

    @property
    def gate_indices(self):
        """Domain points w^(3k) at which gate k's (L, R, O) triple starts.

        Note: despite the name these are field elements of Omega, not integer
        indices. One entry per gate (i.e. per selector value).
        """
        n_gates = len(self.selector_values)
        return self.env.omega_domain[0:3 * n_gates:3]

    @property
    def input_domain(self):
        """Domain points reserved for public inputs (the last len(public_inputs))."""
        # Slice from an explicit start so zero public inputs yields an empty
        # array (a negative slice `[-0:]` would return the whole domain).
        start = self.env.d - len(self.public_inputs)
        return self.env.omega_domain[start:]

In [6]:
class PlonkProtocol:
    """
    Executes the PLONK-style protocol steps:
    1. Interpolate Trace Polynomial T(X)
    2. Check Input Constraints
    3. Check Gate Constraints
    4. Check Permutation/Wiring Constraints

    Checks are performed directly on evaluations over Omega (no commitments or
    random challenges); a failing check raises AssertionError. Steps 2-4 and
    the output check require step 1 to have been run first.
    """
    def __init__(self, env: PlonkEnvironment, circuit: CircuitDefinition):
        self.env = env
        self.circuit = circuit
        self.T_X = None  # Trace Polynomial, set by step_1
        self.S_X = None  # Selector Polynomial, set by step_3
        self.W_X = None  # Permutation Polynomial, set by step_4

    def _require_trace(self):
        """Return T(X), raising a clear error if step 1 has not been run."""
        if self.T_X is None:
            raise RuntimeError(
                "Trace polynomial not compiled; call step_1_compile_trace() first"
            )
        return self.T_X

    def step_1_compile_trace(self):
        """Interpolate the computational trace T(X) over Omega.

        Returns:
            The trace polynomial T(X), also stored on ``self.T_X``.
        """
        domain = self.env.omega_domain
        self.T_X = self.env.lagrange_interpolate(domain, self.circuit.trace_values)

        # Sanity check: T must reproduce every trace value on the domain.
        assert np.array_equal(self.T_X(domain), self.circuit.trace_values), \
            "Trace polynomial does not interpolate the trace"
        return self.T_X

    def step_2_verify_inputs(self):
        """
        Prove T encodes correct inputs for public input positions:
        T(a) == V_inp(a) for every a in the input domain.
        """
        T_X = self._require_trace()
        input_domain = self.circuit.input_domain
        if len(input_domain) == 0:
            return  # nothing public to check

        # Interpolate the Public Input Polynomial V_inp(X)
        V_inp_X = self.env.lagrange_interpolate(input_domain, self.circuit.public_inputs)

        # ZeroTest: T(x) - V(x) should be 0 on input domain
        for i, domain_point in enumerate(input_domain):
            val_t = T_X(domain_point)
            val_v = V_inp_X(domain_point)
            assert val_t - val_v == 0, f"Input mismatch at index {i}"

    def step_3_verify_gates(self):
        """
        Prove every gate is evaluated correctly using Selector Polynomial S(X).
        Constraint: S(x)*(L+R) + (1-S(x))*(L*R) - O == 0
        where, for a gate starting at domain point x, L = T(x), R = T(x*w),
        O = T(x*w^2).
        """
        T_X = self._require_trace()
        gate_points = self.circuit.gate_indices

        # Interpolate Selector S(X) over the gate start points
        self.S_X = self.env.lagrange_interpolate(gate_points, self.circuit.selector_values)

        one = self.env.GF(1)
        w1 = self.env.get_element(1)
        w2 = self.env.get_element(2)
        for i in gate_points:
            L = T_X(i)        # w^0 shift
            R = T_X(i * w1)   # w^1 shift
            O = T_X(i * w2)   # w^2 shift
            s_val = self.S_X(i)

            constraint = s_val * (L + R) + (one - s_val) * (L * R) - O
            assert constraint == 0, f"Gate constraint failed at index {i}"

    def step_4_verify_wiring(self):
        """
        Prove wiring is correct: T(a) == T(W(a)) for every a in Omega.

        Comparing only the multisets {T(a)} and {T(W(a))} is NOT sufficient:
        whenever W permutes Omega those multisets are identical for any T, so
        such a check can never detect a violated copy constraint. Instead this
        checks (1) W maps Omega onto itself and (2) T is invariant under W
        pointwise (the prescribed-permutation condition).
        """
        T_X = self._require_trace()
        domain = self.env.omega_domain

        # Interpolate Wiring Polynomial W(X)
        self.W_X = self.env.lagrange_interpolate(domain, self.circuit.wiring_permutation)
        wired = self.W_X(domain)

        # (1) W must be a permutation of the domain.
        assert sorted(int(a) for a in wired) == sorted(int(a) for a in domain), \
            "Wiring polynomial does not permute the domain"

        # (2) Every slot must hold the same value as the slot it is wired to.
        for a, t_a, t_wa in zip(domain, T_X(domain), T_X(wired)):
            assert t_a == t_wa, f"Wiring constraint failed at {a}"

    def run_output_check(self):
        """Evaluate the circuit output: the O wire of the last gate.

        Prints "Verification Passed" and returns T(x_last * w^2), where x_last
        is the last gate's start point. The checks themselves are performed by
        the step_* methods; this method does not re-run them.
        """
        T_X = self._require_trace()
        last_gate = self.circuit.gate_indices[-1]
        output_val = T_X(last_gate * self.env.get_element(2))
        print("Verification Passed")
        return output_val

In [7]:
# 1. Setup Environment
# GF(193) with a 12-point evaluation domain; a primitive 12th root of unity
# exists because 12 divides 193 - 1 = 192.
env = PlonkEnvironment(prime=193, degree=12)

# 2. Define Circuit Data
circuit = CircuitDefinition(env)

# 3. Initialize Protocol
plonk = PlonkProtocol(env, circuit)

# 4. Run Steps (steps 2-4 raise AssertionError if a constraint fails)
plonk.step_1_compile_trace()
plonk.step_2_verify_inputs()
plonk.step_3_verify_gates()
plonk.step_4_verify_wiring()

# 5. Check Result: evaluate and display the circuit output (gate 2's output wire)
final_output = plonk.run_output_check()
final_output

Verification Passed
