Declare some basic symbols:
A_v denotes a variable A that is a vector
Capital letter denote Group points
Small letters denote group scalars

polynomial terms are also used in `X`, `T`, but they would be clear from context

In [2]:
from sympy import *
v, G = MatrixSymbol('v', 1, 1), MatrixSymbol('G', 1, 1)
h_n = 8
g_n = 4
gamma_v = MatrixSymbol('gamma_v', 1, h_n)
G_v = MatrixSymbol('G_v', g_n, 1)
H_v = MatrixSymbol('H_v', h_n, 1)
# T value commitment
# gamma_v scalar is multiplies to H_0, the first term of H_v. Others are 0
V = v*G + gamma_v*H_v
print(V)

gamma_v*H_v + v*G


Define
    - d_v(vector of digits of v),
    - m_v(vector of multiplicities of digits of v)
    - r_v(vector of reciprocals of v)
    - s(vector of blinding factors)

We will first attempt at dealing with the last verification equation in terms of T

$r_i = \frac{1}{(e + d_i)} => r_i \times d_i = 1 - e \times r_i$

$\Sigma(\frac{1}{(e + i)})$

In [3]:
from sympy.solvers.solveset import linear_coeffs
d_v, m_v, r_v, s_v = MatrixSymbol('d_v', 1, g_n), MatrixSymbol('m_v', 1, g_n), MatrixSymbol('r_v', 1, g_n), MatrixSymbol('s_v', 1, g_n)
ld_v, lm_v, lr_v, ls_v = MatrixSymbol('ld_v', 1, h_n), MatrixSymbol('lm_v', 1, h_n), MatrixSymbol('lr_v', 1, h_n), MatrixSymbol('ls_v', 1, h_n)

b_s, b_d, b_m, b_r = MatrixSymbol('b_s', 1, 1), MatrixSymbol('b_d', 1, 1), MatrixSymbol('b_m', 1, 1), MatrixSymbol('b_r', 1, 1)

D = b_d*G + d_v*G_v + ld_v*H_v
M = b_m*G + m_v*G_v + lm_v*H_v
R = b_r*G + r_v*G_v + lr_v*H_v
S = b_s*G + s_v*G_v + ls_v*H_v

# create w_v
T = symbols('T', commutative=True)
x = symbols('Y', commutative=True)
q = symbols('Q', commutative=True)
e = symbols('e', commutative=True)

b = symbols('bits', commutative=True)

alpha_d = transpose(Matrix([b**0, b**1, b**2, b**3]))
alpha_m = transpose(Matrix([1/e, 1/(e + 1), 1/(e + 2), 1/(e + 3)]))

# Create the public terms
alpha_r2 = transpose(Matrix([e, e, e, e]))
one_v = transpose(Matrix([1, 1, 1, 1]))
alpha_r = -1*one_v

Q_v = transpose(Matrix([q, q**2, q**3, q**4]))
Q_inv_v = transpose(Matrix([q**-1, q**-2, q**-3, q**-4]))

# Public value
P = hadamard_product(Q_inv_v, T**3*alpha_d + x*T**4*alpha_m + x*T**2*alpha_r)*G_v #  + T**3*alpha_d)*G_v
P = P + T**2*alpha_r2*G_v

# Consider the expression
C = S + T*M + T**2*D + T**3*R + 2*T**5*V + P
C = expand(C)
w_v = linear_coeffs(C, G_v)[0]
l_v = linear_coeffs(C, H_v)[0]
v_hat = linear_coeffs(C, G)[0]

v_hat = v_hat[0] + 2*T**5*one_v.dot(Q_v) # Reciprocal constraint constant
v_hat = v_hat + 2*T**5*((alpha_d*transpose(alpha_r2))[0, :].as_explicit()[0]) # Sum value constraint constant
v_hat = v_hat + 2*T**5*x*((alpha_r*transpose(hadamard_product(Q_inv_v, alpha_d)))[0, :].as_explicit()[0]) # Range check constraint constant
v_hat = v_hat + T**8*x**2*((alpha_m*transpose(hadamard_product(Q_inv_v, alpha_m)))[0, :].as_explicit()[0])

