# ⚠️ Before committing to git, run "Cell -> All Output -> Clear"!

# Parameter tracking for protocols
This script contains functions for tracking the protocols' effects on parameters, including correctness bounds, extraction bounds, required transcripts, and communication.
- Evaluating on symbolic variables gives abstract expressions for the paper.
- Evaluation (partially) on actual values gives resulting actual bounds.

# TODO
- Can these functions be used for optimization?
  Currently, the "base", "ell", "baseip" and "ellip" parameters are independent variables, but they are actually tightly connected. (Specifying "ell" (resp. "ellip") gives a unique "base" (resp. "baseip"). "ellip" is already determined by the required bounds, so it has no degree of freedom at all.)
  
## Code Robustness
- Make all parameters (except `parin`) of protocol functions into keyword-only! Prevents accidentally swapped parameters. Already caught several bugs. (Doing this in Python is not very ergonomic, but the robustness is worth the added verbosity!)

In [None]:
# Set up the symbolic variables for all tracked protocol parameters.
# Every symbol is declared with domain='positive', so that Sage may simplify e.g. $sqrt(r) r = sqrt(r^3)$.
# See: https://doc.sagemath.org/html/en/reference/calculus/sage/symbolic/expression.html#sage.symbolic.expression.Expression.is_negative
# Each var(...) call also makes the symbol available as a global name (e.g. v_phi is used below without assignment);
# latex_name only controls how the symbol is rendered by show().

# Repetition counts and size/count parameters.
var("v_rep", latex_name="r", domain='positive')
var("v_repin", latex_name="r_{\\mathrm{in}}", domain='positive')
var("v_repout", latex_name="r_{\\mathrm{out}}", domain='positive')
var("v_wdim", latex_name="m", domain='positive')
var("v_tdim", latex_name="d", domain='positive')
var("v_nF", latex_name="n_F", domain='positive')
var("v_ntop", latex_name="n'", domain='positive')
var("v_nbot", latex_name="n''", domain='positive')
var("v_nout", latex_name="n^{\\mathrm{out}}", domain='positive')

# Norm bounds: unprimed betas (par_default uses beta_0 for correctness),
# primed betas (par_default uses beta'_1 for extraction), and the SIS bound.
var("v_beta", latex_name="\\beta", domain='positive')
var("v_beta0", latex_name="\\beta_0", domain='positive')
var("v_beta1", latex_name="\\beta_1", domain='positive')
var("v_beta2", latex_name="\\beta_2", domain='positive')
var("v_beta3", latex_name="\\beta_3", domain='positive')
var("v_betaprime", latex_name="\\beta'", domain='positive')
var("v_beta0prime", latex_name="\\beta'_0", domain='positive')
var("v_beta1prime", latex_name="\\beta'_1", domain='positive')
var("v_beta2prime", latex_name="\\beta'_2", domain='positive')
var("v_beta3prime", latex_name="\\beta'_3", domain='positive')
var("v_betasis", latex_name="\\beta_{\\mathsf{sis}}", domain='positive')

# Soundness error and number of (required) transcripts.
var("v_snderr", latex_name="\\kappa", domain='positive')
var("v_numtr", latex_name="\\#\\mathrm{tr}", domain='positive')

# Decomposition lengths and bases: general (pi_bdecomp) and inner-product (pi_ipmain) decomposition.
var("v_ell", latex_name="\\ell", domain='positive')
var("v_base", latex_name="b", domain='positive')
var("v_ellip", latex_name="\\ell_{\\mathsf{ip}}", domain='positive')
var("v_baseip", latex_name="b_{\\mathsf{ip}}", domain='positive')

# Ring/modulus parameters: phi, modulus q and exponent e (appearing as q^e in soundness errors).
var("v_phi", latex_name="\\varphi", domain='positive')
var("v_q", latex_name="q", domain='positive')
var("v_e", latex_name="e", domain='positive')

