In [1]:
from collections import namedtuple
from sage.rings.polynomial.polydict import ETuple
load('utils.sage') # log2_strict, MerkleTree, Transcript

In [2]:
F = GF(15*2^27+1); print(F)  # the BabyBear prime p = 15*2^27 + 1
# Name the polynomial ring R, not RR: RR is Sage's built-in real field and
# rebinding it would silently break any code that expects real numbers.
R.<X> = F[]
# X^2 - 11 must be irreducible over F for F[X]/(X^2 - 11) to be a field.
assert not F(11).is_square(), "11 must be a quadratic non-residue mod p"
FE.<u> = F.extension(X^2 - 11); print(FE)

Finite Field of size 2013265921
Finite Field in u of size 2013265921^2


In [3]:
# Factor |K^*| for both fields; the power of two in each order is the size of the
# largest two-adic subgroup (FFT / Reed-Solomon domain) that field offers.
tuple(factor(K.order() - 1) for K in (F, FE))

(2^27 * 3 * 5, 2^28 * 3 * 5 * 31 * 32472031)

In [4]:
def to_bits(x, n_bits):
    """Yield the n_bits lowest bits of x, least-significant bit first."""
    shifted = x
    for _ in range(n_bits):
        yield shifted & 1
        shifted >>= 1

# The boolean hypercube
def B(n_bits):
    """Yield every point of {0,1}^n_bits as a bit list (LSB first), in index order 0..2^n_bits-1."""
    yield from (list(to_bits(idx, n_bits)) for idx in range(2^n_bits))

In [5]:
def eq(ys, xs=None, R=None):
    """Multilinear equality polynomial prod_i (x_i*y_i + (1 - x_i)*(1 - y_i)).

    ys: iterable of values (consumed once into a list).
    xs: values or variables paired with ys; if None, the generators of R are used,
        which makes the result a polynomial.
    R:  ring supplying those generators; when None, an len(ys)-variable
        PolynomialRing over the parent of ys[0] is created.
    """
    ys = list(ys)
    if xs is None:
        ring = R if R is not None else PolynomialRing(ys[0].parent(), len(ys), 'x')
        xs = list(ring.gens())
    assert len(xs) == len(ys)
    factors = (xi*yi + (1 - xi)*(1 - yi) for xi, yi in zip(xs, ys))
    return product(factors)

def mle(evals):
    """Multilinear extension of evals: sum_i eq(bits(i)) * evals[i].

    Index i is read LSB-first (via to_bits), so bit j of i is the value of variable x_j.
    len(evals) is passed through log2_strict (defined in utils.sage), which presumably
    rejects non-power-of-two lengths.
    """
    n_bits = log2_strict(len(evals))
    terms = [eq(to_bits(idx, n_bits)) * evals[idx] for idx in range(2^n_bits)]
    return sum(terms)

# Sanity check: the MLE reproduces the original values on every point of the 4-cube.
x = matrix.random(F, 1, 16).row(0)
assert x == vector(list(map(mle(x), B(4))))

In [6]:
def rs_matrix(field, msg_bits, rate_bits):
    """Recursive Reed-Solomon generator matrix, 2^msg_bits x 2^(msg_bits+rate_bits).

    The base case is a single all-ones row of length 2^rate_bits. Each step doubles
    both dimensions:
        [ G      G    ]
        [ G*T1   G*T2 ]
    where T1 / T2 are diagonal matrices built from the first / second half of
    two_adic_subgroup(field, msg_bits + rate_bits).
    """
    if msg_bits == 0:
        return matrix.ones(field, 1, 2^rate_bits)
    G = rs_matrix(field, msg_bits - 1, rate_bits)
    subgroup = two_adic_subgroup(field, msg_bits + rate_bits)
    half = len(subgroup) // 2
    T1 = matrix.diagonal(subgroup[:half])
    T2 = matrix.diagonal(subgroup[half:])
    return matrix.block([[G, G], [G * T1, G * T2]], subdivide=False)