print("Coeffs of G_v:", w_v)
print("Coeffs of H_v:", l_v)
print("Coeffs of G:", v_hat)

Coeffs of G_v: T**3*r_v + T**2*d_v + T*m_v + Matrix([[T**2*e + T**4*Y/(Q*e) + T**3/Q - T**2*Y/Q, T**4*Y/(Q**2*e + Q**2) + T**2*e + T**3*bits/Q**2 - T**2*Y/Q**2, T**4*Y/(Q**3*e + 2*Q**3) + T**2*e + T**3*bits**2/Q**3 - T**2*Y/Q**3, T**4*Y/(Q**4*e + 3*Q**4) + T**2*e + T**3*bits**3/Q**4 - T**2*Y/Q**4]]) + s_v
Coeffs of H_v: T**3*lr_v + T**2*ld_v + T*lm_v + (2*T**5)*gamma_v + ls_v
Coeffs of G: T**8*Y**2*(1/(Q*e**2) + 1/(Q**2*(e + 1)**2) + 1/(Q**3*(e + 2)**2) + 1/(Q**4*(e + 3)**2)) + 2*T**5*Y*(-1/Q - bits/Q**2 - bits**2/Q**3 - bits**3/Q**4) + 2*T**5*(Q**4 + Q**3 + Q**2 + Q) + 2*T**5*(bits**3*e + bits**2*e + bits*e + e) + 2*T**5*v[0, 0] + T**3*b_r[0, 0] + T**2*b_d[0, 0] + T*b_m[0, 0] + b_s[0, 0]


The challenge $q$ is used to separate the `n` reciprocal constraints and one final sum constraint. The sum constraint occurs in $q^0$ and other constraints start from $q^1$


Norm argument:
$|\vec{w}|^2_q + \langle \vec{l},\vec{c} \rangle = v$ for a given $C = vG + \langle w, \vec{G} \rangle + \langle l,\vec{H} \rangle$

In [4]:
import sympy.vector
R = symbols('R', commutative=True)
c_v = Matrix([[T], [T**2], [T**3], [T**4], [T**6], [T**7], [0], [0]])
c_v = R*c_v
w_norm = hadamard_product(Q_v,w_v)*transpose(w_v)
expanded_expr = expand(w_norm[0] + (l_v*c_v)[0, :].as_explicit()[0] - v_hat)

# print("T5 coeff", expanded_expr.coeff(T, 5)) # Q (.) (d_v . r_v) + r*(ld_v(2) + lm_v(3) + lr_v(1)) - 2v
print("Value Total Check:", expanded_expr.coeff(T, 5).coeff(R, 0).coeff(x, 0).coeff(q, 0))
print("Reciprocal Check:", expanded_expr.coeff(T, 5).coeff(R, 0).coeff(x, 0).coeff(q, 1))
print("Range Check:", collect(expanded_expr.coeff(T, 5).coeff(R, 0).coeff(x, 1), q))
print("Extra terms:", expanded_expr.coeff(T, 5).coeff(R, 1))


Value Total Check: 2*bits**3*d_v[0, 3] + 2*bits**2*d_v[0, 2] + 2*bits*d_v[0, 1] + 2*d_v[0, 0] - 2*v[0, 0]
Reciprocal Check: 2*e*r_v[0, 0] + 2*d_v[0, 0]*r_v[0, 0] - 2
Range Check: -2*r_v[0, 0] - 2*r_v[0, 1] - 2*r_v[0, 2] - 2*r_v[0, 3] + 2*m_v[0, 3]/(e + 3) + 2*m_v[0, 2]/(e + 2) + 2*m_v[0, 1]/(e + 1) + 2*m_v[0, 0]/e
Extra terms: ld_v[0, 2] + lm_v[0, 3] + lr_v[0, 1]


