# 2. Gadget decomposition

**GOAL:** for a given polynomial $\boldsymbol{a}$ and an encryption $\textsf{RLWE}(m)$ (or something similar to it), find $\textsf{RLWE}(\boldsymbol{a}\cdot \boldsymbol{m})$, with small noise.

Gadget decomposition and $RLWE'$ technique allows us to multiply ciphertext to a *large* constant with *small* noise increment.

RLWE' is usually used for the key switching. We make an example of key switching from s1 to s2 here.

In [1]:
# Functions from previous lecturenote
import torch
import math

stddev = 3.2
logQ = 27

N = 2**10
Q = 2**logQ


def keygen(dim):
    return torch.randint(2, size = (dim,))

def errgen(stddev):
    e = torch.round(stddev*torch.randn(1))
    e = e.squeeze()
    return e.to(torch.int)

def errgen(stddev, N):
    e = torch.round(stddev*torch.randn(N))
    e = e.squeeze()
    return e.to(torch.int)

def uniform(dim, modulus):
    return torch.randint(modulus, size = (dim,))

def polymult(a, b, dim, modulus):
    res = torch.zeros(dim).to(torch.int)
    for i in range(dim):
        for j in range(dim):
            if i >= j:
                res[i] += a[j]*b[i-j]
                res[i] %= modulus
            else:
                res[i] -= a[j]*b[i-j] # Q - x mod Q = -x
                res[i] %= modulus

    res %= modulus
    return res

