This notebook is based on https://github.com/dmpierre/air-ccs/blob/main/sage/air-to-ccs.ipynb. This notebook implements
- Plonk-to-CCS conversion algorithm, found in the [CCS paper](https://eprint.iacr.org/2023/552), Section 2.2.
- A PoC of incorporating Halo2 lookup into CCS+.
[An explanation of the math behind the algorithm is available here](https://hackmd.io/@pnyda/SkG0qaUuA)


# Derive a CCS instance from a custom gate and constant columns, without regard to copy constraints or lookups

In [3]:
# Classes to represent a custom gate

# Prime modulus of the field every matrix, vector and coefficient in this notebook lives in.
# NOTE(review): this value looks like the BN254 (alt_bn128) base-field modulus — confirm that
# the base field, rather than the scalar field, is the one intended here.
p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
# Finite field of order p, built with Sage's GF constructor.
F = GF(p)

from dataclasses import dataclass
from typing import Literal

@dataclass(frozen=True, order=True)
class Column:
    """Identifies one column of the Plonkish table.

    Instances are immutable, hashable and ordered field by field
    (``index`` first, then ``kind``).
    """

    # Position of the column in the table.
    index: int
    # Whether the column holds witness values or fixed values.
    kind: Literal["advice", "constant"]

@dataclass(frozen=True, order=True)
class RelativeCellPointer:
    """Refers to a cell by its column and a row offset from the current row.

    Instances are immutable, hashable and ordered field by field
    (``column`` first, then ``row_offset``).
    """

    # Column the referenced cell lives in.
    column: Column
    # Offset added to the row being evaluated (0 = the current row).
    row_offset: int

@dataclass(frozen=True, order=True)
class Monomial:
    """One term of a custom gate: a coefficient times a product of cells.

    Instances are immutable and ordered field by field. Because
    ``variables`` is a list, hashing an instance raises ``TypeError``.
    """

    # Scalar factor of the term.
    coefficient: F
    # Cells whose values are multiplied together.
    variables: list[RelativeCellPointer]

@dataclass(eq=True, frozen=True, order=True)
class CustomGate:
    """A custom gate written as a sum of monomials over relative cell references.

    The methods derive the parts of a CCS structure (t, c, S and the
    matrices M_j) from the gate and the constant columns. Copy constraints
    and lookups are not taken into account here.
    """

    monomials: list[Monomial]

    def cell_pointers(self) -> list[RelativeCellPointer]:
        """Return the distinct cells used by the gate, in sorted order.

        The position of a cell in this list is the ID j of its matrix M_j.
        """
        # We sometimes refer to the same cell from multiple monomials.
        # In that case a naive implementation would generate redundant M_j.
        # To avoid that we take the distinct set of RelativeCellPointers
        # and assign the ID j of M_j by sorting it.
        return sorted({variable for monomial in self.monomials for variable in monomial.variables})

    def t(self) -> int:
        """Return the number of matrices M_j, one per distinct cell."""
        return len(self.cell_pointers())

    def c(self) -> list[F]:
        """Return the coefficient of every monomial, in monomial order."""
        return [monomial.coefficient for monomial in self.monomials]

    def s(self) -> list[set[int]]:
        """Return, for every monomial, the set of matrix IDs it multiplies.

        NOTE(review): a set is used, so a cell repeated inside one monomial
        (for example a * a) collapses into a single entry.
        """
        # Build the cell -> ID table once instead of re-sorting the cells for
        # every variable of every monomial.
        matrix_id = {pointer: j for j, pointer in enumerate(self.cell_pointers())}
        return [{matrix_id[variable] for variable in monomial.variables} for monomial in self.monomials]

    def m(self, constant_columns: list[list[F]]) -> list[matrix]:
        """Build the matrices M_j, one per entry of ``cell_pointers()``.

        Z is laid out as ``[1, public i/o (l() entries), advice column 0,
        advice column 1, ...]`` with each advice column contributing
        ``table_height`` entries. Row y of M_j selects the value the j-th cell
        takes when the gate is evaluated on row y; row offsets wrap around the
        table.

        ``constant_columns`` holds the values of the constant columns, in
        column order; the table height is the length of the longest one.
        """
        table_height = max(len(column) for column in constant_columns)
        num_public = self.l()
        num_advice = self.num_advice_columns()
        # 1 for the leading constant 1 in Z, then the public i/o, then the advice cells.
        # Including l() keeps the width consistent with the column indices used below.
        z_length = 1 + num_public + num_advice * table_height
        matrices = []

        for cell_pointer in self.cell_pointers():
            m_j = matrix(F, table_height, z_length)
            kind = cell_pointer.column.kind
            for y in range(table_height):
                # Row actually referenced from row y; wraps around the table.
                source_row = (y + cell_pointer.row_offset) % table_height
                if kind == "advice":
                    x = 1 + num_public + cell_pointer.column.index * table_height + source_row
                    m_j[y, x] = 1
                elif kind == "constant":
                    # We assume here that constant columns come after the advice columns.
                    # This is PoC code and the assumption is not essential.
                    constant_column_index = cell_pointer.column.index - num_advice
                    # Column 0 multiplies the first element of Z, which is always 1,
                    # so the constant value itself is stored in the matrix.
                    m_j[y, 0] = constant_columns[constant_column_index][source_row]
                # Any other kind leaves the row at zero.

            matrices.append(m_j)

        return matrices

    def l(self) -> int:
        """Return the number of public i/o elements in Z."""
        # To keep this PoC code simple we assume l=0
        return 0

    def num_advice_columns(self) -> int:
        """Return the number of distinct advice columns the gate refers to."""
        return len({cell_pointer.column.index for cell_pointer in self.cell_pointers() if cell_pointer.column.kind == "advice"})

# Check if a CCS instance is satisfied

In [4]:
# https://github.com/dmpierre/air-ccs/blob/main/sage/air-to-ccs.ipynb

def ccs_is_satisfied(F, z, matrices, multisets, constants):
    """Check whether ``z`` satisfies the CCS relation.

    For every term i, the vectors ``matrices[j] * z`` for j in
    ``multisets[i]`` are multiplied entry-wise and scaled by
    ``constants[i]``. The relation holds when the sum of all terms is the
    zero vector.
    """
    num_rows = matrices[0].dimensions()[0]
    total = vector(F, [0] * num_rows)
    for term_index, coefficient in enumerate(constants):
        # Start from the all-ones vector, the identity of the entry-wise product.
        hadamard = vector(F, [1] * num_rows)
        for j in multisets[term_index]:
            hadamard = hadamard.pairwise_product(matrices[j] * z)
        total += coefficient * hadamard
    return total == vector(F, [0] * num_rows)

# Original Plonk

In [62]:
# Vanilla Plonk gate, applied to every row of the table:
# g = q_m * a * b + q_l * a + q_r * b + q_o * c + q_c

## advice columns
a = Column(0, "advice")
b = Column(1, "advice")
c = Column(2, "advice")

## constant columns
## (their indices come after the advice columns, which CustomGate.m relies on)
q_m = Column(3, "constant")
q_l = Column(4, "constant")
q_r = Column(5, "constant")
q_o = Column(6, "constant")
q_c = Column(7, "constant")

# Every variable of the gate refers to the current row (row offset 0).
a_cur = RelativeCellPointer(a, 0)
b_cur = RelativeCellPointer(b, 0)
c_cur = RelativeCellPointer(c, 0)
q_m_cur = RelativeCellPointer(q_m, 0)
q_l_cur = RelativeCellPointer(q_l, 0)
q_r_cur = RelativeCellPointer(q_r, 0)
q_o_cur = RelativeCellPointer(q_o, 0)
q_c_cur = RelativeCellPointer(q_c, 0)

# One monomial per term of g, each with coefficient 1.
g_plonk = CustomGate([
    Monomial(F(1), [q_m_cur, a_cur, b_cur]),
    Monomial(F(1), [q_l_cur, a_cur]),
    Monomial(F(1), [q_r_cur, b_cur]),
    Monomial(F(1), [q_o_cur, c_cur]),
    Monomial(F(1), [q_c_cur]),
])

print("c_plonk", g_plonk.c())
print("s_plonk", g_plonk.s())

# Example values taken from https://hackmd.io/nQtquuk9QCGiB9EhUPgXvg#CCS-and-Plonkish-intro
# Rows 0 and 1 enforce a * b = c, row 2 enforces 2a + 2b = c, row 3 enforces nothing.
q_m_values = [F(1), F(1), F(0), F(0)]
q_l_values = [F(0), F(0), F(2), F(0)]
q_r_values = [F(0), F(0), F(2), F(0)]
q_o_values = [F(-1), F(-1), F(-1), F(0)]
q_c_values = [F(0), F(0), F(0), F(0)]
# Listed in the same order as the constant column indices 3..7.
constant_columns = [q_m_values, q_l_values, q_r_values, q_o_values, q_c_values]
m_plonk = g_plonk.m(constant_columns)
print("m_plonk", *m_plonk, sep="\n\n")

# Witness values of the advice columns, one entry per row.
a_values = [F(0), F(1), F(1), F(1)]
b_values = [F(0), F(1), F(2), F(2)]
c_values = [F(0), F(1), F(6), F(3)]

# Z = [1, column a, column b, column c]
z_plonk = vector(F, [F(1)] + a_values + b_values + c_values)
print("z_plonk", z_plonk)
assert ccs_is_satisfied(F, z_plonk, m_plonk, g_plonk.s(), g_plonk.c())

c_plonk [1, 1, 1, 1, 1]
s_plonk [{0, 1, 3}, {0, 4}, {1, 5}, {2, 6}, {7}]
m_plonk

[0 1 0 0 0 0 0 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0 0 0 0 0 0]
[0 0 0 1 0 0 0 0 0 0 0 0 0]
[0 0 0 0 1 0 0 0 0 0 0 0 0]

[0 0 0 0 0 1 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0]

[0 0 0 0 0 0 0 0 0 1 0 0 0]
[0 0 0 0 0 0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 1]

[1 0 0 0 0 0 0 0 0 0 0 0 0]
[1 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[2 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[2 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[21888242871839275222246405745257275088696311157297823662689037894645226208582                                                                             0                                                                             0      

# Fibonacci

In [54]:
# Fibonacci gate; row offsets 0, 1 and 2 stand for X, Xω and Xω^2:
# enable(X) * (fib(X) + fib(Xω) - fib(Xω^2)) = 0
fib = Column(0, "advice")
enable = Column(1, "constant")

fib_cur = RelativeCellPointer(fib, 0)
fib_next = RelativeCellPointer(fib, 1)
fib_next_next = RelativeCellPointer(fib, 2)
enable_cur = RelativeCellPointer(enable, 0)

# The product is expanded into three monomials, each multiplied by enable.
# The coefficients here are plain Python integers rather than elements of F,
# which is why c_fib prints as [1, 1, -1].
g_fib = CustomGate([
    Monomial(1, [enable_cur, fib_cur]),
    Monomial(1, [enable_cur, fib_next]),
    Monomial(-1, [enable_cur, fib_next_next]),
])
print("c_fib", g_fib.c())
print("s_fib", g_fib.s())

# The gate is enabled on rows 0 and 1 only; on rows 2 and 3 the offsets would wrap around the table.
enable_values = [F(1), F(1), F(0), F(0)]
m_fib = g_fib.m([enable_values])
print("m_fib", *m_fib, sep="\n\n")

# Z = [1, column fib]
fib_values = [F(1), F(1), F(2), F(3)]
z_fib = vector(F, [F(1)] + fib_values)
assert ccs_is_satisfied(F, z_fib, m_fib, g_fib.s(), g_fib.c())

c_fib [1, 1, -1]
s_fib [{0, 3}, {1, 3}, {2, 3}]
m_fib

[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]

[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]
[0 1 0 0 0]

[0 0 0 1 0]
[0 0 0 0 1]
[0 1 0 0 0]
[0 0 1 0 0]

[1 0 0 0 0]
[1 0 0 0 0]
[0 0 0 0 0]
[0 0 0 0 0]


# Apply copy constraints to a CCS instance

In [55]:
from dataclasses import dataclass

@dataclass(eq=True)
class AbsoluteCellPointer:
    """Refers to one concrete cell of the table by column and absolute row."""

    # Column the cell lives in.
    column: Column
    # Absolute row of the cell, counted from 0.
    row_index: int

def z_index(cell: AbsoluteCellPointer, table_height: int) -> int:
    """Return the position of ``cell`` in Z.

    Z starts with the constant 1, followed by the columns laid out one after
    another with ``table_height`` entries each.
    """
    # Skip the leading 1, then every column that comes before this one.
    column_start = 1 + cell.column.index * table_height
    return column_start + cell.row_index

# Step 1:
# If M_j refers to multiple elements of Z that have to be equal,
# rewrite those references so that they all point at one and the same element.
def apply_copy_constraints(
    copy_constraints: list[list[AbsoluteCellPointer]],
    matrices: list[matrix],
    constant_columns: list[list[F]],
) -> list[matrix]:
    """Return copies of ``matrices`` with every copy constraint folded in.

    Each entry of ``copy_constraints`` is a list of cells that must hold the
    same value. One cell of the list is chosen as the representative: the
    first constant cell if there is one, otherwise the first cell. Every
    reference to another advice cell of the list is redirected to the
    representative:

    * advice representative: the coefficient moves to the representative's
      column of Z;
    * constant representative: the coefficient times the constant value is
      added to column 0, which multiplies the leading 1 of Z.

    Constant cells that are not the representative are skipped, because
    constants have no column in Z. The input matrices are not modified.
    """
    table_height = max(len(column) for column in constant_columns)
    num_advice_columns = (max(m_j.ncols() for m_j in matrices) - 1) // table_height

    # Resolve every equivalence class once, up front, into
    # (columns to clear, column to add to, factor applied to the moved coefficient).
    substitutions = []
    for equal_cells in copy_constraints:
        if not equal_cells:
            continue
        # Prefer a constant cell as the representative wherever it appears in the
        # list: a constant cell has no index in Z, so it can only be a target.
        representative_position = next(
            (position for position, cell in enumerate(equal_cells) if cell.column.kind == "constant"),
            0,
        )
        representative = equal_cells[representative_position]

        if representative.column.kind == "advice":
            target_index = z_index(representative, table_height)
            factor = 1
        elif representative.column.kind == "constant":
            # We assume here that constant columns come after the advice columns.
            # This is PoC code and the assumption is not essential.
            constant_column_index = representative.column.index - num_advice_columns
            # Column 0 multiplies the first element of Z, which is always 1.
            target_index = 0
            factor = constant_columns[constant_column_index][representative.row_index]
        else:
            # Unknown column kind: leave this equivalence class untouched.
            continue

        duplicate_indices = [
            z_index(cell, table_height)
            for position, cell in enumerate(equal_cells)
            if position != representative_position and cell.column.kind == "advice"
        ]
        # A cell listed twice must not be redirected onto itself and then cleared.
        duplicate_indices = [index for index in duplicate_indices if index != target_index]
        substitutions.append((duplicate_indices, target_index, factor))

    matrices = [copy(m_j) for m_j in matrices]
    for m_j in matrices:
        for row_index in range(m_j.nrows()):
            for duplicate_indices, target_index, factor in substitutions:
                for duplicate_index in duplicate_indices:
                    coefficient = m_j[row_index, duplicate_index]
                    if coefficient != 0:
                        # Accumulate instead of overwriting so that the product of
                        # this row with Z stays the same.
                        m_j[row_index, target_index] += factor * coefficient
                        m_j[row_index, duplicate_index] = 0

    return matrices


# Step 2:
# Remove the elements of Z that are referenced by no M_j
def clean_up(old_matrices: list[matrix], old_z: vector) -> tuple[list[matrix], vector]:
    """Drop every element of Z that no matrix refers to.

    A column index counts as used when at least one matrix has a non-zero
    entry in that column. The unused elements are removed from ``old_z`` and
    the matching columns are removed from every matrix, keeping the relative
    order of what remains.

    Prints the unused and the used index sets, then returns
    ``(new_matrices, new_z)``.
    """
    used_z_indices = {
        x
        for m_j in old_matrices
        for y in range(m_j.nrows())
        for x in range(m_j.ncols())
        if m_j[y, x] != 0
    }
    unused_z_indices = set(range(len(old_z))) - used_z_indices
    print("unused_z_indices", unused_z_indices)
    print("used_z_indices", used_z_indices)

    new_z = [value for index, value in enumerate(old_z) if index in used_z_indices]

    # Map each surviving column to its new position once, instead of counting
    # the unused indices below it again for every matrix entry.
    new_column = {old_index: new_index for new_index, old_index in enumerate(sorted(used_z_indices))}

    new_matrices = []
    for old_m_j in old_matrices:
        new_m_j = matrix(F, old_m_j.nrows(), len(used_z_indices))
        for y in range(old_m_j.nrows()):
            for x in range(old_m_j.ncols()):
                # Every non-zero entry sits in a used column by construction.
                if old_m_j[y, x] != 0:
                    new_m_j[y, new_column[x]] = old_m_j[y, x]
        new_matrices.append(new_m_j)

    return (new_matrices, vector(F, new_z))

# Apply copy constraints to m_plonk

In [64]:
# Each inner list is one group of cells that must hold the same value.
# The first cell of a group is the one the others are merged into.
copy_constraints = [
    [AbsoluteCellPointer(a, 0), AbsoluteCellPointer(b, 0), AbsoluteCellPointer(c, 0)],
    [AbsoluteCellPointer(a, 1), AbsoluteCellPointer(b, 1), AbsoluteCellPointer(c, 1)],
    # NOTE(review): the witness values on row 3 are a=1, b=2, c=3, which do not satisfy this
    # constraint. The final assertion still passes because every selector is 0 on row 3.
    [AbsoluteCellPointer(a, 3), AbsoluteCellPointer(b, 3), AbsoluteCellPointer(c, 3)],
    [AbsoluteCellPointer(q_l, 2), AbsoluteCellPointer(b, 2)],  # This copy constraint was not in https://hackmd.io/nQtquuk9QCGiB9EhUPgXvg
]
# Step 1: redirect references to duplicated cells. Step 2: drop the elements of Z nothing refers to any more.
deduplicated_m = apply_copy_constraints(copy_constraints, m_plonk, constant_columns)
(cleaned_up_m, cleaned_up_z) = clean_up(deduplicated_m, z_plonk)
print("original", *m_plonk, sep="\n\n")
print("deduplicated", *deduplicated_m, sep="\n\n")
print("cleaned_up_m", *cleaned_up_m, sep="\n\n")
print("original_z", z_plonk)
print("cleaned_up_z", cleaned_up_z)
# The gate's S and c are unchanged; only the matrices and Z were rewritten.
assert ccs_is_satisfied(F, cleaned_up_z, cleaned_up_m, g_plonk.s(), g_plonk.c())

unused_z_indices {5, 6, 7, 8, 9, 10, 12}
used_z_indices {0, 1, 2, 3, 4, 11}
original

[0 1 0 0 0 0 0 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0 0 0 0 0 0]
[0 0 0 1 0 0 0 0 0 0 0 0 0]
[0 0 0 0 1 0 0 0 0 0 0 0 0]

[0 0 0 0 0 1 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0 0]

[0 0 0 0 0 0 0 0 0 1 0 0 0]
[0 0 0 0 0 0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 1]

[1 0 0 0 0 0 0 0 0 0 0 0 0]
[1 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[2 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[2 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[21888242871839275222246405745257275088696311157297823662689037894645226208582                                                                             0                                                                             0  