In [1]:
import numpy as np
import galois
from dataclasses import dataclass
from typing import List, Tuple, Optional
import hashlib
from py_ecc.bn128 import G1, G2, pairing, multiply, add, neg, curve_order, field_modulus, FQ, FQ2, FQ12, twist, final_exponentiate

# Scalar field used for every polynomial, witness value and Fiat-Shamir
# challenge in this notebook. curve_order (from py_ecc.bn128) is the order of
# the BN128 G1/G2 groups, which is why exponents of tau in TrustedSetup are
# also reduced modulo it.
BN128_SCALAR_INT = curve_order
GF = galois.GF(BN128_SCALAR_INT)

In [2]:
# Classes
class TrustedSetup:
    """Toy KZG structured reference string built from a secret ``tau``.

    Attributes:
        tau: the secret, reduced modulo ``curve_order``.
        g1_powers: ``tau^i * G1`` for ``i = 0..degree``; ``commit`` multiplies
            index ``i`` by the coefficient of ``X^i``.
        g2_powers: ``[G2, tau * G2]``, the two G2 elements ``verify_open`` uses.

    Args:
        degree: largest polynomial degree that can be committed (must be >= 0).
        tau: the secret. Defaults to the toy value 10 used throughout this
            notebook. WARNING: anyone who knows tau can forge openings; a real
            setup samples it randomly and destroys it afterwards.

    Raises:
        ValueError: if ``degree`` is negative or ``tau`` is 0 modulo ``curve_order``.
    """
    def __init__(self, degree, tau=10):
        if degree < 0:
            raise ValueError(f"degree must be non-negative, got {degree}")
        tau %= curve_order
        if tau == 0:
            # tau == 0 would make every commitment just coeff_0 * G1.
            raise ValueError("tau must be non-zero modulo curve_order")
        self.tau = tau
        self.g1_powers = [multiply(G1, pow(self.tau, i, curve_order)) for i in range(degree + 1)]
        self.g2_powers = [multiply(G2, pow(self.tau, i, curve_order)) for i in range(2)]

class Commitment:
    """Thin wrapper around a G1 point: a KZG commitment or an opening proof.

    ``point`` is an affine (x, y) pair of FQ coordinates, or None, which this
    file treats as the point at infinity (``point_to_str`` renders it "inf").
    """
    def __init__(self, point: Optional[Tuple[FQ, FQ]]):
        self.point = point

# Functions
def commit(poly: galois.Poly, setup: TrustedSetup) -> Commitment:
    """KZG-commit to ``poly``: the G1 point sum_i coeff_i * (tau^i * G1).

    Args:
        poly: polynomial over GF whose degree is at most the setup degree.
        setup: trusted setup providing ``g1_powers``.

    Returns:
        Commitment wrapping the resulting point. The point may be None (the
        point at infinity), presumably e.g. for the zero polynomial.

    Raises:
        ValueError: if ``poly`` has more coefficients than ``setup`` has G1 powers.
    """
    coeffs = poly.coefficients(order='asc')  # low to high: coeffs[i] multiplies X^i
    if len(coeffs) > len(setup.g1_powers):
        raise ValueError(
            f"polynomial degree {poly.degree} exceeds trusted setup degree "
            f"{len(setup.g1_powers) - 1}"
        )
    point = None
    # zip is safe here: the length check above guarantees every coefficient has a base.
    for base, coeff in zip(setup.g1_powers, coeffs):
        term = multiply(base, int(coeff))
        point = add(point, term) if point else term
    return Commitment(point)

def create_proof(poly: galois.Poly, z: GF, y: GF, setup: TrustedSetup) -> Commitment:
    """Build a KZG opening proof that ``poly(z) == y``.

    The proof is a commitment to the quotient q(X) = (poly(X) - y) / (X - z).

    Args:
        poly: committed polynomial.
        z: evaluation point.
        y: claimed value of ``poly`` at ``z``.
        setup: trusted setup used to commit to the quotient.

    Returns:
        Commitment to q(X).

    Raises:
        ValueError: if ``poly(z) != y``. In that case (X - z) does not divide
            poly(X) - y and no valid proof exists.
    """
    x_minus_z = galois.Poly([1, -z], field=GF)
    poly_minus_y = poly - galois.Poly([y], field=GF)
    quotient, remainder = divmod(poly_minus_y, x_minus_z)
    if remainder != galois.Poly.Zero(field=GF):
        # Keep the original "Division failed" prefix but say what actually went wrong.
        raise ValueError(
            f"Division failed: poly({z}) != {y}, so (X - z) does not divide poly(X) - y"
        )
    return commit(quotient, setup)