Now, let's move onto trying to adjust other terms in different powers of t that we cannot balance. Makes sense to try from last power as terms till $t^8$ are public terms. Forcing `gamma` corresponding to the terms  to be 0.

In [5]:
print("Last terms(T12):", expanded_expr.coeff(T, 13))
print("Last terms(T12):", expanded_expr.coeff(T, 12))
print("Last terms(T11):", expanded_expr.coeff(T, 11))
print("Last terms(T10):", expanded_expr.coeff(T, 10))
print("Last terms(T9):", expanded_expr.coeff(T, 9))
print("Last terms(T8):", expanded_expr.coeff(T, 8).collect(q))
print("Last terms(T1):", expanded_expr.coeff(T, 1).collect(q))

Last terms(T12): 0
Last terms(T12): 2*R*gamma_v[0, 5]
Last terms(T11): 2*R*gamma_v[0, 4]
Last terms(T10): R*lr_v[0, 5]
Last terms(T9): 2*R*gamma_v[0, 3] + R*ld_v[0, 5] + R*lr_v[0, 4]
Last terms(T8): 2*R*gamma_v[0, 2] + R*ld_v[0, 4] + R*lm_v[0, 5]
Last terms(T1): 2*Q**4*m_v[0, 3]*s_v[0, 3] + 2*Q**3*m_v[0, 2]*s_v[0, 2] + 2*Q**2*m_v[0, 1]*s_v[0, 1] + 2*Q*m_v[0, 0]*s_v[0, 0] + R*ls_v[0, 0] - b_m[0, 0]


Now, we have a $t^7$ term that depends on $\vec{ls}$. It is easy for the prover to set this value so that the overall term becomes zero.

In [6]:
print("Last terms(T7):", expanded_expr.coeff(T, 7))
print("Last terms(T7):", expanded_expr.coeff(T, 7).coeff(R, 0).coeff(q, 2))
print("Last terms(T7):", expanded_expr.coeff(T, 7).coeff(R, 1))

Last terms(T7): 2*Q**4*Y*bits**3/(Q**8*e + 3*Q**8) + 2*Q**4*Y*r_v[0, 3]/(Q**4*e + 3*Q**4) + 2*Q**3*Y*bits**2/(Q**6*e + 2*Q**6) + 2*Q**3*Y*r_v[0, 2]/(Q**3*e + 2*Q**3) + 2*Q**2*Y*bits/(Q**4*e + Q**4) + 2*Q**2*Y*r_v[0, 1]/(Q**2*e + Q**2) + 2*R*gamma_v[0, 1] + R*lm_v[0, 4] + R*lr_v[0, 3] + R*ls_v[0, 5] + 2*Y*r_v[0, 0]/e + 2*Y/(Q*e)
Last terms(T7): 2*Y*bits/(Q**4*e + Q**4) + 2*Y*r_v[0, 1]/(Q**2*e + Q**2)
Last terms(T7): 2*gamma_v[0, 1] + lm_v[0, 4] + lr_v[0, 3] + ls_v[0, 5]


Note: This comes back to the $t^7$ issue that we were discussing with Liam. In this setting, it is impossible to balance the $t^7$ because it depends on secret data and challenge data that is unknown yet. There are two approaches to fix this issue
- The most natural way would be to introduce a $t^7$ term in the $\vec{c}$
- The other way(as Liam suggested) would be to modify the $\vec{c}$ so that the term that generates $t^7$ from $l_r$ can be set "free". To do this, we would need to set $t^4$ power with an additional challenge.

This second route is promising and probably works, but for understanding, we focus on the first route

The logic for the remaining powers of $t^i$, $i \leq 7$ and $i \ne 5$ is the same. We have a free term from the blinded $ls$ vector that we can use to balance out the equation. 

Note that it is crucial that we cannot generate power of $t^5$ from the blinding term as it would disturb the soundness of the protocol and allow to prover to create bad proofs