# Compare the recursive construction with Sage's Reed-Solomon generator matrix.
# In the printed output the rows are the same, with rows 1 and 2 swapped
# (bit-reversed row order).
print(rs_matrix(F, 2, 1), 'vs:', sep='\n')
print(codes.ReedSolomonCode(F, 8, 4).generator_matrix())

[         1          1          1          1          1          1          1          1]
[         1 1728404513 2013265920  284861408          1 1728404513 2013265920  284861408]
[         1 1592366214 1728404513  211723194 2013265920  420899707  284861408 1801542727]
[         1  211723194  284861408 1592366214 2013265920 1801542727 1728404513  420899707]
vs:
[         1          1          1          1          1          1          1          1]
[         1 1592366214 1728404513  211723194 2013265920  420899707  284861408 1801542727]
[         1 1728404513 2013265920  284861408          1 1728404513 2013265920  284861408]
[         1  211723194  284861408 1592366214 2013265920 1801542727 1728404513  420899707]


In [7]:
def multilinear_coeffs(p):
    """Coefficient vector of a multilinear polynomial, indexed like B().

    Entry i is the coefficient of the monomial prod_j x_j^(bit j of i), with the bits
    taken LSB-first exactly as B(n) yields them. Both univariate rings (dict() keyed
    by int degree) and multivariate rings (dict() keyed by ETuple) are handled.
    This matters because PolynomialRing(K, n, ...) with an explicit n returns a
    *multivariate* ring even when n == 1, so mle() of two values produces
    ETuple-keyed dicts even though the ring has a single generator.

    Returns a vector over p's base ring, so an all-zero result keeps the right ring.
    Raises AssertionError if p has an exponent > 1 in any variable.
    """
    R = p.parent()
    n = R.ngens()
    coeffs = {}
    for key, c in p.dict().items():
        # Normalise both key styles to a plain exponent tuple.
        exps = tuple(key) if isinstance(key, ETuple) else (key,)
        assert all(e <= 1 for e in exps), "polynomial is not multilinear"
        coeffs[exps] = c
    return vector(R.base_ring(), [coeffs.get(tuple(b), 0) for b in B(n)])

In [8]:
def encode(v, rate_bits):
    """Reed-Solomon encode the coefficient vector v (length 2^k) with blow-up 2^rate_bits.

    Returns v * rs_matrix(...), a codeword of length len(v) * 2^rate_bits over v's base ring.
    """
    generator = rs_matrix(v.base_ring(), log2_strict(len(v)), rate_bits)
    return v * generator

def interpolate2(x0, y0, x1, y1, x):
    """Evaluate at x the line through (x0, y0) and (x1, y1); requires x0 != x1."""
    rise = y1 - y0
    run = x1 - x0
    return y0 + rise * (x - x0) / run

def fold_evals(v, beta):
    """Fold a codeword of length n into one of length n/2 at the challenge beta.

    With sg = two_adic_subgroup(field, log2 n), entry j of the result is the value
    at beta of the line through (sg[j], v[j]) and (sg[j + n/2], v[j + n/2]). This
    mirrors the [[G, G], [G*T1, G*T2]] block structure of rs_matrix: each pair of
    entries is a_j + T*b_j for T in {T1[j], T2[j]}, and the fold returns a_j + beta*b_j.

    Raises AssertionError if v has fewer than two entries (nothing to fold).
    """
    n = len(v)
    assert n >= 2, "cannot fold a codeword of length < 2"
    half = n // 2
    sg = two_adic_subgroup(v.base_ring(), log2_strict(n))
    # Assuming sg = (g^0, .., g^(n-1)):
    # T1 = (g^0, .., g^(n/2 - 1)), T2 = (g^(n/2), .., g^(n-1)).
    T1, T2 = sg[:half], sg[half:]
    return vector([
        interpolate2(T1[j], v[j], T2[j], v[j + half], beta)
        for j in range(half)
    ])

In [14]:
# Evaluation proof: Merkle roots of the committed codewords, the claimed value y,
# and the sum-check round polynomials hs.
EvalProof = namedtuple('EvalProof', 'comms y hs')

