In [1]:
# Toy prime field for the FRI / sumcheck experiments below.
# ORDER = 3 * 2^6 + 1, so the multiplicative group (order ORDER - 1 = 192)
# contains a subgroup of order 2^TWO_ADICITY usable as an FFT domain.
ORDER = 64*3+1
TWO_ADICITY = 6
assert (ORDER - 1) % 2**TWO_ADICITY == 0, "2^TWO_ADICITY must divide ORDER - 1"
F193 = GF(ORDER)
MULTIPLICATIVE_GENERATOR = F193.primitive_element()
# g^((p - 1) / 2^s) has order exactly 2^s because g generates the whole group.
ROOT_OF_UNITY = F193(MULTIPLICATIVE_GENERATOR**((ORDER - 1) // 2**TWO_ADICITY))
F193.primitive_element()
R193.<X, A, B, X0, X1, X2, X3, X4, X5, A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15> = PolynomialRing(F193)

In [2]:
from typing import Union, Optional, TypeVar
from random import randint, Random
from utils import prime_field_inv, is_power_of_two, log_2, next_power_of_two

In [3]:
class Fp:
    """Element of the prime field GF(ORDER), stored as a residue ``n`` in ``[0, field_modulus)``.

    Operands may be another ``Fp`` or anything Sage's ``Integer`` can convert
    (Python ``int``, Sage ``Integer``, elements of ``F193``). Every operation
    returns a new, reduced instance of ``type(self)``.
    """

    field_modulus = ORDER

    ROOT_OF_UNITY = ROOT_OF_UNITY
    MULTIPLICATIVE_GENERATOR = MULTIPLICATIVE_GENERATOR
    TWO_ADICITY = TWO_ADICITY

    def __init__(self, val: Integer | int) -> None:
        if self.field_modulus is None:
            raise AttributeError("Field Modulus hasn't been specified")
        self.n = self._coerce(val) % self.field_modulus

    @staticmethod
    def _coerce(other: Union["Fp", Integer, int]) -> Integer:
        """Return the (unreduced) integer value carried by ``other``.

        Raises:
            TypeError: if ``other`` is neither an ``Fp`` nor convertible by ``Integer``.
        """
        if isinstance(other, Fp):
            return other.n
        if isinstance(other, Integer):
            return other
        try:
            return Integer(other)
        except TypeError as err:
            raise TypeError(
                "Expected an int or Fp object, but got object of type {}"
                .format(type(other))
            ) from err

    @classmethod
    def _inv_mod(cls, value: Integer) -> Integer:
        """Return ``value^{-1} mod p``; refuse to invert zero instead of producing garbage."""
        value = value % cls.field_modulus
        if value == 0:
            raise ZeroDivisionError("division by zero in Fp")
        return prime_field_inv(value, cls.field_modulus)

    def __add__(self, other: Union[Integer, "Fp"]) -> "Fp":
        return type(self)(self.n + self._coerce(other))

    def __mul__(self, other: Union["Fp", Integer]) -> "Fp":
        return type(self)(self.n * self._coerce(other))

    def __rmul__(self, other: Union["Fp", Integer]) -> "Fp":
        return self * other

    def __radd__(self, other: Union["Fp", Integer]) -> "Fp":
        return self + other

    def __rsub__(self, other: Union["Fp", Integer]) -> "Fp":
        return type(self)(self._coerce(other) - self.n)

    def __sub__(self, other: Union["Fp", Integer]) -> "Fp":
        return type(self)(self.n - self._coerce(other))

    def __div__(self, other: Union["Fp", Integer]) -> "Fp":
        """Return ``self * other^{-1}``; raises ZeroDivisionError when ``other`` is 0 mod p."""
        return type(self)(self.n * self._inv_mod(self._coerce(other)))

    def __truediv__(self, other: Union["Fp", Integer]) -> "Fp":
        return self.__div__(other)

    def __rdiv__(self, other: Union["Fp", Integer]) -> "Fp":
        """Return ``other * self^{-1}``; raises ZeroDivisionError when ``self`` is 0."""
        on = self._coerce(other)
        return type(self)(self._inv_mod(self.n) * on)

    def __rtruediv__(self, other: Union["Fp", Integer]) -> "Fp":
        return self.__rdiv__(other)

    def __eq__(self, other: Union["Fp", Integer]) -> bool:
        # Compare against the raw (unreduced) integer value, as before:
        # Fp(1) == 1 is True, Fp(1) == 1 + p is False.
        try:
            on = self._coerce(other)
        except TypeError:
            # Not comparable with a field element; let Python fall back (== gives False).
            return NotImplemented
        return self.n == on

    def __ne__(self, other: Union["Fp", Integer]) -> bool:
        return not self == other

    def __hash__(self) -> int:
        # Defining __eq__ alone makes a class unhashable; hash the canonical residue
        # so that equal elements hash equally (also against the plain integer n).
        return hash(self.n)

    def __neg__(self) -> "Fp":
        return type(self)(-self.n)

    def __str__(self) -> str:
        return self.repr()

    def __repr__(self) -> str:
        return self.repr()

    def __pow__(self: "Fp", other: int) -> "Fp":
        """Modular exponentiation via the built-in three-argument ``pow``."""
        return type(self)(pow(self.n, other, self.field_modulus))

    def __int__(self) -> int:
        # __int__ must return a real Python int; self.n is a Sage Integer.
        return int(self.n)

    @classmethod
    def one(cls) -> "Fp":
        return cls(1)

    @classmethod
    def zero(cls) -> "Fp":
        return cls(0)

    @classmethod
    def neg_one(cls) -> "Fp":
        return cls(cls.field_modulus - 1)

    @classmethod
    def rand(cls, rndg: Optional[Random] = None) -> "Fp":
        """Return a uniformly random *non-zero* element, optionally from ``rndg``."""
        if rndg is None:
            return cls(randint(1, cls.field_modulus - 1))
        return cls(rndg.randint(1, cls.field_modulus - 1))

    @classmethod
    def random(cls) -> "Fp":
        return cls.rand()

    @classmethod
    def rands(cls, rndg: Random, n: int) -> list["Fp"]:
        """Return ``n`` random non-zero elements drawn from ``rndg``."""
        return [cls(rndg.randint(1, cls.field_modulus - 1)) for _ in range(n)]

    @classmethod
    def from_bytes(cls, b: bytes) -> "Fp":
        """Interpret ``b`` as a big-endian integer and reduce it mod p."""
        i = int.from_bytes(b, "big")
        return cls(i)

    def inv(self) -> "Fp":
        """Return the multiplicative inverse; raises ZeroDivisionError for 0."""
        return type(self)(self._inv_mod(self.n))

    def repr(self) -> str:
        """Signed representative: 0..(p-1)/2 print as-is, larger residues as -(p - n)."""
        k = self.field_modulus // 2
        if self.n <= k:
            return f"{self.n}"
        else:
            return f"-{self.field_modulus - self.n}"

    def exp(self: "Fp", other: int) -> "Fp":
        return self ** other

    @classmethod
    def compute_root_of_unity(cls) -> "Fp":
        """Return g^((p - 1) / 2^TWO_ADICITY).

        This has order 2^TWO_ADICITY assuming MULTIPLICATIVE_GENERATOR generates
        the whole multiplicative group.
        """
        exponent = (cls.field_modulus - 1) // 2 ** cls.TWO_ADICITY
        return cls(cls.MULTIPLICATIVE_GENERATOR) ** exponent

    @classmethod
    def root_of_unity(cls) -> "Fp":
        return cls(cls.ROOT_OF_UNITY)

    @classmethod
    def multiplicative_generator(cls) -> "Fp":
        return cls(cls.MULTIPLICATIVE_GENERATOR)

    @classmethod
    def nth_root_of_unity(cls, n: int) -> "Fp":
        """Return an n-th root of unity obtained by repeatedly squaring ROOT_OF_UNITY.

        Raises:
            AssertionError: if ``n`` is not a power of two.
            ValueError: if ``n`` exceeds ``2**TWO_ADICITY`` (no such root exists).
        """
        assert is_power_of_two(n), "n must be a power of two"
        log_n = log_2(n)
        if log_n > cls.TWO_ADICITY:
            raise ValueError(
                f"no {n}-th root of unity: field 2-adicity is {cls.TWO_ADICITY}"
            )
        return cls(cls.ROOT_OF_UNITY) ** (2 ** (cls.TWO_ADICITY - log_n))

In [4]:
# Project-local polynomial / commitment helpers.
from unipoly2 import UniPolynomial, UniPolynomialWithFft, bit_reverse_permutation
from mle2 import MLEPolynomial
# Register Fp as the field type of both polynomial helpers
# (presumably used for the field constants they create internally — see unipoly2/mle2).
UniPolynomialWithFft.set_field_type(Fp)
MLEPolynomial.set_field_type(Fp)
from merkle import MerkleTree

In [5]:
# A 3-variable MLE given by its 8 hypercube evaluations, a Merkle commitment to
# those evaluations, and the claimed value v = f(us) at the point us.
f_mle = MLEPolynomial(list(map(Fp, [1, 3, 2, 1, 2, -2, 1, 2])), 3)
f_cm = MerkleTree(f_mle.evals)
us = list(map(Fp, [2, -1, 2]))
v = f_mle.evaluate(us)
(v, f_cm.root)

(-40, 'c86117467dc48f44adcbdff0f6ccfd5e1edfbf358ff48efe3441f2cb8b34aca2')

In [6]:
def rs_encode(f: list[Fp], coset: Fp, blowup_factor: int) -> list[Fp]:
    """Reed-Solomon encode the coefficient vector ``f`` on a multiplicative coset.

    ``f`` is zero-padded to ``N = next_power_of_two(len(f)) * blowup_factor``
    coefficients, then evaluated with ``UniPolynomialWithFft.fft_coset_rbo`` over
    the coset shifted by ``coset``, using a primitive N-th root of unity.

    Args:
        f: polynomial coefficients (not modified).
        coset: the coset shift.
        blowup_factor: inverse code rate; must be a power of two so that N is one.

    Returns:
        The N codeword symbols, in the order produced by ``fft_coset_rbo``
        (bit-reversed; callers undo it with ``bit_reverse_permutation``).

    Raises:
        ValueError: if ``blowup_factor`` is not a positive power of two.
    """
    if blowup_factor < 1 or not is_power_of_two(blowup_factor):
        raise ValueError(f"blowup_factor must be a power of two, got {blowup_factor}")

    n = next_power_of_two(len(f))
    N = n * blowup_factor

    omega_Nth = Fp.nth_root_of_unity(N)
    k = log_2(N)
    # Pad with zero coefficients up to the full evaluation-domain size.
    vec = list(f) + [Fp.zero()] * (N - len(f))
    return UniPolynomialWithFft.fft_coset_rbo(vec, coset, k, omega=omega_Nth)
     

In [7]:
# Code parameters: the codeword is blowup_factor times longer than f.
blowup_factor = 2
f = f_mle.evals
f_len = len(f)
f_code_len = blowup_factor * f_len
g = Fp.multiplicative_generator()
alpha = Fp(2)
print(f"f_code_len = {f_code_len}, g={g}, blowup_factor={blowup_factor}")

f_code_len = 16, g=5, blowup_factor=2


## Commit phase

In [8]:
# Debug level: > 0 enables the sumcheck consistency asserts,
# > 1 additionally decodes the RS codeword back to f.
debug = 2
# The coset shift of the evaluation domain is the field's multiplicative generator.
coset = Fp.multiplicative_generator()
twiddles = UniPolynomialWithFft.precompute_twiddles_for_fft(
    f_code_len, is_bit_reversed=True
)

In [9]:
# Work on a copy so rebinding/folding f below never touches f_mle.evals.
f = f_mle.evals.copy()

# Sumcheck Round 0: initial state.
n = len(f)
half = n // 2
sum0 = v
eq = MLEPolynomial.eqs_over_hypercube(us)
(v, f, eq)

(-40, [1, 3, 2, 1, 2, -2, 1, 2], [2, -4, -1, 2, -4, 8, 2, -4])

In [10]:
# Hard-coded challenge; its successive squares are computed in the next cell.
alpha = Fp(31)

In [11]:
def compute_alpha_powers(alpha: Fp, n: int) -> list[Fp]:
    """Return ``[alpha^(2^0), alpha^(2^1), ..., alpha^(2^(n-1))]`` (empty for n <= 0)."""
    exponents = (1 << j for j in range(n))
    return [alpha ** e for e in exponents]
alpha_powers = compute_alpha_powers(alpha, 3)
alpha_powers

[31, -4, 16]

In [40]:
def fold(f: list[Fp], r: Fp) -> list[Fp]:
    """Fold ``f`` in half by binding its lowest variable to ``r``.

    Each adjacent pair ``(f[2j], f[2j+1])`` (the x_0 = 0 / x_0 = 1 values) is
    replaced by the linear interpolation ``f[2j] + r * (f[2j+1] - f[2j])``.
    ``len(f)`` must be even.
    """
    n = len(f)
    assert n % 2 == 0, f"n = {n}, n must be even"
    pairs = zip(f[0::2], f[1::2])
    return [lo + r * (hi - lo) for lo, hi in pairs]
fold([A0, A1, A2, A3, A4, A5, A6, A7], X0)

[-X0*A0 + X0*A1 + A0,
 -X0*A2 + X0*A3 + A2,
 -X0*A4 + X0*A5 + A4,
 -X0*A6 + X0*A7 + A6]

In [72]:
def expanded_partial_evaluate(f: list[Fp], us: list[Fp]) -> list[Fp]:
    """
    Partially evaluate the MLE with evaluation table ``f`` at ``us``, one variable
    at a time, from the highest variable x_{k-1} (bound to us[k-1]) down to x_0
    (bound to us[0]).

    Returns the concatenation of every intermediate table: the 2^{k-1} values left
    after fixing x_{k-1}, then 2^{k-2} after fixing x_{k-2}, ..., and finally the
    single value f(us). The result therefore has length len(f) - 1; in particular
    its last three entries are f(0, u_1, ...), f(1, u_1, ...) and f(u_0, u_1, ...).
    """
    k, n = len(us), len(f)
    assert n == 2**k, f"n = {n}, k = {k}"

    levels = []
    table = f.copy()
    for u in reversed(us):
        mid = len(table) // 2
        table = [lo + u * (hi - lo) for lo, hi in zip(table[:mid], table[mid:])]
        levels.extend(table)
    return levels

expanded_partial_evaluate([A0, A1, A2, A3], [X0, X1])

[-X1*A0 + X1*A2 + A0,
 -X1*A1 + X1*A3 + A1,
 X0*X1*A0 - X0*X1*A1 - X0*X1*A2 + X0*X1*A3 - X0*A0 - X1*A0 + X0*A1 + X1*A2 + A0]

In [73]:
# Numeric spot-check: the last entry should equal f(us), computed directly in the next cell.
expanded_partial_evaluate([Fp(1), Fp(2), Fp(3), Fp(4)], [Fp(3), Fp(2)])

[5, 6, 8]

In [74]:
# Reference value for the spot-check above: the full MLE evaluation at the same point.
MLEPolynomial([Fp(1), Fp(2), Fp(3), Fp(4)], 2).evaluate([Fp(3), Fp(2)])

8

In [75]:
# Same spot-check with 3 variables; the last entry should match the next cell.
expanded_partial_evaluate([Fp(1), Fp(2), Fp(3), Fp(4), Fp(5), Fp(6), Fp(7), Fp(8)], [Fp(4), Fp(2), Fp(3)])

[13, 14, 15, 16, 17, 18, 21]

In [76]:
# Reference value for the 3-variable spot-check.
MLEPolynomial([Fp(1), Fp(2), Fp(3), Fp(4), Fp(5), Fp(6), Fp(7), Fp(8)], 3).evaluate([Fp(4), Fp(2), Fp(3)])

21

In [14]:
# Evaluate the MLE at the point (alpha, alpha^2, alpha^4) built above.
f_at_alpha = f_mle.evaluate(alpha_powers)
f_at_alpha


62

In [15]:
# FRI Round 0: RS-encode f (as coefficients) over the coset.

f0_code = rs_encode(f, coset, blowup_factor)
if debug > 1:
    print("P> check f0_code")
    # rs_encode emits symbols in bit-reversed order (fft_coset_rbo); undo that
    # before the inverse coset FFT and check we recover f's coefficients.
    f_orig_evals = bit_reverse_permutation(f0_code)
    f_orig_coeffs = UniPolynomialWithFft.ifft_coset(f_orig_evals, coset, log_2(f_code_len))
    f_orig = UniPolynomial(f_orig_coeffs)
    assert f_orig.coeffs == f, f"f_orig != f0, f_orig = {f_orig.coeffs}, f = {f}"
    print("P> check f0_code passed")

P> check f0_code
P> check f0_code passed


In [16]:
# A second hard-coded challenge and its successive squares (alpha0, alpha0^2, alpha0^4).
alpha0 = Fp(32)
alpha0_powers = compute_alpha_powers(alpha0, 3)
alpha0_powers

[32, 59, 7]

In [17]:
# Sumcheck Round 1

f0_even = f[::2]
f0_odd = f[1::2]
f0_even_mle = MLEPolynomial(f0_even, f_mle.num_var-1)
f0_odd_mle = MLEPolynomial(f0_odd, f_mle.num_var-1)

# construct hz(X): hz(X) = f(X, u1, u2,...), given by its values at X = 0 and X = 1
hz_at_0 = f0_even_mle.evaluate(us[1:])
hz_at_1 = f0_odd_mle.evaluate(us[1:])
hz = [hz_at_0, hz_at_1]

# hz(u0) must reproduce the claimed value v = f(us). The failure message uses the
# same {0, 1} domain as the check itself.
assert (UniPolynomial.evaluate_from_evals(hz, us[0], [Fp(0), Fp(1)])) == v, \
    f"hz(us[0]) = {UniPolynomial.evaluate_from_evals(hz, us[0], [Fp(0), Fp(1)])}, v = {v}"

In [18]:
# r0 = Fp.rand()  -- the round-0 verifier challenge, hard-coded below instead of sampled
r0 = Fp(-30)
r0

-30

In [19]:
# Sumcheck fold: bind the first variable to r0, so f1(X1, X2) = f(r0, X1, X2).
f1 = [f0_even[j] + r0 * (f0_odd[j] - f0_even[j]) for j in range(half)]

# The claim carried into the next round is h(r0).
hz_r0 = UniPolynomial.evaluate_from_evals(hz, r0, [Fp(0), Fp(1)])

In [20]:
half //= 2  # each fold halves the table size

In [21]:
# Sumcheck Round 2

f1_even = f1[::2]
f1_odd = f1[1::2]
f1_even_mle = MLEPolynomial(f1_even, f_mle.num_var-2)
f1_odd_mle = MLEPolynomial(f1_odd, f_mle.num_var-2)

# construct hz(X): hz(X) = f1(X, u2, ...) = f(r0, X, u2, ...)
hz_at_0 = f1_even_mle.evaluate(us[2:])
hz_at_1 = f1_odd_mle.evaluate(us[2:])
hz = [hz_at_0, hz_at_1]

# hz(u1) must reproduce the previous round's claim h(r0). The failure message uses
# the same {0, 1} domain as the check itself.
assert (UniPolynomial.evaluate_from_evals(hz, us[1], [Fp(0), Fp(1)])) == hz_r0, \
    f"hz(us[1]) = {UniPolynomial.evaluate_from_evals(hz, us[1], [Fp(0), Fp(1)])}, hz_r0 = {hz_r0}"


In [22]:
# r1 = Fp.rand()  -- the round-1 verifier challenge, hard-coded below instead of sampled
r1 = Fp(-32)
r1

-32

In [23]:
# Sumcheck fold: bind the next variable to r1, so f2(X2) = f(r0, r1, X2).
f2 = [f1_even[j] + r1 * (f1_odd[j] - f1_even[j]) for j in range(half)]

# The claim carried into the next round is h(r1).
hz_r1 = UniPolynomial.evaluate_from_evals(hz, r1, [Fp(0), Fp(1)])
hz_r1

-52

In [24]:
half //= 2  # each fold halves the table size

In [25]:
# Sumcheck Round 3

f2_even = f2[::2]
f2_odd = f2[1::2]
f2_even_mle = MLEPolynomial(f2_even, f_mle.num_var-3)
f2_odd_mle = MLEPolynomial(f2_odd, f_mle.num_var-3)

# construct hz(X): hz(X) = f2(X) = f(r0, r1, X)
hz_at_0 = f2_even_mle.evaluate(us[3:])
hz_at_1 = f2_odd_mle.evaluate(us[3:])
hz = [hz_at_0, hz_at_1]

# hz(u2) must reproduce the previous round's claim h(r1). The failure message uses
# the same {0, 1} domain as the check itself.
assert (UniPolynomial.evaluate_from_evals(hz, us[2], [Fp(0), Fp(1)])) == hz_r1, \
    f"hz(us[2]) = {UniPolynomial.evaluate_from_evals(hz, us[2], [Fp(0), Fp(1)])}, hz_r1 = {hz_r1}"

In [26]:
# r2 = Fp.rand()  -- the round-2 verifier challenge, hard-coded below instead of sampled
r2 = Fp(-33)
r2


-33

In [27]:
# Sumcheck fold: bind the last variable to r2, leaving the single value f(r0, r1, r2).
f3 = [f2_even[j] + r2 * (f2_odd[j] - f2_even[j]) for j in range(half)]

# The final claim is h(r2).
hz_r2 = UniPolynomial.evaluate_from_evals(hz, r2, [Fp(0), Fp(1)])
hz_r2

-86

In [28]:
# Final check: the last round's claim h(r2) must equal f evaluated at all challenges.
final_point = [r0, r1, r2]
hz_r2 == f_mle.evaluate(final_point)

True

## Simplified Sumcheck

In [29]:
# Reset the prover state for the simplified sumcheck.
k = len(us)          # number of variables = number of rounds
half = n // 2
sum_checked = v      # round 0 must reproduce the claimed value f(us)
constant = None
r_vec, sumcheck_h_vec = [], []
# Unused here; carried over from the class-based prover:
# f_code = f0_code
# coset = self.coset_gen
# code_commitments = []
# trees = [MerkleTree(f_code)]
# codes = [f_code]

In [30]:
# All intermediate partial evaluations of f at us; the last entry is f(us) = v.
f_partial_evals = expanded_partial_evaluate(f, us)
f_partial_evals

rs=[]
rs=[3, -7, 0, 3]
rs=[3, -7, 0, 3, 6, -17]


[3, -7, 0, 3, 6, -17, -40]

In [31]:
# Sumcheck Round 0
i = 0 

print(f"P> f = {f}")

# construct h_i(X) = f_i(r0,r1,...,X,u_{i+1},u_{i+2},...,u_k)
# The last three entries of the partial-evaluation table are
#   [f_i(0, u_{i+1}, ...), f_i(1, u_{i+1}, ...), f_i(u_i, u_{i+1}, ...)],
# i.e. h(0), h(1) and h(u_i). Because h is linear, h(u_i + 1) = h(u_i) + (h(1) - h(0)).
h_at_0 = f_partial_evals[-3]
h_at_1 = f_partial_evals[-2]
h_at_u = f_partial_evals[-1]
h_at_u_plus_1 = h_at_u + (h_at_1 - h_at_0)

print(f"P> h_at_0 = {h_at_0}")
print(f"P> h_at_1 = {h_at_1}")
print(f"P> h_at_u = {h_at_u}")
print(f"P> h_plus_1_at_u = {h_at_u_plus_1}")

if debug > 0:
    print(f"P> check h(0), h(1), h(u), h(u+1)")
    partial_f_mle = MLEPolynomial(f, k-i)
    assert h_at_0 == partial_f_mle.evaluate([Fp(0)]+us[i+1:]), \
        f"h_at_0 = {h_at_0}, partial_f_mle.evaluate([Fp(0)]+us[i+1:]) = {partial_f_mle.evaluate([Fp(0)]+us[i+1:])}"
    assert h_at_1 == partial_f_mle.evaluate([Fp(1)]+us[i+1:]), \
        f"h_at_1 = {h_at_1}, partial_f_mle.evaluate([Fp(1)]+us[i+1:]) = {partial_f_mle.evaluate([Fp(1)]+us[i+1:])}"
    assert h_at_u == partial_f_mle.evaluate(us[i:]), \
        f"h_at_u = {h_at_u}, partial_f_mle.evaluate(us[i:]) = {partial_f_mle.evaluate(us[i:])}"
    assert h_at_u_plus_1 == partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:]), \
        f"h_plus_1_at_u = {h_at_u_plus_1}, partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:]) = {partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:])}"
    assert h_at_u == sum_checked, \
        f"h_at_u = {h_at_u}, sum_checked = {sum_checked}"
    print(f"P> check h_at_u passed, {h_at_u} = {sum_checked}")
