In [50]:
from scipy.interpolate import lagrange
import galois
# from tate_bilinear_pairing import eta, ecc, f3m
from py_ecc.bn128 import bn128_curve, bn128_pairing
from py_ecc.fields import (
    bn128_FQ as FQ,
    bn128_FQ2 as FQ2,
    bn128_FQ12 as FQ12,
    bn128_FQP as FQP,
)
import random
import math
import numpy as np
# from poly_utils import PrimeField
from scipy.interpolate import lagrange
from fractions import Fraction

In [288]:
class PrimeField():
    """Arithmetic in the prime field GF(modulus) plus Laurent-polynomial helpers.

    Univariate polynomials are passed as a coefficient list together with an
    ``init_order``: ``coeffs[i]`` is the coefficient of ``x**(init_order + i)``,
    so negative powers are allowed. Polynomial methods return a
    ``(coeffs, init_order)`` pair.

    Bivariate polynomials (the ``*_bivar`` methods) use the same layout in the
    outer variable, but each outer coefficient is itself a univariate
    ``(coeffs, init_order)`` pair in the inner variable (``eval_poly_Y``
    evaluates those inner polynomials at a point).
    """

    def __init__(self, modulus):
        # Fermat check (2**p == 2 mod p): a cheap sanity test that `modulus`
        # is prime. Pseudoprimes can pass it, so it is not a proof.
        assert pow(2, modulus, modulus) == 2
        self.modulus = modulus

    def add(self, x, y):
        """Return (x + y) mod modulus."""
        return (x+y) % self.modulus

    def sub(self, x, y):
        """Return (x - y) mod modulus."""
        return (x-y) % self.modulus

    def mul(self, x, y):
        """Return (x * y) mod modulus."""
        return (x*y) % self.modulus

    def exp(self, x, p):
        """Return x**p mod modulus.

        A negative ``p`` returns the inverse of x**|p| (which is 0 when x is 0
        mod modulus, following the ``inv`` convention).
        """
        if p < 0:
            return self.inv(pow(x, -p, self.modulus))
        else:
            return pow(x, p, self.modulus)

    def toField(self, x):
        """Reduce an integer into the range [0, modulus)."""
        return x % self.modulus

    def inv(self, a):
        """Modular inverse using the extended Euclidean algorithm.

        By convention returns 0 for any ``a`` congruent to 0, which has no inverse.
        """
        a %= self.modulus
        # Reduce before the zero check: a nonzero multiple of the modulus is
        # zero in the field and must not fall through to the loop, which
        # would otherwise return 1.
        if a == 0:
            return 0
        lm, hm = 1, 0
        low, high = a, self.modulus
        while low > 1:
            r = high // low
            nm, new = hm - lm * r, high - low * r
            lm, low, hm, high = nm, new, lm, low
        return lm % self.modulus

    def multi_inv(self, values):
        """Invert every value with a single field inversion (batch inversion).

        Zero entries map to 0 in the output, matching ``inv(0) == 0``.
        """
        partials = [1]
        for i in range(len(values)):
            partials.append(self.mul(partials[-1], values[i] or 1))
        inv = self.inv(partials[-1])
        outputs = [0] * len(values)
        for i in range(len(values), 0, -1):
            outputs[i-1] = self.mul(partials[i-1], inv) if values[i-1] else 0
            inv = self.mul(inv, values[i-1] or 1)
        return outputs

    def div(self, x, y):
        """Return x / y mod modulus (0 when y is 0, because inv(0) == 0)."""
        return self.mul(x, self.inv(y))

    def eval_poly_at(self, p, x, init_order = 0):
        """Evaluate the polynomial with coefficients ``p`` at ``x``.

        ``p`` must have consecutive orders starting from ``init_order``: the
        result is sum(p[i] * x**(init_order + i)) mod modulus. Negative powers
        use the field inverse of ``x``.
        """
        y = 0
        power_of_x = 1
        for p_coeff in p:
            y = (y + power_of_x * p_coeff) % self.modulus
            power_of_x = (power_of_x * x) % self.modulus
        # Shift by x**init_order with modular exponentiation instead of
        # materialising the full integer power first.
        return (y * self.exp(x, init_order)) % self.modulus

    def eval_poly_Y(self, coeffs, order, y):
        """Evaluate every inner (Y) polynomial of a bivariate polynomial at ``y``.

        ``coeffs`` holds one ``(y_coeffs, y_init_order)`` pair per power of the
        outer variable, starting at ``order``. Each pair becomes the constant
        polynomial ``[[value], 0]``; ``order`` is returned unchanged.
        """
        new_coeffs = []
        for poly in coeffs:
            new_coeffs.append([[self.eval_poly_at(poly[0], y, poly[1])], 0])
        return new_coeffs, order

    @staticmethod
    def _coeff_at(coeffs, init_order, power, default=0):
        """Entry of ``coeffs`` for ``power`` given its ``init_order``, else ``default``."""
        index = power - init_order
        return coeffs[index] if 0 <= index < len(coeffs) else default

    def add_polys(self, a, b, init_order_a = 0, init_order_b = 0):
        """Add two polynomials given as coefficient lists plus their init orders.

        Returns ``(coeffs, init_order)``. The result starts at the smaller
        init order and extends to the larger top power of the two inputs.
        """
        init_order_result = min(init_order_a, init_order_b)
        end_power = max(len(a) + init_order_a, len(b) + init_order_b)
        return ([(self._coeff_at(a, init_order_a, power) + self._coeff_at(b, init_order_b, power))
                 % self.modulus for power in range(init_order_result, end_power)],
                init_order_result)

    def add_polys_bivar(self, a, b, init_order_a = 0, init_order_b = 0):
        """Add two bivariate polynomials term by term in the outer variable.

        Each outer coefficient is a ``(coeffs, init_order)`` pair in the inner
        variable. Outer terms missing from one input count as the zero
        polynomial. Returns ``(list_of_pairs, init_order)``.
        """
        init_order_result = min(init_order_a, init_order_b)
        end_power = max(len(a) + init_order_a, len(b) + init_order_b)
        zero_poly = ([0], 0)
        result_poly = []
        for power in range(init_order_result, end_power):
            poly_a = self._coeff_at(a, init_order_a, power, zero_poly)
            poly_b = self._coeff_at(b, init_order_b, power, zero_poly)
            # Use this instance, not a notebook-global field object.
            result_poly.append(self.add_polys(poly_a[0], poly_b[0], poly_a[1], poly_b[1]))
        return result_poly, init_order_result

    def mul_by_const(self, a, c):
        """Multiply a ``(coeffs, init_order)`` pair by scalar ``c``; the order is kept."""
        return [(x*c) % self.modulus for x in a[0]], a[1]

    def mul_polys(self, a, b, init_order_a = 0, init_order_b = 0):
        """Multiply two polynomials; the result's init order is the sum of the inputs'."""
        o = [0] * (len(a) + len(b) - 1)
        for i, aval in enumerate(a):
            for j, bval in enumerate(b):
                o[i+j] += aval * bval
        return [x % self.modulus for x in o], init_order_a + init_order_b

    def mul_polys_bivar(self, a, b, init_order_a = 0, init_order_b = 0):
        """Multiply two bivariate polynomials (schoolbook product in the outer variable).

        Inner coefficients are multiplied with ``mul_polys`` and accumulated
        with ``add_polys``.
        """
        # One distinct zero polynomial per slot; each slot is replaced as
        # partial products accumulate.
        o = [([0], 0) for _ in range(len(a) + len(b) - 1)]
        for i, aval in enumerate(a):
            for j, bval in enumerate(b):
                mul_result = self.mul_polys(aval[0], bval[0], aval[1], bval[1])
                o[i+j] = self.add_polys(o[i+j][0], mul_result[0], o[i+j][1], mul_result[1])
        return o, init_order_a + init_order_b

    def div_polys(self, a, b, init_order_a = 0, init_order_b = 0):
        """Polynomial long division of ``a`` by ``b``; the remainder is discarded.

        The division works on the raw coefficient lists. Only the returned
        init order (``init_order_a - init_order_b``) accounts for the input orders.
        """
        assert len(a) >= len(b)
        a = list(a)  # work on a copy; the caller's list is not modified
        o = []
        apos = len(a) - 1
        bpos = len(b) - 1
        diff = apos - bpos
        while diff >= 0:
            quot = self.div(a[apos], b[bpos])
            o.insert(0, quot)
            for i in range(bpos, -1, -1):
                a[diff+i] -= b[i] * quot
            apos -= 1
            diff -= 1
        return [x % self.modulus for x in o], init_order_a - init_order_b

    def sparse(self, coeff_dict):
        """Build ``(coeffs, init_order)`` from a ``{power: coefficient}`` dict.

        Powers absent from the dict get coefficient 0, and coefficients are
        reduced mod modulus. An empty dict raises ValueError (from ``min``).
        """
        low = min(coeff_dict.keys())
        o = [0] * (max(coeff_dict.keys()) - low + 1)
        for k, v in coeff_dict.items():
            o[k - low] = v % self.modulus
        return (o, low)

    def sparse_bivar(self, coeff_dict):
        """Build a bivariate polynomial from ``{outer_power: (coeffs, init_order)}``.

        Missing outer powers are filled with the zero polynomial ``[[0], 0]``.
        """
        low = min(coeff_dict.keys())
        # A fresh zero polynomial per slot; `[x] * n` would alias one list object.
        o = [[[0], 0] for _ in range(max(coeff_dict.keys()) - low + 1)]
        for k, v in coeff_dict.items():
            o[k - low] = v
        return (o, low)

    def lagrange(self, xs, ys, max_denom=100):
        """Interpolate (xs, ys) with scipy and scale the coefficients to integers.

        scipy interpolates in floating point. Each coefficient is approximated
        by a fraction with denominator <= ``max_denom``, and all coefficients
        are multiplied by the lcm of those denominators.

        Returns ``(scaled_coeffs, factor, 0)``. The coefficients are ordered
        highest degree first (scipy ``poly1d.c`` order) and are not reduced
        mod modulus.
        """
        # Inside this body `lagrange` resolves to the module-level scipy
        # function, not to this method.
        fn = lagrange(xs, ys)
        fractions = [Fraction(val).limit_denominator(max_denom) for val in fn.c]
        ratios = np.array([(f.numerator, f.denominator) for f in fractions])
        factor = np.lcm.reduce(ratios[:,1])
        result = [round(v * factor) for v in fn.c]
        return result, factor, 0