def basefold_prove(t, evals, rate_bits, challenge_field):
    """BaseFold evaluation proof for f = mle(evals).

    Commits to the Reed-Solomon encoding of f's coefficients. It then draws the
    evaluation point z from the transcript and runs sum-check on
        y = f(z) = sum_{b in {0,1}^n} f(b) * eq(z)(b),
    binding the last variable first. After each round the codeword is folded with
    the same challenge and the folded codeword is committed.

    Args:
        t: transcript (observe / challenge); basefold_verify replays the same order.
        evals: vector of 2^n_vars field values.
        rate_bits: log2 of the code's blow-up factor.
        challenge_field: field z and the sum-check challenges are drawn from.

    Returns:
        EvalProof(comms=[initial root, root after each fold], y=f(z),
                  hs=one round polynomial per variable).
    """
    n_vars = log2_strict(len(evals))
    f = mle(evals)
    codeword = encode(multilinear_coeffs(f), rate_bits)
    oracles = [MerkleTree(codeword)]
    t.observe(oracles[0].root())
    z = [t.challenge(challenge_field) for _ in range(n_vars)]
    y = f(z)
    print(z, y)

    # Build the round-polynomial variable locally instead of relying on a global X
    # defined in another cell.
    X = PolynomialRing(challenge_field, 'X').gen()
    eq_z = eq(z)  # hoisted: the same polynomial is evaluated for every summand
    rs = []  # challenges bound so far; rs[k] is the value of x_{n_vars - len(rs) + k}

    def eval_sum(b):
        # f * eq(z) with the leading variables set to b, the next one free (X),
        # and the trailing variables fixed to the challenges rs.
        x = b + [X] + rs
        return f(x) * eq_z(x)

    hs = [sum(eval_sum(b) for b in B(n_vars-1))]
    for i in range(n_vars):
        rs.insert(0, t.challenge(challenge_field))
        codeword = fold_evals(codeword, rs[0])
        oracles.append(MerkleTree(codeword))
        t.observe(oracles[-1].root())
        if i < n_vars - 1:
            hs.append(sum(eval_sum(b) for b in B(n_vars-i-2)))
    # NOTE(review): y and the round polynomials hs are never absorbed into the
    # transcript, so the Fiat-Shamir challenges do not depend on them.
    print(f'{codeword=}')
    print(f'{f(rs)=}')
    return EvalProof(comms=[o.root() for o in oracles], y=y, hs=hs)

def basefold_verify(t, proof, n_vars, rate_bits, challenge_field):
    """Verify an EvalProof produced by basefold_prove.

    Replays the transcript in the prover's observe/challenge order and checks every
    sum-check round. It then checks that the last commitment is the rate_bits
    encoding of f(rs) = (final claim) / eq(rs, z). Oracle queries (IOPP.query) are
    not implemented yet, so intermediate commitments are absorbed but never opened.

    Raises AssertionError on any failed check.
    """
    comms, y, hs = proof
    # basefold_prove commits once up front plus once per fold (n_vars folds),
    # and sends one round polynomial per variable.
    assert len(comms) == n_vars + 1, "wrong number of commitments"
    assert len(hs) == n_vars, "wrong number of round polynomials"
    t.observe(comms[0])
    z = [t.challenge(challenge_field) for _ in range(n_vars)]
    print(z, y)
    rs = []
    old_eval = y
    for comm, h in zip(comms[1:], hs):
        # f * eq(z) is multilinear times multilinear, so each round has degree <= 2.
        assert h.degree() <= 2, "round polynomial degree too high"
        assert h(0) + h(1) == old_eval
        rs.insert(0, t.challenge(challenge_field))
        old_eval = h(rs[0])
        t.observe(comm)
    # todo: IOPP.query
    final_eval = old_eval / eq(rs, z)
    print(f'{final_eval=}')
    # After n_vars folds the prover's codeword has length 2^rate_bits, so re-encode
    # with the same rate (previously hard-coded to 1).
    assert MerkleTree(encode(vector([final_eval]), rate_bits)).root() == comms[-1]