# The verifier already knows h(u_i) = sum_checked, so only h(u_i + 1) is sent
# (same as in the later rounds).
sumcheck_h_vec.append(h_at_u_plus_1)

# tr.absorb(b"h(X)", h_eval_at_u_plus_1)

# Receive a random number from the verifier
r = Fp(-12)

if debug > 0:
    print(f"P> r[{i}] = {r}")


# NOTE: The verifier computes h_i(r) = f_i(r, u_{i+1}, ...) itself, which equals
#   f_{i+1}(u_{i+1}, ...), the claim for the next round. So in the next round the
#   prover only needs to send h_{i+1}(u_{i+1} + 1). Computing h_i(r) from h_i(u_i)
#   and h_i(u_i + 1) costs the verifier only one multiplication.
h_at_r = h_at_u + (h_at_u_plus_1 - h_at_u) * (r - us[i])

if debug > 0:
    print("P> check new sumcheck")
    assert h_at_r == MLEPolynomial(f, k-i).evaluate([r]+us[i+1:]), \
        f"h_at_r = {h_at_r}, f(r) = {MLEPolynomial(f, k-i).evaluate([r]+us[i+1:])}"
    new_sum = UniPolynomial.evaluate_from_evals([h_at_u, h_at_u_plus_1], r, [us[i], us[i]+Fp(1)])
    assert h_at_r == new_sum
    print(f"P> check new sumcheck passed, {h_at_r} = {new_sum}")