def verify_open(com: Commitment, z: GF, y: GF, pi: Commitment, setup: TrustedSetup) -> bool:
    """Return whether ``pi`` proves that the polynomial behind ``com`` equals ``y`` at ``z``.

    KZG pairing check, with ``setup.g2_powers == [G2, tau * G2]``:
        e(tau*G2 - z*G2, pi) == e(G2, com - y*G1)
    """
    g2_one, g2_tau = setup.g2_powers[0], setup.g2_powers[1]

    # (tau - z) * G2, formed as tau*G2 + (-(z*G2))
    tau_minus_z = add(g2_tau, neg(multiply(G2, int(z))))
    # com - y*G1
    com_minus_y = add(com.point, neg(multiply(G1, int(y))))

    return pairing(tau_minus_z, pi.point) == pairing(g2_one, com_minus_y)

def point_to_str(pt: Optional[Tuple[FQ, FQ]]) -> str:
    """Serialize a point for the Fiat-Shamir transcript.

    Returns "x,y" with integer coordinates, or "inf" when ``pt`` is None
    (the point at infinity).
    """
    if pt is not None:
        x, y = pt[0].n, pt[1].n
        return f"{x},{y}"
    return "inf"

# Plonk classes
@dataclass
class Program:
    """Circuit description consumed by ``Plonk``.

    Row i enforces ``qL*L + qR*R + qM*L*R + qO*O + qC = 0`` on its wire values.
    This matches the gate identity built in ``Plonk.prover``. Each selector
    array holds one field element per row.

    The sigma arrays encode the copy-constraint permutation using flattened
    wire indices:
      - ``0..n-1`` for the L column;
      - ``n..2n-1`` for the R column;
      - ``2n..3n-1`` for the O column.
    ``sigma_k[i]`` is the slot that column k, row i is wired to.
    """
    code: str  # human-readable label of the circuit; Plonk never reads it
    n: int  # number of rows; Plonk uses it as the size of the root-of-unity domain
    qL: galois.FieldArray  # left-wire selector, one value per row
    qR: galois.FieldArray  # right-wire selector
    qO: galois.FieldArray  # output-wire selector
    qM: galois.FieldArray  # multiplication (L*R) selector
    qC: galois.FieldArray  # constant selector
    sigma_1: np.ndarray  # permutation images of the L-column slots
    sigma_2: np.ndarray  # permutation images of the R-column slots
    sigma_3: np.ndarray  # permutation images of the O-column slots

