This notebook is based on https://github.com/dmpierre/air-ccs/blob/main/sage/air-to-ccs.ipynb. It implements:
- the Plonk-to-CCS conversion algorithm described in the [CCS paper](https://eprint.iacr.org/2023/552), Section 2.2;
- a proof of concept (PoC) of incorporating Halo2 lookups into CCS+.
[An explanation of the math behind the algorithm is available here](https://hackmd.io/@pnyda/SkG0qaUuA).


# Deriving c, M, and S from custom gates

In [106]:
# Classes to represent a custom gate

from dataclasses import dataclass

@dataclass(order=True, frozen=True)
class RelativeCellPointer:
    """Address of a table cell relative to the row a gate is evaluated on.

    Instances are frozen (hence hashable) and ordered, so they can be
    collected in a set and sorted; ordering compares
    ``absolute_column_index`` first, then ``relative_row_index``.
    """

    # Index of the column, counted from 0.
    absolute_column_index: int
    # Row offset from the current row: 0 is the same row, 1 the next one, ...
    relative_row_index: int

@dataclass
class Monomial:
    """One term of a custom gate: ``coefficient`` times the product of ``variables``."""

    # Scalar multiplier of the term; becomes the CCS constant c_i.
    coefficient: int
    # Cells multiplied together in this term.
    variables: list[RelativeCellPointer]

@dataclass
class CustomGate:
    """A custom gate: the constraint ``sum_i c_i * prod(variables_i) = 0``.

    Each monomial is one CCS term. Its variables are cells addressed relative
    to the current row, so the same constraint is applied on every row of the
    table.
    """

    monomials: "list[Monomial]"

    def _cell_pointers(self) -> "list[RelativeCellPointer]":
        """Return the distinct cells used by any monomial, sorted.

        A cell's position in this list is its matrix index j. Both ``s()`` and
        ``m()`` use it, so S_i and M_j agree on the numbering.
        """
        return sorted({variable for monomial in self.monomials for variable in monomial.variables})

    def c(self) -> list[int]:
        """Return the CCS constants c_i: one coefficient per monomial, in order."""
        return [monomial.coefficient for monomial in self.monomials]

    def s(self) -> list[set[int]]:
        """Return the CCS index sets S_i: the matrix indices used by each monomial.

        Raises:
            ValueError: if a monomial repeats a cell. A power such as x^2 would
                collapse to a single index in a set and silently encode a
                different constraint.
        """
        index_of = {cell: j for j, cell in enumerate(self._cell_pointers())}
        sets = []
        for i, monomial in enumerate(self.monomials):
            indices = [index_of[variable] for variable in monomial.variables]
            s_i = set(indices)
            if len(s_i) != len(indices):
                raise ValueError(
                    f"monomial {i} repeats a cell; powers cannot be encoded as a set S_i"
                )
            sets.append(s_i)
        return sets

    def t(self) -> int:
        """Return t, the number of distinct cells (one matrix M_j per cell)."""
        return len(self._cell_pointers())

    def m(self, table_width: int, table_height: int, field=None) -> list:
        """Return one selector matrix M_j per distinct cell, ordered by ``_cell_pointers``.

        Each M_j is ``table_height x (table_width * table_height)``, matching a
        witness z that stores the table column by column. Row y of M_j has a
        single 1 at column
        ``absolute_column_index * table_height + (y + relative_row_index) % table_height``.
        Relative offsets therefore wrap around cyclically past the last row.

        Args:
            table_width: number of columns in the table.
            table_height: number of rows in the table.
            field: base field of the matrices; defaults to the global ``F``.
        """
        ring = F if field is None else field
        matrices = []
        for cell in self._cell_pointers():
            m_j = matrix(ring, table_height, table_width * table_height)
            column_start = cell.absolute_column_index * table_height
            for y in range(table_height):
                m_j[y, column_start + (y + cell.relative_row_index) % table_height] = 1
            matrices.append(m_j)
        return matrices


# Prime modulus and field used by every matrix and vector below; CustomGate.m
# and the later cells rely on the global name F.
# NOTE(review): p appears to be the BN254 base-field modulus — confirm this is the intended field.
p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
F = GF(p)

# a * b + c = d
# Encoded as a*b + c - d = 0: columns 0..3 hold a, b, c, d on the same row
# (relative row offset 0), and each monomial is one CCS term.
g_muladd = CustomGate([
    Monomial(1, [RelativeCellPointer(0, 0), RelativeCellPointer(1, 0)]),
    Monomial(1, [RelativeCellPointer(2, 0)]),
    Monomial(-1, [RelativeCellPointer(3, 0)]),
])

print("c", g_muladd.c())
print("s", g_muladd.s())

# Selector matrices for a table that is 4 columns wide and 3 rows tall.
m_muladd = g_muladd.m(4, 3)
print("m_muladd", *m_muladd, sep="\n\n")

c [1, 1, -1]
s [{0, 1}, {2}, {3}]
m_muladd

[1 0 0 0 0 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0 0 0 0 0]

[0 0 0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 1 0 0 0 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0 0 0]

[0 0 0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0 0 0]

[0 0 0 0 0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 0 0 0 0 1]


In [107]:
# https://github.com/dmpierre/air-ccs/blob/main/sage/air-to-ccs.ipynb
def ccs_is_satisfied(F, z, matrices, multisets, constants):
    """Check the CCS relation ``sum_i c_i * (Hadamard product of M_j z for j in S_i) == 0``.

    Args:
        F: field over which the accumulator vectors are built.
        z: witness vector multiplied by every matrix.
        matrices: the M_j; all are expected to have the same number of rows.
        multisets: the S_i; one indexable collection of matrix indices per term.
        constants: the c_i; one constant per term.

    Returns:
        True iff every entry of the weighted sum is zero.

    Raises:
        ValueError: if ``matrices`` is empty, or ``multisets`` and ``constants``
            differ in length (a mismatch would otherwise silently drop terms).
    """
    if not matrices:
        raise ValueError("ccs_is_satisfied needs at least one matrix")
    if len(multisets) != len(constants):
        raise ValueError(
            f"got {len(multisets)} multisets but {len(constants)} constants"
        )

    n_rows = matrices[0].nrows()
    # Each M_j * z may be used by several terms; compute it once.
    products = [m_j * z for m_j in matrices]

    total = vector(F, [0] * n_rows)
    for c, multiset in zip(constants, multisets):
        term = vector(F, [1] * n_rows)
        for j in multiset:
            term = term.pairwise_product(products[j])
        total += c * term
    return total == vector(F, [0] * n_rows)

In [108]:
# Advice columns
# Each inner list is one column (a, b, c, d) with 3 rows; every row satisfies a*b + c = d.
table = [[2, 1, 6], [3, 2, 7], [4, 3, 3], [10, 5, 45]]
# Witness z: the table flattened column by column, the layout CustomGate.m indexes into.
z = vector(F, table[0] + table[1] + table[2] + table[3])
assert ccs_is_satisfied(F, z, m_muladd, g_muladd.s(), g_muladd.c())

In [109]:
# latch(X) * (fib(X) + fib(Xω) - fib(Xω^2)) = 0
# Column 0 is `latch`, column 1 is `fib`; relative row offsets 0, 1, 2 stand for X, Xω, Xω^2.
# The latch cell appears in every monomial, so the gate is trivially satisfied where latch is 0.
g_fib = CustomGate([
    Monomial(1, [RelativeCellPointer(0, 0), RelativeCellPointer(1, 0)]),
    Monomial(1, [RelativeCellPointer(0, 0), RelativeCellPointer(1, 1)]),
    Monomial(-1, [RelativeCellPointer(0, 0), RelativeCellPointer(1, 2)]),
])
# 2 columns x 4 rows; offsets past the last row wrap around (CustomGate.m takes rows mod 4).
m_fib = g_fib.m(2, 4)
print("c_fib", g_fib.c())
print("s_fib", g_fib.s())
print("m_fib", *m_fib, sep="\n\n")

# Advice columns
# latch = 1 on rows 0 and 1 only, so the wrapped-around rows 2 and 3 are not constrained.
table = [[1, 1, 0, 0], [1, 1, 2, 3]]
# Witness z: the table flattened column by column.
z = vector(F, table[0] + table[1])
assert ccs_is_satisfied(F, z, m_fib, g_fib.s(), g_fib.c())

c_fib [1, 1, -1]
s_fib [{0, 1}, {0, 2}, {0, 3}]
m_fib

[1 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0]
[0 0 0 1 0 0 0 0]

[0 0 0 0 1 0 0 0]
[0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]

[0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]
[0 0 0 0 1 0 0 0]

[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]
[0 0 0 0 1 0 0 0]
[0 0 0 0 0 1 0 0]


# TODO
- handling constant columns
- lookup

In [114]:
from dataclasses import dataclass

@dataclass
class AbsoluteCellPointer:
    """Fixed position of a cell in the table.

    ``x`` selects the column and ``y`` the row, as consumed by ``z_index``.
    """

    x: int  # column index
    y: int  # row index

def z_index(cell: AbsoluteCellPointer, table_height: int) -> int:
    """Return the position of ``cell`` in the flattened witness vector z.

    z stores the table column by column, so column ``x`` occupies the
    ``table_height`` consecutive slots starting at ``x * table_height``.
    """
    column_start = table_height * cell.x
    return column_start + cell.y

# Deduplicate the elements of Z that have to be the same
def apply_copy_constraints(copy_constraints: list[list[AbsoluteCellPointer]], matrices: list[matrix], table_height: int) -> list[matrix]:
    """Fold every copy-constrained cell into the first cell of its group.

    Because z[duplicate] == z[representative], each entry M[r, duplicate] can
    be moved onto M[r, representative] without changing M z. The duplicate
    columns end up all-zero, so ``clean_up_z`` can later drop them from z.

    Args:
        copy_constraints: groups of cells that must hold equal values; the
            first cell of each group is kept as the representative.
        matrices: the M_j; they are copied, and the inputs are left untouched.
        table_height: number of rows of the table (used to locate cells in z).

    Returns:
        New matrices with duplicate columns merged into their representatives.
    """
    matrices = [copy(m) for m in matrices]
    for m in matrices:
        for equal_cells in copy_constraints:
            if not equal_cells:
                continue
            representative = z_index(equal_cells[0], table_height)
            for cell in equal_cells[1:]:
                duplicate = z_index(cell, table_height)
                # A group listing the representative twice must not wipe it out.
                if duplicate == representative:
                    continue
                for row_index in range(m.nrows()):
                    value = m[row_index, duplicate]
                    if value != 0:
                        # Accumulate: the row may already reference the representative.
                        m[row_index, representative] += value
                        m[row_index, duplicate] = 0

    return matrices
        

# Remove unused elements from Z
def clean_up_z(old_matrices: list[matrix], old_z: vector, *, verbose: bool = True) -> tuple[list[matrix], vector]:
    """Drop the entries of z that no matrix references, and the matching columns.

    A column that is zero in every matrix contributes nothing to any M_j z, so
    removing it from all matrices and from z leaves every product M_j z
    unchanged.

    Args:
        old_matrices: the M_j, all with ``len(old_z)`` columns.
        old_z: witness vector the matrices are applied to.
        verbose: print the unused/used index sets (the original debug output).

    Returns:
        ``(new_matrices, new_z)``, where the kept columns and z entries preserve
        their original order.

    Raises:
        ValueError: if ``old_matrices`` is empty or ``old_z`` has the wrong length.
    """
    if not old_matrices:
        raise ValueError("clean_up_z needs at least one matrix")
    n_cols = old_matrices[0].ncols()
    if len(old_z) != n_cols:
        raise ValueError(f"old_z has {len(old_z)} entries but the matrices have {n_cols} columns")

    used_z_indices = set()
    for m in old_matrices:
        for row_index in range(m.nrows()):
            for col_index in range(m.ncols()):
                if m[row_index, col_index] != 0:
                    used_z_indices.add(col_index)
    unused_z_indices = set(range(n_cols)) - used_z_indices

    if verbose:
        print("unused_z_indices", unused_z_indices)
        print("used_z_indices", used_z_indices)

    # Keep the referenced columns in order; M_j and z stay aligned index by index.
    kept_columns = sorted(used_z_indices)
    new_matrices = [m.matrix_from_columns(kept_columns) for m in old_matrices]
    new_z = vector(old_z.base_ring(), [old_z[i] for i in kept_columns])

    return (new_matrices, new_z)
    

# Advice columns
table = [[1, 1, 0, 0], [1, 1, 2, 3]]
z = vector(F, table[0] + table[1])

copy_constraints = [
    [AbsoluteCellPointer(0, 0), AbsoluteCellPointer(0, 1)],
    [AbsoluteCellPointer(0, 2), AbsoluteCellPointer(0, 3)],
    [AbsoluteCellPointer(1, 0), AbsoluteCellPointer(1, 1)]
]
deduplicated_m = apply_copy_constraints(copy_constraints, m_fib, 4)
(cleaned_up_m, cleaned_up_z) = clean_up_z(deduplicated_m, z)
print("original", *m_fib, sep="\n\n")
print("deduplicated", *deduplicated_m, sep="\n\n")
print("cleaned_up_m", *cleaned_up_m, sep="\n\n")
print("original_z", z)
print("cleaned_up_z", cleaned_up_z)
assert ccs_is_satisfied(F, cleaned_up_z, cleaned_up_m, g_fib.s(), g_fib.c())

unused_z_indices {1, 3, 5}
used_z_indices {0, 2, 4, 6, 7}
original

[1 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0]
[0 0 0 1 0 0 0 0]

[0 0 0 0 1 0 0 0]
[0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]

[0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]
[0 0 0 0 1 0 0 0]

[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]
[0 0 0 0 1 0 0 0]
[0 0 0 0 0 1 0 0]
deduplicated

[1 0 0 0 0 0 0 0]
[1 0 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0]
[0 0 1 0 0 0 0 0]

[0 0 0 0 1 0 0 0]
[0 0 0 0 1 0 0 0]
[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]

[0 0 0 0 1 0 0 0]
[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]
[0 0 0 0 1 0 0 0]

[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]
[0 0 0 0 1 0 0 0]
[0 0 0 0 1 0 0 0]
cleaned_up_m

[1 0 0 0 0]
[1 0 0 0 0]
[0 1 0 0 0]
[0 1 0 0 0]

[0 0 1 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]

[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]
[0 0 1 0 0]

[0 0 0 1 0]
[0 0 0 0 1]
[0 0 1 0 0]
[0 0 1 0 0]
original_z (1, 1, 0, 0, 1, 1, 2, 3)
cleaned_up_z (1, 0, 1, 2, 3)


In [111]:
# Scratch cell: quick checks of Sage's ncols() and of a set comprehension.
# Note: this rebinds the global name `m`.
m = matrix(F, 2, 3)
m.ncols()

{x for x in range(3)}

{0, 1, 2}