# fold f: bind the current variable to r
f_folded = fold(f, r)

# Fold the partial-evaluation table too: drop the final value f(u) and fold every
# remaining level pairwise, giving the partial evaluations of f(r, X_{i+1}, ...).
f_parital_evals_trimmed = f_partial_evals[:-1]
print(f"P> f_parital_evals_trimmed = {f_parital_evals_trimmed}")
f_parital_evals_folded = fold(f_parital_evals_trimmed, r)


P> f = [1, 3, 2, 1, 2, -2, 1, 2]
P> h_eval_at_0 = 6
P> h_eval_at_1 = -17
P> h_eval_at_u = -40
P> h_eval_at_u_plus_1 = -63
P> check h(0), h(1), h(u), h(u+1)
P> check h_eval_at_u passed, -40 = -40
P> r[0] = -12
P> check new sumcheck passed
P> check new sumcheck passed, 89 = 89
P> f_parital_evals_trimmed = [3, -7, 0, 3, 6, -17]


In [32]:
# # update parameters for the next round
sum_checked = h_at_r
r_vec.append(r)
f = f_folded
f_partial_evals = f_parital_evals_folded
# f_code = f_code_folded
half >>= 1
coset *= coset

In [35]:
# Inspect the folded partial-evaluation table and the halved table size.
f_partial_evals, half