# Plonk
class Plonk:
    """Toy, non-batched PLONK prover and verifier over the scalar field GF.

    Construction preprocesses a ``Program``. The selector and permutation
    columns are interpolated over the n-th roots of unity, and a KZG
    ``TrustedSetup`` of degree 3n + 6 is created.

    ``prover`` returns:
      - KZG commitments to the wire, accumulator and quotient polynomials;
      - their evaluations at the Fiat-Shamir point zeta, plus Z at zeta*omega;
      - one opening proof per evaluation.

    ``verifier`` re-derives the challenges, checks every opening, and checks
    t(zeta) * Z_H(zeta) against the alpha-combined constraints at zeta.

    NOTE(review): this is illustrative only, not a sound proof system:
      - the verifier evaluates the selector and sigma polynomials directly
        instead of checking them against commitments;
      - TrustedSetup uses a fixed, publicly known tau.
    """
    def __init__(self, program: Program):
        """Preprocess ``program`` (selectors and copy permutation) and build the setup.

        Args:
            program: circuit with ``program.n`` rows. ``n`` presumably must
                divide |GF| - 1 for ``primitive_root_of_unity`` to succeed
                (the examples use powers of two).
        """
        self.program = program
        self.n = program.n

        # omega generates the evaluation domain H = {omega^0, ..., omega^(n-1)}.
        self.omega = GF.primitive_root_of_unity(self.n)
        self.roots = np.array([self.omega ** i for i in range(self.n)])
        self.roots_galois = GF(self.roots)

        # Vanishing polynomial of H: Z_H(X) = X^n - 1.
        self.Zh = galois.Poly.Degrees([self.n, 0], coeffs=[1, -1], field=GF)

        # Selector polynomials: each takes row i's selector value at omega^i.
        self.Qm = galois.lagrange_poly(self.roots_galois, program.qM)
        self.Ql = galois.lagrange_poly(self.roots_galois, program.qL)
        self.Qr = galois.lagrange_poly(self.roots_galois, program.qR)
        self.Qo = galois.lagrange_poly(self.roots_galois, program.qO)
        self.Qc = galois.lagrange_poly(self.roots_galois, program.qC)

        # Identity labels per column: L uses H, R uses k1*H, O uses k2*H, so each
        # wire slot gets its own label. This assumes k1 and k2 give cosets that are
        # disjoint from H and from each other. TODO confirm for GF(2), GF(3).
        self.k1 = GF(2)
        self.k2 = GF(3)
        self.id_1 = self.roots_galois
        self.id_2 = self.roots_galois * self.k1
        self.id_3 = self.roots_galois * self.k2

        def get_sigma_val(idx):
            """Return the label of flattened wire slot ``idx``.

            Slot ranges: L is 0..n-1, R is n..2n-1, O is 2n..3n-1.
            """
            if idx < self.n: return self.roots[idx]
            elif idx < 2*self.n: return self.roots[idx - self.n] * self.k1
            else: return self.roots[idx - 2*self.n] * self.k2

        # S_k(omega^i) is the label of the slot that column k, row i is wired to.
        s1_vals = GF([get_sigma_val(i) for i in program.sigma_1])
        s2_vals = GF([get_sigma_val(i) for i in program.sigma_2])
        s3_vals = GF([get_sigma_val(i) for i in program.sigma_3])

        self.S1 = galois.lagrange_poly(self.roots_galois, s1_vals)
        self.S2 = galois.lagrange_poly(self.roots_galois, s2_vals)
        self.S3 = galois.lagrange_poly(self.roots_galois, s3_vals)

        # G1 powers up to tau^(3n+6). The largest polynomial committed is the
        # quotient T. With the blinding used in prover, its degree is at most 3n + 5.
        self.setup = TrustedSetup(degree=3 * self.n + 6)

    def prover(self, assignments: dict):
        """Prove that the wire values in ``assignments`` satisfy the program.

        Args:
            assignments: mapping with keys 'a', 'b', 'c' holding the L, R and O
                column values, one per row.

        Returns:
            dict with:
              - commitments 'com_a', 'com_b', 'com_c', 'com_z', 'com_t';
              - 'evals': A, B, C, Z and T at zeta, plus Z_shifted at zeta
                (i.e. Z at zeta*omega);
              - opening proofs 'pi_a', 'pi_b', 'pi_c', 'pi_z', 'pi_t',
                'pi_z_shifted'.

        Raises:
            ValueError: if the combined constraint polynomial is not divisible
                by Z_H, i.e. some gate, copy or start constraint fails on H.
        """
        n = self.n

        # Round 1: wire polynomials. Each is blinded by (b*X + b') * Z_H. Z_H
        # vanishes on H, so the values at the roots are unchanged.
        a_vals, b_vals, c_vals = GF(assignments['a']), GF(assignments['b']), GF(assignments['c'])
        A = galois.lagrange_poly(self.roots_galois, a_vals)
        b1, b2 = GF.Random(), GF.Random()
        blind_a = (b1 * galois.Poly([1, 0], field=GF) + b2) * self.Zh
        A += blind_a

        B = galois.lagrange_poly(self.roots_galois, b_vals)
        b3, b4 = GF.Random(), GF.Random()
        blind_b = (b3 * galois.Poly([1, 0], field=GF) + b4) * self.Zh
        B += blind_b

        C = galois.lagrange_poly(self.roots_galois, c_vals)
        b5, b6 = GF.Random(), GF.Random()
        blind_c = (b5 * galois.Poly([1, 0], field=GF) + b6) * self.Zh
        C += blind_c

        com_a = commit(A, self.setup)
        com_b = commit(B, self.setup)
        com_c = commit(C, self.setup)

        # Fiat-Shamir: beta and gamma come from the two 16-byte halves of SHA-256
        # over the wire commitments. verifier rebuilds this exact string.
        # NOTE(review): point strings are concatenated without a separator, so the
        # encoding is not prefix-free in principle.
        transcript = point_to_str(com_a.point) + point_to_str(com_b.point) + point_to_str(com_c.point)
        hash_trans = hashlib.sha256(transcript.encode()).digest()
        beta = GF(int.from_bytes(hash_trans[:16], 'big') % BN128_SCALAR_INT)
        gamma = GF(int.from_bytes(hash_trans[16:], 'big') % BN128_SCALAR_INT)

        # Round 2: permutation accumulator. z_vals[0] = 1, and each step multiplies
        # by row i's ratio of identity-label factors to sigma-label factors.
        # The (n+1)-th entry is dropped, leaving the n values Z takes on H.
        z_vals = [GF(1)] * (n + 1)
        for i in range(n):
            numer = (a_vals[i] + beta * self.id_1[i] + gamma) * \
                    (b_vals[i] + beta * self.id_2[i] + gamma) * \
                    (c_vals[i] + beta * self.id_3[i] + gamma)
            denom = (a_vals[i] + beta * self.S1(self.roots_galois[i]) + gamma) * \
                    (b_vals[i] + beta * self.S2(self.roots_galois[i]) + gamma) * \
                    (c_vals[i] + beta * self.S3(self.roots_galois[i]) + gamma)
            z_vals[i+1] = z_vals[i] * (numer / denom)
        z_vals = GF(z_vals[:-1])
        Z = galois.lagrange_poly(self.roots_galois, z_vals)
        # Z is blinded by a degree-2 factor times Z_H, which again is invisible on H.
        b7, b8, b9 = GF.Random(), GF.Random(), GF.Random()
        blind_z = (b7 * galois.Poly([1, 0, 0], field=GF) + b8 * galois.Poly([1, 0], field=GF) + b9) * self.Zh
        Z += blind_z
        com_z = commit(Z, self.setup)

        # alpha: SHA-256 of the transcript extended with com_z.
        transcript += point_to_str(com_z.point)
        hash_trans = hashlib.sha256(transcript.encode()).digest()
        alpha = GF(int.from_bytes(hash_trans, 'big') % BN128_SCALAR_INT)

        # Gate identity. It vanishes on H iff every row satisfies its gate.
        gate_constraints = A * B * self.Qm + A * self.Ql + B * self.Qr + C * self.Qo + self.Qc

        # Compute Z_shifted as composition Z(omega * X).
        # The coefficient of X^(d-k) is scaled by omega^(d-k).
        coeffs = Z.coefficients()  # high to low
        d = Z.degree
        new_coeffs = [coeffs[k] * (self.omega ** (d - k)) for k in range(d + 1)]
        Z_shifted = galois.Poly(new_coeffs, field=GF)

        # Permutation identity. Poly([beta, gamma]) is beta*X + gamma, so:
        #   - term1 uses the identity labels X, k1*X, k2*X and Z(X);
        #   - term2 uses S1..S3 and the shifted accumulator Z(omega*X).
        term1 = (A + galois.Poly([beta, gamma], field=GF)) * \
                (B + galois.Poly([beta * self.k1, gamma], field=GF)) * \
                (C + galois.Poly([beta * self.k2, gamma], field=GF)) * Z
        term2 = (A + beta * self.S1 + gamma) * \
                (B + beta * self.S2 + gamma) * \
                (C + beta * self.S3 + gamma) * Z_shifted
        perm_constraints = term1 - term2

        # L1 is 1 at omega^0 = 1 and 0 on the rest of H. So (Z - 1) * L1 vanishing
        # on H forces Z(1) = 1.
        L1 = galois.lagrange_poly(self.roots_galois, GF([1] + [0]*(n-1)))
        start_constraints = (Z - GF(1)) * L1

        # Combine the three identities with powers of alpha. Division by Z_H is
        # exact only when the combination vanishes on H.
        numerator = gate_constraints + alpha * perm_constraints + (alpha**2) * start_constraints
        quotient, remainder = divmod(numerator, self.Zh)
        if remainder != galois.Poly.Zero(field=GF):
            raise ValueError("Remainder not zero: constraints not satisfied")

        T = quotient
        com_t = commit(T, self.setup)

        # zeta: SHA-256 of the transcript extended with com_t.
        transcript += point_to_str(com_t.point)
        hash_trans = hashlib.sha256(transcript.encode()).digest()
        zeta = GF(int.from_bytes(hash_trans, 'big') % BN128_SCALAR_INT)

        # Claimed evaluations; "z_shifted" is Z_shifted(zeta) = Z(zeta * omega).
        evals = {
            "a": A(zeta), "b": B(zeta), "c": C(zeta),
            "z": Z(zeta),
            "z_shifted": Z_shifted(zeta), "t": T(zeta)
        }

        # One KZG opening proof per evaluation. Z is opened a second time, at
        # zeta * omega, to back the "z_shifted" value.
        pi_a = create_proof(A, zeta, evals["a"], self.setup)
        pi_b = create_proof(B, zeta, evals["b"], self.setup)
        pi_c = create_proof(C, zeta, evals["c"], self.setup)
        pi_z = create_proof(Z, zeta, evals["z"], self.setup)
        pi_t = create_proof(T, zeta, evals["t"], self.setup)
        pi_z_shifted = create_proof(Z, zeta * self.omega, evals["z_shifted"], self.setup)

        return {
            "com_a": com_a, "com_b": com_b, "com_c": com_c, "com_z": com_z, "com_t": com_t,
            "evals": evals,
            "pi_a": pi_a, "pi_b": pi_b, "pi_c": pi_c,
            "pi_z": pi_z, "pi_t": pi_t, "pi_z_shifted": pi_z_shifted
        }

    def verifier(self, proof):
        """Check a proof dict produced by ``prover``.

        Steps:
          1. Re-derive beta, gamma, alpha and zeta from the commitments, exactly
             as the prover did.
          2. Verify the six KZG openings.
          3. Check t(zeta) * Z_H(zeta) == gate + alpha * perm + alpha^2 * start,
             all evaluated at zeta, and print both sides.

        Returns:
            False as soon as an opening fails; otherwise the result of
            ``lhs == rhs`` for the final check.
        """
        # Same transcript construction and challenge derivation as prover.
        transcript = point_to_str(proof["com_a"].point) + point_to_str(proof["com_b"].point) + point_to_str(proof["com_c"].point)
        hash_trans = hashlib.sha256(transcript.encode()).digest()
        beta = GF(int.from_bytes(hash_trans[:16], 'big') % BN128_SCALAR_INT)
        gamma = GF(int.from_bytes(hash_trans[16:], 'big') % BN128_SCALAR_INT)

        transcript += point_to_str(proof["com_z"].point)
        hash_trans = hashlib.sha256(transcript.encode()).digest()
        alpha = GF(int.from_bytes(hash_trans, 'big') % BN128_SCALAR_INT)

        transcript += point_to_str(proof["com_t"].point)
        hash_trans = hashlib.sha256(transcript.encode()).digest()
        zeta = GF(int.from_bytes(hash_trans, 'big') % BN128_SCALAR_INT)

        e = proof["evals"]

        # Every claimed evaluation must open correctly against its commitment.
        if not verify_open(proof["com_a"], zeta, e["a"], proof["pi_a"], self.setup):
            return False
        if not verify_open(proof["com_b"], zeta, e["b"], proof["pi_b"], self.setup):
            return False
        if not verify_open(proof["com_c"], zeta, e["c"], proof["pi_c"], self.setup):
            return False
        if not verify_open(proof["com_z"], zeta, e["z"], proof["pi_z"], self.setup):
            return False
        if not verify_open(proof["com_t"], zeta, e["t"], proof["pi_t"], self.setup):
            return False
        if not verify_open(proof["com_z"], zeta * self.omega, e["z_shifted"], proof["pi_z_shifted"], self.setup):
            return False

        # Z_H(zeta) and L1(zeta) = (zeta^n - 1) / (n * (zeta - 1)), which is valid
        # because omega^0 = 1.
        # NOTE(review): zeta == 1 would divide by zero. Typical PLONK verifiers
        # also reject any zeta that lies in H.
        zh_zeta = zeta**self.n - GF(1)
        l1_zeta = zh_zeta / (GF(self.n) * (zeta - GF(1)))

        # Selector and sigma polynomials are evaluated directly (no commitments).
        qm_z, ql_z, qr_z, qo_z, qc_z = self.Qm(zeta), self.Ql(zeta), self.Qr(zeta), self.Qo(zeta), self.Qc(zeta)
        gate_part = qm_z * e["a"] * e["b"] + ql_z * e["a"] + qr_z * e["b"] + qo_z * e["c"] + qc_z

        z_zeta = e["z"]
        s1_z = self.S1(zeta)
        s2_z = self.S2(zeta)
        s3_z = self.S3(zeta)

        # Same permutation identity as the prover, evaluated at zeta.
        term1 = (e["a"] + beta * zeta + gamma) * \
                (e["b"] + beta * self.k1 * zeta + gamma) * \
                (e["c"] + beta * self.k2 * zeta + gamma) * z_zeta
        term2 = (e["a"] + beta * s1_z + gamma) * \
                (e["b"] + beta * s2_z + gamma) * \
                (e["c"] + beta * s3_z + gamma) * e["z_shifted"]
        perm_part = term1 - term2

        start_part = (z_zeta - GF(1)) * l1_zeta

        lhs = e["t"] * zh_zeta
        rhs = gate_part + alpha * perm_part + (alpha**2) * start_part

        print(f"Verifier Check:\n LHS: {lhs}\n RHS: {rhs}")
        return lhs == rhs