In [7]:
print("Last terms(T0):", expanded_expr.coeff(T, 0))
temp_expr = expand((l_v*c_v)[0, :].as_explicit()[0])
print("temp expr:", temp_expr.coeff(T, 1))
# print("Last terms(T10):", expanded_expr.coeff(T, 10))
# print("Last terms(T9):", expanded_expr.coeff(T, 9))
print("Last terms(T8):", expanded_expr.coeff(T, 8))

Last terms(T0): Q**4*s_v[0, 3]**2 + Q**3*s_v[0, 2]**2 + Q**2*s_v[0, 1]**2 + Q*s_v[0, 0]**2 - b_s[0, 0]
temp expr: R*ls_v[0, 0]
Last terms(T8): Q**4*Y**2/(Q**8*e**2 + 6*Q**8*e + 9*Q**8) + Q**3*Y**2/(Q**6*e**2 + 4*Q**6*e + 4*Q**6) + Q**2*Y**2/(Q**4*e**2 + 2*Q**4*e + Q**4) + 2*R*gamma_v[0, 2] + R*ld_v[0, 4] + R*lm_v[0, 5] - Y**2/(Q**4*e**2 + 6*Q**4*e + 9*Q**4) - Y**2/(Q**3*e**2 + 4*Q**3*e + 4*Q**3) - Y**2/(Q**2*e**2 + 2*Q**2*e + Q**2)


Write down all constraints on blinding factors 

In [8]:
from sympy import *
n0, n1 = symbols('n0, n1', commutative=True)
c0, c1 = symbols('c0, c1', commutative=True)
l0, l1 = symbols('l0, l1', commutative=True)
G0, G1 = symbols('G0, G1', commutative=True)
H0, H1 = symbols('H0, H1', commutative=True)
e, q = symbols('e q', commutative=True)

v_init = n0*n0*q*q + n1*n1*q*q*q*q + l0*c0 + l1*c1
n_prime = n0/q + e*n1
l_prime = l0 + e*l1
c_prime = c0 + e*c1

x = c0*l1 + c1*l0 + 2*n0*n1*q**3
r = c1*l1 + n1*n1*q**4
v1_final = v_init + e * x + (e**2 - 1)*r

v_final = n_prime * n_prime*q*q + l_prime * c_prime

print(expand(v_final))
print(expand(v1_final))


c0*e*l1 + c0*l0 + c1*e**2*l1 + c1*e*l0 + e**2*n1**2*q**2 + 2*e*n0*n1*q + n0**2
c0*e*l1 + c0*l0 + c1*e**2*l1 + c1*e*l0 + e**2*n1**2*q**4 + 2*e*n0*n1*q**3 + n0**2*q**2


In [9]:
from sympy import *
n0, n1, n2, n3 = symbols('n0 n1 n2 n3', commutative=True)
G0, G1, G2, G3 = symbols('G0 G1 G2 G3', commutative=True)
e1, e2, q = symbols('e1, e2, q', commutative=True)

# v_init = n0*n0*q + n1*n1*q*q + n2*n2*q*q*q + n3*n3*q*q*q*q
# n_prime1 = n0*q**(-1/2) + e1*n2
# n_prime2 = n1*q**(-1/2) + e1*n3
# n_final = n_prime1 * q**(-1) + n_prime2 * e2
# print(n_final)

# X1 = 2*(n0*n2 + n1*n3)*q**(3/2)
# R1 = (n2*n2 + n3*n3)*q**2
# X2 = 2*(n_prime1*n_prime2)*q**3
# R2 = (n_prime2*n_prime2)*q**4
# v1_final = v_init + e1 * X1 + (e1**2 - 1)*R1 + e2 * X2 + (e2**2 - 1)*R2

# v_final = n_final * n_final*q**4

# print(simplify(v_final))
# print(simplify(v1_final))