# Parameters
n_vars, rate_bits = 5, 1
with seed(0):
    evals = matrix.random(F, 1, 2^n_vars).row(0)
print(len(evals))

# Prove and verify an evaluation of mle(evals) at a transcript-derived point over FE.
proof = basefold_prove(Transcript(), evals, rate_bits, FE)
print(f'\n{proof}\n')
basefold_verify(Transcript(), proof, n_vars, rate_bits, FE)

32
[1499181377*u + 1199012862, 1180221880*u + 629569893, 1534477982*u + 1391300206, 1216036703*u + 1991132894, 1250196892*u + 905577601] 1996849893*u + 1856409673
codeword=(798045927*u + 1910377036, 798045927*u + 1910377036)
f(rs)=798045927*u + 1910377036

EvalProof(comms=[b'fe94afbb', b'afec1c6c', b'b8eb985a', b'79b5c7da', b'9ce40ee9', b'dcf9c335'], y=1996849893*u + 1856409673, hs=[(1257690178*u + 1814137626)*X^2 + (369943786*u + 1481993966)*X + 1191240925*u + 286772001, (1199660290*u + 684527314)*X^2 + (553659853*u + 766749018)*X + 201412270*u + 1109390407, (1889954906*u + 460017120)*X^2 + (703272500*u + 624244290)*X + 1720892355*u + 1589138416, (1229167621*u + 1429114330)*X^2 + (1160454070*u + 56502040)*X + 1544554262*u + 1542444642, (980821393*u + 1768304341)*X^2 + (888699390*u + 1646143300)*X + 504404100*u + 1175369309])

[1499181377*u + 1199012862, 1180221880*u + 629569893, 1534477982*u + 1391300206, 1216036703*u + 1991132894, 1250196892*u + 905577601] 1996849893*u + 1856409673
f

In [10]:
# Redefines EvalProof with the same fields (comms, y, hs) as the earlier cell,
# presumably so the binfold cells can be run without the basefold cell.
# NOTE: this creates a new class; proofs built before this cell ran are not
# instances of it (tuple unpacking still works).
EvalProof = namedtuple('EvalProof', [
    'comms', 'y', 'hs'
])

In [40]:
def binfold_prove(domain_field, challenge_field, t, evals, z, y, rate_bits):
    """Evaluation proof for f = mle(evals) at z, committing over a packed domain field.

    Consecutive pairs of evaluations are packed into single elements
    domain_field([evals[2k], evals[2k+1]]) of the degree-2 field domain_field. The
    packed MLE's coefficients are RS-encoded and committed. Sum-check on
    y = sum_b f(b) * eq(z)(b) then runs over all n_vars variables, last variable
    first. The codeword is folded for every variable except x0, which the packing
    absorbs, so the final round polynomial is in x0.

    Args:
        domain_field: field used for packing/encoding; only degree 2 is supported.
        challenge_field: field the sum-check challenges are drawn from.
        t: transcript.
        evals: 2^n_vars evaluations.
        z, y: evaluation point and claimed value (y is passed through, not checked).
        rate_bits: log2 of the code's blow-up factor.

    Returns:
        EvalProof(comms, y, hs) with n_vars commitments and n_vars round polynomials.
    """
    n_vars = log2_strict(len(evals))
    d = domain_field.degree()
    # The final check f(0, rs) + u*f(1, rs) hard-codes exactly one packed variable;
    # other degrees previously failed later with obscure errors.
    assert d == 2, "only degree-2 domain fields are supported"
    d_bits = log2_strict(d)
    assert len(evals) % d == 0, "domain degree must divide trace length"

    packed_evals = vector([domain_field(evals[i:i+d]) for i in range(0, len(evals), d)])
    poly = multilinear_coeffs(mle(packed_evals))
    codeword = encode(poly, rate_bits)

    oracles = [MerkleTree(codeword)]
    t.observe(oracles[0].root())

    f = mle(evals)
    X = PolynomialRing(domain_field, 'X').gen()  # built once, not on every eval_sum call
    eq_z = eq(z)  # hoisted: the same polynomial is evaluated for every summand
    rs = []

    def eval_sum(b):
        # f * eq(z) with leading variables b, the next one free (X), trailing ones = rs.
        x = b + [X] + rs
        return f(x) * eq_z(x)

    hs = [sum(eval_sum(b) for b in B(n_vars-1))]
    for i in range(n_vars - d_bits):
        rs.insert(0, t.challenge(challenge_field))
        codeword = fold_evals(codeword, rs[0])
        oracles.append(MerkleTree(codeword))
        t.observe(oracles[-1].root())
        # Every round emits its successor. The last one is the polynomial in the
        # unfolded variable x0 that binfold_verify opens at 0 and 1.
        hs.append(sum(eval_sum(b) for b in B(n_vars-i-2)))

    print(f'{codeword=}')
    u = domain_field.gens()[0]
    expected_symbol = f([0]+rs) + u * f([1]+rs)
    assert all(symbol == expected_symbol for symbol in codeword)

    return EvalProof(comms=[o.root() for o in oracles], y=y, hs=hs)