# Folding constants used by pi_fold. (NOTE: these do not follow the v_ prefix convention.)
var("gammafold", latex_name="\\gamma_{\\mathsf{fold}}", domain='positive')
var("thetafold", latex_name="\\theta_{\\mathsf{fold}}", domain='positive')
var("Cfold", latex_name="\\mathcal{C}_{\\mathsf{fold}}", domain='positive')

var("v_ringelq", latex_name="\\mathcal{R}_{q}", domain='positive') # Ring element (size) placeholder; communication is counted in multiples of it.

# To get the positive (or, more generally, a canonical) root, we need to be explicit in Sage.
# For this we canonicalize radicals. (Otherwise, sqrt(x*y) and sqrt(x) * sqrt(y) are NOT equal, and such simplifications will not be made.)
# See: https://doc.sagemath.org/html/en/reference/calculus/sage/symbolic/expression.html#sage.symbolic.expression.Expression.canonicalize_radical

def has_method(obj, method):
    """Return True iff ``obj`` has an attribute named ``method`` that is callable.

    A missing attribute (AttributeError on lookup) counts as "no such method".
    """
    try:
        candidate = getattr(obj, method)
    except AttributeError:
        return False
    return callable(candidate)

def my_canonicalize(x):
    """Return a canonicalized version of ``x``.

    If ``x`` supports Sage's ``canonicalize_radical`` (symbolic expressions), the
    canonicalized expression is returned. Otherwise a deep copy of ``x`` is
    returned unchanged, so callers never receive an alias of the input.
    """
    # Local import: this cell is otherwise evaluated before the cell that does `import copy`,
    # so the function must not rely on a global `copy` name being present yet.
    import copy
    if has_method(x, "canonicalize_radical"):
        return x.canonicalize_radical()
    return copy.deepcopy(x)

# Estimating SIS hardness

def findMSISdelta(beta_sis, n, phi, logq):
    """Estimate the MSIS root Hermite factor delta.

    Estimates MSIS hardness for an (n x m) matrix over R_q with solution bound
    ``beta_sis``, following the methodology of [GamNgu08, Mic08]. Under this
    methodology m is irrelevant, so it is not a parameter.

    Args:
        beta_sis: norm bound B of the sought solution (must be > 0 for math.log).
        n: number of rows of the module matrix.
        phi: presumably the ring degree varphi (multiplies n in the exponent).
        logq: log2 of the modulus q.

    Returns:
        ``2 ** (log2(B)**2 / (4 * n * phi * logq))``, or the sentinel value ``2``
        when ``beta_sis >= 2**logq`` (norm bound not below q).
    """
    import math  # Local import, so the function does not depend on a global `math` name.

    if beta_sis >= 2 ** logq:  # Norm bound at or above q.
        return 2
    log_b = math.log(beta_sis, 2)
    log_delta = log_b ** 2 / (4 * n * phi * logq)
    return 2 ** log_delta

In [None]:
import dataclasses
from dataclasses import dataclass
import copy

#from sage.repl.display.pretty_print import SagePrettyPrinter