v_init = n0*n0*q + n1*n1*q*q + n2*n2*q*q*q + n3*n3*q*q*q*q
n_prime1 = n0*q**(-1/2) + e1*n1
n_prime2 = n2*q**(-1/2) + e1*n3
n_final = n_prime1*q**(-1) + n_prime2 * e2
# print(n_final)

X1 = 2*(n0*n1*q**(3/2) + n2*n3*q**(7/2))
R1 = (n1*n1*q**2 + n3*n3*q**4)

v_int1 = v_init + e1 * X1 + (e1**2 - 1)*R1
v_int2 = n_prime1*n_prime1*q**2 + n_prime2*n_prime2*q**4

print(expand(v_int1))
print(expand(v_int2))
print(expand(v_int1) - expand(v_int2))

# X2 = 2*(n_prime1*n_prime2)*q**3
# R2 = (n_prime2*n_prime2)*q**4
# v1_final = v_init + e1 * X1 + (e1**2 - 1)*R1 + e2 * X2 + (e2**2 - 1)*R2

# v_final = n_final * n_final * q**4

# print(expand(v_final).coeff(e1))
# print(expand(v1_final).coeff(e1))

# expand(v_final) == expand(v1_final)

e1**2*n1**2*q**2 + e1**2*n3**2*q**4 + 2*e1*n0*n1*q**1.5 + 2*e1*n2*n3*q**3.5 + n0**2*q + n2**2*q**3
e1**2*n1**2*q**2 + e1**2*n3**2*q**4 + 2*e1*n0*n1*q**1.5 + 2*e1*n2*n3*q**3.5 + n0**2*q**1.0 + n2**2*q**3.0
-n0**2*q**1.0 + n0**2*q + n2**2*q**3 - n2**2*q**3.0


In [10]:
from sympy import *
x1, x2, y1, y2, a1, a2, b1, b2 = symbols('x1 x2 y1 y2 a1 a2 b1 b2', commutative=True)
xr, yr, lambdar = symbols('xr yr lambdar', commutative=True)
xb1, xb2, yb1, yb2 = symbols('xb1 xb2 yb1 yb2', commutative=True)
xb1 = x1 + a1
xb2 = x2 + a2
yb1 = y1 + b1
yb2 = y2 + b2

lambdar = (y2 - y1)/(x2 - x1)
xr = lambdar**2 - x1 - x2
yr = lambdar*(x1 - xr) - y1
print(xr)
print(yr)

-x1 - x2 + (-y1 + y2)**2/(-x1 + x2)**2
-y1 + (-y1 + y2)*(2*x1 + x2 - (-y1 + y2)**2/(-x1 + x2)**2)/(-x1 + x2)


In [11]:
from sympy import *
s, m, d, r, alpha_m, t = symbols('s m d r alpha_m t', commutative=True)
w = s + m*t + d*t**2 + r*t**3 + alpha_m*t**4
print(expand(w*w).collect(t))

alpha_m**2*t**8 + 2*alpha_m*r*t**7 + 2*m*s*t + s**2 + t**6*(2*alpha_m*d + r**2) + t**5*(2*alpha_m*m + 2*d*r) + t**4*(2*alpha_m*s + d**2 + 2*m*r) + t**3*(2*d*m + 2*r*s) + t**2*(2*d*s + m**2)


In [12]:
lm0, lm1, lm2, lm3, lm4, lm5 = symbols('lm0 lm1 lm2 lm3 lm4 lm5', commutative=True)
ld2, ld4 = symbols('ld2 ld4', commutative=True)
T = symbols('T', commutative=True)
expr = lm0*T**2 + lm1*T**3 + lm2*T**4 + lm3*T**5 + lm4*T**7 + lm5*T**8
expr = expr + ld2*T**5 + ld4*T**8
print(expand(expr).collect(T))

T**8*(ld4 + lm5) + T**7*lm4 + T**5*(ld2 + lm3) + T**4*lm2 + T**3*lm1 + T**2*lm0