([-70, -36, 89], 2)

In [37]:
# Sumcheck Round 1
i = 1

print(f"P> f = {f}")

# construct h(X) = f(r0, X, u_{i+1}, ...)
# The last three entries of the (folded) partial-evaluation table are h(0), h(1), h(u_i);
# h is linear, so h(u_i + 1) = h(u_i) + (h(1) - h(0)).
h_eval_at_0 = f_partial_evals[-3]
h_eval_at_1 = f_partial_evals[-2]
h_eval_at_u = f_partial_evals[-1]
h_eval_at_u_plus_1 = h_eval_at_u + (h_eval_at_1 - h_eval_at_0)

print(f"P> h_eval_at_0 = {h_eval_at_0}")
print(f"P> h_eval_at_1 = {h_eval_at_1}")
print(f"P> h_eval_at_u = {h_eval_at_u}")
print(f"P> h_eval_at_u_plus_1 = {h_eval_at_u_plus_1}")

if debug > 0:
    print(f"P> check h(0), h(1), h(u), h(u+1)")
    partial_f_mle = MLEPolynomial(f, k-i)
    assert h_eval_at_0 == partial_f_mle.evaluate([Fp(0)]+us[i+1:]), \
        f"h_eval_at_0 = {h_eval_at_0}, partial_f_mle.evaluate([Fp(0)]+us[i+1:]) = {partial_f_mle.evaluate([Fp(0)]+us[i+1:])}"
    assert h_eval_at_1 == partial_f_mle.evaluate([Fp(1)]+us[i+1:]), \
        f"h_eval_at_1 = {h_eval_at_1}, partial_f_mle.evaluate([Fp(1)]+us[i+1:]) = {partial_f_mle.evaluate([Fp(1)]+us[i+1:])}"
    assert h_eval_at_u == partial_f_mle.evaluate(us[i:]), \
        f"h_eval_at_u = {h_eval_at_u}, partial_f_mle.evaluate(us[i:]) = {partial_f_mle.evaluate(us[i:])}"
    assert h_eval_at_u_plus_1 == partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:]), \
        f"h_eval_at_u_plus_1 = {h_eval_at_u_plus_1}, partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:]) = {partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:])}"
    assert h_eval_at_u == sum_checked, \
        f"h_eval_at_u = {h_eval_at_u}, sum_checked = {sum_checked}"
    print(f"P> check h_eval_at_u passed, {h_eval_at_u} = {sum_checked}")