In [3]:
# =============================================================================
# Example 1: (x1 + x2) * (x2 + w1) with x1 = 5, x2 = 6, w1 = 1
# =============================================================================
# Execution trace, one row per gate:
#   row 0 | input row: stores x1=5, x2=6, w1=1 (all selectors zero)
#   row 1 | addition:       x1 + x2 = 5 + 6  = 11   (v1)
#   row 2 | addition:       x2 + w1 = 6 + 1  = 7    (v2)
#   row 3 | multiplication: v1 * v2 = 11 * 7 = 77   (out)
#
# NOTE: `out` is not pinned to a public value. The proof shows that a satisfying
# witness exists; it does not by itself bind the result to 77.

n = 4

# Selector columns, indexed by row. Every computing row enforces
#   qL*L + qR*R + qM*L*R + qO*O + qC = 0
# with qO = -1. In the field, -1 is represented as BN128_SCALAR_INT - 1.
qL = GF([0, 1, 1, 0])
qR = GF([0, 1, 1, 0])
qM = GF([0, 0, 0, 1])
qO = GF([0] + [BN128_SCALAR_INT - 1] * 3)
qC = GF([0] * n)

# Witness columns, indexed by row:
#         row 0   row 1   row 2   row 3
#   L:    x1=5    x1=5    x2=6    v1=11
#   R:    x2=6    x2=6    w1=1    v2=7
#   O:    w1=1    v1=11   v2=7    out=77
L = GF([5, 5, 6, 11])
R = GF([6, 6, 1, 7])
O = GF([1, 11, 7, 77])

