# 2. Signed gadget decomposition

**GOAL:** Reduce the decomposition error (by a factor of 4 in variance) by making the digits signed, $-B/2 \le h(a) < B/2$, instead of unsigned, $0 \le h(a) < B$.

In [1]:
# Functions from the previous lecture note
import torch
import math

# Standard deviation handed to errgen when sampling encryption noise
stddev = 3.2
# Bit width of the ciphertext modulus
logQ = 27

# Number of polynomial coefficients
N = 1 << 10
# Ciphertext modulus; a power of two, so reductions can be done with bit masks
Q = 1 << logQ


def keygen(dim):
    """Sample a uniformly random binary vector.

    Args:
        dim: number of entries to draw.

    Returns:
        1-D integer tensor of length ``dim`` whose entries are 0 or 1.
    """
    return torch.randint(low=0, high=2, size=(dim,))

def errgen(stddev, N=1):
    """Draw rounded Gaussian noise.

    This merges two same-named definitions: Python has no overloading, so the
    later two-argument ``def`` used to shadow the one-argument one and
    ``errgen(stddev)`` raised ``TypeError``. With ``N`` defaulting to 1 both
    call forms work.

    Args:
        stddev: standard deviation of the Gaussian before rounding.
        N: number of samples to draw (default 1).

    Returns:
        int32 tensor of rounded samples. Shape is ``(N,)`` for ``N > 1``; for
        ``N == 1`` the singleton axis is squeezed away and a 0-d tensor is
        returned.
    """
    e = torch.round(stddev * torch.randn(N))
    return e.squeeze().to(torch.int)

def uniform(dim, modulus):
    """Sample ``dim`` integers uniformly from ``[0, modulus)``.

    Returns:
        1-D integer tensor of length ``dim``.
    """
    return torch.randint(low=0, high=modulus, size=(dim,))

def polymult(a, b, dim, modulus):
    """Multiply two polynomials modulo ``x**dim + 1`` and ``modulus``.

    Schoolbook negacyclic convolution: a product term that wraps past degree
    ``dim - 1`` comes back with its sign flipped, because ``x**dim = -1``.

    Intermediate sums are kept in int64 and reduced after every step. The
    previous int32 accumulator wrapped silently, which is only harmless when
    ``modulus`` is a power of two.

    Args:
        a, b: 1-D integer tensors holding ``dim`` coefficients each, lowest
            degree first.
        dim: number of coefficients.
        modulus: coefficient modulus; products are exact as long as
            ``modulus**2`` fits in int64.

    Returns:
        int32 tensor of length ``dim`` with entries in ``[0, modulus)``.
    """
    a64 = a.to(torch.int64) % modulus
    b64 = b.to(torch.int64) % modulus

    acc = torch.zeros(dim, dtype=torch.int64)
    for j in range(dim):
        # x**j * b(x): rotate b right by j; the j coefficients that wrapped
        # around pick up a minus sign. torch.roll returns a fresh tensor, so
        # negating part of it leaves b64 untouched.
        shifted = torch.roll(b64, j)
        shifted[:j] = -shifted[:j]
        acc = (acc + a64[j] * shifted) % modulus

    return acc.to(torch.int)