# Only h(u_i + 1) is sent: the verifier already knows h(u_i) = sum_checked.
sumcheck_h_vec.append(h_eval_at_u_plus_1)

# tr.absorb(b"h(X)", h_eval_at_u_plus_1)

# Receive a random number from the verifier
r = Fp(-13)

if debug > 0:
    print(f"P> r[{i}] = {r}")

# The verifier interpolates h(r) from h(u_i) and h(u_i + 1).
h_eval_at_r = h_eval_at_u + (h_eval_at_u_plus_1 - h_eval_at_u) * (r - us[i])

if debug > 0:
    print("P> check new sumcheck")
    assert h_eval_at_r == MLEPolynomial(f, k-i).evaluate([r]+us[i+1:]), \
        f"h_eval_at_r = {h_eval_at_r}, f(r) = {MLEPolynomial(f, k-i).evaluate([r]+us[i+1:])}"
    new_sum = UniPolynomial.evaluate_from_evals([h_eval_at_u, h_eval_at_u_plus_1], r, [us[i], us[i]+Fp(1)])
    assert h_eval_at_r == new_sum
    print(f"P> check new sumcheck passed, {h_eval_at_r} = {new_sum}")

# fold f: bind the current variable to r
f_folded = fold(f, r)

# Fold the partial-evaluation table: drop the final value and fold each remaining level.
f_parital_evals_trimmed = f_partial_evals[:-1]
print(f"P> f_parital_evals_trimmed = {f_parital_evals_trimmed}")
f_parital_evals_folded = fold(f_parital_evals_trimmed, r)
print(f"P> f_parital_evals_folded = {f_parital_evals_folded}")