# Copy constraints, expressed as a permutation over flattened wire slots:
#   L[r] -> r,   R[r] -> n + r,   O[r] -> 2n + r.
# Swapping (a, b) and then (b, c) leaves the cycle a -> b -> c -> a in sigma.
sigma = list(range(3 * n))
def connect(i, j):
    """Swap the sigma images of slots i and j, joining their two cycles.

    Assumes i and j are currently in different cycles.
    """
    sigma[i], sigma[j] = sigma[j], sigma[i]

connect(0, 1)    # x1 : L[0] <-> L[1]
connect(4, 5)    # x2 : R[0] <-> R[1]
connect(5, 2)    # x2 : ...  <-> L[2]
connect(8, 6)    # w1 : O[0] <-> R[2]
connect(9, 3)    # v1 : O[1] <-> L[3]
connect(10, 7)   # v2 : O[2] <-> R[3]

# sigma is split into its L, R and O column slices.
program = Program("(x1+x2)(x2+w1)", n, qL, qR, qO, qM, qC,
                  *(np.array(sigma[col * n:(col + 1) * n]) for col in range(3)))

plonk = Plonk(program)

print("Generating Proof for (5+6)*(6+1) = 77...")
proof = plonk.prover({'a': L, 'b': R, 'c': O})

print("Verifying Proof...")
assert plonk.verifier(proof), "Verification Failed"
print("SUCCESS: Proof Verified!")

