In [1]:
from concrete.ml.common.preprocessors import TLUDeltaBasedOptimizer, InsertRounding
from concrete import fhe
import numpy as np
import matplotlib.pyplot as plt
from concrete.fhe import Configuration, Integer

# Bounds of the values used to calibrate the circuits.
input_range = (-234, 283)

# Calibration inputs: every integer from the lower bound up to, but not
# including, the upper bound (np.arange leaves its stop value out).
inputset = np.arange(*input_range, dtype=np.int64)

# Integer type able to represent the whole inputset, and the values it spans.
# NOTE(review): integer.max() itself is also left out of full_range because of
# the exclusive stop of np.arange — confirm this is intended.
integer = Integer.that_can_represent(inputset)
full_range = np.arange(integer.min(), integer.max(), dtype=np.int64)

# Constant function
# (overwritten by the later `def f` statements of this cell)
def f(x):
    """Scale by 0.75, shift by -200, zero out non-positive values, floor-divide by 118.

    The result is rounded to the nearest integer and returned as int64.
    """
    as_float = x.astype(np.float64)
    shifted = 0.75 * as_float - 200
    # Multiplying by the boolean mask keeps positive values and zeroes the rest.
    kept = shifted * (shifted > 0)
    bucket = kept // 118
    # x = (x + 2.1) / 3.4
    return np.rint(bucket).astype(np.int64)

# 2 jumps -> like what we have in CIFAR
# (overwritten by the later `def f` statements of this cell)
def f(x):
    """Scale by 0.75, shift by +134, zero out non-positive values, floor-divide by 118.

    The result is rounded to the nearest integer and returned as int64.
    """
    as_float = x.astype(np.float64)
    shifted = 0.75 * as_float + 134.
    # Multiplying by the boolean mask keeps positive values and zeroes the rest.
    kept = shifted * (shifted > 0)
    bucket = kept // 118
    # x = (x + 2.1) / 3.4
    return np.rint(bucket).astype(np.int64)

# 1 jump
# (overwritten by the later `def f` statement of this cell)
def f(x):
    """Scale by 0.75 (shift of 0), zero out non-positive values, floor-divide by 118.

    The result is rounded to the nearest integer and returned as int64.
    """
    as_float = x.astype(np.float64)
    shifted = 0.75 * as_float + 0.
    # Multiplying by the boolean mask keeps positive values and zeroes the rest.
    kept = shifted * (shifted > 0)
    bucket = kept // 118
    # x = (x + 2.1) / 3.4
    return np.rint(bucket).astype(np.int64)

# 5 jumps
# (last `def f` of this cell, so this is the one compiled further down)
def f(x):
    """Scale by 0.75, shift by +163, zero out non-positive values, floor-divide by 69.

    The result is rounded to the nearest integer and returned as int64.
    """
    as_float = x.astype(np.float64)
    shifted = 0.75 * as_float + 163.
    # Multiplying by the boolean mask keeps positive values and zeroes the rest.
    kept = shifted * (shifted > 0)
    bucket = kept // 69
    # x = (x + 2.1) / 3.4
    return np.rint(bucket).astype(np.int64)

# TODO: check with f that has non-constant delta

def compute(circuit, inputs=None):
    """Simulate ``circuit`` on each input value, one value at a time.

    Args:
        circuit: object exposing ``simulate(x)``; it is called once per value.
        inputs: iterable of values to evaluate. When omitted, the notebook-level
            ``full_range`` is used, which is what one-argument calls rely on.

    Returns:
        np.ndarray: the simulated outputs, in the same order as the inputs.
    """
    if inputs is None:
        # Looked up at call time so the notebook-level array is picked up.
        inputs = full_range
    return np.array([circuit.simulate(value) for value in inputs])

def _compile_and_simulate(configuration=None):
    """Compile the current ``f`` on ``inputset`` and run ``compute`` on the circuit.

    Args:
        configuration: optional ``Configuration`` forwarded to ``compile``.
            When omitted, ``compile`` is called with the inputset only.

    Returns:
        tuple: ``(compiler, circuit, results)`` where ``results`` is whatever
        ``compute(circuit)`` returns.
    """
    compiler = fhe.compiler({"x": "encrypted"})(f)
    compile_kwargs = {} if configuration is None else {"configuration": configuration}
    circuit = compiler.compile(inputset, **compile_kwargs)
    return compiler, circuit, compute(circuit)