P> f = [-23, 14, 50, -11]
P> h_eval_at_0 = -70
P> h_eval_at_1 = -36
P> h_eval_at_u = 89
P> h_eval_at_u_plus_1 = -70
P> check h(0), h(1), h(u), h(u+1)
P> check h_eval_at_u passed, 89 = 89
P> r[1] = -13
P> check new sumcheck passed
P> check new sumcheck passed, 67 = 67
P> f_parital_evals_trimmed = [-70, -36]
P> f_parital_evals_folded = [67]


In [38]:
# # update parameters for the next round
sum_checked = h_eval_at_r
r_vec.append(r)
f = f_folded
f_partial_evals = f_parital_evals_folded
# f_code = f_code_folded
half >>= 1
coset *= coset

In [None]:
# Sumcheck Round 2 (final round)
i = 2

print(f"P> f = {f}")

# construct h(X) = f(r0, r1, X)
# Only one variable is left, so f itself is h's evaluation table: h(0) = f[0] and
# h(1) = f[1]. The partial-evaluation table has shrunk to the single value
# f(r0, r1, u_i) = h(u_i); it no longer contains an [h(0), h(1)] level to read from.
h_eval_at_0 = f[0]
h_eval_at_1 = f[1]
h_eval_at_u = f_partial_evals[-1]
h_eval_at_u_plus_1 = h_eval_at_u + (h_eval_at_1 - h_eval_at_0)

