# 2. Gadget decomposition

**GOAL:** for a given polynomial $\boldsymbol{a}$ and an encryption $\textsf{RLWE}(m)$ (or something similar to it), find $\textsf{RLWE}(\boldsymbol{a}\cdot \boldsymbol{m})$, with small noise.

Gadget decomposition and $RLWE'$ technique allows us to multiply ciphertext to a *large* constant with *small* noise increment.

RLWE' is usually used for the key switching. We make an example of key switching from s1 to s2 here.

In [1]:
# Functions from previous lecturenote
import torch
import math

stddev = 3.2
logQ = 27

N = 2**10
Q = 2**logQ


def keygen(dim):
    return torch.randint(2, size = (dim,))

def errgen(stddev):
    e = torch.round(stddev*torch.randn(1))
    e = e.squeeze()
    return e.to(torch.int)

def errgen(stddev, N):
    e = torch.round(stddev*torch.randn(N))
    e = e.squeeze()
    return e.to(torch.int)

def uniform(dim, modulus):
    return torch.randint(modulus, size = (dim,))

def polymult(a, b, dim, modulus):
    res = torch.zeros(dim).to(torch.int)
    for i in range(dim):
        for j in range(dim):
            if i >= j:
                res[i] += a[j]*b[i-j]
                res[i] %= modulus
            else:
                res[i] -= a[j]*b[i-j] # Q - x mod Q = -x
                res[i] %= modulus

    res %= modulus
    return res