# Naive: no additional pre-processor, used as the ground truth in the plot.
f_naive, circuit_naive, naive_res = _compile_and_simulate()

# Optim - Approx
exactness = fhe.Exactness.APPROXIMATE
optim = TLUDeltaBasedOptimizer(overflow_protection=True, exactness=exactness)
pre_proc_optim = [optim]
cfg_optim = Configuration(additional_pre_processors=pre_proc_optim)
f_optim, circuit_optim, optim_res = _compile_and_simulate(cfg_optim)

# Optim - Exact
exactness = fhe.Exactness.EXACT
optim_exact = TLUDeltaBasedOptimizer(overflow_protection=True, exactness=exactness)
pre_proc_optim_exact = [optim_exact]
cfg_optim_exact = Configuration(additional_pre_processors=pre_proc_optim_exact)
f_optim_exact, circuit_optim_exact, optim_res_exact = _compile_and_simulate(cfg_optim_exact)

# Round (bit-width-from-optim): plain rounding using the bit-width recorded by
# the approximate optimizer for the first entry of its statistics.
# NOTE(review): when `optim.statistics` is empty, None is passed to
# InsertRounding — confirm that InsertRounding accepts it.
n_bits_round = list(optim.statistics.values())[0]["optimized_bitwidth"] if optim.statistics else None
rounding_from_optim_no_scaling = InsertRounding(n_bits_round, overflow_protection=True)
pre_proc_round_from_optim_no_scaling = [rounding_from_optim_no_scaling]
cfg_round_from_optim_no_scaling = Configuration(additional_pre_processors=pre_proc_round_from_optim_no_scaling)
(
    f_round_from_optim_no_scaling,
    circuit_round_from_optim_no_scaling,
    round_res_from_optim_no_scaling,
) = _compile_and_simulate(cfg_round_from_optim_no_scaling)

# Plot every variant against the naive circuit over the full representable range.
fig, ax = plt.subplots()
ax.plot(full_range, naive_res, label="ground-truth", linestyle="-")
ax.plot(full_range, optim_res, label="optim-approx", linestyle="-.")
ax.plot(full_range, optim_res_exact, label="optim-exact", linestyle=":")
ax.plot(full_range, round_res_from_optim_no_scaling, label="round-from-optim", linestyle="-.")
# Vertical markers at the two ends of the calibration input range.
ax.vlines(input_range, np.min(naive_res), np.max(naive_res), color="grey", linestyle="--", label="bounds")
ax.set_xlabel("input value")
ax.set_ylabel("simulated output")
ax.legend()

# Top axis: mark the inputs after which the rounded representation changes value.
if optim.statistics:
    lsbs_to_remove = integer.bit_width - list(optim.statistics.values())[0]["optimized_bitwidth"]
    rounded_inputs = fhe.round_bit_pattern(full_range, lsbs_to_remove=lsbs_to_remove)
    # np.diff is one element shorter than its input; pad with False so the mask
    # lines up with full_range.
    changes_after = np.concatenate([np.diff(rounded_inputs).astype(bool), np.array([False])])
    rounded_ticks = full_range[changes_after]

    # Create secondary axes for the top ticks
    ax_top = ax.twiny()
    ax_top.set_xlim(ax.get_xlim())  # Keep the secondary axis aligned with the primary axis

    # set_xticks places *major* ticks, so the styling has to target
    # which="major"; styling the minor ticks had no visible effect.
    ax_top.set_xticks(rounded_ticks)
    ax_top.tick_params(which="major", length=4, color="red")

plt.show()

DETECTED A TLU
subgraph_input_shape=(517,)
subgraph_inputs.shape=(517,)

steps_indexes=array([-125,  -33,   59,  151,  243])
delta_axis=array([92, 92, 92, 92])

threshold=-125, delta=92