Generating Proof for (5+6)*(6+1) = 77...
Verifying Proof...
Verifier Check:
 LHS: 21828049561480831531051403413155953295024721628514504579244575221537862876022
 RHS: 21828049561480831531051403413155953295024721628514504579244575221537862876022
SUCCESS: Proof Verified!


In [4]:
# =============================================================================
# Example 2: Vitalik's equation x^3 + x + 5 = 35 (witness x = 3)
# =============================================================================
# Execution trace, one row per gate:
#   row 0    | mul:          x * x = a        (3 * 3  = 9)
#   row 1    | mul:          a * x = b        (9 * 3  = 27)
#   row 2    | add:          b + x = c        (27 + 3 = 30)
#   row 3    | add constant: c + 5 = out      (30 + 5 = 35)
#   row 4    | public check: out - 35 = 0
#   rows 5-7 | padding (all selectors zero), so n is a power of two

n = 8

# Selector columns, indexed by row. They enforce
#   qL*L + qR*R + qM*L*R + qO*O + qC = 0.
# Field negatives are written as BN128_SCALAR_INT - k.
#   rows 0-1: qM = 1, qO = -1
#   row 2:    qL = qR = 1, qO = -1
#   row 3:    qL = 1, qC = 5, qO = -1     (R unused)
#   row 4:    qL = 1, qC = -35            (O unused)
qL = GF([0, 0, 1, 1, 1] + [0] * 3)
qR = GF([0, 0, 1] + [0] * 5)
qO = GF([BN128_SCALAR_INT - 1] * 4 + [0] * 4)
qM = GF([1, 1] + [0] * 6)
qC = GF([0, 0, 0, 5, BN128_SCALAR_INT - 35, 0, 0, 0])