@dataclass
class RedParams:
    """Keep track of all parameters specifying the reduction properties, e.g., claims w.r.t. correctness and soundness in protocols."""
    # We don't have \mu here, since it won't matter and we don't represent any tensor structure.
    # TODO: Also add communication tracking.
    _: dataclasses.KW_ONLY # All parameters are KW-only.
    phi: int
    q: int
    e: int
    rep: int
    wdim: int
    nF: int
    ntop: int
    nbot: int
    nout: int
    corbeta: float
    sndbeta: float
    snderr: float
    betasis: float
    numtr: float
    prover_comm: int     # Communication sent by prover until now. Protocols used as reductions, so don't include sending the final witness (of size (ntop + nbot)*wdim*rep * ceil(log(beta)) _bits_)
    verifier_comm: int
    def deepcopy(self):
        return copy.deepcopy(self)
    def pretty_print(self):
        # Pretty printing seems pretty involved. (Horribly so? With(out) good reason?)
        # Can't just get a string output in a simple manner... Hack things together like this for now...
        self.canonicalize()
        show((v_phi, self.phi))
        show((v_q, self.q))
        show((v_e, self.e))
        show((v_rep, self.rep))
        show((v_wdim, self.wdim))
        show((v_nF, self.nF))
        show(((v_ntop, v_nbot, v_nout), (self.ntop, self.nbot, self.nout)))
        show((v_beta1, self.corbeta))
        show((v_beta0prime, self.sndbeta))
        #show((v_betasis, self.betasis))
        show((v_snderr, self.snderr))
        show((v_numtr, self.numtr))
        show(("#prover_comm", self.prover_comm))
        show(("#verifier_comm", self.verifier_comm))        
    def canonicalize(self):
        def simplify(x):
            my_canonicalize(x).factor() # FIXME: Not sure if ".factor()" is a good idea or not.
        self.phi = my_canonicalize(self.phi)
        self.q = my_canonicalize(self.q)
        self.e = my_canonicalize(self.e)
        self.rep = my_canonicalize(self.rep)
        self.wdim = my_canonicalize(self.wdim)
        self.nF = my_canonicalize(self.nF)
        self.ntop = my_canonicalize(self.ntop)
        self.nbot = my_canonicalize(self.nbot)
        self.nout = my_canonicalize(self.nout)
        self.corbeta = my_canonicalize(self.corbeta)
        self.sndbeta = my_canonicalize(self.sndbeta)
        self.snderr = my_canonicalize(self.snderr)
        self.betasis = my_canonicalize(self.betasis)
        self.numtr = my_canonicalize(self.numtr)
        self.prover_comm = my_canonicalize(self.prover_comm)
        self.verifier_comm = my_canonicalize(self.verifier_comm)


In [None]:
# Default, fully symbolic parameter set: every field is its own symbol and no communication has happened yet.
par_default = RedParams(
  phi = v_phi,
  q = v_q,
  e = v_e,  # Was `v_q` (copy-paste slip); the `e` field is shown as v_e in RedParams.pretty_print.
  rep = v_rep,
  wdim = v_wdim,
  nF = v_nF,
  ntop = v_ntop,
  nbot = v_nbot,
  nout = v_nout,
  corbeta = v_beta0,
  sndbeta = v_beta1prime,
  snderr = v_snderr,
  numtr = v_numtr,
  betasis = v_betasis,
  prover_comm = 0,
  verifier_comm = 0
  )
## Some tests
#par2 = dataclasses.replace(par_default, rep = par_default.rep + 5)
#par3 = par2.deepcopy()
#par3.rep += 2
#print(par_default)
#print(par2)
#print(par3)
#par2.pretty_print()

In [None]:
## Atomic Protocol

def pi_bdecomp(parin, *, base, ell):
    """Track parameters through a base-`base` decomposition into `ell` pieces.

    The repetition count grows by a factor `ell`, the correctness bound becomes a
    hard bound on the decomposed witness, the extraction bound is (sloppily)
    doubled times the correctness bound, and the prover sends `ell - 1` additional
    ring elements for each of the (ntop + nbot) * rep entries.
    """
    parout = parin.deepcopy()
    # FIXME: The ell's depend on the input! How to capture that here? Make a new variable?
    #        Good enough fix for parameter testing: use $\ell$ as input instead and set the base
    #        accordingly, presumably to $(2\beta + 1)^{1/\ell}$.
    #        (NOTE(review): this formula was garbled in an earlier version of this comment; confirm.)
    parout.rep = parin.rep * ell
    # Hard bound, not a factor! Uses parin.phi (not the global symbol v_phi) so numeric parameter sets are respected.
    parout.corbeta = 1/2 * sqrt(ell * parin.rep * parin.wdim) * parin.phi^(3/2) * base
    # Tighter (unused) alternative: parout.sndbeta = 9 * (parin.corbeta)^(1 - 1/ell) * parin.sndbeta
    parout.sndbeta = 2 * parin.corbeta * parin.sndbeta # Simplified, sloppy bound
    parout.prover_comm += v_ringelq * (ell-1) * (parin.ntop + parin.nbot) * parin.rep
    return parout