def binfold_verify(domain_field, challenge_field, t, proof, n_vars, z, rate_bits):
    """Verify an EvalProof produced by binfold_prove.

    Replays the n_vars - 1 folding rounds of sum-check. The last round polynomial
    h (in the unfolded variable x0) then gives f0 = f(0, rs) and f1 = f(1, rs), and
    the verifier checks that the final commitment is the rate_bits encoding of
    f0 + u*f1. Oracle queries (IOPP.query) are not implemented yet.

    Raises AssertionError on any failed check.
    """
    comms, y, hs = proof
    # Mirrors binfold_prove, which only supports a degree-2 domain field:
    # n_vars - 1 folds -> n_vars commitments and n_vars round polynomials.
    assert domain_field.degree() == 2, "only degree-2 domain fields are supported"
    assert len(comms) == n_vars, "wrong number of commitments"
    assert len(hs) == n_vars, "wrong number of round polynomials"
    # f * eq(z) is multilinear times multilinear, so each round has degree <= 2.
    assert all(h.degree() <= 2 for h in hs), "round polynomial degree too high"
    t.observe(comms[0])
    rs = []
    old_eval = y
    for comm, h in zip(comms[1:], hs[:-1]):
        assert h(0) + h(1) == old_eval
        rs.insert(0, t.challenge(challenge_field))
        old_eval = h(rs[0])
        t.observe(comm)

    # todo: IOPP.query
    h_last = hs[-1]
    assert h_last(0) + h_last(1) == old_eval
    f0 = h_last(0) / eq([0]+rs, z)
    f1 = h_last(1) / eq([1]+rs, z)
    # NOTE(review): h_last is only opened at 0 and 1, never at a random point.
    # When f0, f1 may lie in domain_field itself (e.g. challenge_field == domain_field),
    # the sum check above and the f0 + u*f1 check below are two linear conditions on
    # (h_last(0), h_last(1)). A prover can generally satisfy both for any old_eval.
    # Confirm this final step is sound.
    u = domain_field.gens()[0]
    print(f'{f0 + u*f1=}')
    # After n_vars - 1 folds the prover's codeword has length 2^rate_bits
    # (previously hard-coded to rate 1).
    assert MerkleTree(encode(vector([f0 + u*f1]), rate_bits)).root() == comms[-1]

n_vars, rate_bits = 5, 1
with seed(0):
    evals = matrix.random(F, 1, 2^n_vars).row(0)
print(len(evals), evals)

# Seeded evaluation point in FE and the claimed value of the MLE there.
with seed(123):
    z = [FE.random_element() for _ in range(n_vars)]
y = mle(evals)(z)
print(f'proving p({z}) = {y}')

proof = binfold_prove(FE, FE, Transcript(), evals, z, y, rate_bits)
print(f'\n{proof}\n')
binfold_verify(FE, FE, Transcript(), proof, n_vars, z, rate_bits)