In [289]:
# All field arithmetic in this notebook (polynomial coefficients, SRS
# exponents) is done modulo the BN128 curve order.
order = bn128_curve.curve_order
field = PrimeField(modulus=order)

In [290]:
# Circuit size: n multiplication gates and q linear constraints. These must
# agree with the circuit cell (len(aL) == n, len(k) == q).
n, q = 2, 5
# Hard-coded SRS trapdoor values.
# NOTE(review): public constants are only acceptable for a demo; a real setup
# would sample these randomly and discard them.
srsX, srsAlpha = 12, 10

In [291]:
G1 = bn128_curve.G1
G2 = bn128_curve.G2

# Structured reference string built from the trapdoor srsX.
# srsD bounds the powers of X covered in G1: negative powers 1..srsD-1 and
# non-negative powers 0..srsD-1. G2 only gets powers 0 and 1.
srsD = n * 3

# [srsX^-i] * G1 for i = 1 .. srsD-1
gNegativeX = [bn128_curve.multiply(G1, field.exp(srsX, -power))
              for power in range(1, srsD)]
# [srsX^i] * G1 for i = 0 .. srsD-1
gPositiveX = [bn128_curve.multiply(G1, field.exp(srsX, power))
              for power in range(srsD)]
# [srsX^i] * G2 for i = 0, 1
hPositiveX = [bn128_curve.multiply(G2, field.exp(srsX, power))
              for power in range(2)]