root_powers = torch.arange(N//2).to(torch.complex128)
root_powers = torch.exp((1j*math.pi/N)*root_powers)

root_powers_inv = torch.arange(0,-N//2,-1).to(torch.complex128)
root_powers_inv = torch.exp((1j*math.pi/N)*root_powers_inv)

def negacyclic_fft(a, N, Q):
    acomplex = a.to(torch.complex128)

    a_precomp = (acomplex[...,:N//2] + 1j * acomplex[..., N//2:]) * root_powers

    return torch.fft.fft(a_precomp)

def negacyclic_ifft(A, N, Q):
    b = torch.fft.ifft(A)
    b *= root_powers_inv

    a = torch.cat((b.real, b.imag), dim=-1)

    aint = a.to(torch.int32)
    # only when Q is a power-of-two
    aint &= Q-1

    return aint

In [2]:
# generate two keys
s1 = keygen(N)
s2 = keygen(N)

s1fft = negacyclic_fft(s1, N, Q)
s2fft = negacyclic_fft(s2, N, Q)

In [3]:
# make an RLWE encryption of message
def encrypt_to_fft(m, sfft):
    ct = torch.stack([errgen(stddev, N), uniform(N, Q)])
    ctfft = negacyclic_fft(ct, N, Q)

    ctfft[0] += -ctfft[1]*sfft + negacyclic_fft(m, N, Q)

    return ctfft

def decrypt_from_fft(ctfft, sfft):
    return negacyclic_ifft(ctfft[0] + ctfft[1]*sfft, N, Q)

In [4]:
m = torch.zeros((N), dtype=torch.int32)
m[0] = 1000000

m

tensor([1000000,       0,       0,  ...,       0,       0,       0],
       dtype=torch.int32)

In [5]:
ctfft = encrypt_to_fft(m, s1fft)
mdec = decrypt_from_fft(ctfft, s1fft)
mdec

tensor([   999999,         1,         1,  ...,         0, 134217725,
        134217727], dtype=torch.int32)

We cannot decrypt mdec with s2, as the secret key is different. It should look like a random value.

In [6]:
mdec_wrong = decrypt_from_fft(ctfft, s2fft)
mdec_wrong

tensor([94206457, 32622688, 21222813,  ..., 89645965, 13428919, 65218450],
       dtype=torch.int32)

We want to transform ct as a ciphertext of the same message but with different key s2 *without decryption*.

## 2.1. Gadget decomposition

We define *gadget decomposition* $h$ corresponding to gadget vector $\vec{g} = (g_0, g_1, \dots, g_{d-1})$ as follows.
$$
h: \mathbb{Z} \longmapsto \mathbb{Z}^d
$$
$$
||h(a)|| < B, \left< h(a), \vec{g}\right> = a,
$$
where B is a upper bound.

Also, we can naturally extend it to $\mathcal{R}$ and $\mathbb{Z}^n$.

For example, a number $77 = 0\text{b}01001101$ can be decomposed to $(0,1,0,0,1,1,0,1)$ when $\vec{g} = (2^7, 2^6, \dots, 1)$ and $B=2$

Otherwise, it can also be decomposed to $(1,0,3,1)$ when $\vec{g} = (4^3, 4^2, 4, 1)$, and $B=4$ here.

We also can do a *approximate gadget decomposition* where $\left< h(a), \vec{g}\right> \approx a$, i.e., $\left< h(a), \vec{g}\right>$ is not exactly the same but similar to $a$.

We first set d and B satisfying $B^{d} \le Q \le B^{d+1}$

In [40]:
# we will use B-ary decomposition, i.e., cut digits by logB bits
d = 3
logB = 7

In [41]:
decomp_shift = logQ - logB*torch.arange(1,d+1).view(d,1)
decomp_shift

mask = (1<< logB)-1

In [42]:
def decompose(a):

    assert len(a.size()) <= 2

    if len(a.size()) == 1:
        return (a.unsqueeze(0) >> decomp_shift.view(d, 1)) & mask
    elif len(a.size()) == 2:
        return (a.unsqueeze(0) >> decomp_shift.view(d, 1, 1)) & mask


In [43]:
gvector = 1<<decomp_shift
gvector

tensor([[1048576],
        [   8192],
        [     64]])

In [44]:
a = uniform(N, Q)
a

tensor([ 97657358,  85445077,  29105743,  ..., 127586060,  93580405,
         44880300])

In [45]:
da = decompose(a)
# see all values are smaller than 2^7 = 128
da

tensor([[ 93,  81,  27,  ..., 121,  89,  42],
        [ 17,  62,  96,  ...,  86,  31, 102],
        [  8,  39, 121,  ...,  60,  49,  70]])

In [46]:
#composition is inner product, see it is similar to a
torch.sum(da * gvector, dim = 0)

tensor([ 97657344,  85445056,  29105728,  ..., 127586048,  93580352,
         44880256])

We can extend it to a ciphertext too

In [47]:
ctfft = encrypt_to_fft(m, s1fft)
ct = negacyclic_ifft(ctfft, N, Q)
ct

tensor([[ 53410785,  43055478,  73744037,  ..., 103529846,  82593721,
         115845510],
        [ 33853395,  17871358, 103050552,  ...,  71641821,  69220000,
          65527806]], dtype=torch.int32)

In [48]:
torch.sum(decompose(ct) * gvector.view(d, 1, 1), dim = 0)

tensor([[ 53410752,  43055424,  73744000,  ..., 103529792,  82593664,
         115845504],
        [ 33853376,  17871296, 103050496,  ...,  71641792,  69219968,
          65527744]])

## 2.2.RLWE' Ciphertext and key switching keys.

We can make a tuple of RLWE ciphertexts corresponding to a gadget vector $\vec{g} = (g_0, \dots, g_{d-1})$, and call it RLWE'.
$$
RLWE'( \boldsymbol{s} ) =\left( RLWE(g_0 \boldsymbol{s}), RLWE(g_1 \boldsymbol{s}), \dots, RLWE(g_{d-1} \boldsymbol{s})  \right) 
\in \mathcal{R}^{d\times N}
$$

Then the inner product between $h(\boldsymbol{m})$ and $RLWE'( \boldsymbol{s} )$ will give us RLWE (not RLWE') encryption of $\boldsymbol{m}$, $RLWE(\boldsymbol{a \cdot s})$.
The correctness can be seen as follows:
$$
\left<(\boldsymbol{m}_0, \dots, \boldsymbol{m}_{d-1}),  \left( RLWE(g_0 \boldsymbol{s}), \dots, RLWE(g_{d-1} \boldsymbol{s})  \right) \right>
= \sum_{i = 0}^{d-1} (\boldsymbol{m}_i \cdot RLWE(g_i \boldsymbol{s}))
= RLWE( \sum_{i = 0}^{d-1} (\boldsymbol{m}_i \cdot g_i \boldsymbol{s}))
= RLWE( \sum_{i = 0}^{d-1} (\boldsymbol{m} \cdot \boldsymbol{s}))
$$

### 2.2.1. Error analysis (why RLWE'?)

We can also get a ciphertext of $RLWE(\boldsymbol{m}\boldsymbol{s})$ by multiplying $\boldsymbol{m}$ to each element of $RLWE(\boldsymbol{s})$.
In other words, $(\boldsymbol{m} \cdot \boldsymbol{b}, \boldsymbol{m}\cdot \boldsymbol{a})$ is $RLWE(\boldsymbol{m}\boldsymbol{s})$, 
where $RLWE(\boldsymbol{s}) = (\boldsymbol{b}, \boldsymbol{a})$.

**Naive multiplication**

However, an error $\boldsymbol{e}$ is contained in $RLWE(\boldsymbol{s})$, so the decryption $\boldsymbol{b} + \boldsymbol{a} \cdot  \boldsymbol{s}$ will be given as 
$$ 
\boldsymbol{s} +  \boldsymbol{e}.
$$
Thus, the decryption of $RLWE(\boldsymbol{ms}) = (\boldsymbol{mb}, \boldsymbol{ma})$ results in $ \boldsymbol{ms} +  \boldsymbol{me}$.

**Multiplication using RLWE'**

It is okay when $\boldsymbol{m}$ is small (so it is used to multiply a small constant), but we usually need to to multiply $\boldsymbol{m}$ uniformly sampled in $\mathbb{Z}_Q$.
In this case, the error variance will be $Q^2 \sigma^2$, where $\sigma$ is variance of $\boldsymbol{e}$, and it overwhelms the message.

Instead, if we use the RLWE' product, each ciphertext $RLWE'(g_i \boldsymbol{s})$ is multiplied by $\boldsymbol{m}_i$, whose size is smaller than $B$.
Assuming they are uniformly distributed, the error variance of $RLWE'(\boldsymbol{m}_i g_i \boldsymbol{s})$ should be $B^2/12 \sigma^2$.
Adding $d$ of them, the error variance is  
$$
dB^2/12 \sigma^2,
$$ 
where $B = Q^{1/d} \ll Q$.

Naive multiplication is *infeasible* when $\boldsymbol{m}$ is large, which is exactly the case we want. 
So, we need RLWE' ciphertext and multiplication using it.

Let's encrypt key using RLWE' ciphertext.

In [50]:
rlwep = torch.zeros(d, 2, N, dtype=torch.int32)

# generate a part
rlwep[:, 1, :] = torch.randint(Q, size = (d, N), dtype= torch.int32)

rlwep

tensor([[[        0,         0,         0,  ...,         0,         0,
                  0],
         [124659440,  94520205,  71042207,  ..., 110570277,   3546280,
          124776653]],

        [[        0,         0,         0,  ...,         0,         0,
                  0],
         [ 19101960,   5690381,  27804331,  ...,  46262763,  57761828,
           71623088]],

        [[        0,         0,         0,  ...,         0,         0,
                  0],
         [ 12979970,  65446211, 110722805,  ..., 116226067,  47933314,
           77185437]]], dtype=torch.int32)

In [51]:
# add error on b
rlwep[:, 0, :] = torch.round(stddev * torch.randn(size = (d, N)))
rlwep

tensor([[[       -1,        -2,        10,  ...,         3,        -1,
                  5],
         [124659440,  94520205,  71042207,  ..., 110570277,   3546280,
          124776653]],

        [[        2,        -1,         5,  ...,        -1,         0,
                 -5],
         [ 19101960,   5690381,  27804331,  ...,  46262763,  57761828,
           71623088]],

        [[       -7,        -2,         4,  ...,         4,        -1,
                  3],
         [ 12979970,  65446211, 110722805,  ..., 116226067,  47933314,
           77185437]]], dtype=torch.int32)

In [52]:
# do fft for easy a*s
rlwepfft = negacyclic_fft(rlwep, N, Q)

# now b = -a*s2 + e
rlwepfft[:, 0, :] -= rlwepfft[:, 1, :] * s2fft

In [53]:
# return back to R_Q
rlwepifft = negacyclic_ifft(rlwepfft, N, Q)
rlwepifft

tensor([[[120516150,  62334107,  89060741,  ...,  95901350,  72683104,
           58747394],
         [124659440,  94520205,  71042206,  ..., 110570277,   3546280,
          124776653]],

        [[ 32846998, 116461100,  25032157,  ..., 116260390,  18385891,
           19095183],
         [ 19101960,   5690380,  27804330,  ...,  46262763,  57761828,
           71623088]],

        [[125341286, 101676018,  15652642,  ...,  27402630, 111753956,
           33526242],
         [ 12979970,  65446210, 110722805,  ..., 116226067,  47933314,
           77185437]]], dtype=torch.int32)

In [54]:
# add decomposition of s1* vec{g}
gs1 = gvector * s1

gs1.size()

torch.Size([3, 1024])

In [55]:
rlwepifft[:, 0, :] += gs1

In [56]:
ksk = rlwepifft

In [85]:
ksk

tensor([[[121564726,  63382683,  89060741,  ...,  95901350,  73731680,
           59795970],
         [124659440,  94520205,  71042206,  ..., 110570277,   3546280,
          124776653]],

        [[ 32855190, 116469292,  25032157,  ..., 116260390,  18394083,
           19103375],
         [ 19101960,   5690380,  27804330,  ...,  46262763,  57761828,
           71623088]],

        [[125341350, 101676082,  15652642,  ...,  27402630, 111754020,
           33526306],
         [ 12979970,  65446210, 110722805,  ..., 116226067,  47933314,
           77185437]]], dtype=torch.int32)

In [86]:
ct

tensor([[ 53410785,  43055478,  73744037,  ..., 103529846,  82593721,
         115845510],
        [ 33853395,  17871358, 103050552,  ...,  71641821,  69220000,
          65527806]], dtype=torch.int32)

In [139]:
mdec = decrypt_from_fft(negacyclic_fft(ct, N, Q), s1fft)
print(mdec)
normalize(mdec, Q)
print(mdec)

tensor([  1000222,       205,       223,  ..., 134217515, 134217506,
        134217501], dtype=torch.int32)
tensor([1000222,     205,     223,  ...,    -213,    -222,    -227],
       dtype=torch.int32)


In [140]:
mdec

tensor([1000222,     205,     223,  ...,    -213,    -222,    -227],
       dtype=torch.int32)

In [142]:
decompose(ct[1]).size()

torch.Size([3, 1024])

In [143]:
ksk.size()

torch.Size([3, 2, 1024])

In [144]:
prod = decompose(ct[1]).view(d, 1, N) * ksk
prodsum = torch.sum(prod, dim = 0)

prodsum.size(), prodsum

(torch.Size([2, 1024]),
 tensor([[12969363122,  8878853893, 10418015497,  ..., 12466280880,
          14942546763,  9149600252],
         [ 5494510750,  6281976295, 13692873598,  ..., 14413273000,
           4605814568, 21979334441]]))

In [145]:
prodfft = negacyclic_fft(prodsum, N, Q)

Now, we can decrypt using the switched key $s_2$.

In [147]:
mks = decrypt_from_fft(prodfft, s2fft)
print(mks)

tensor([115243276,  29342566, 110492725,  ...,  57805258,  79197029,
         71597867], dtype=torch.int32)


In [149]:
normalize(mks, Q)

tensor([-18974452,  29342566, -23725003,  ...,  57805258, -55020699,
        -62619861], dtype=torch.int32)

In [138]:
def normalize(v, Q):
    # same as follows but no branch
    """
    if v > Q//2:
        v -= Q
    """
    # vmod Q when Q is a power-of-two
    v &= Q-1
    # get msb
    msb = (v & Q//2) >>  (logQ - 1)
    v -= (Q) * msb
    return v