This notebook is based on https://github.com/dmpierre/air-ccs/blob/main/sage/air-to-ccs.ipynb. This notebook implements
- Plonk-to-CCS conversion algorithm, found in the [CCS paper](https://eprint.iacr.org/2023/552), Section 2.2.
- A PoC of incorporating Halo2 lookup into CCS+.
[The explanation of math behind the algorithm is here](https://hackmd.io/@pnyda/SkG0qaUuA)


# The part where we derive c, M, S from custom gates

In [135]:
# Classes to represent a custom gate

from dataclasses import dataclass
from typing import Literal

@dataclass(eq=True, frozen=True, order=True)
class Column:
    index: int
    kind: Literal["advice", "constant"]

@dataclass(eq=True, frozen=True, order=True)
class RelativeCellPointer:
    column: Column
    row_offset: int

@dataclass(eq=True, frozen=True, order=True)
class Monomial:
    coefficient: F
    variables: list[RelativeCellPointer]

@dataclass(eq=True, frozen=True, order=True)
class CustomGate:
    monomials: list[Monomial]

    def cell_pointers(self) -> list[RelativeCellPointer]:
        # We sometimes refer to the same set of cells from multiple monomials.
        # In that case a naive implementation would generate redundant M_j.
        # To avoid it we take distinct set of RelativeCellPointers.
        # and assign ID j for M_j by sorting it
        return sorted({variable for monomial in self.monomials for variable in monomial.variables})

    def t(self) -> list[int]:
        return len(self.cell_pointers())

    def c(self) -> list[int]:
        return [monomial.coefficient for monomial in self.monomials]

    def s(self) -> list[set[int]]:
        return [{j := self.cell_pointers().index(variable) for variable in monomial.variables} for monomial in self.monomials]
        
    def m(self, constant_columns: list[list[F]]):
        table_height = max(len(column) for column in constant_columns)
        matrices = []
    
        for cell_pointer in self.cell_pointers():
            # Here I add 1 because 1 is always in Z
            m_j = matrix(F, table_height, self.num_advice_columns() * table_height + 1)
            for y in range(table_height):
                if cell_pointer.column.kind == "advice":
                    x = cell_pointer.column.index * table_height + (y + cell_pointer.row_offset) % table_height
                    m_j[y, x] = 1
                elif cell_pointer.column.kind == "constant":
                    # We multiply the 1 in Z by constants
                    last_element_of_z = self.num_advice_columns() * table_height
                    constant_column_index = cell_pointer.column.index - self.num_advice_columns()
                    m_j[y, last_element_of_z] = constant_columns[constant_column_index][(y + cell_pointer.row_offset) % table_height]
                else:
                    pass

            matrices.append(m_j)
    
        return matrices

    def num_advice_columns(self) -> int:
        return len({cell_pointer.column.index for cell_pointer in self.cell_pointers() if cell_pointer.column.kind == "advice"})


p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
F = GF(p)

# g = q_m * a * b + q_l * a + q_r * b + q_o * c + q_c
## advice columns
a = RelativeCellPointer(Column(0, "advice"), 0)
b = RelativeCellPointer(Column(1, "advice"), 0)
c = RelativeCellPointer(Column(2, "advice"), 0)
## constant columns
q_m = RelativeCellPointer(Column(3, "constant"), 0)
q_l = RelativeCellPointer(Column(4, "constant"), 0)
q_r = RelativeCellPointer(Column(5, "constant"), 0)
q_o = RelativeCellPointer(Column(6, "constant"), 0)
q_c = RelativeCellPointer(Column(7, "constant"), 0)

g_plonk = CustomGate([
    Monomial(F(1), [q_m, a, b]),
    Monomial(F(1), [q_l, a]),
    Monomial(F(1), [q_r, b]),
    Monomial(F(1), [q_o, c]),
    Monomial(F(1), [q_c]),
])

print("c_plonk", g_plonk.c())
print("s_plonk", g_plonk.s())

# Example values taken from https://hackmd.io/nQtquuk9QCGiB9EhUPgXvg#CCS-and-Plonkish-intro
q_m_values = [F(1), F(1), F(0), F(0)]
q_l_values = [F(0), F(0), F(2), F(0)]
q_r_values = [F(0), F(0), F(2), F(0)]
q_o_values = [F(-1), F(-1), F(-1), F(0)]
q_c_values = [F(0), F(0), F(0), F(0)]
constant_columns = [q_m_values, q_l_values, q_r_values, q_o_values, q_c_values]
m_plonk = g_plonk.m(constant_columns)
print(len(m_plonk))
print("m_plonk", *m_plonk, sep="\n\n")

c_plonk [1, 1, 1, 1, 1]
s_plonk [{0, 1, 3}, {0, 4}, {1, 5}, {2, 6}, {7}]
8
m_plonk

[1 0 0 0 0 0 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0 0 0 0 0 0]
[0 0 0 1 0 0 0 0 0 0 0 0 0]

[0 0 0 0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0 0]

[0 0 0 0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 0 0 0 0 1 0 0 0]
[0 0 0 0 0 0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 0 0 0 0 0 1 0]

[0 0 0 0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 2]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 2]
[0 0 0 0 0 0 0 0 0 0 0 0 0]

[                                                                            0                                                                             0                                                                             0    

In [136]:
# https://github.com/dmpierre/air-ccs/blob/main/sage/air-to-ccs.ipynb
def ccs_is_satisfied(F, z, matrices, multisets, constants):
    satisfied_instance_witness = vector(F, [0 for i in range(matrices[0].dimensions()[0])])
    z_final = vector(F, [0 for i in range(matrices[0].dimensions()[0])])
    for i, c in enumerate(constants):
        multiset = multisets[i]
        z_i = vector(F, [1 for i in range(matrices[0].dimensions()[0])])
        for j in multiset:
            z_i = z_i.pairwise_product(matrices[j] * z)
        z_final += c * z_i
    return z_final == satisfied_instance_witness

In [137]:
# Advice columns
advice_columns = [
    [F(0), F(1), F(0), F(1)], 
    [F(1), F(1), F(0), F(1)],
    [F(0), F(1), F(0), F(2)]
]
z = vector(F, advice_columns[0] + advice_columns[1] + advice_columns[2] + [F(1)])
assert ccs_is_satisfied(F, z, m_plonk, g_plonk.s(), g_plonk.c())

In [141]:
# latch(X) * (fib(X) + fib(Xω) - fib(Xω^2)) = 0
g_fib = CustomGate([
    Monomial(1, [RelativeCellPointer(Column(0, "advice"), 0), RelativeCellPointer(Column(1, "constant"), 0)]),
    Monomial(1, [RelativeCellPointer(Column(0, "advice"), 1), RelativeCellPointer(Column(1, "constant"), 0)]),
    Monomial(-1, [RelativeCellPointer(Column(0, "advice"), 2), RelativeCellPointer(Column(1, "constant"), 0)]),
])
print("c_fib", g_fib.c())
print("s_fib", g_fib.s())

constant_columns = [[F(1), F(1), F(0), F(0)]]
m_fib = g_fib.m(constant_columns)
print("m_fib", *m_fib, sep="\n\n")

advice_columns = [[F(1), F(1), F(2), F(3)]]
z = vector(F, advice_columns[0] + [F(1)])
assert ccs_is_satisfied(F, z, m_fib, g_fib.s(), g_fib.c())

c_fib [1, 1, -1]
s_fib [{0, 3}, {1, 3}, {2, 3}]
m_fib

[1 0 0 0 0]
[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]

[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]

[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[0 1 0 0 0]

[0 0 0 0 1]
[0 0 0 0 1]
[0 0 0 0 0]
[0 0 0 0 0]


# TODO
- lookup

In [145]:
from dataclasses import dataclass

@dataclass
class AbsoluteCellPointer:
    x: int
    y: int

def z_index(cell: AbsoluteCellPointer, table_height: int) -> int: 
    return cell.x * table_height + cell.y

# Deduplicate the elements of Z that has to be the same
def apply_copy_constraints(copy_constraints: list[list[AbsoluteCellPointer]], matrices: list[matrix], table_height: int) -> list[matrix]:
    matrices = [copy(m) for m in matrices]
    for m in matrices:
        for row_index in range(m.nrows()):
            for equal_cells in copy_constraints:
                for cell in equal_cells[1:]:
                    if m[row_index, z_index(cell, table_height)] != 0:
                        m[row_index, z_index(equal_cells[0], table_height)] = m[row_index, z_index(cell, table_height)]
                        m[row_index, z_index(cell, table_height)] = 0

    return matrices
        

# Remove unused elements from Z
def clean_up_z(old_matrices: list[matrix], old_z: vector) -> (list[matrix], vector):
    new_z = list(z)
    
    used_z_indices = set()
    for m in old_matrices:
        for row_index in range(m.nrows()):
            for col_index in range(m.ncols()):
                if m[row_index, col_index] != 0:
                    used_z_indices.add(col_index)

    unused_z_indices = set(range(m.ncols())) - used_z_indices
    for unused_z_index in reversed(sorted(list(unused_z_indices))):
        new_z.pop(unused_z_index)

    # We have removed unused elements of Z. M_j is now inconsistent with new Z so we have to adjust it.
    print("unused_z_indices", unused_z_indices)
    print("used_z_indices", used_z_indices)

    new_matrices = [matrix(F, m.nrows(), len(used_z_indices)) for m in old_matrices]
    
    for (j, m) in enumerate(old_matrices):
        for row_index in range(m.nrows()):
            for col_index in range(m.ncols()):
                offset = len([unused_z_index for unused_z_index in unused_z_indices if unused_z_index < col_index])
                if m[row_index, col_index] != 0:
                    new_matrices[j][row_index, col_index - offset] = m[row_index, col_index]

    return (new_matrices, vector(F, new_z))
    

# Advice columns
advice_column = [1, 1, 2, 3]
z = vector(F, advice_column + [F(1)])

copy_constraints = [
    [AbsoluteCellPointer(0, 0), AbsoluteCellPointer(0, 1)]
]
deduplicated_m = apply_copy_constraints(copy_constraints, m_fib, 4)
(cleaned_up_m, cleaned_up_z) = clean_up_z(deduplicated_m, z)
print("original", *m_fib, sep="\n\n")
print("deduplicated", *deduplicated_m, sep="\n\n")
print("cleaned_up_m", *cleaned_up_m, sep="\n\n")
print("original_z", z)
print("cleaned_up_z", cleaned_up_z)
assert ccs_is_satisfied(F, cleaned_up_z, cleaned_up_m, g_fib.s(), g_fib.c())

unused_z_indices {1}
used_z_indices {0, 2, 3, 4}
original

[1 0 0 0 0]
[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]

[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]

[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[0 1 0 0 0]

[0 0 0 0 1]
[0 0 0 0 1]
[0 0 0 0 0]
[0 0 0 0 0]
deduplicated

[1 0 0 0 0]
[1 0 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]

[1 0 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]

[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[1 0 0 0 0]

[0 0 0 0 1]
[0 0 0 0 1]
[0 0 0 0 0]
[0 0 0 0 0]
cleaned_up_m

[1 0 0 0]
[1 0 0 0]
[0 1 0 0]
[0 0 1 0]

[1 0 0 0]
[0 1 0 0]
[0 0 1 0]
[1 0 0 0]

[0 1 0 0]
[0 0 1 0]
[1 0 0 0]
[1 0 0 0]

[0 0 0 1]
[0 0 0 1]
[0 0 0 0]
[0 0 0 0]
original_z (1, 1, 2, 3, 1)
cleaned_up_z (1, 2, 3, 1)


In [111]:
m = matrix(F, 2, 3)
m.ncols()

{x for x in range(3)}

{0, 1, 2}