# ## Multilinear Extensions and Encoding Forms
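# In HyperKZG (as inherited from Gemini-PCS), a function $f: \{0,1\}^n \to \mathbb{F}$ is extended to its unique multilinear extension (MLE) $\tilde{f}$.
# Two representations of $\tilde{f}$ are used:
# - **Coefficient Form**: coefficients in the tensor (monomial) basis $\prod_i X_i^{b_i}$, $b \in \{0,1\}^n$
# - **Evaluation Form**: values on $\{0,1\}^n$ in the point-value (Lagrange) basis given by eq_tilde(X)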
#
# This section corresponds to the following equations in the HyperKZG article:
# - Equation (1): $\tilde{f}(X_0, \ldots, X_{n-1})$ in coefficient form
# - Equation (2): $f(X) = \sum_i f_i X^i$ (univariate encoding)
# - Equation (3): tensor dot-product view: $\langle \vec{f}, \otimes_i (1, u_i) \rangle$
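#
# Equation (3) says that evaluating the coefficient form at a point $u$ is an inner product
# with the tensor vector $\otimes_i (1, u_i)$. The sketch below checks this numerically. It
# assumes coefficient $f_i$ multiplies $\prod_j X_j^{b_j}$ with $b$ read MSB-first (the same
# order as the `bits()` helper defined later), which is also the natural Kronecker-product
# order. Under that assumption, substituting $u_j = X^{2^{n-1-j}}$ turns the tensor vector
# into $(1, X, \ldots, X^{N-1})$, linking Equations (2) and (3). The helper names
# (`tensor_vector`, `mle_coeff_eval`, `dot`) are hypothetical.

```python
from fractions import Fraction


def tensor_vector(point):
    """(1, p_0) ⊗ (1, p_1) ⊗ ... ⊗ (1, p_{n-1}).

    Entry i equals prod_j p_j^{b_j}, where b_0 is the MOST significant bit of i.
    """
    vec = [Fraction(1)]
    for p in point:
        vec = [v * w for v in vec for w in (Fraction(1), p)]
    return vec


def mle_coeff_eval(coeffs, point):
    """Directly evaluate sum_i f_i * prod_j X_j^{b_j(i)} with MSB-first bits."""
    n_vars = len(point)
    total = Fraction(0)
    for i, f_i in enumerate(coeffs):
        term = Fraction(f_i)
        for j, p in enumerate(point):
            if (i >> (n_vars - 1 - j)) & 1:
                term *= p
        total += term
    return total


def dot(lhs, rhs):
    return sum(s * t for s, t in zip(lhs, rhs))


coeffs_demo = [3, -1, 4, 1, -5, 9, 2, -6]  # arbitrary f_0..f_7
u_demo = [Fraction(1, 2), Fraction(2, 3), Fraction(-3, 5)]
lhs = dot(coeffs_demo, tensor_vector(u_demo))
rhs = mle_coeff_eval(coeffs_demo, u_demo)
print(lhs == rhs)

# Univariate view: u_j = X^(2^(n-1-j)) makes the tensor vector (1, X, ..., X^(N-1)).
X_val = Fraction(3, 7)
n_demo = len(u_demo)
powers = [X_val ** (2 ** (n_demo - 1 - j)) for j in range(n_demo)]
print(tensor_vector(powers) == [X_val ** i for i in range(2 ** n_demo)])
```

# Exact rational arithmetic (`Fraction`) is used so the equalities can be compared exactly.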

# Let's start with the setup.

# Code Cell:
from sympy import symbols, simplify, expand
from sympy.abc import x
from itertools import product

# Parameters
n = 3  # Number of variables (log_2 N)
N = 2**n  # Total number of points in {0,1}^n

print(f"Setting up MLE with n={n} variables. This gives N=2^n={N} total basis terms.")

# Define input symbols X_0, ..., X_{n-1} and u_0, ..., u_{n-1}
X = symbols(f"X0:{n}")  # X0, X1, X2
u = symbols(f"u0:{n}")  # u0, u1, u2

print("Defined symbolic variables:")
print("X =", X)
print("u =", u)

# Binary encoding (big-endian, MSB first): bits(i, n)[0] is the most significant
# bit of i and is the value assigned to X0.
def bits(i, n):
    return list(map(int, format(i, f"0{n}b")))
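#
# The bit order decides which index positions share a value of $X_0$. The sketch below
# contrasts `bits()` with a hypothetical LSB-first variant, `bits_lsb`, and lists the
# indices with $X_0 = 0$ under each. It assumes $n = 3$ as in the setup cell.

```python
def bits(i, n):
    # MSB first: bits(i, n)[0] is the most significant bit of i (same as above)
    return list(map(int, format(i, f"0{n}b")))


def bits_lsb(i, n):
    # Hypothetical LSB-first variant: bits_lsb(i, n)[0] is the least significant bit
    return [(i >> j) & 1 for j in range(n)]


n_demo = 3
N_demo = 2 ** n_demo

# Indices whose X0 bit is 0 under each convention
x0_zero_msb = [i for i in range(N_demo) if bits(i, n_demo)[0] == 0]
x0_zero_lsb = [i for i in range(N_demo) if bits_lsb(i, n_demo)[0] == 0]

print(bits(1, n_demo), bits_lsb(1, n_demo))  # [0, 0, 1] [1, 0, 0]
print(x0_zero_msb)                            # [0, 1, 2, 3]: the first half
print(x0_zero_lsb)                            # [0, 2, 4, 6]: the even indices
```

# So folding $X_0$ pairs index $k$ with $k + N/2$ under MSB-first indexing, and pairs
# $2k$ with $2k+1$ (an even/odd split) only under LSB-first indexing.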

# eq_tilde(b, p) = \prod_i ( (1 - b_i)*(1 - p_i) + b_i*p_i ): the Lagrange basis
# polynomial for the hypercube point b, evaluated at p (symbolic or numeric).
def eq_tilde(bits_i, point):
    acc = 1
    for bit, p_i in zip(bits_i, point):
        acc *= (1 - bit) * (1 - p_i) + bit * p_i
    return acc
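#
# Two properties make eq_tilde a valid point-value basis: on the hypercube it is the
# Kronecker delta, and at any point the $N$ basis values sum to 1 because each factor
# satisfies $(1 - p_i) + p_i = 1$. The sketch below checks both with exact rationals;
# the test point `u_demo` is arbitrary.

```python
from fractions import Fraction
from itertools import product


def eq_tilde(bits_i, point):
    # Same definition as above
    acc = 1
    for bit, p_i in zip(bits_i, point):
        acc *= (1 - bit) * (1 - p_i) + bit * p_i
    return acc


n_demo = 3
cube = [list(c) for c in product([0, 1], repeat=n_demo)]

# 1) On the hypercube, eq_tilde(b, c) is 1 if b == c and 0 otherwise.
delta_ok = all(eq_tilde(b, c) == (1 if b == c else 0) for b in cube for c in cube)

# 2) At any point, the basis values sum to 1 (partition of unity).
u_demo = [Fraction(1, 3), Fraction(5, 2), Fraction(-2, 7)]
total = sum(eq_tilde(b, u_demo) for b in cube)

print(delta_ok, total)  # True 1
```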


# ## Evaluation Form of MLE Polynomial
# This covers Equation (4):
# $$ \tilde{f}(X) = \sum_{i=0}^{N-1} a_i \cdot \tilde{eq}(\text{bits}(i), X) $$
# where $a_i = \tilde{f}(\text{bits}(i))$ is the value at the $i$-th hypercube point.
# Let us define symbolic $a_i$ and construct this MLE.

# Code Cell:
a = symbols(f"a0:{N}")
print("\nConstructing evaluation-form MLE:")
print("Hypercube evaluations a_i:", a)

# Build the MLE as a polynomial in X (Equation (4)); substituting u for X gives f~(u).
f_eval = sum(a[i] * eq_tilde(bits(i, n), X) for i in range(N))

# Expand into monomials in X0, X1, X2, i.e. the coefficient form of the same MLE
f_eval_expanded = expand(f_eval)
print("\nEvaluation-form MLE expanded into monomials:")
print(f_eval_expanded)
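#
# The evaluation vector $a_i$ and the coefficient vector of Equation (1) describe the same
# multilinear polynomial. One standard conversion is the subset (Möbius) transform sketched
# below. It assumes both vectors use the MSB-first bit order of `bits()`; the helpers
# `evals_to_coeffs`, `eval_monomial_form` and `eval_lagrange_form` are hypothetical and not
# part of HyperKZG itself.

```python
from fractions import Fraction
from itertools import product


def evals_to_coeffs(evals):
    """Möbius transform: hypercube evaluations -> monomial coefficients.

    Index i encodes the same bit pattern in both vectors. Since
    a_S = sum_{T subset of S} c_T, we invert it one bit at a time.
    """
    c = [Fraction(v) for v in evals]
    step = 1
    while step < len(c):
        for i in range(len(c)):
            if i & step:
                c[i] -= c[i ^ step]
        step <<= 1
    return c


def eval_monomial_form(coeffs, point):
    """sum_i c_i * prod_j X_j^{b_j(i)}, with b_0 the most significant bit."""
    n_vars = len(point)
    total = Fraction(0)
    for i, c_i in enumerate(coeffs):
        term = c_i
        for j, p in enumerate(point):
            if (i >> (n_vars - 1 - j)) & 1:
                term *= p
        total += term
    return total


def eval_lagrange_form(evals, point):
    """sum_i a_i * eq_tilde(bits(i), point), with MSB-first bits."""
    n_vars = len(point)
    total = Fraction(0)
    for i, a_i in enumerate(evals):
        w = Fraction(1)
        for j, p in enumerate(point):
            b = (i >> (n_vars - 1 - j)) & 1
            w *= (1 - b) * (1 - p) + b * p
        total += a_i * w
    return total


a_demo = [7, -2, 0, 5, 1, 1, -3, 4]  # arbitrary hypercube evaluations
coeffs_demo = evals_to_coeffs(a_demo)

# Both forms agree on the hypercube and at an arbitrary point.
cube_ok = all(
    eval_monomial_form(coeffs_demo, list(b)) == a_demo[k]
    for k, b in enumerate(product([0, 1], repeat=3))
)
u_demo = [Fraction(2, 5), Fraction(-1, 3), Fraction(7, 4)]
print(cube_ok, eval_monomial_form(coeffs_demo, u_demo) == eval_lagrange_form(a_demo, u_demo))
```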


# ## Univariate Encoding from Coefficient Form
# These correspond to Equations (5), (6), and (7):
# $$ f(X) = \sum f_i X^i $$
# $$ f(-X) = \sum (-1)^i f_i X^i $$
# $$ f(X) \pm f(-X) $$ for folding

# Code Cell:
f = symbols(f"f0:{N}")
print("\nDefining univariate polynomial from coefficients f_i:")
print("f =", f)

f_X = sum(f[i] * x**i for i in range(N))
print("\nf(X) =", f_X)

f_X_neg = f_X.subs(x, -x)
print("\nf(-X) =", f_X_neg)

f_X_plus = simplify(f_X + f_X_neg)
f_X_minus = simplify(f_X - f_X_neg)

print("\nf(X) + f(-X) =", f_X_plus)
print("\nf(X) - f(-X) =", f_X_minus)

f_X_plus, f_X_minus
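#
# The symbolic outputs above isolate the even and odd coefficients: writing
# $f(X) = f_e(X^2) + X f_o(X^2)$ gives $(f(X) + f(-X))/2 = f_e(X^2)$ and
# $(f(X) - f(-X))/(2X) = f_o(X^2)$. The sketch below checks this numerically and applies one
# Gemini-style coefficient-form fold $f'(Y) = f_e(Y) + u_0 f_o(Y)$; the evaluation-form
# analogue uses $(1 - u_0) f_e + u_0 f_o$ instead. Coefficients and challenge values are
# arbitrary, and `poly_eval` is a hypothetical helper.

```python
from fractions import Fraction


def poly_eval(coeffs, x):
    # Horner evaluation of sum_i coeffs[i] * x**i
    acc = Fraction(0)
    for c in reversed(coeffs):
        acc = acc * x + c
    return acc


f_demo = [Fraction(c) for c in [3, -1, 4, 1, -5, 9, 2, -6]]  # arbitrary f_0..f_7
f_even, f_odd = f_demo[0::2], f_demo[1::2]

x_val = Fraction(5, 3)
fx, fmx = poly_eval(f_demo, x_val), poly_eval(f_demo, -x_val)

# (f(X) + f(-X)) / 2 = f_e(X^2)   and   (f(X) - f(-X)) / (2X) = f_o(X^2)
even_ok = (fx + fmx) / 2 == poly_eval(f_even, x_val * x_val)
odd_ok = (fx - fmx) / (2 * x_val) == poly_eval(f_odd, x_val * x_val)

# One folding step with challenge u0: f'(Y) = f_e(Y) + u0 * f_o(Y)
u0_demo = Fraction(-2, 9)
folded = [e + u0_demo * o for e, o in zip(f_even, f_odd)]
fold_ok = poly_eval(folded, x_val * x_val) == (fx + fmx) / 2 + u0_demo * (fx - fmx) / (2 * x_val)

print(even_ok, odd_ok, fold_ok)  # True True True
```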


# ## Round 1 Folding (Equations 8 and 9)
# HyperKZG applies the split-and-fold trick directly to the evaluation form.
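# Given $\tilde{f}^{(0)}$ over $\{0,1\}^n$, we split its evaluations into the half with $X_0 = 0$
# and the half with $X_0 = 1$ (when $X_0$ is the least significant bit of the index, these are
# the even- and odd-indexed entries) and define:
#
# $$ \tilde{f}^{(1)}(x) = (1 - u_0) \cdot \tilde{f}^{(0)}(0, x) + u_0 \cdot \tilde{f}^{(0)}(1, x) $$
#
# where $x \in \{0,1\}^{n-1}$.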
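#
# Why this fold computes the right thing: with $x$ fixed, the multilinear $\tilde{f}^{(0)}$ is
# affine in $X_0$, so
# $$ \tilde{f}^{(0)}(u_0, x) = (1 - u_0) \cdot \tilde{f}^{(0)}(0, x) + u_0 \cdot \tilde{f}^{(0)}(1, x) = \tilde{f}^{(1)}(x). $$
# Hence $\tilde{f}^{(0)}(u_0, \cdot)$ is multilinear and agrees with $\tilde{f}^{(1)}$ on $\{0,1\}^{n-1}$,
# so it is the MLE of $\tilde{f}^{(1)}$. Repeating the step with $u_1, \ldots, u_{n-1}$ leaves one value:
# $$ \tilde{f}^{(n)} = \tilde{f}^{(0)}(u_0, u_1, \ldots, u_{n-1}). $$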

# Code Cell:
print("\n--- Round 1 Folding ---")

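# We simulate f^(0) by its symbolic hypercube evaluations a_0, ..., a_{N-1}
print(f"Simulating f^(0) using symbolic evaluations a0 to a{N-1}")

f0_vals = a  # f0_vals[i] = f^(0)(bits(i, n))

# Split by the value of X0, read from bits() so the split matches eq_tilde's indexing.
# With the MSB-first bits() defined above, X0 is the top bit, so these are the first
# and second halves of the vector; entry k of each half has the same (X1, X2).
half_X0_0 = [f0_vals[i] for i in range(N) if bits(i, n)[0] == 0]  # X0 = 0
half_X0_1 = [f0_vals[i] for i in range(N) if bits(i, n)[0] == 1]  # X0 = 1

print("X0 = 0 half:", half_X0_0)
print("X0 = 1 half:", half_X0_1)

# Folded values: f1[k] = (1 - u0) * f^(0)(0, x_k) + u0 * f^(0)(1, x_k)
f1 = [(1 - u[0]) * lo + u[0] * hi for lo, hi in zip(half_X0_0, half_X0_1)]

print("\nFolded values for f^(1)(X1, X2):")
for i, expr in enumerate(f1):
    print(f"f1[{i}] =", expr)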
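#
# Repeating the round-1 fold $n$ times should reduce the evaluation vector to
# $\tilde{f}^{(0)}(u_0, \ldots, u_{n-1})$. The numeric sketch below checks that against a direct
# eq_tilde sum. It assumes the MSB-first indexing of `bits()`, so each round pairs entry $k$
# with entry $k + \text{len}/2$; `fold_once`, `fold_all` and `mle_eval` are hypothetical helpers
# and the values are arbitrary.

```python
from fractions import Fraction


def fold_once(vals, r):
    # Fold the variable carried by the top bit (X_j in round j under MSB-first bits):
    # pair entry k (variable = 0) with entry k + len/2 (variable = 1).
    half = len(vals) // 2
    return [(1 - r) * vals[k] + r * vals[k + half] for k in range(half)]


def fold_all(vals, point):
    # Round j folds X_j with challenge point[j]; after n rounds one value remains.
    for r in point:
        vals = fold_once(vals, r)
    assert len(vals) == 1
    return vals[0]


def mle_eval(vals, point):
    # Direct evaluation: sum_i a_i * eq_tilde(bits(i), point), bits MSB-first.
    n_vars = len(point)
    total = Fraction(0)
    for i, a_i in enumerate(vals):
        w = Fraction(1)
        for j, p in enumerate(point):
            b = (i >> (n_vars - 1 - j)) & 1
            w *= (1 - b) * (1 - p) + b * p
        total += a_i * w
    return total


a_demo = [Fraction(v) for v in [7, -2, 0, 5, 1, 1, -3, 4]]
u_demo = [Fraction(2, 5), Fraction(-1, 3), Fraction(7, 4)]
print(fold_all(a_demo, u_demo) == mle_eval(a_demo, u_demo))  # True
```

# Each round costs one multiplication per pair if written as
# `vals[k] + r * (vals[k + half] - vals[k])`; the form above mirrors the equation instead.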