32 (239314053, 1752017257, 1104552879, 1403531344, 95970379, 40893827, 714018130, 1670834276, 1117980501, 1899472573, 392374258, 1127884130, 1505102375, 1320876833, 171853400, 389593204, 1909242991, 960203414, 1249605700, 171838036, 1894868810, 581665188, 1659567657, 324824829, 321456214, 421210553, 1859528577, 61056999, 1858629403, 300476657, 1062500079, 1911182510)
proving p([90983862*u + 1672417394, 1464525613*u + 374489362, 1933342190*u + 274816201, 982149277*u + 250167490, 731848166*u + 1638096986]) = 337604180*u + 68796779
codeword=(1450103034*u + 1213348500, 1450103034*u + 1213348500)

EvalProof(comms=[b'6e98b27f', b'b9f481e8', b'ad293db2', b'a93a1612', b'375b3c20'], y=337604180*u + 68796779, hs=[(1586601167*u + 454330318)*X^2 + (1178072827*u + 1906621340)*X + 799731014*u + 1873821442, (363794974*u + 1298389066)*X^2 + (1112165725*u + 1678096280)*X + 683578733*u + 1595399468, (1149845144*u + 65564683)*X^2 + (939432230*u + 718652166)*X + 844424944*u + 625665703, (1337902894*u + 12

In [41]:
# A tiny field for experiments: GF(17) and GF(17^2) = GF(17)[X17]/(X17^2 - 3)
# (3 is not a square mod 17, so the quotient is a field).
F17 = GF(17)
R17 = PolynomialRing(F17, 'X17'); X17 = R17.gen()
F289 = F17.extension(X17^2 - 3, names=('z17',)); z17 = F289.gen()
F17, F289

(Finite Field of size 17, Finite Field in z17 of size 17^2)

In [42]:
# Multiplicative group orders of GF(17) and GF(17^2); the 2-power bounds the RS domain size.
tuple(factor(q - 1) for q in (17, 17^2))

(2^4, 2^5 * 3^2)

##### Committing to a 5-variable trace (2^5 = 32 evaluations), which is more than the 17 elements of GF(17); the Reed–Solomon codeword is therefore built over GF(17^2)

In [44]:
n_vars, rate_bits = 5, 1
with seed(1):
    evals = matrix.random(F17, 1, 2^n_vars).row(0)
print(len(evals), evals)

# Seeded evaluation point in GF(17^2) and the claimed value of the MLE there.
with seed(123):
    z = [F289.random_element() for _ in range(n_vars)]
y = mle(evals)(z)
print(f'proving p({z}) = {y}')

proof = binfold_prove(F289, F289, Transcript(), evals, z, y, rate_bits)
print(f'\n{proof}\n')
binfold_verify(F289, F289, Transcript(), proof, n_vars, z, rate_bits)

32 (8, 5, 15, 10, 13, 0, 11, 16, 15, 4, 9, 8, 8, 8, 6, 1, 8, 11, 8, 13, 6, 5, 8, 12, 5, 6, 1, 9, 14, 10, 9, 16)
proving p([4*z17 + 15, 16*z17 + 7, 15*z17, 12*z17 + 8, 10*z17 + 10]) = 11*z17 + 8
codeword=(6*z17 + 7, 6*z17 + 7)

EvalProof(comms=[b'8882532c', b'8bf4c59f', b'42543e65', b'64d9d43d', b'346ea1f7'], y=11*z17 + 8, hs=[(15*z17 + 4)*X^2 + 9*X + 15*z17 + 6, (8*z17 + 13)*X^2 + (8*z17 + 8)*X + 9*z17 + 7, (5*z17 + 16)*X^2 + (8*z17 + 10)*X + 12*z17 + 16, X^2 + (8*z17 + 2)*X + 10*z17 + 6, (15*z17 + 5)*X^2 + (8*z17 + 6)*X + 8*z17 + 10])

f0 + u*f1=6*z17 + 7