print(f"P> h_eval_at_0 = {h_eval_at_0}")
print(f"P> h_eval_at_1 = {h_eval_at_1}")
print(f"P> h_eval_at_u = {h_eval_at_u}")
print(f"P> h_eval_at_u_plus_1 = {h_eval_at_u_plus_1}")

if debug > 0:
    print(f"P> check h(0), h(1), h(u), h(u+1)")
    partial_f_mle = MLEPolynomial(f, k-i)
    assert h_eval_at_0 == partial_f_mle.evaluate([Fp(0)]+us[i+1:]), \
        f"h_eval_at_0 = {h_eval_at_0}, partial_f_mle.evaluate([Fp(0)]+us[i+1:]) = {partial_f_mle.evaluate([Fp(0)]+us[i+1:])}"
    assert h_eval_at_1 == partial_f_mle.evaluate([Fp(1)]+us[i+1:]), \
        f"h_eval_at_1 = {h_eval_at_1}, partial_f_mle.evaluate([Fp(1)]+us[i+1:]) = {partial_f_mle.evaluate([Fp(1)]+us[i+1:])}"
    assert h_eval_at_u == partial_f_mle.evaluate(us[i:]), \
        f"h_eval_at_u = {h_eval_at_u}, partial_f_mle.evaluate(us[i:]) = {partial_f_mle.evaluate(us[i:])}"
    assert h_eval_at_u_plus_1 == partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:]), \
        f"h_eval_at_u_plus_1 = {h_eval_at_u_plus_1}, partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:]) = {partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:])}"
    assert h_eval_at_u == sum_checked, \
        f"h_eval_at_u = {h_eval_at_u}, sum_checked = {sum_checked}"
    print(f"P> check h_eval_at_u passed, {h_eval_at_u} = {sum_checked}")
# Only h(u_i + 1) is sent: the verifier already knows h(u_i) = sum_checked.
sumcheck_h_vec.append(h_eval_at_u_plus_1)

# tr.absorb(b"h(X)", h_eval_at_u_plus_1)

# Receive a random number from the verifier
r = Fp(-13)

if debug > 0:
    print(f"P> r[{i}] = {r}")

# NOTE: The verifier computes h_i(r) = f_i(r) itself from h_i(u_i) and h_i(u_i + 1);
#   in this last round that value is the final claim f(r0, r1, r2).
h_eval_at_r = h_eval_at_u + (h_eval_at_u_plus_1 - h_eval_at_u) * (r - us[i])