#pi_bdecomp(par_default, base=v_base, ell=v_ell).pretty_print()

def pi_split(parin, *, tdim):
    """Track parameters through splitting the witness dimension `wdim` into `tdim` parts.

    Repetitions grow by a factor `tdim`, the witness dimension shrinks by `tdim`,
    the extraction bound grows by sqrt(tdim), and the prover sends `tdim - 1`
    extra ring elements for each of the (ntop + nbot) * rep entries.
    """
    result = parin.deepcopy()
    grown_rep = parin.rep * tdim
    shrunk_wdim = parin.wdim / tdim
    split_sndbeta = sqrt(tdim) * parin.sndbeta
    sent = v_ringelq * (tdim - 1) * (parin.ntop + parin.nbot) * parin.rep

    result.rep, result.wdim, result.sndbeta = grown_rep, shrunk_wdim, split_sndbeta
    result.prover_comm += sent
    return result

#pi_split(par_default, tdim=v_tdim).pretty_print()

def pi_fold(parin, *, repout):
    """Track parameters through folding `parin.rep` instances down to `repout`.

    Correctness bound: times sqrt(repout) * rep * gamma_fold.
    Extraction bound: times 2 * sqrt(rep) * theta_fold.
    Soundness error: plus rep / C_fold^repout. Transcripts: times (rep + 1).
    The verifier sends repout * rep ring elements.
    """
    r_in = parin.rep
    result = parin.deepcopy()
    result.rep = repout
    result.corbeta = sqrt(repout) * r_in * gammafold * parin.corbeta
    result.sndbeta = 2*sqrt(r_in) * thetafold * parin.sndbeta
    result.snderr = parin.snderr + r_in/(Cfold^repout)
    result.numtr = parin.numtr * (r_in + 1)
    result.verifier_comm += v_ringelq * repout * r_in
    return result

#pi_fold(par_default, repout=v_repout).pretty_print()

def pi_batch(parin):
    """Track parameters through batching, implicitly down to nbot = 1.

    Soundness error grows by rep * nbot / q^e, the number of transcripts doubles,
    and the verifier sends one ring element.
    """
    # FIXME: Requirement 2\beta'_0 \leq \betasis not reflected anywhere...
    batch_err = parin.rep * parin.nbot / v_q^v_e
    result = parin.deepcopy()
    result.nbot = 1
    result.snderr = parin.snderr + batch_err
    result.numtr = 2*parin.numtr
    result.verifier_comm += v_ringelq
    return result

#pi_batch(par_default).pretty_print()

def pi_ipmain(parin, *, ellip, docheck = True):
    """Track parameters through the main inner-product check (no batching or folding).

    Adds `ellip` repetitions and 3 to both nF and nbot, scales the correctness
    bound by sqrt(2), adds 2 * wdim / q^e to the soundness error and doubles the
    transcript count.

    Returns the updated parameters, or ``None`` if ``docheck`` is set and the
    norm check below evaluates to True (decomposition not small enough).
    """
    # NOTE(review): the two comments below appear to be copied from pi_batch; confirm they apply here.
    # Implicit assumption: Batch to \nbot = 1
    # FIXME: Requirement 2\beta'_0 \leq \betasis not reflected anywhere...
    parout = parin.deepcopy()
    parout.rep = parin.rep + ellip
    parout.nF = parin.nF + 3
    parout.nbot = parin.nbot + 3
    parout.corbeta = sqrt(2) * parin.corbeta
    parout.snderr = parin.snderr + 2*parin.wdim / v_q^v_e
    parout.numtr = 2*parin.numtr
    parout.prover_comm += v_ringelq * (ellip * (parin.ntop + parin.nbot) # Y'
                                        + 3 * (parin.rep + ellip)) # Y_E, Y'_E
    parout.verifier_comm += v_ringelq
    if docheck:
        # Uses parin.phi (not the global symbol v_phi), so the check can actually evaluate when phi
        # is numeric. For symbolic inputs Sage only treats an inequality as True if it can prove it,
        # so in the fully symbolic setting the check effectively passes through.
        if (2 * parin.corbeta^2 + 1)^(1/ellip) > 2 * parin.corbeta / sqrt(parin.wdim * parin.phi^3):
            return None # The $\baseip$-ary decomposition from $\ellip$-pieces does not yield sufficiently small norm.
    return parout