# TODO: the alpha-scaled SRS elements (srsAlpha * srsX^i in G1 and G2) and the
# negative powers of X in G2 are not generated yet; srsAlpha is unused here.


In [292]:
# Witness for a toy circuit with two multiplication gates:
# gate i computes aL[i] * aR[i] = aO[i].
aL = np.array([4, 9])
aR = np.array([9, 4])
aO = np.array([36, 36])

# Linear constraints, one row per constraint and one column per gate:
#   row 0:    aO[0] - aO[1] == 0   (both products are equal)
#   rows 1-2: aL[0] == 4, aL[1] == 9
#   rows 3-4: aR[0] == 9, aR[1] == 4
# i.e. the circuit shows that 4 * 9 == 9 * 4.
u = np.array([[0, 0],
              [1, 0],
              [0, 1],
              [0, 0],
              [0, 0]])
v = np.array([[0, 0],
              [0, 0],
              [0, 0],
              [1, 0],
              [0, 1]])
w = np.array([[1, -1],
              [0, 0],
              [0, 0],
              [0, 0],
              [0, 0]])
k = np.array([0, 4, 9, 9, 4])

assignment = [aL, aR, aO]
circuit = [u, v, w, k]

# Shape sanity checks: one entry per gate, one matrix row per constraint.
assert aL.shape == aR.shape == aO.shape
assert u.shape == v.shape == w.shape == (len(k), len(aL))
# The witness must satisfy every multiplication gate ...
assert (aL * aR == aO).all()
# ... and every linear constraint.
assert (aL @ u.T + aR @ v.T + aO @ w.T == k).all()