[0m

AssertionError: 

In [None]:
# NOTE(review): unconditional raise — presumably a deliberate barrier so that
# "Run All" stops before the exploratory cells below; confirm before removing.
raise ValueError()

In [None]:
# Number of evaluated points where the rounding-only result differs from the naive result.
(naive_res != round_res_from_optim_no_scaling).sum()

In [None]:
# Number of evaluated points where the exact-optimizer result differs from the naive result.
(naive_res != optim_res_exact).sum()

In [None]:
# approx vs exact, restricted to the calibration inputs.
# The result arrays are ordered like full_range, so entries are selected by
# matching values. Indexing with the raw inputset values would use them as
# positions, and the negative ones would wrap around to the end of the array.
in_inputset = np.isin(full_range, inputset)
(optim_res != optim_res_exact)[in_inputset].sum()

In [None]:
# Ground truth vs Optimized, restricted to the calibration inputs.
# The result arrays are ordered like full_range, so entries are selected by
# matching values. Indexing with the raw inputset values would use them as
# positions, and the negative ones would wrap around to the end of the array.
in_inputset = np.isin(full_range, inputset)
(naive_res != optim_res)[in_inputset].sum()

In [None]:
# Ground truth vs Just rounded, restricted to the calibration inputs.
# The result arrays are ordered like full_range, so entries are selected by
# matching values. Indexing with the raw inputset values would use them as
# positions, and the negative ones would wrap around to the end of the array.
in_inputset = np.isin(full_range, inputset)
(naive_res != round_res_from_optim_no_scaling)[in_inputset].sum()

In [None]:
# The delta of 92 and its double.
deltas = [92 * factor for factor in (1, 2)]

In [None]:
# For each delta: scale by 713 and subtract 1, truncate the 16 low bits, then
# undo the offset and the scaling; printed next to the untruncated round trip.
for index, delta in enumerate(deltas):
    scaled = delta * 713 - 1
    truncated = fhe.truncate_bit_pattern(scaled, lsbs_to_remove=16)
    print(delta, (truncated + 1) / 713, (scaled + 1) / 713)
# We should probably add some rounding in the TLU itself too.
# Also, the value that we use in the TLU can be a float -> check that we
# actually do everything after the cast to float.

In [None]:
# Value obtained for each input after scaling by 713 and subtracting 1,
# rounding away the 16 low bits, then undoing the offset and the scaling.
optimized_repr = [
    (fhe.round_bit_pattern(x * 713 - 1, lsbs_to_remove=16) + 1) / 713
    for x in inputset
]
plt.plot(inputset, optimized_repr)

In [None]:
# Toy input: the integers from -16 to 15.
x = np.arange(-16, 16)


# Note: this rebinds the name `f` used by the first cell of the notebook.
def f(x, thresholds=((9, 1), (14, 1),)):
    """Staircase function: add ``value`` wherever ``x >= threshold``.

    Args:
        x: array of inputs (its ``shape`` sizes the output).
        thresholds: iterable of ``(threshold, value)`` pairs.

    Returns:
        Float array shaped like ``x`` holding the accumulated step heights.
    """
    steps = np.zeros(x.shape)
    for cutoff, height in thresholds:
        steps += height * (x >= cutoff)
    return steps
# Reference output on the exact inputs.
y = f(x)
# Same function evaluated on inputs whose 2 least-significant bits were rounded away.
x_star = fhe.round_bit_pattern(x, lsbs_to_remove=2)
y_star = f(x_star)
for series, series_label in ((y, "baseline"), (y_star, "removing 2 bits")):
    plt.plot(x, series, label=series_label)
plt.legend()

In [None]:
# Inputs at which y takes a different value than at the previous input.
step_positions = np.flatnonzero(np.diff(y)) + 1
thresholds = x[step_positions]
thresholds

In [None]:
# Gap between each threshold and the next one (rebinds `deltas` as an array).
deltas = thresholds[1:] - thresholds[:-1]
deltas

In [None]:
# Smallest and largest input values.
x_min, x_max = x.min(), x.max()
x_min, x_max

In [None]:
# For each delta, print ceil(log2(ceil((x_max - x_min) / delta))): the input
# span is divided by delta, rounded up, and the base-2 logarithm of that count
# is rounded up again and reported as a number of bits.
for delta in deltas:
    print(f"{x_min=}, {x_max=}, {delta=} requires {np.ceil(np.log2(np.ceil((x_max - x_min)/delta))).astype(np.int32)} bits")

In [None]:
# Integer type able to represent both ends of the input range, and its bit-width.
input_bounds = [x_min, x_max]
base_integer_repr = fhe.Integer.that_can_represent(input_bounds)
base_integer_repr.bit_width

In [None]:
# Same bit estimate as printed per delta, computed for the first delta only.
span_over_delta = np.ceil((x_max - x_min) / deltas[0])
tlu_bit_width_target = np.ceil(np.log2(span_over_delta)).astype(np.int32)
print(tlu_bit_width_target)

In [None]:
# Squared error of the plain "remove 2 bits" approximation; candidates are compared to it.
mse_to_beat = np.sum((y - y_star) ** 2)
plt.plot(x, y, label="baseline")
plt.plot(x, y_star, label="removing 2 bits")
solution_found = False
# Search for (a, b) such that rounding a*x - b and mapping the rounded value
# back with (v + b) / a reproduces y exactly; stop at the first exact match.
for a in range(1, 1024):
    for b in range(1024):
        pre_tlu_x = a * x - b
        new_repr = Integer.that_can_represent([pre_tlu_x.min(), pre_tlu_x.max()])
        lsbs_to_remove = new_repr.bit_width - 3
        rounded_pre_tlu_x = fhe.round_bit_pattern(pre_tlu_x, lsbs_to_remove)
        y_starstar = f((rounded_pre_tlu_x + b) / a)
        mse = np.sum((y - y_starstar) ** 2)
        if mse < mse_to_beat:
            print("beats mse", a, b)
        if np.all(y == y_starstar):
            print("exact match", a, b, new_repr.bit_width)
            plt.plot(x, y_starstar)
            solution_found = True
            break
    else:
        # No exact match for this a: try the next one.
        continue
    # The inner loop ended on an exact match: stop the whole search.
    break

In [None]:
# Sweep b for a fixed a and record the squared error of every candidate.
mse_to_beat = np.sum((y - y_star) ** 2)
plt.figure()
solution_found = False
mses = []
a = 3  # Fixed known solution a
for b in range(-1024, 1024):
    pre_tlu_x = a * x - b
    new_repr = Integer.that_can_represent([pre_tlu_x.min(), pre_tlu_x.max()])
    lsbs_to_remove = new_repr.bit_width - 3  # Target 3-bit LUT
    rounded_pre_tlu_x = fhe.round_bit_pattern(pre_tlu_x, lsbs_to_remove)
    y_starstar = f((rounded_pre_tlu_x + b) / a)

    mse = np.sum((y - y_starstar) ** 2)
    mses.append(mse)
    if mse < mse_to_beat:
        print("beats mse", a, b)
    if np.all(y == y_starstar):
        print("exact match", a, b, new_repr.bit_width)
        solution_found = True

plt.plot(mses)

In [None]:
# Sweep a for a fixed b and record the squared error of every candidate.
mse_to_beat = np.sum((y - y_star) ** 2)
solution_found = False
mses = []
plt.figure()
b = 1  # Fixed known solution b (b==3 works too)
for a in range(1, 1024):
    pre_tlu_x = a * x - b
    new_repr = Integer.that_can_represent([pre_tlu_x.min(), pre_tlu_x.max()])
    lsbs_to_remove = new_repr.bit_width - 3  # Target 3-bit LUT
    rounded_pre_tlu_x = fhe.round_bit_pattern(pre_tlu_x, lsbs_to_remove)
    y_starstar = f((rounded_pre_tlu_x + b) / a)

    mse = np.sum((y - y_starstar) ** 2)
    mses.append(mse)
    if mse < mse_to_beat:
        print("beats mse", a, b)
    if np.all(y == y_starstar):
        print("exact match", a, b, new_repr.bit_width)
        solution_found = True

plt.plot(mses)