# Run the computation for x = 3 to obtain the witness.
w_x = 3
w_a = w_x ** 2        # 9
w_b = w_a * w_x       # 27
w_c = w_b + w_x       # 30
w_out = w_c + 5       # 35

# Witness columns (0 marks an unused or padding slot):
#         row0  row1  row2  row3  row4  rows5-7
#   L:    x     a     b     c     out   0
#   R:    x     x     x     0     0     0
#   O:    a     b     c     out   0     0
L = GF([w_x, w_a, w_b, w_c, w_out] + [0] * 3)
R = GF([w_x] * 3 + [0] * 5)
O = GF([w_a, w_b, w_c, w_out] + [0] * 4)

# Copy constraints over flattened wire slots: L[r] -> r, R[r] -> 8 + r, O[r] -> 16 + r.
sigma = list(range(3 * n))
def connect(i, j):
    """Swap the sigma images of slots i and j, joining their two cycles.

    Assumes i and j are currently in different cycles.
    """
    sigma[i], sigma[j] = sigma[j], sigma[i]

connect(0, 8)    # x   : L[0] <-> R[0]
connect(8, 9)    # x   : ...  <-> R[1]
connect(9, 10)   # x   : ...  <-> R[2]
connect(16, 1)   # a   : O[0] <-> L[1]
connect(17, 2)   # b   : O[1] <-> L[2]
connect(18, 3)   # c   : O[2] <-> L[3]
connect(19, 4)   # out : O[3] <-> L[4]