if debug > 0:
    print("P> check new sumcheck")
    assert h_eval_at_r == MLEPolynomial(f, k-i).evaluate([r]+us[i+1:]), \
        f"h_eval_at_r = {h_eval_at_r}, f(r) = {MLEPolynomial(f, k-i).evaluate([r]+us[i+1:])}"
    new_sum = UniPolynomial.evaluate_from_evals([h_eval_at_u, h_eval_at_u_plus_1], r, [us[i], us[i]+Fp(1)])
    assert h_eval_at_r == new_sum
    print(f"P> check new sumcheck passed, {h_eval_at_r} = {new_sum}")

# fold f: leaves the single value f(r0, r1, r)
f_folded = fold(f, r)

# The trimmed partial-evaluation table is empty in the final round, so this folds to [].
f_parital_evals_trimmed = f_partial_evals[:-1]
print(f"P> f_parital_evals_trimmed = {f_parital_evals_trimmed}")
f_parital_evals_folded = fold(f_parital_evals_trimmed, r)
print(f"P> f_parital_evals_folded = {f_parital_evals_folded}")


In [None]:
# # Sumcheck Round 1
# i = 0 

# f_even = f[::2]
# f_odd = f[1::2]
# print(f"P> f = {f}")

# # construct h(X)
# h_eval_at_0 = MLEPolynomial(f_even, k-i-1).evaluate(us[i+1:])
# h_eval_at_1 = MLEPolynomial(f_odd, k-i-1).evaluate(us[i+1:])
# h_eval_at_u = (Fp(1) - us[i]) * h_eval_at_0 + us[i] * h_eval_at_1
# h_eval_at_u_plus_1 = h_eval_at_u + (h_eval_at_1 - h_eval_at_0)

# print(f"P> h_eval_at_0 = {h_eval_at_0}")
# print(f"P> h_eval_at_1 = {h_eval_at_1}")
# print(f"P> h_eval_at_u = {h_eval_at_u}")
# print(f"P> h_eval_at_u_plus_1 = {h_eval_at_u_plus_1}")

# if debug > 0:
#     print(f"P> check h(0), h(1), h(u), h(u+1)")
#     partial_f_mle = MLEPolynomial(f, k-i)
#     assert h_eval_at_0 == partial_f_mle.evaluate([Fp(0)]+us[i+1:]), \
#         f"h_eval_at_0 = {h_eval_at_0}, partial_f_mle.evaluate([Fp(0)]+us[i+1:]) = {partial_f_mle.evaluate([Fp(0)]+us[i+1:])}"
#     assert h_eval_at_1 == partial_f_mle.evaluate([Fp(1)]+us[i+1:]), \
#         f"h_eval_at_1 = {h_eval_at_1}, partial_f_mle.evaluate([Fp(1)]+us[i+1:]) = {partial_f_mle.evaluate([Fp(1)]+us[i+1:])}"
#     assert h_eval_at_u == partial_f_mle.evaluate(us[i:]), \
#         f"h_eval_at_u = {h_eval_at_u}, partial_f_mle.evaluate(us[i:]) = {partial_f_mle.evaluate(us[i:])}"
#     assert h_eval_at_u_plus_1 == partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:]), \
#         f"h_eval_at_u_plus_1 = {h_eval_at_u_plus_1}, partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:]) = {partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:])}"
#     # print(f"P> f(0) = {partial_f_mle.evaluate([Fp(0)]+us[i+1:])}")
#     # print(f"P> f(u) = {partial_f_mle.evaluate(us[i:])}")
#     # print(f"P> f(u+1) = {partial_f_mle.evaluate([us[i]+Fp(1)]+us[i+1:])}")
#     assert h_eval_at_u == sum_checked, \
#         f"h_eval_at_u = {h_eval_at_u}, sum_checked = {sum_checked}"
#     print(f"P> check h_eval_at_u passed, {h_eval_at_u} = {sum_checked}")
# sumcheck_h_vec.append(h_eval_at_u_plus_1)

# # tr.absorb(b"h(X)", h_eval_at_u_plus_1)

# # check sum
# # assert UniPolynomial.evaluate_from_evals(h, us[i], [Field(0), Field(1)]) == sum_checked, \
# #     f"h(us[{i}]) = {UniPolynomial.evaluate_from_evals(h, us[i], [Field(0), Field(1)])}, sum_checked = {sum_checked}"

# # Receive a random number from the verifier
# r = Fp(-12)

# if debug > 0:
#     print(f"P> r[{i}] = {r}")

# h_eval_at_r = h_eval_at_u + (h_eval_at_u_plus_1 - h_eval_at_u) * (r - us[i])

# if debug > 0:
#     print(f"P> check new sumcheck passed")
#     assert h_eval_at_r == MLEPolynomial(f, k-i).evaluate([r]+us[i+1:]), \
#         f"h_eval_at_r = {h_eval_at_r}, f(r) = {MLEPolynomial(f, k-i).evaluate([r]+us[i+1:])}"
#     new_sum = UniPolynomial.evaluate_from_evals([h_eval_at_u, h_eval_at_u_plus_1], r, [us[i], us[i]+Fp(1)])
#     assert h_eval_at_r == new_sum
#     print(f"P> check new sumcheck passed, {h_eval_at_r} = {new_sum}")

# # fold f

# f_folded = [(Fp(1) - r) * f_even[i] + r * f_odd[i] for i in range(half)]