In [293]:
def rPoly(aL, aR, aO, n):
    """Build the bivariate witness polynomial r(X, Y).

    Each witness value is placed on a diagonal term X^p Y^p:
    aL[i] at p = i+1, aR[i] at p = -(i+1), aO[i] at p = -(n+i+1).
    Assumes len(aL) == len(aR) == len(aO) == n (not checked here).

    Returns the ``(y_polys, init_order)`` pair produced by ``field.sparse_bivar``.
    """
    powers = np.concatenate([np.arange(-2*n, 0), np.arange(1, n+1)])
    coeffs = np.concatenate([np.flip(aO), np.flip(aR), aL])
    # One single-term Y-polynomial per witness value, keyed by its X power.
    y_polys = [field.sparse({powers[idx]: coeffs[idx]}) for idx in range(len(coeffs))]
    return field.sparse_bivar(dict(zip(powers, y_polys)))


def sPoly(u,v,w, n, q):
    """Build the bivariate circuit polynomial s(X, Y) from u, v, w.

    Column i of each matrix becomes a Y-polynomial whose coefficient of
    Y^(n+1+j) is matrix[j, i] (first q rows). These polynomials are placed at
    X^-(i+1) for u, X^(i+1) for v and X^(n+i+1) for w. Each w polynomial also
    gets -Y^(i+1) - Y^-(i+1) added.

    Returns the ``(y_polys, init_order)`` pair produced by ``field.sparse_bivar``.
    """
    y_powers = np.arange(n+1, n+q+1)

    def column_poly(matrix, col):
        # Y-polynomial for one column of a constraint matrix.
        return field.sparse(dict(zip(y_powers, matrix[:, col])))

    # u columns run from the highest gate down so they line up with X^-n .. X^-1.
    uiYs = [column_poly(u, i) for i in reversed(range(n))]
    viYs = [column_poly(v, i) for i in range(n)]
    wiYs = []
    for i in range(n):
        wi_matrix_part = column_poly(w, i)
        wi_diag_part = field.sparse({-i-1: -1, i+1: -1})
        wiYs.append(field.add_polys(wi_matrix_part[0], wi_diag_part[0],
                                    wi_matrix_part[1], wi_diag_part[1]))
    x_powers = np.concatenate([np.arange(-n, 0), np.arange(1, 2*n+1)])
    return field.sparse_bivar(dict(zip(x_powers, np.concatenate([uiYs, viYs, wiYs], dtype=object))))

def kPoly(k, n, q):
    """Build -k(Y) as a bivariate polynomial with a single X^0 term.

    k[j] becomes the coefficient of Y^(n+1+j) (first q entries), negated mod
    the field modulus. Returns ``[[(coeffs, init_order)], 0]``.
    """
    k_of_y = field.sparse(dict(zip(np.arange(n+1, n+q+1), k)))
    # Negated so callers can *add* this term to subtract k(Y).
    negated = field.mul_by_const(k_of_y, -1)
    return [[negated], 0]

sXY = sPoly(u, v, w, n, q)
rXY = rPoly(aL, aR, aO, n)
neg_kXY = kPoly(k, n, q)

# r'(X, Y) = r(X, Y) + s(X, Y)
r_dash_XY = field.add_polys_bivar(rXY[0], sXY[0], rXY[1], sXY[1])
# r(X, 1): every Y-polynomial of r collapsed to a constant
rX1 = field.eval_poly_Y(rXY[0], rXY[1], 1)

# c(X, Y) = r(X, 1) * r'(X, Y) - k(Y); neg_kXY already carries the minus sign.
rX1_times_r_dash = field.mul_polys_bivar(rX1[0], r_dash_XY[0], rX1[1], r_dash_XY[1])
cX = field.add_polys_bivar(rX1_times_r_dash[0], neg_kXY[0],
                           rX1_times_r_dash[1], neg_kXY[1])

cX


([([1296, 0, 0, 0, 0], -4),
  ([1296, 1296, 0, 0, 0], -4),
  ([144, 1296, 144, 0, 0, 0, 0, 0, 0, 36, 0, 0], -4),
  ([324, 144, 144, 324, 0, 0, 0, 0, 36, 36, 0, 0], -4),
  ([0, 324, 16, 324, 0, 0, 0, 0, 36, 4, 0, 0], -4),
  ([144, 0, 36, 36, 0, 144, 0, 0, 4, 9, 36, 0], -4),
  ([324, 144, 0, 81, 0, 144, 324, 0, 9, 0, 36, 36], -4),
  ([324,
    16,
    21888242871839275222246405745257275088548364400416034343698204186575808495581,
    0,
    21888242871839275222246405745257275088548364400416034343698204186575808495597,
    324,
    36,
    0,
    4,
    4,
    36],
   -3),
  ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], -2),
  ([21888242871839275222246405745257275088548364400416034343698204186575808495581,
    77,
    0,
    21888242871839275222246405745257275088548364400416034343698204186575808495613,
    45,
    21888242871839275222246405745257275088548364400416034343698204186575808495585,
    9,
    0,
    0,
    9],
   -2),
  ([21888242871839275222246405745257275088548364400416034343698204186575808