#pi_ipmain(par_default, ellip=v_ell).pretty_print()


In [None]:
## Composite Protocols 

# Normmain: Only the norm and batch. Don't do the last folding.
def pi_normmain(parin, *, bound, repout, ellip):
    """Norm check (via pi_ipmain) followed by batching, without the final folding.

    `repout` is accepted for signature parity with the other composites but is unused here.
    """
    state = parin.deepcopy()
    # Prover sends the claimed bound $\nu$.
    # NOTE(review): Sage's log() is the natural logarithm; confirm whether bits (log base 2) are intended.
    state.prover_comm += log(bound)
    checked = pi_ipmain(state, ellip=ellip)
    fixed = dataclasses.replace(checked, sndbeta = bound)  # The norm check fixes the extraction bound.
    return pi_batch(fixed)

#pi_normmain(par_default, bound=v_beta, repout=v_repout, ellip=v_ellip).pretty_print()

# IPfull: Check IP, batch and fold
def pi_ipfull(parin, *, repout, ellip):
    """Full inner-product reduction: IP check, batch, then fold down to `repout`."""
    checked = pi_ipmain(parin.deepcopy(), ellip=ellip)
    return pi_fold(pi_batch(checked), repout=repout)

#pi_ipfull(par_default, repout=v_repout, ellip=v_ellip).pretty_print()

# Norm full: Check norm (and batch) and fold
def pi_normfull(parin, *, bound, repout, ellip):
    """Full norm reduction: norm check and batch (pi_normmain), then fold down to `repout`."""
    normed = pi_normmain(parin.deepcopy(), bound=bound, repout=repout, ellip=ellip)
    # Pin the extraction bound to `bound` before folding.
    pinned = dataclasses.replace(normed, sndbeta = bound)
    return pi_fold(pinned, repout=repout)

#pi_normfull(par_default, bound=v_beta, repout=v_repout, ellip=v_ellip).pretty_print()

# Decomp-Split-and-fold: simple version without norm checks
def pi_decomp_split_fold(parin, *, base, ell, tdim, repout):
    """Decompose (base `base`, `ell` pieces), split into `tdim` parts, then fold; no norm checks."""
    # FIXME Handling base and ell correctly as dependent variables.
    decomposed = pi_bdecomp(parin.deepcopy(), base=base, ell=ell)
    pieces = pi_split(decomposed, tdim=tdim)
    return pi_fold(pieces, repout=repout)

#pi_decomp_split_fold(par_default, base=v_base, ell=v_ell, tdim=v_tdim, repout=v_repout).pretty_print()

# Decomp-and-fold: 
def pi_decomp_fold(parin, *, base, ell, repout):
    """Decompose (base `base`, `ell` pieces) and then fold down to `repout`."""
    # FIXME Handling base and ell correctly as dependent variables.
    decomposed = pi_bdecomp(parin.deepcopy(), base=base, ell=ell)
    return pi_fold(decomposed, repout=repout)

# pi_decomp_fold(par_default, base=v_base, ell=v_ell, repout=v_repout).pretty_print()

# Split-and-fold: 
def pi_split_fold(parin, *, tdim, repout):
    """Split into `tdim` parts and then fold down to `repout`."""
    pieces = pi_split(parin.deepcopy(), tdim=tdim)
    return pi_fold(pieces, repout=repout)

# pi_split_fold(par_default, tdim=v_tdim, repout=v_repout).pretty_print()

# Norm-Split-Fold: Check the norm before the split-and-fold occurs.
def pi_norm_split_fold(parin, *, repout, bound, tdim, base, ell, ellip):
    """Check the norm (with batch and fold) first, then decompose, split and fold again.

    Both folding steps fold down to the same `repout`.
    """
    after_norm = pi_normfull(parin.deepcopy(), bound=bound, repout=repout, ellip=ellip)
    return pi_decomp_split_fold(after_norm, base=base, ell=ell, tdim=tdim, repout=repout)

