In [81]:
from random import random
from random import seed
import math

In [82]:
q = 32 # Bit-width (log2) of the ciphertext modulus, i.e. the modulus is 2**32; read as a global by closest_representable and decompose_input
level = 3 # Decomposition level for the PBS
base_log = 7 # Decomposition base log for the PBS

In [83]:
def closest_representable(input_word, level, base_log):
    """Round input_word to the closest number representable by the decomposition.

    The decomposition defined by level and base_log only covers the
    level * base_log most significant bits of a q-bit word (q is the
    module-level modulus bit-width). The remaining low bits are rounded away,
    rounding up when the most significant dropped bit is set.

    Inputs:
      input_word: integer word to be "rounded".
      level: number of decomposition levels.
      base_log: log2 of the decomposition base.
    Outputs:
      input_word rounded to a multiple of 2**(q - level * base_log). The result
      is NOT reduced modulo 2**q, so rounding up near the top of the range can
      return 2**q; callers reduce it themselves.
    Raises:
      ValueError: if level * base_log exceeds q.
    """
    non_rep_bit_count = q - level * base_log
    if non_rep_bit_count < 0:
        raise ValueError(
            "level * base_log (%d) exceeds the %d-bit word size"
            % (level * base_log, q)
        )
    if non_rep_bit_count == 0:
        # Every bit is covered by the decomposition: nothing to round.
        # (Without this guard the shift below would be by -1 and raise.)
        return input_word
    # The most significant non-representable bit decides the rounding direction.
    non_rep_msb = (input_word >> (non_rep_bit_count - 1)) & 1
    # Drop the non-representable bits, apply the rounding, and realign.
    res = (input_word >> non_rep_bit_count) + non_rep_msb
    return res << non_rep_bit_count

def decompose_input(decomp_input, level_l, base_log):
    """Signed decomposition of a coefficient into level_l digits in base 2**base_log.

    The input is first rounded to the closest representable number (see
    closest_representable) and reduced modulo 2**q, where q is the
    module-level modulus bit-width. Its level_l * base_log top bits are then
    split into digits, least significant first, propagating a carry so that
    digits with the top bit set can be stored as negative values.

    Inputs:
      decomp_input: coefficient to be decomposed.
      level_l: number of decomposition levels.
      base_log: log2 of the decomposition base.
    Output:
      List of level_l coefficients representing the closest representable
      number, most significant level first: entry i carries the weight
      2**(q - (i + 1) * base_log), modulo 2**q.

    Side effects: prints the partial reconstruction sums and the final
    reconstructed / expected pair, plus "problem decomposing" if they differ.
    """
    closest_representable_input = closest_representable(decomp_input, level_l, base_log) % 2**q
    state = closest_representable_input >> (q - base_log * level_l)
    mod_b_mask = (1 << base_log) - 1
    decomposed_input = [0] * level_l
    # Digits come out least significant first, so fill the list from the end.
    for index in reversed(range(level_l)):
        digit = state & mod_b_mask
        state >>= base_log
        # Carry is 1 when the digit's top bit is set and either the digit is
        # above half the base or bit (base_log - 1) of the remaining state is set.
        carry = ((digit - 1) | state) & digit
        carry >>= base_log - 1
        state += carry
        decomposed_input[index] = digit - (carry << base_log)

    # Reconstruct to check. Weights use q (not a hard-coded 32) so the check
    # stays valid when the modulus bit-width is changed.
    recons = 0
    for i in range(level_l):
        recons += (decomposed_input[i] * 2**(q - (1 + i) * base_log)) % 2**q
        print("reconstruction", i, recons)
    recons = recons % 2**q
    if recons != closest_representable_input:
        print("problem decomposing")
        print(recons, " ", closest_representable_input)
    print(recons, " ", closest_representable_input)

    return decomposed_input

In [84]:
# Generate test vectors
# Each input is written once, as text, and parsed with int(text, 0), which
# accepts both decimal and 0x-prefixed hexadecimal. The printed label is that
# same text, so it can never disagree with the value that was decomposed.
test_vectors = [
    "2617242525",
    # Disabled inputs, kept for reference (move a line out of the comment to enable it):
    # "0x187fc55f",  # NOTE(review): its print label used to read "0xF87fc55f" - confirm which value is intended
    # "0xe07fa59b",
    # "0x7f9d0e65",
    # "0x5f08769",
    # "0x1fea30de",
    # "0x81f7631",
    # "0xd0b0d0f7",
    # "0x913685b5",
]

for index, text in enumerate(test_vectors):
    if index:
        # Separator between consecutive vectors.
        print("----------------------------------------")
    # Decomposition parameters come from the configuration cell (level, base_log).
    decomposition = decompose_input(int(text, 0), level, base_log)
    print("Decomposition of ", text, " is ", decomposition)


reconstruction 0 2617245696
reconstruction 1 2617245696
reconstruction 2 6912208896
2617241600   2617241600
Decomposition of  2617242525  is  [-50, 0, -2]