root_powers = torch.arange(N//2).to(torch.complex128)
root_powers = torch.exp((1j*math.pi/N)*root_powers)

root_powers_inv = torch.arange(0,-N//2,-1).to(torch.complex128)
root_powers_inv = torch.exp((1j*math.pi/N)*root_powers_inv)

def negacyclic_fft(a, N, Q):
    acomplex = a.to(torch.complex128)

    a_precomp = (acomplex[...,:N//2] + 1j * acomplex[..., N//2:]) * root_powers

    return torch.fft.fft(a_precomp)

def negacyclic_ifft(A, N, Q):
    b = torch.fft.ifft(A)
    b *= root_powers_inv

    a = torch.cat((b.real, b.imag), dim=-1)

    aint = a.to(torch.int32)
    # only when Q is a power-of-two
    aint &= Q-1

    return aint

In [2]:
# generate two keys
s1 = keygen(N)
s2 = keygen(N)

s1fft = negacyclic_fft(s1, N, Q)
s2fft = negacyclic_fft(s2, N, Q)

In [3]:
# make an RLWE encryption of message
def encrypt_to_fft(m, sfft):
    ct = torch.stack([errgen(stddev, N), uniform(N, Q)])
    ctfft = negacyclic_fft(ct, N, Q)

    ctfft[0] += -ctfft[1]*sfft + negacyclic_fft(m, N, Q)

    return ctfft

def normalize(v, logQ):
    # same as follows but no branch
    """
    if v > Q//2:
        v -= Q
    """
    # vmod Q when Q is a power-of-two
    Q = (1 << logQ)
    v &= Q-1
    # get msb
    msb = (v & Q//2) >> (logQ - 1)
    v -= (Q) * msb
    return v

def decrypt_from_fft(ctfft, sfft):
    return normalize(negacyclic_ifft(ctfft[0] + ctfft[1]*sfft, N, Q), logQ)

In [4]:
m = torch.zeros((N), dtype=torch.int32)
m[0] = 1000000

m

tensor([1000000,       0,       0,  ...,       0,       0,       0],
       dtype=torch.int32)

In [5]:
ctfft = encrypt_to_fft(m, s1fft)
mdec = decrypt_from_fft(ctfft, s1fft)
mdec

tensor([999997,     -3,     -1,  ...,     -1,      1,      0],
       dtype=torch.int32)

We cannot decrypt mdec with s2, as the secret key is different. It should look like a random value.

In [6]:
mdec_wrong = decrypt_from_fft(ctfft, s2fft)
mdec_wrong

tensor([-30775409,  29496231,  58264218,  ...,  30844357, -58722876,
         -5122044], dtype=torch.int32)

We want to transform ct as a ciphertext of the same message but with different key s2 *without decryption*.

## 2.1. Gadget decomposition

We define *gadget decomposition* $h$ corresponding to gadget vector $\vec{g} = (g_0, g_1, \dots, g_{d-1})$ as follows.
$$
h: \mathbb{Z} \longmapsto \mathbb{Z}^d
$$
$$
||h(a)|| < B, \left< h(a), \vec{g}\right> = a,
$$
where B is a upper bound.

Also, we can naturally extend it to $\mathcal{R}$ and $\mathbb{Z}^n$.

For example, a number $77 = 0\text{b}01001101$ can be decomposed to $(1,0,1,1,0,0,1,0)$ when $\vec{g} = (1, 2, \dots, 2^7)$ and $B=2$

Otherwise, it can also be decomposed to $(1,3,0,1)$ when $\vec{g} = (1, 4, 4^2, 4^3)$, and $B=4$ here.

We also can do a *approximate gadget decomposition* where $\left< h(a), \vec{g}\right> \approx a$, i.e., $\left< h(a), \vec{g}\right>$ is not exactly the same but similar to $a$.

We first set d and B satisfying $B^{d} \le Q \le B^{d+1}$

In [7]:
# we will use B-ary decomposition, i.e., cut digits by logB bits
d = 4
logB = 6

We shift by amount of decomp_shit and cut `logB` MSBs

In [8]:
decomp_shift = logQ - logB*torch.arange(d,0,-1).view(d,1)
decomp_shift

mask = (1 << logB) - 1

decomp_shift

tensor([[ 3],
        [ 9],
        [15],
        [21]])

We can see that the gadget vector, 
$$
\vec{g} = (B, B^2, \dots, B^{d-1})
$$
is given as follows.

In [9]:
gvector = 1 << decomp_shift
gvector

tensor([[      8],
        [    512],
        [  32768],
        [2097152]])

The decomposition function handles both RLWE' and RGSW case.

In [10]:
def decompose(a):
    
    assert len(a.size()) <= 2
    # for RLWE'
    if len(a.size()) == 1:
        res = (a.unsqueeze(0) >> decomp_shift.view(d, 1)) & mask
        return res
    # for RGSW
    elif len(a.size()) == 2:
        res = (a.unsqueeze(0) >> decomp_shift.view(d, 1, 1)) & mask
        return res



In [11]:
a = uniform(N, Q)
a[:10]

tensor([ 94193827,  51940049,  84661416,  21719618,  41322980,  35156455,
         49498149, 116411239,  20979234,  63728665])

See the decomposed vector always has value less than $B = 2^6 = 64$.

In [12]:
da = decompose(a)
# see all values are smaller than 2^6 = 64
da

tensor([[20, 26, 21,  ...,  3, 48, 60],
        [36,  5, 42,  ..., 33, 22, 10],
        [58, 49, 23,  ..., 57, 49, 34],
        [44, 24, 40,  ..., 12, 23, 30]])

See $<h(a), \vec{g}> \approx a$. 

In [13]:
#composition is inner product, see it is similar to a
torch.sum(da * gvector, dim = 0)

tensor([94193824, 51940048, 84661416,  ..., 27050520, 49851776, 64034272])

In [14]:
a

tensor([94193827, 51940049, 84661416,  ..., 27050525, 49851783, 64034279])

In [15]:
diff = a - torch.sum(da * gvector, dim = 0)
diff %= Q
diff = normalize(diff, logQ)
diff[:10]

tensor([3, 1, 0, 2, 4, 7, 5, 7, 2, 1])

We can extend it to a ciphertext too

In [16]:
m = torch.zeros(N)
m[0] = Q//4
m[3] = Q//4
print("m:\n", m[:10])
ctfft = encrypt_to_fft(m, s1fft)
ct = negacyclic_ifft(ctfft, N, Q)
ct

m:
 tensor([33554432.,        0.,        0., 33554432.,        0.,        0.,
               0.,        0.,        0.,        0.])


tensor([[ 39737414,  13959872,  92131300,  ...,  15676777,  59094439,
          44448668],
        [ 81619725, 119691098,  49104517,  ...,   7291234,   4035796,
         130607505]], dtype=torch.int32)

We will see the message part in the next subsection

In [17]:
ctnew = torch.sum(decompose(ct) * gvector.view(d, 1, 1), dim = 0)
ctnew

tensor([[ 39737408,  13959872,  92131296,  ...,  15676776,  59094432,
          44448664],
        [ 81619720, 119691096,  49104512,  ...,   7291232,   4035792,
         130607504]])

Decrypt it

In [18]:
ctnewfft = negacyclic_fft(ctnew, N, Q)

dm = decrypt_from_fft(ctnewfft, s1fft)[:10]
print("m:\n", m[:10])
print("dm:\n", dm[:10])


m:
 tensor([33554432.,        0.,        0., 33554432.,        0.,        0.,
               0.,        0.,        0.,        0.])
dm:
 tensor([33556360,     1840,     1849, 33556272,     1968,     1873,     1832,
            1904,     1904,     1968], dtype=torch.int32)


## 2.2.RLWE' Ciphertext and key switching keys.

We can make a tuple of RLWE ciphertexts corresponding to a gadget vector $\vec{g} = (g_0, \dots, g_{d-1})$, and call it RLWE'.
$$
RLWE'( \boldsymbol{s} ) =\left( RLWE(g_0 \boldsymbol{s}), RLWE(g_1 \boldsymbol{s}), \dots, RLWE(g_{d-1} \boldsymbol{s})  \right) 
\in \mathcal{R}^{d\times N}
$$

Then the inner product between $h(\boldsymbol{m})$ and $RLWE'( \boldsymbol{s} )$ will give us RLWE (not RLWE'!!!) encryption of $\boldsymbol{m\cdot s}$, $RLWE(\boldsymbol{m \cdot s})$.
The correctness can be seen as follows:
$$
\left<(\boldsymbol{m}_0, \dots, \boldsymbol{m}_{d-1}),  \left( RLWE(g_0 \boldsymbol{s}), \dots, RLWE(g_{d-1} \boldsymbol{s})  \right) \right>
= \sum_{i = 0}^{d-1} (\boldsymbol{m}_i \cdot RLWE(g_i \boldsymbol{s}))
= RLWE( \sum_{i = 0}^{d-1} (\boldsymbol{m}_i \cdot g_i \boldsymbol{s}))
= RLWE( \sum_{i = 0}^{d-1} (\boldsymbol{m} \cdot \boldsymbol{s}))
$$



### 2.2.1. Error analysis (why RLWE'?)

We can also get a ciphertext of $RLWE(\boldsymbol{m}\boldsymbol{s})$ by multiplying $\boldsymbol{m}$ to each element  of $RLWE(\boldsymbol{s})$.
In other words, $(\boldsymbol{m} \cdot \boldsymbol{b}, \boldsymbol{m}\cdot \boldsymbol{a})$ is $RLWE(\boldsymbol{m}\boldsymbol{s})$, 
where $RLWE(\boldsymbol{s}) = (\boldsymbol{b}, \boldsymbol{a})$.

**Naive multiplication**

However, an error $\boldsymbol{e}$ is contained in $RLWE(\boldsymbol{s})$, so the decryption $\boldsymbol{b} + \boldsymbol{a} \cdot  \boldsymbol{s}$ will be given as 
$$ 
\boldsymbol{s} +  \boldsymbol{e}.
$$
Thus, the decryption of $RLWE(\boldsymbol{ms}) = (\boldsymbol{mb}, \boldsymbol{ma})$ results in $ \boldsymbol{ms} +  \boldsymbol{me}$.

**Multiplication using RLWE'**

It is okay when $\boldsymbol{m}$ is small (so it is used to multiply a small constant), but we usually need to to multiply $\boldsymbol{m}$ uniformly sampled in $\mathbb{Z}_Q$.
In this case, the error variance will be $O(Q^2 \sigma^2)$, where $\sigma$ is variance of $\boldsymbol{e}$, and it overwhelms the message.

Instead, if we use the RLWE' product, each ciphertext $RLWE'(g_i \boldsymbol{s})$ is multiplied by $\boldsymbol{m}_i$, whose size is smaller than $B$.
Assuming they are uniformly distributed, the error variance of $RLWE'(\boldsymbol{m}_i g_i \boldsymbol{s})$ should be $B^2/12 \sigma^2$.
Adding $d$ of them, the error variance is  
$$
dB^2/12 \sigma^2,
$$ 
where $B = Q^{1/d} \ll Q$.

Naive multiplication is *infeasible* when $\boldsymbol{m}$ is large, which is exactly the case we want. 
So, we need RLWE' ciphertext (as key) and multiplication using it.

Let's encrypt a key using RLWE' ciphertext.

In [19]:
# it has a dimension of d, 2, N
rlwep = torch.zeros(d, 2, N, dtype=torch.int32)

# generate the 'a' part
rlwep[:, 1, :] = torch.randint(Q, size = (d, N), dtype= torch.int32)

rlwep

tensor([[[        0,         0,         0,  ...,         0,         0,
                  0],
         [105131693,  43028727,  59218047,  ...,  23781605,  83649302,
           61094060]],

        [[        0,         0,         0,  ...,         0,         0,
                  0],
         [ 98837028,  22036715,  31255247,  ...,  60782403,  44504675,
            6568757]],

        [[        0,         0,         0,  ...,         0,         0,
                  0],
         [105422446, 103038924,  36428025,  ...,  57190672,  22163141,
           72082054]],

        [[        0,         0,         0,  ...,         0,         0,
                  0],
         [ 93398305,  29353121, 126119049,  ...,   2045833, 133602593,
           63093644]]], dtype=torch.int32)

In [20]:
# add error on b
rlwep[:, 0, :] = torch.round(stddev * torch.randn(size = (d, N)))

rlwep

tensor([[[       -1,         1,         2,  ...,         5,         0,
                 -3],
         [105131693,  43028727,  59218047,  ...,  23781605,  83649302,
           61094060]],

        [[        5,        -5,         1,  ...,         1,        -2,
                 -3],
         [ 98837028,  22036715,  31255247,  ...,  60782403,  44504675,
            6568757]],

        [[        4,        -5,        -7,  ...,        -4,         3,
                  2],
         [105422446, 103038924,  36428025,  ...,  57190672,  22163141,
           72082054]],

        [[        2,        -1,        -6,  ...,         8,         4,
                 -4],
         [ 93398305,  29353121, 126119049,  ...,   2045833, 133602593,
           63093644]]], dtype=torch.int32)

In [21]:
# following is equal to rlwep %= Q, but a faster version
rlwep &= (Q-1)

rlwep

tensor([[[134217727,         1,         2,  ...,         5,         0,
          134217725],
         [105131693,  43028727,  59218047,  ...,  23781605,  83649302,
           61094060]],

        [[        5, 134217723,         1,  ...,         1, 134217726,
          134217725],
         [ 98837028,  22036715,  31255247,  ...,  60782403,  44504675,
            6568757]],

        [[        4, 134217723, 134217721,  ..., 134217724,         3,
                  2],
         [105422446, 103038924,  36428025,  ...,  57190672,  22163141,
           72082054]],

        [[        2, 134217727, 134217722,  ...,         8,         4,
          134217724],
         [ 93398305,  29353121, 126119049,  ...,   2045833, 133602593,
           63093644]]], dtype=torch.int32)

In [22]:
# do fft for easy a*s
rlwepfft = negacyclic_fft(rlwep, N, Q)

# now b = -a*s2 + e
rlwepfft[:, 0, :] -= rlwepfft[:, 1, :] * s2fft.view(1,N//2)

In [23]:
# return back to R_Q
rlwepifft = negacyclic_ifft(rlwepfft, N, Q)

rlwepifft

tensor([[[105703223,  46420690,  53786687,  ..., 102437439,  49696800,
           88625034],
         [105131693,  43028727,  59218046,  ...,  23781604,  83649302,
           61094060]],

        [[ 82051802,  19886463,   1014332,  ..., 120222834,  79401825,
          128812660],
         [ 98837028,  22036715,  31255247,  ...,  60782403,  44504675,
            6568757]],

        [[ 41358369,   8097865, 128448284,  ..., 101864579,  30638379,
           64260024],
         [105422446, 103038924,  36428024,  ...,  57190672,  22163141,
           72082054]],

        [[ 37733434,  34789102, 107132699,  ...,  58341145,  41061950,
           45352168],
         [ 93398305,  29353121, 126119048,  ...,   2045832, 133602593,
           63093644]]], dtype=torch.int32)

In [24]:
# add decomposition of s1* vec{g}
gs1 = gvector * s1

gs1

tensor([[      8,       0,       8,  ...,       0,       8,       0],
        [    512,       0,     512,  ...,       0,     512,       0],
        [  32768,       0,   32768,  ...,       0,   32768,       0],
        [2097152,       0, 2097152,  ...,       0, 2097152,       0]])

In [25]:
rlwepifft[:, 0, :] += gs1

In [26]:
ksk = negacyclic_fft(rlwepifft, N, Q)

In [27]:
ksk.size()

torch.Size([4, 2, 512])

In [28]:
ct

tensor([[ 39737414,  13959872,  92131300,  ...,  15676777,  59094439,
          44448668],
        [ 81619725, 119691098,  49104517,  ...,   7291234,   4035796,
         130607505]], dtype=torch.int32)

In [29]:
mdec = decrypt_from_fft(ctfft, s1fft)
print(mdec[:10])

tensor([33554427,        0,        3, 33554432,        0,       -3,        0,
              -5,       -1,        1], dtype=torch.int32)


In [30]:
dct = decompose(ct[1])
dctfft = negacyclic_fft(dct, N, Q)
dctfft.size()

torch.Size([4, 512])

In [31]:
print(ct[1])
print(torch.sum(dct*gvector, dim = 0))

tensor([ 81619725, 119691098,  49104517,  ...,   7291234,   4035796,
        130607505], dtype=torch.int32)
tensor([ 81619720, 119691096,  49104512,  ...,   7291232,   4035792,
        130607504])


Following is
$$
a \odot RLWE'_{s2}(s1) = RLWE_{s2}(a \cdot s1)
$$

In [32]:
prodfft = dctfft.view(d, 1, N//2) * ksk
prodsumfft = torch.sum(prodfft, dim = 0)

prodsumfft.size(), prodsumfft

(torch.Size([2, 512]),
 tensor([[-3.6756e+15-6.4560e+13j, -4.0228e+14+4.5322e+11j,
          -7.3409e+13+4.6276e+12j,  ...,
          -2.4512e+13-2.4045e+12j, -5.0447e+13+3.9861e+12j,
          -1.4421e+14+7.1230e+12j],
         [-3.5640e+15-2.0938e+13j, -3.9533e+14-1.4981e+12j,
          -7.8482e+13+6.2964e+12j,  ...,
          -2.0769e+13-2.4418e+12j, -5.8814e+13+7.9083e+12j,
          -1.3519e+14+2.3577e+13j]], dtype=torch.complex128))

Adding $b$ to above, we get
$$
RLWE_{\boldsymbol{s2}}(\boldsymbol{a} \cdot \boldsymbol{s1}) + (\boldsymbol{b}, \boldsymbol{0}) = RLWE_{\boldsymbol{s2}}(\boldsymbol{a} \cdot \boldsymbol{s1} + \boldsymbol{b})
= RLWE_{\boldsymbol{s2}}(\boldsymbol{m} )
$$


In [33]:
prodsumfft[0] += ctfft[0]

Now, we can decrypt using the switched key $s_2$.

Check if the decryption is successful.

In [34]:
mks = decrypt_from_fft(prodsumfft, s2fft)
print("m:\n",m[:10])
print("decrypted:\n",mks[:10])
print("decrypted (scaled to 1):\n",mks[:10]/(Q//4))

m:
 tensor([33554432.,        0.,        0., 33554432.,        0.,        0.,
               0.,        0.,        0.,        0.])
decrypted:
 tensor([34075874,   589722,   652463, 34230423,   729336,   798338,   854322,
          885747,   950882,  1012988], dtype=torch.int32)
decrypted (scaled to 1):
 tensor([1.0155, 0.0176, 0.0194, 1.0201, 0.0217, 0.0238, 0.0255, 0.0264, 0.0283,
        0.0302])