# exp(i*pi*k/N) for k = 0 .. N/2 - 1: the twist applied before the FFT
root_powers = torch.exp((1j*math.pi/N) * torch.arange(N//2).to(torch.complex128))

# exp(-i*pi*k/N) for k = 0 .. N/2 - 1: undoes the twist after the inverse FFT
root_powers_inv = torch.exp((1j*math.pi/N) * torch.arange(0, -N//2, -1).to(torch.complex128))

def negacyclic_fft(a, N, Q):
    """Forward transform used for negacyclic polynomial products.

    The last axis of ``a`` is folded in half: the first ``N//2`` entries become
    real parts and the remaining entries become imaginary parts. The folded
    vector is multiplied element-wise by ``root_powers`` and passed to
    ``torch.fft.fft``.

    Args:
        a: tensor whose last axis holds the coefficients.
        N: number of coefficients; only ``N//2`` is used, as the split point.
        Q: unused here.

    Returns:
        complex128 tensor whose last axis has half the length of ``a``'s.
    """
    half = N//2
    coeffs = a.to(torch.complex128)
    folded = coeffs[..., :half] + 1j * coeffs[..., half:]
    return torch.fft.fft(folded * root_powers)

def negacyclic_ifft(A, N, Q):
    """Inverse of ``negacyclic_fft``: return integer coefficients modulo ``Q``.

    Applies ``torch.fft.ifft``, multiplies by ``root_powers_inv``, lays the real
    parts followed by the imaginary parts along the last axis, and converts to
    integers.

    The float result is rounded to the nearest integer before the cast. A bare
    float-to-int cast truncates toward zero, so a coefficient that came back as
    4.9999999 because of floating-point error would have become 4 instead of 5.

    Args:
        A: complex tensor in the transform domain.
        N: unused here.
        Q: modulus; must be a power of two because the reduction is a bit mask.

    Returns:
        int32 tensor with entries in ``[0, Q)``; the last axis is twice as long
        as ``A``'s.
    """
    b = torch.fft.ifft(A)
    b *= root_powers_inv

    a = torch.cat((b.real, b.imag), dim=-1)

    # round to nearest, then go through int64 so large magnitudes wrap
    # instead of saturating on the way to int32
    aint = torch.round(a).to(torch.int64).to(torch.int32)
    # only when Q is a power-of-two
    aint &= Q-1

    return aint

In [2]:
# make an RLWE encryption of message
def encrypt_to_fft(m, sfft):
    """Encrypt ``m`` and return the ciphertext in the transform domain.

    Draws a noise vector and a uniform vector, transforms both, and then adds
    ``-ctfft[1]*sfft + negacyclic_fft(m)`` to row 0.

    Args:
        m: message coefficients, as accepted by ``negacyclic_fft``.
        sfft: secret key already in the transform domain.

    Returns:
        Complex tensor with two rows: row 0 is the body, row 1 is the transform
        of the uniform vector.
    """
    noise = errgen(stddev, N)
    random_part = uniform(N, Q)
    ctfft = negacyclic_fft(torch.stack([noise, random_part]), N, Q)

    message_fft = negacyclic_fft(m, N, Q)
    ctfft[0] += -ctfft[1]*sfft + message_fft

    return ctfft

def normalize(v, logQ):
    """Reduce ``v`` modulo ``2**logQ`` into the centered range ``[-Q/2, Q/2)``.

    Branch-free equivalent of "reduce mod Q, then subtract Q when the result is
    at least Q/2". When ``v`` is a tensor it is modified IN PLACE and the same
    object is returned.
    """
    modulus = 1 << logQ
    half = modulus >> 1
    # v mod Q when Q is a power-of-two
    v &= modulus - 1
    # the top bit is 1 exactly when v >= Q/2
    top_bit = (v & half) >> (logQ - 1)
    v -= modulus * top_bit
    return v

def decrypt_from_fft(ctfft, sfft):
    """Decrypt a transform-domain ciphertext with the transform-domain key.

    Computes ``ctfft[0] + ctfft[1]*sfft``, maps it back to integer coefficients
    with ``negacyclic_ifft``, and centers the result with ``normalize``.
    """
    phase_fft = ctfft[0] + ctfft[1]*sfft
    coeffs = negacyclic_ifft(phase_fft, N, Q)
    return normalize(coeffs, logQ)

In [3]:
# we will use B-ary decomposition, i.e., cut digits by logB bits
d = 4  # number of digits
logB = 6  # bits per digit, so the base is B = 2**logB = 64

In [4]:
# keeps the low logB bits, i.e. one digit
mask = (1 << logB) - 1

# bit offset of each digit, least-significant digit first, as a (d, 1) column
decomp_shift = logQ - logB*torch.arange(d, 0, -1).view(d, 1)
decomp_shift

tensor([[ 3],
        [ 9],
        [15],
        [21]])

In [5]:
# gadget vector: entry i is 2**decomp_shift[i], the weight that digit i is multiplied by when recomposing
gvector = 1 << decomp_shift
gvector

tensor([[      8],
        [    512],
        [  32768],
        [2097152]])

In [6]:
def decompose(a):
    """Split every entry of ``a`` into ``d`` unsigned base-``2**logB`` digits.

    Digit ``i`` is ``(a >> decomp_shift[i]) & mask``. The digit axis is added in
    front, so the result has shape ``(d,) + a.shape``.

    The shift column is reshaped to broadcast against whatever rank ``a`` has,
    so one code path serves the 1-D case (RLWE'), the 2-D case (RGSW) and a 0-d
    scalar. The scalar case previously matched neither branch and returned
    ``None``.

    Args:
        a: integer tensor of rank at most 2.

    Returns:
        Integer tensor of digits in ``[0, 2**logB)``.
    """
    rank = a.dim()
    assert rank <= 2
    # one leading digit axis plus a size-1 axis for each axis of `a`
    shifts = decomp_shift.view(d, *([1] * rank))
    return (a.unsqueeze(0) >> shifts) & mask



New decomposition with sign, so that every digit satisfies $|h(a)| \le B/2$

In [7]:
# one bit set at the most-significant position of every digit
# (bit decomp_shift[i] + logB - 1); kept as a one-element tensor
msbmask = torch.sum(1 << (decomp_shift + logB - 1), dim=0)

bin(msbmask)[2:]

'100000100000100000100000000'

In [8]:
# about twice heavier than unsigned decomposition
# it returns value -B/2 <= * <= B/2, not < B/2, but okay
def signed_decompose(a):
    """Decompose ``a`` into signed digits in ``[-B/2, B/2]``.

    For every digit whose top bit is set, adding that bit back onto ``a``
    pushes a carry into the next digit. Subtracting the decomposition of the
    top bits afterwards removes the remaining ``B/2``, so such a digit ends up
    lowered by ``B`` while its neighbour is raised by one.
    """
    top_bits = a & msbmask
    # carry
    digits = decompose(a + top_bits)
    # -B
    digits -= decompose(top_bits)
    return digits
    

In [9]:
# draw N uniform values in [0, Q) and show the first ten
a = uniform(N, Q)
a[:10]

tensor([123779987,  20546552, 109057511,  40920167, 108014325,  64325800,
        130649608,  79066916, 114671913, 128893103])

In [10]:
# unsigned digits of a: one row per digit, one column per entry of a
da = decompose(a)
# see all values are smaller than 2^6 = 64
da

tensor([[50, 63, 60,  ..., 11, 41, 17],
        [29,  1, 10,  ..., 21, 28, 32],
        [ 1, 51,  0,  ..., 15, 19,  0],
        [59,  9, 52,  ...,  6, 11, 51]])

In [11]:
# smallest and largest unsigned digit
da.min(), da.max()

(tensor(0), tensor(63))

In [12]:
# signed digits of a, laid out like the unsigned decomposition
sda = signed_decompose(a)
sda

tensor([[-14,  -1,  -4,  ...,  11, -23,  17],
        [ 30,   2,  11,  ...,  21,  29, -32],
        [  1, -13,   0,  ...,  15,  19,   1],
        [ -5,  10, -12,  ...,   6,  11, -13]])

$|h(a)| \le B/2$

In [13]:
# smallest and largest signed digit
sda.min(), sda.max()

(tensor(-32), tensor(32))

In [14]:
# recompose from the unsigned digits: weight each digit row by gvector and sum over the digit axis
aprime = torch.sum(da * gvector, dim = 0)
aprime

tensor([123779984,  20546552, 109057504,  ...,  13085272,  23705928,
        106971272])

$\langle h(a), \vec{g} \rangle \approx a$, with the result already in normalized (centered) form

In [15]:
# recompose from the signed digits in the same way
saprime = torch.sum(sda * gvector, dim = 0)
saprime

tensor([-10437744,  20546552, -25160224,  ...,  13085272,  23705928,
        -27246456])

In [16]:
# normalize works in place, so pass a copy: aprime keeps the value shown when it was computed
torch.equal(normalize(aprime.clone(), logQ), saprime)

True