# sigma is split into its L, R and O column slices.
program = Program("x^3 + x + 5 = 35", n, qL, qR, qO, qM, qC,
                  *(np.array(sigma[col * n:(col + 1) * n]) for col in range(3)))

plonk = Plonk(program)

print("Generating Proof")
proof = plonk.prover({'a': L, 'b': R, 'c': O})

print("Verifying Proof...")
assert plonk.verifier(proof), "Verification Failed"
print("Verification Succeeded")

Generating Proof
Verifying Proof...
Verifier Check:
 LHS: 19554100873310025582395413796556580491145494638819468597990652040322175923193
 RHS: 19554100873310025582395413796556580491145494638819468597990652040322175923193
Verification Succeeded


In [5]:
# =============================================================================
# Example 3: Pythagorean triple a^2 + b^2 = c^2 for (3, 4, 5)
# =============================================================================
# Execution trace, one row per gate:
#   row 0 | mul: a * a = a^2          (3 * 3  = 9)
#   row 1 | mul: b * b = b^2          (4 * 4  = 16)
#   row 2 | mul: c * c = c^2          (5 * 5  = 25)
#   row 3 | add: a^2 + b^2 = O, with O wired to c^2   (9 + 16 = 25)

n = 4

# Selectors: rows 0-2 multiply (qM = 1), row 3 adds (qL = qR = 1).
# Every row has qO = -1, written as BN128_SCALAR_INT - 1.
qL = GF([0, 0, 0, 1])
qR = GF([0, 0, 0, 1])
qO = GF([BN128_SCALAR_INT - 1] * n)
qM = GF([1, 1, 1, 0])
qC = GF([0] * n)

# Witness for a=3, b=4, c=5 and their squares.
val_a, val_b, val_c = 3, 4, 5
s1, s2, s3 = val_a ** 2, val_b ** 2, val_c ** 2   # 9, 16, 25

# Witness columns:
#         row 0   row 1   row 2   row 3
#   L:    a=3     b=4     c=5     a^2=9
#   R:    a=3     b=4     c=5     b^2=16
#   O:    a^2=9   b^2=16  c^2=25  c^2=25
L = GF([val_a, val_b, val_c, s1])
R = GF([val_a, val_b, val_c, s2])
O = GF([s1, s2, s3, s3])

# Copy constraints over flattened wire slots: L[r] -> r, R[r] -> 4 + r, O[r] -> 8 + r.
sigma = list(range(3 * n))
def connect(i, j):
    """Swap the sigma images of slots i and j, joining their two cycles.

    Assumes i and j are currently in different cycles.
    """
    sigma[i], sigma[j] = sigma[j], sigma[i]

connect(0, 4)    # a   : L[0] <-> R[0]
connect(1, 5)    # b   : L[1] <-> R[1]
connect(2, 6)    # c   : L[2] <-> R[2]
connect(8, 3)    # a^2 : O[0] <-> L[3]
connect(9, 7)    # b^2 : O[1] <-> R[3]
connect(10, 11)  # c^2 : O[2] <-> O[3], forcing a^2 + b^2 == c^2

# sigma is split into its L, R and O column slices.
program = Program("Pythagorean: 3^2 + 4^2 = 5^2", n, qL, qR, qO, qM, qC,
                  *(np.array(sigma[col * n:(col + 1) * n]) for col in range(3)))

plonk = Plonk(program)

print("Generating Proof for (3,4,5)...")
proof = plonk.prover({'a': L, 'b': R, 'c': O})

print("Verifying Proof...")
assert plonk.verifier(proof), "Verification Failed"
print("SUCCESS: Proof Verified!")

Generating Proof for (3,4,5)...
Verifying Proof...
Verifier Check:
 LHS: 18138521904379644128250743623693862765320302612963744433316591358425590682434
 RHS: 18138521904379644128250743623693862765320302612963744433316591358425590682434
SUCCESS: Proof Verified!
