# Efficiently Provable Bit Decoding in a Finite Field

Below, we provide an implementation of a protocol that can generate proofs for the value of the $i^{th}$ bit of $n$ bits encoded in a single field element.

This protocol operates within finite field arithmetic, leveraging the property that some numbers have square roots while others don't, as a mechanism to encode bits into a field element.
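
As a minimal sketch of the square/non-square property, the snippet below uses Euler's criterion in plain Python. The toy prime `P = 23` and the helper name `is_square` are assumptions made for illustration only; the notebook itself works in a much larger field and relies on Sage's `sqrt`.

```python
# Toy prime chosen for illustration only (an assumption, not the protocol's field).
P = 23


def is_square(a, p=P):
    """Return True if `a` has a square root modulo the odd prime `p`."""
    a %= p
    # Euler's criterion: a^((p-1)/2) is 1 for nonzero squares and p-1 otherwise.
    return a == 0 or pow(a, (p - 1) // 2, p) == 1


# Exactly half of the nonzero elements are squares.
nonzero_squares = [a for a in range(1, P) if is_square(a)]
print(nonzero_squares)  # [1, 2, 3, 4, 6, 8, 9, 12, 13, 16, 18]
```

Which elements are squares looks irregular along consecutive values, and that irregularity is what lets a run of consecutive elements carry a bit pattern.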

## Overview
The protocol's intent is to take $n$ bits, encode them into a single field element, and subsequently allow for a proof that a given $i^{th}$ bit of the encoded $n$ bits was indeed either a 1 or a 0. Impressively, our protocol requires only a single constraint to provide this proof for the $i^{th}$ bit.

## Motivation

The purpose of the protocol is to compress binary lookup tables and to speed up bit decomposition. Binary lookup tables, such as ones storing the evaluations of a function $f: \{0, 1\}^n \rightarrow \{0, 1\}$, usually need a single field element to store each evaluation. By compressing multiple consecutive evaluations into a single field value, our protocol reduces the number of stored field elements by a factor of $n$, at the cost of a single additional constraint per looked-up value.

## Protocol


### Pre-processing 

A significant portion of this protocol relies on a precomputation step that establishes the encoding table. Once computed, this table is universally applicable and doesn't require subsequent recalculations. Currently, the computation time for this table, given $n$ bits, is approximately $O(4^n)$, while the table size is $O(2^n)$. We anticipate that future refinements may significantly enhance this efficiency.
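
A rough way to see where the $4^n$ comes from, under the heuristic assumption (not proven here) that the squareness of consecutive field elements behaves like independent fair coin flips: encoding $n$ bits fixes the square/non-square status of $2n$ consecutive elements, so

$$
\Pr[\text{the window starting at } x \text{ matches a fixed pattern}] \approx 2^{-2n} = 4^{-n}.
$$

Under that assumption, a linear scan is expected to visit on the order of $4^n$ starting positions before a given pattern appears, while only $2^n$ patterns need to be stored.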

### Encoding

After the precomputation step, encoding $n$ bits is straightforward: we look up the encoding for the bit pattern in the previously calculated table.

#### Example

Suppose we want to encode bits `[1, 0, 1]`. We then look for an $x$ such that:

- $x$ has a square root
- $x + 1$ does not
- $x + 2$ does not
- $x + 3$ does
- $x + 4$ does
- $x + 5$ does not

Finding such an $x$ does require brute-forcing. However, with the encoding table we do not need to repeat the search.
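
The search in the example above can be sketched in plain Python. This is only an illustration under stated assumptions: the toy prime `P = 23` and the helper names `is_square` and `find_encoding` are hypothetical and are not part of the notebook, which builds a reusable table instead of searching per pattern.

```python
P = 23  # toy prime, for illustration only (assumption)


def is_square(a, p=P):
    """Euler's criterion: True if `a` is a square modulo the odd prime `p`."""
    a %= p
    return a == 0 or pow(a, (p - 1) // 2, p) == 1


def find_encoding(bits, p=P):
    """Brute-force the smallest x whose window of 2n elements matches `bits`."""
    for x in range(p):
        if all(
            # Bit 1: x + 2i is a square and x + 2i + 1 is not; bit 0 is the reverse.
            is_square(x + 2 * i, p) == bool(bit)
            and is_square(x + 2 * i + 1, p) == (not bit)
            for i, bit in enumerate(bits)
        ):
            return x
    return None


print(find_encoding([1, 0, 1]))  # 9
```

In this toy field, 9, 12 and 13 are squares while 10, 11 and 14 are not, which is exactly the pattern listed above.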

### Decoding

The decoding process for the $i^{th}$ bit checks which of two field elements is a square. If $\text{enc_v} + 2 \times i$ is a square, the bit decodes to $1$; if $\text{enc_v} + 2 \times i + 1$ is a square, the bit decodes to $0$. A valid encoding has exactly one of the two as a square; if both or neither are squares, the decoding is deemed invalid.
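
The decoding rule can be sketched as follows, again as a plain-Python illustration rather than the notebook's implementation. The toy prime `P = 23`, the helper names, and the sample value `9` (which encodes `[1, 0, 1]` in this toy field) are assumptions chosen so the result can be checked by hand.

```python
P = 23  # toy prime, for illustration only (assumption)


def is_square(a, p=P):
    """Euler's criterion: True if `a` is a square modulo the odd prime `p`."""
    a %= p
    return a == 0 or pow(a, (p - 1) // 2, p) == 1


def decode_bit(enc_v, i, p=P):
    """Decode the i-th bit, rejecting encodings where both or neither value is a square."""
    is_one = is_square(enc_v + 2 * i, p)
    is_zero = is_square(enc_v + 2 * i + 1, p)
    if is_one == is_zero:
        raise ValueError(f"corrupted encoding at bit {i}")
    return 1 if is_one else 0


print([decode_bit(9, i) for i in range(3)])  # [1, 0, 1]
```

A value such as `1` is rejected for bit 0 in this field, because both 1 and 2 are squares modulo 23.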

### Proving

We chose the square root property primarily because it is cheap to verify. If a prover supplies the square root of a number, a single constraint suffices to check that the number is indeed a square.

Capitalizing on this, our protocol can assert that the $i^{th}$ bit was indeed a $1$ by verifying that $\text{root}^2 = \text{enc_v} + 2 \times i$, or a $0$ by verifying that $\text{root}^2 = \text{enc_v} + 2 \times i + 1$.

Consequently, the entirety of the $n$ bits can be decoded with a mere $n$ constraints.
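
A minimal sketch of the prover and verifier roles, assuming a toy prime `P = 23` with `P % 4 == 3` so that a square root can be computed as `a^((P+1)/4)`. The names `sqrt_mod`, `prove_bit` and `verify_bit` are hypothetical; the notebook uses Sage's `sqrt` in its own field instead.

```python
P = 23  # toy prime with P % 4 == 3, for illustration only (assumption)


def sqrt_mod(a, p=P):
    """Return a square root of `a` mod p (valid for p % 4 == 3), or None."""
    a %= p
    r = pow(a, (p + 1) // 4, p)
    return r if r * r % p == a else None


def prove_bit(enc_v, i, p=P):
    """Prover: decode the i-th bit and return it with a square root as witness."""
    root_one = sqrt_mod(enc_v + 2 * i, p)
    root_zero = sqrt_mod(enc_v + 2 * i + 1, p)
    if (root_one is None) == (root_zero is None):
        raise ValueError("corrupted encoding")
    return (1, root_one) if root_one is not None else (0, root_zero)


def verify_bit(enc_v, i, bit, root, p=P):
    """Verifier: the single constraint root^2 == enc_v + 2i + (1 - bit)."""
    return (root * root - (enc_v + 2 * i + (1 - bit))) % p == 0


enc_v = 9  # encodes [1, 0, 1] in this toy field
proofs = [prove_bit(enc_v, i) for i in range(3)]
print(proofs)  # [(1, 3), (0, 9), (1, 6)]
print(all(verify_bit(enc_v, i, b, r) for i, (b, r) in enumerate(proofs)))  # True
```

The check shows only that one of the two candidate values is a square, not that the other is a non-square, so it is meaningful when `enc_v` is known to be a well-formed encoding, for example because it comes from a trusted table.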


In [1]:
from sage.all import *  # provides FiniteField and sqrt

### Field Setup

In [2]:
# Finite Field we have a constraint system in
p = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
F = FiniteField(p)

### Generate the Encoding Table 
This requires quite a bit of work, $O(2^{2n})$ in the number of bits $n$ we want to encode.
However, we only need to do it once: the results can be cached and the encoding table reused anywhere.
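
One possible way to cache the table between sessions is sketched below using JSON. The helper names `save_table` and `load_table` and the file name are hypothetical; the single sample entry maps pattern `41` to value `13`, which corresponds to the `[True, False, False]` example later in this notebook.

```python
import json
import os
import tempfile


def save_table(table, path):
    """Write the encoding table to disk as JSON."""
    # JSON keys must be strings, and field elements can exceed 64 bits,
    # so both keys and values are stored as decimal strings.
    with open(path, "w") as fh:
        json.dump({str(k): str(v) for k, v in table.items()}, fh)


def load_table(path):
    """Read the encoding table back, restoring integer keys and values."""
    with open(path) as fh:
        return {int(k): int(v) for k, v in json.load(fh).items()}


toy_table = {41: 13}  # sample entry only, not a full table
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "encoding_table.json")
    save_table(toy_table, path)
    restored = load_table(path)

print(restored == toy_table)  # True
```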

In [3]:
def get_sqrt(n):
    """
    Return the square root of a number `n` in a finite field.
    
    Parameters:
    - n: Number to compute the square root for.

    Returns:
    - Square root of `n` if it exists, otherwise None.
    """
    
    fn = F(n)
    
    try:
        return sqrt(fn, extend=False)
    except ValueError:
        return None

In [4]:
def generate_square_roots_data(amount, start=0):
    """
    Precompute the square roots of a specified amount of numbers, starting 
    from a given number.

    Parameters:
    - amount: The number of square roots to compute.
    - start: The starting number to compute the square root from. Defaults to 0.

    Returns:
    - A list of tuples, each containing a number and a boolean indicating 
      whether the number has a valid square root.
    """

    return [(i, get_sqrt(i) is not None) for i in range(start, start + amount)]

In [5]:
def has_distinct_alternating_bits(data, start_idx, length):
    """
    Verify if the sequence extracted from the data has distinct alternating bits
    for every even index. Terminate quickly for non-compliant sequences.
    """
    for offset in range(0, length, 2):
        current_bit, next_bit = data[start_idx + offset][1], data[start_idx + offset + 1][1]

        # A sequence isn't of interest if two consecutive bits are the same
        if current_bit == next_bit:
            return False

    return True


def generate_encoding_map(bit_size_to_encode, data=None, extend_data=False):
    """
    Construct a mapping from each valid square/non-square pattern (stored as
    an integer bitmask) to the first number in `data` that starts that pattern.
    """
    
    # Calculate encoding bit size as it's double the input bit size
    encoding_bit_size = bit_size_to_encode * 2
    
    # If data isn't provided, create an initial data structure
    if data is None:
        data = generate_square_roots_data(encoding_bit_size)
        extend_data = True
    
    # Start with an empty map; entries are added as patterns are found
    encoding_map = {}
    start_idx = 0
    required_encodings = 2 ** bit_size_to_encode
    
    while required_encodings > 0:
        if start_idx == len(data) - encoding_bit_size + 1:
            if not extend_data:
                return encoding_map
            
            # Extend data for the next round of encoding
            data += generate_square_roots_data(len(data), len(data))
             
        if not has_distinct_alternating_bits(data, start_idx, encoding_bit_size):
            start_idx += 1
            continue
            
        # Derive the original value from the encoded pattern
        pattern_bits = [data[start_idx + offset][1] for offset in range(encoding_bit_size)]
        original_value = sum([int(bit) << offset for offset, bit in enumerate(pattern_bits)])

        # If this original value hasn't been encoded yet, store it
        if original_value not in encoding_map:
            encoding_map[original_value] = data[start_idx][0]
            required_encodings -= 1

        start_idx += 1

    return encoding_map


In [6]:
def has_alternating_bits_pattern(num, bit_count):
    """
    Check if the binary representation of a number has an alternating bits pattern 
    for the given bit count.
    """
    for _ in range(bit_count):
        # Extract the last two bits
        consecutive_bits = num & 0b11
        # If not 01 and not 10 in binary, return False
        if consecutive_bits not in (0b01, 0b10):
            return False
        num >>= 2  # Check the next pair of bits

    # If there are no remaining bits after checking, it's a valid pattern
    return num == 0


def count_missing_entries_for_valid_patterns(encoding_table, bit_size_to_encode):
    """
    Count the number of valid bit patterns (of interest) 
    that are missing in the encoding table.
    """
    missing_count = 0
    total_patterns = 2 ** (bit_size_to_encode * 2)  # Calculating the total number of patterns

    for pattern_value in range(total_patterns):
        if has_alternating_bits_pattern(pattern_value, bit_size_to_encode) and pattern_value not in encoding_table:
            missing_count += 1

    return missing_count

#### Generate the Setup on the go 
The `generate_encoding_map` function supports generating the sqrt data on the fly, though it is not very efficient.

After we generate the table, we run the `count_missing_entries_for_valid_patterns` function, which counts how many valid patterns still lack an encoding in the `encoding_table`. For a well-formed `encoding_table` this should be zero.

In [None]:
# Define the encoding size
encoding_size = 16

# Create the encoding table. This method has the capability to generate square root data on-the-fly when data isn't pre-provided.
# Note: Generating the data dynamically, though possible, may not be the most efficient approach.
encoding_table = generate_encoding_map(encoding_size)

# Display the total number of square roots utilized for the given encoding size.
print(f"Utilized {len(encoding_table)} square roots to encode with {encoding_size} bits.")

# Identify and display the count of encodings that couldn't be mapped with the generated square roots.
# A well-optimized encoding table would ideally have this count as zero.
missing_encodings_count = count_missing_entries_for_valid_patterns(encoding_table, encoding_size)
print(f"{missing_encodings_count} of the {2**(encoding_size)} possible encodings are currently without a suitable square root mapping.")


#### Pregenerate the Setup
We can also precompute the sqrt data using the `generate_square_roots_data` function. This will take a while and we need to set the size we want to generate. However, all subsequent requests will be much faster, such as when we generate the `encoding_table`.

In [8]:
# Define the size of the data we want to pregenerate for square roots
# The value `100000` has been chosen in the hope that it covers our requirements.
data_size = 100000
data = generate_square_roots_data(data_size)

In [9]:
# Specify the encoding size
encoding_size = 6

# Create the encoding table using the pregenerated sqrt data. 
# This approach ensures that generating the encoding table is faster.
encoding_table = generate_encoding_map(encoding_size, data)

# Display the total number of square roots utilized for the given encoding size.
print(f"Utilized {len(encoding_table)} square roots to encode with {encoding_size} bits.")

# Display how many potential encodings are still without a corresponding square root in our pregenerated data.
missing_encodings_count = count_missing_entries_for_valid_patterns(encoding_table, encoding_size)
print(f"{missing_encodings_count} out of {2**(encoding_size)} possible encodings are missing from our encoding table.")

Utilized 64 square roots to encode with 6 bits.
0 out of 64 possible encodings are missing from our encoding table.


### Encoding
In this section, we demonstrate how a prover can consolidate a list of bit values into a singular field value. This process facilitates efficient subsequent proofs concerning the decoded value of any given i-th bit.

The quantity of binary values we aim to encode aligns with the previously mentioned parameter, `encoding_size`.

In [598]:
def encode(bits, encoding_table):
    """
    Encode a list of boolean bits into a single value using the provided encoding table.
    
    Args:
        bits (List[bool]): List of boolean values representing bits.
        encoding_table (Dict[int, int]): Precomputed encoding table mapping each bit pattern to its encoded value.
        
    Returns:
        int: The encoded value corresponding to the bits sequence.
    """
    
    binary_num = 0
    for i, bit in enumerate(bits):
        shift_position = 2 * i + int(not bit)
        binary_num |= (1 << shift_position)
    
    # Lookup the encoding value using the computed binary number
    encoded_value = encoding_table[binary_num]
    
    return encoded_value
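
To make the table key concrete, here is a small stand-alone sketch of the same bitmask computation, without the table lookup. The helper name `pattern_index` is hypothetical. Bit `k` of the mask, counting from the least significant bit, is set when `x + k` must be a square.

```python
def pattern_index(bits):
    """Build the 2n-bit mask used as the encoding table key."""
    index = 0
    for i, bit in enumerate(bits):
        # A 1 bit marks offset 2i as a square; a 0 bit marks offset 2i + 1.
        index |= 1 << (2 * i + (0 if bit else 1))
    return index


print(pattern_index([True, False, False]))                # 41
print(format(pattern_index([True, False, True]), "06b"))  # 011001
```

Read from the least significant bit, `011001` gives square, non-square, non-square, square, square, non-square, which matches the `[1, 0, 1]` example in the protocol description.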

#### Example of Encoding

In [600]:
# Step 1: Define the size of encoding and generate the encoding table
encoding_size = 3
encoding_table = generate_encoding_map(encoding_size)

# Step 2: Specify the list of binary flags to encode into a field element
bits = [True, False, False]

# Step 3: Encode the binary flags into a single field element. 
encoded_bits = None

# We need to verify that the number of flags is compatible with the encoding size
if len(bits) == encoding_size:
    encoded_bits = encode(bits, encoding_table)
    print(f"\nBits {bits} were encoded into value {encoded_bits}.")
else:
    print("\nEncoding cannot proceed due to mismatch in the number of bits and encoding size.")



Bits [True, False, False] were encoded into value 13.


### Decoding

The process of decoding the i-th bit from an encoded value `e` is as follows:

1. **Bit is 1**: If $e + 2i$ has a square root, then the i-th bit is 1.
2. **Bit is 0**: If $e + 2i + 1$ has a square root, then the i-th bit is 0.
3. **Corrupted Encoding**: If both $e + 2i$ and $e + 2i + 1$ are squares, or neither is, the encoding is deemed corrupted and cannot be decoded reliably.

Decoding may therefore need to attempt the square root of two distinct numbers, $e + 2i$ and $e + 2i + 1$.

#### Proof of Correct Decoding

To assert the correctness of the decoded i-th bit from the encoded value, the prover only has to present a number, `t`. The conditions are:

- For an i-th bit valued at 1: $e + 2i$ should equal $t^2$.
- For an i-th bit valued at 0: $e + 2i + 1$ should equal $t^2$.

By satisfying the condition for the claimed bit, the prover demonstrates that the decoding was executed correctly. Given that the prover has provided `t` as a hint, we need just a single constraint to check the decoding.

In [613]:
class DecodingError(Exception):
    """Raised when decoding cannot be performed correctly."""
    pass

def decode(encoded_value, decoded_bit_index, square_root_table=None):
    """Decode the i-th bit from the encoded value.
    
    Args:
    - encoded_value: The encoded value to be decoded.
    - decoded_bit_index: The position of the bit to decode.
    - square_root_table: (Optional) A precomputed table of square roots.
    
    Returns:
    - A tuple containing a boolean indicating the decoded bit and a proof of the decoding.
    
    Raises:
    - DecodingError: If the decoding process encounters inconsistencies.
    """
    
    # Helper function to get the square root
    def lookup_square_root(value):
        if square_root_table:
            return square_root_table.get(value)
        else:
            return get_sqrt(value)
    
    # Candidate values: a square at offset 2i means bit 1, at offset 2i + 1 means bit 0
    value_for_bit_1 = encoded_value + 2 * decoded_bit_index
    value_for_bit_0 = value_for_bit_1 + 1

    root_for_bit_1 = lookup_square_root(value_for_bit_1)
    root_for_bit_0 = lookup_square_root(value_for_bit_0)

    # Exactly one candidate must be a square; otherwise the encoding is corrupted
    if (root_for_bit_1 is None) == (root_for_bit_0 is None):
        raise DecodingError("The encoded value is incorrect.")

    # The square root itself serves as the proof of the decoded bit
    if root_for_bit_1 is not None:
        return True, root_for_bit_1

    # Print the witness root for inspection
    print(root_for_bit_0)
    return False, root_for_bit_0


#### Example of Decoding

In [614]:
# Index of the bit we aim to decode
bit_index = 2

# Decode the value
decoded_bit, proof = decode(encoded_bits, bit_index)

# Verification
assert(decoded_bit == bits[bit_index]), f"Decoding failed for bit at index {bit_index}."
print(f"Successfully decoded bit at index {bit_index} as {decoded_bit}.")

3091064037240671772801388301617798075985677455689113174543917243621579674803
Successfully decoded bit at index 2 as False.


#### Example of Proving Decoding
We are given the trusted value `encoded_bits`, and someone claims that the bit at `bit_index` decodes to `decoded_bit`, providing `root` as a proof.

To verify, we do the following:

In [617]:
root = proof
assert(root ** 2 == encoded_bits + 2 * bit_index + (1 - decoded_bit))