# parnsf = pi_norm_split_fold(par_default, repout=v_repout, bound=v_beta, tdim=v_tdim, base=v_base, ell=v_ell, ellip=v_ellip)
# parnsf.pretty_print()

# Split-Norm-Fold: Check norm after decomp-split.
def pi_split_norm_fold(parin, *, repout, bound, tdim, base, ell, ellip):
    """Decompose and split first, then check the norm (with batch), and fold once at the end."""
    stages = (
        lambda p: pi_bdecomp(p, base=base, ell=ell),
        lambda p: pi_split(p, tdim=tdim),
        lambda p: pi_normmain(p, bound=bound, repout=repout, ellip=ellip),
        lambda p: pi_fold(p, repout=repout),
    )
    state = parin.deepcopy()
    for stage in stages:
        state = stage(state)
    return state

# parsnf = pi_split_norm_fold(par_default, repout=v_repout, bound=v_beta, tdim=v_tdim, base=v_base, ell=v_ell, ellip=v_ellip)
# parsnf.pretty_print()

In [None]:
# Comparison of norm-split-fold over split-norm-fold for correctness, relaxed soundness, and communication.
# Every shown value is "norm-split-fold relative to split-norm-fold": a quotient > 1 (or a
# positive difference) means norm-split-fold yields the larger value.
# Simplifications are made to get readable results.
# - Always simplify to ellip = ell, and repout = rep. That's not fully correct, but close enough.
# - Also set nbot = 1 and use ntop = nout - nbot for communication.
# - Use v_beta=1 for communication quotient. Should be a tiny difference, but huge simplification
#   (the bound is sent as log(bound) in pi_normmain, which vanishes for bound = 1).

parnsf = pi_norm_split_fold(par_default, repout=v_repout, bound=v_beta, tdim=v_tdim, base=v_base, ell=v_ell, ellip=v_ellip)
# parnsf.pretty_print()

parsnf = pi_split_norm_fold(par_default, repout=v_repout, bound=v_beta, tdim=v_tdim, base=v_base, ell=v_ell, ellip=v_ellip)
# parsnf.pretty_print()

print("Norm-Split-Fold versus Split-Norm-Fold")
show("Correctness quotient\t", (parnsf.corbeta/parsnf.corbeta).subs(v_ellip = v_ell, v_repout = v_rep).canonicalize_radical())
show("Extraction quotient\t", (parnsf.sndbeta/parsnf.sndbeta).subs(v_ellip = v_ell, v_repout = v_rep).canonicalize_radical())
# The second .subs() runs after v_ntop has been rewritten in terms of v_nbot, so nbot = 1 also affects ntop.
show("Communication difference\t", (parnsf.prover_comm - parsnf.prover_comm).subs(v_ellip = v_ell, v_repout = v_rep, v_ntop = v_nout - v_nbot).subs(v_nbot = 1).canonicalize_radical().factor())
show("Communication quotient\t", (parnsf.prover_comm/parsnf.prover_comm).subs(v_ellip = v_ell, v_repout = v_rep, v_beta = 1, v_ntop = v_nout - v_nbot).subs(v_nbot = 1).canonicalize_radical()) # Simplifies beta

In [None]:
# Make some random, somewhat reasonable choices and evaluate the communication quotient.
# After the first .subs(), v_ntop has been replaced by v_nout - v_nbot, so the free "count" variable is v_nout.
# (An earlier version substituted `v_ntot`, which is not a declared variable and does not occur in the expression.)
show("Communication quotient\t", (parnsf.prover_comm/parsnf.prover_comm)
     .subs(v_ellip = v_ell, v_repout = v_rep, v_beta = 1, v_ntop = v_nout - v_nbot)
     .subs(v_nbot = 1, v_ell = 5, v_nout = 5, v_rep = 16, v_tdim=2).canonicalize_radical()) # Simplifies beta