# 3. RGSW Operation

Now we have all the building blocks required for RGSW operations.
We have already checked that RLWE' multiplication works.

Re-define all helper functions from the previous lecture notes.

In [1]:
# Functions from the previous lecture notes
import torch
import math

# Gaussian error standard deviation and the modulus Q = 2**logQ
stddev = 3.2
logQ = 27

# Ring dimension: polynomials are reduced modulo X^N + 1
N = 1 << 10
Q = 1 << logQ

def keygen(dim):
    """Sample a uniformly random binary secret key: `dim` independent bits in {0, 1}."""
    key_shape = (dim,)
    return torch.randint(0, 2, key_shape)

def errgen(stddev, N=1):
    """Sample N rounded Gaussian error values with standard deviation `stddev`.

    This used to be defined twice, once without `N`. The second definition
    silently shadowed the first, so `errgen(stddev)` raised TypeError. The
    default N=1 restores that call form: the size-1 sample is squeezed to a
    0-d tensor, which is exactly what the one-argument version returned.

    Args:
        stddev: standard deviation of the Gaussian before rounding.
        N: number of samples (shadows the module-level ring dimension here).

    Returns:
        int32 tensor of shape (N,), or a 0-d tensor when N == 1.
    """
    e = torch.round(stddev * torch.randn(N))
    # squeeze only affects N == 1 (gives a 0-d scalar tensor)
    return e.squeeze().to(torch.int)

def uniform(dim, modulus):
    """Draw `dim` integers independently and uniformly from [0, modulus)."""
    return torch.randint(low=0, high=modulus, size=(dim,))

def polymult(a, b, dim, modulus):
    """Schoolbook negacyclic product a*b in Z_modulus[X] / (X^dim + 1).

    Coefficient i of the result is
        sum_{j <= i} a[j] * b[i - j]  -  sum_{j > i} a[j] * b[dim + i - j]
    where the minus sign comes from X^dim = -1.

    The previous version accumulated into an int32 tensor without reducing the
    factors first, which is only correct when `modulus` divides 2**32. It also
    ran an O(dim^2) Python loop of scalar tensor ops. This version reduces the
    inputs, works in int64 and is vectorized. It uses O(dim^2) memory (about
    8 MB per int64 matrix for dim = 1024).

    Args:
        a, b: integer coefficient vectors of length `dim` (tensors or sequences).
        dim: ring dimension.
        modulus: coefficient modulus, at most 2**31 so that products of reduced
            coefficients fit in int64 and the result fits in int32.

    Returns:
        int32 tensor of shape (dim,) with entries in [0, modulus).
    """
    if dim == 0:
        return torch.zeros(0, dtype=torch.int)

    a64 = torch.as_tensor(a).to(torch.int64)[:dim] % modulus
    b64 = torch.as_tensor(b).to(torch.int64)[:dim] % modulus

    rows = torch.arange(dim).view(dim, 1)  # output index i
    cols = torch.arange(dim).view(1, dim)  # input index j
    # b[(i - j) mod dim]; terms with j > i wrapped past X^dim and are negated
    bmat = b64[(rows - cols) % dim]
    sign = 1 - 2 * (cols > rows).to(torch.int64)

    # reduce each product so the row sums stay far from int64 overflow
    terms = (a64.view(1, dim) * bmat) % modulus
    res = (terms * sign).sum(dim=1) % modulus
    return res.to(torch.int)

# Twisting factors for the negacyclic FFT, with psi = exp(i*pi/N):
#   root_powers[k]     = psi**k      for k = 0 .. N/2 - 1
#   root_powers_inv[k] = psi**(-k)   for k = 0 .. N/2 - 1
root_powers = torch.exp((1j*math.pi/N) * torch.arange(N//2).to(torch.complex128))
root_powers_inv = torch.exp((1j*math.pi/N) * torch.arange(0, -N//2, -1).to(torch.complex128))

def negacyclic_fft(a, N, Q):
    """Forward negacyclic FFT of integer polynomial(s) of length N (last axis).

    Folds the N real coefficients into N/2 complex values: the first half
    becomes the real part and the second half the imaginary part. It then
    multiplies by the module-level twisting factors `root_powers` and applies
    a length-N/2 complex FFT. Leading axes are treated as a batch. `Q` is
    unused here.
    """
    half = N // 2
    as_complex = a.to(torch.complex128)
    low, high = as_complex[..., :half], as_complex[..., half:]
    twisted = (low + 1j * high) * root_powers
    return torch.fft.fft(twisted)

def negacyclic_ifft(A, N, Q):
    """Inverse negacyclic FFT: map an N/2-point spectrum back to N integer coefficients mod Q.

    Undoes `negacyclic_fft`: inverse complex FFT, multiply by the module-level
    `root_powers_inv`, then unfold real parts (coefficients 0..N/2-1) and
    imaginary parts (coefficients N/2..N-1) along the last axis.

    Args:
        A: complex spectrum with N/2 entries on the last axis (leading axes batch).
        N: ring dimension (unused; the length comes from A).
        Q: modulus, assumed <= 2**31 so the result fits in int32. Any Q works,
           not only powers of two.

    Returns:
        int32 tensor of coefficients in [0, Q).
    """
    b = torch.fft.ifft(A) * root_powers_inv
    a = torch.cat((b.real, b.imag), dim=-1)

    # Round rather than truncate: the IFFT carries floating-point error, and
    # truncation would turn e.g. 4.9999999 into 4. Go through int64 because
    # coefficients can exceed the int32 range (e.g. a raw RLWE 'b' row, about
    # N*Q in magnitude), where a direct float->int32 cast is platform-dependent.
    # Doubles represent integers exactly below 2**53.
    aint = torch.remainder(torch.round(a).to(torch.int64), Q)
    return aint.to(torch.int32)

In [2]:
# make an RLWE encryption of message
def encrypt_to_fft(m, sfft):
    """RLWE-encrypt the message polynomial `m` under the FFT-domain secret key `sfft`.

    Samples error e (errgen) and a uniform mask a (uniform), then sets
    b = -a*s + e + m. Returns the ciphertext in the negacyclic-FFT domain,
    stacked along the first axis as (b, a).
    """
    noise = errgen(stddev, N)
    mask_a = uniform(N, Q)
    ctfft = negacyclic_fft(torch.stack([noise, mask_a]), N, Q)

    # b <- e - a*s + m, so that b + a*s = m + e
    ctfft[0] += -ctfft[1]*sfft + negacyclic_fft(m, N, Q)
    return ctfft

def normalize(v, logQ):
    """Reduce `v` modulo Q = 2**logQ and center it into [-Q/2, Q/2).

    Branch-free equivalent of::

        v %= Q
        if v >= Q // 2:
            v -= Q

    Works on Python ints and integer tensors (elementwise). The argument is
    not modified: a new value is returned. The previous version used `&=` and
    `-=` and so silently overwrote the caller's tensor.
    """
    Q = 1 << logQ
    # v mod Q; a bitwise AND is exact because Q is a power of two
    v = v & (Q - 1)
    # 1 where the top bit (value Q/2) is set, i.e. v >= Q/2; else 0
    msb = (v & (Q // 2)) >> (logQ - 1)
    return v - Q * msb

def decrypt_from_fft(ctfft, sfft):
    """Decrypt a single FFT-domain RLWE ciphertext (b, a) with the FFT-domain key `sfft`.

    Computes the phase b + a*s in the FFT domain, transforms it back to
    coefficients mod Q, and centers them into [-Q/2, Q/2) via `normalize`.
    """
    assert ctfft.dim() == 2
    b_fft, a_fft = ctfft[0], ctfft[1]
    phase = negacyclic_ifft(b_fft + a_fft*sfft, N, Q)
    return normalize(phase, logQ)

In [3]:
# B-ary gadget decomposition with B = 2**logB: each of the d digits is logB bits wide
d = 3
logB = 7

# Bit offset of each digit as a (d, 1) column, lowest digit first:
# decomp_shift[i] = logQ - logB*(d - i). The lowest logQ - d*logB bits are
# not covered, so the decomposition is approximate.
decomp_shift = (logQ - logB * torch.arange(d, 0, -1)).view(d, 1)
# selects the low logB bits, i.e. a single digit
mask = 2 ** logB - 1

# gadget vector entries g_i = 2**decomp_shift[i]
gvector = 1 << decomp_shift

In [4]:
def decompose(a):
    """Unsigned B-ary gadget decomposition of an integer tensor `a`.

    Digit i holds bits [decomp_shift[i], decomp_shift[i] + logB) of every
    entry, so the result has shape (d, *a.shape) with entries in [0, mask].
    Bits below the smallest shift are discarded. RLWE' uses 1-D inputs and
    RGSW uses 2-D inputs (both rows of a ciphertext). Any rank is accepted now.
    Previously a 0-D input passed the rank assert but fell through both
    branches and silently returned None.
    """
    # reshape the (d, 1) shift column to (d, 1, ..., 1) so it broadcasts over a's axes
    shifts = decomp_shift.view(d, *([1] * a.dim()))
    return (a.unsqueeze(0) >> shifts) & mask

# Bit mask holding the most significant bit of every logB-bit digit, i.e. bit
# decomp_shift[i] + logB - 1 for each digit i. signed_decompose uses it to spot
# digits >= B/2. It is built from plain ints: iterating the (d, 1) decomp_shift
# tensor directly yields (1,)-shaped rows and turns msbmask into a one-element
# tensor. With the current constants (logQ=27, d=3, logB=7) this sets bits 12,
# 19 and 26; inspect with bin(msbmask).
msbmask = sum(1 << (shift + logB - 1) for shift in decomp_shift.flatten().tolist())

# Signed (balanced) decomposition, costing about two unsigned decompositions.
# Digits land in -B/2 <= digit <= B/2 (both ends inclusive), which is acceptable here.
def signed_decompose(a):
    """Signed gadget decomposition of `a` with the same shape as `decompose(a)`.

    Adding each digit's top bit (a & msbmask) carries 1 into the next digit
    exactly when the digit is >= B/2. Subtracting the decomposition of those
    top bits then recenters the digits around zero.
    """
    top_bits = a & msbmask
    carried = decompose(a + top_bits)
    return carried - decompose(top_bits)

In [5]:
def encrypt_rlwep_fft(z, skfft):
    """RLWE' encryption of polynomial `z` under the FFT-domain key `skfft`.

    Produces d RLWE ciphertexts (b_i, a_i), one per gadget entry g_i, with
    b_i + a_i*s = g_i*z + e_i. The result is in the negacyclic-FFT domain with
    shape (d, 2, N//2); index 0 of the middle axis is b, index 1 is a.
    """
    raw = torch.zeros(d, 2, N, dtype=torch.int32)

    # a_i: uniform in [0, Q)
    raw[:, 1, :] = torch.randint(Q, size=(d, N), dtype=torch.int32)
    # b_i starts as the rounded Gaussian error e_i
    # INSECURE: need to be fixed later
    raw[:, 0, :] = torch.round(stddev * torch.randn(size=(d, N))).to(torch.int32)

    # reduce mod Q (a power of two) and center into [-Q/2, Q/2)
    raw &= (Q - 1)
    raw = normalize(raw, logQ)

    spectrum = negacyclic_fft(raw, N, Q)
    b_fft = spectrum[:, 0, :]
    a_fft = spectrum[:, 1, :]

    # b_i <- e_i - a_i*s  (in-place on the view, so `spectrum` is updated)
    b_fft -= a_fft * skfft.view(1, N // 2)
    # b_i <- b_i + g_i*z
    b_fft += negacyclic_fft(gvector * z, N, Q)

    return spectrum

## 3.1. Test the RLWE' encryptions

1. Key generation and encryption

In [6]:
# Two independent binary secret keys and their FFT-domain forms
s1 = keygen(N)
s2 = keygen(N)
s1fft, s2fft = (negacyclic_fft(key, N, Q) for key in (s1, s2))

In [7]:
# Test message: Q/4 at coefficients 2, 3 and 5, zero elsewhere
m = torch.zeros(N, dtype=torch.int32)
m[[2, 3, 5]] = Q // 4

m[:10]

tensor([       0,        0, 33554432, 33554432,        0, 33554432,        0,
               0,        0,        0], dtype=torch.int32)

In [8]:
# Round trip: encrypt under s1, then decrypt with the same key
ctfft = encrypt_to_fft(m, s1fft)
mdec = decrypt_from_fft(ctfft, s1fft)
mdec[0:10]

tensor([       4,       -1, 33554434, 33554431,        1, 33554431,       -2,
              -4,       -5,        2], dtype=torch.int32)

2. Key-switching key ($s_1 \rightarrow s_2$) generation

In [9]:
# RLWE' encryption of the old key s1 under the new key s2
kskfft = encrypt_rlwep_fft(z=s1, skfft=s2fft)

In [10]:
kskfft.shape

torch.Size([3, 2, 512])

3. Try key switching

In [11]:
def keyswitch(ctfft, kskfft):
    """Key-switch an FFT-domain RLWE ciphertext using signed gadget decomposition.

    Args:
        ctfft: ciphertext (b, a) in the negacyclic-FFT domain (ctfft[0] is b,
            ctfft[1] is a).
        kskfft: key-switching key, i.e. an RLWE' encryption of the old key under
            the new key as built by encrypt_rlwep_fft, shape (d, 2, N//2).

    Returns:
        (b, 0) + sum_i digit_i(a) * kskfft[i], an FFT-domain ciphertext of shape
        (2, N//2) that decrypts under the new key.
    """
    # Only `a` is decomposed, so only that row needs to leave the FFT domain.
    # Inverse-transforming `b` as well doubles the IFFT work and pushes its
    # large coefficients through the float->int conversion for nothing.
    a = negacyclic_ifft(ctfft[1], N, Q)

    da = signed_decompose(a)  # (d, N) signed digits of a
    dafft = negacyclic_fft(da, N, Q)

    # a * RLWE'(s1): multiply each digit by its RLWE' row and sum over digits
    prodfft = dafft.view(d, 1, N // 2) * kskfft
    prodsumfft = torch.sum(prodfft, dim=0)

    # (b, 0) + a * RLWE'(s1)
    prodsumfft[0] += ctfft[0]

    return prodsumfft


In [12]:
switched = keyswitch(ctfft, kskfft)
mks = decrypt_from_fft(switched, s2fft)

# First ten coefficients: reference message, raw decryption, decryption scaled by Q/4
for caption, coeffs in (("m:\n", m[:10]),
                        ("decrypted:\n", mks[:10]),
                        ("decrypted (scaled to 1):\n", mks[:10] / (Q // 4))):
    print(caption, coeffs)

m:
 tensor([       0,        0, 33554432, 33554432,        0, 33554432,        0,
               0,        0,        0], dtype=torch.int32)
decrypted:
 tensor([    9344,    14083, 33583855, 33571263,     7272, 33562586,    18551,
           16576,    11830,     8365], dtype=torch.int32)
decrypted (scaled to 1):
 tensor([2.7847e-04, 4.1971e-04, 1.0009e+00, 1.0005e+00, 2.1672e-04, 1.0002e+00,
        5.5286e-04, 4.9400e-04, 3.5256e-04, 2.4930e-04])


Bonus: unsigned decomposition version

In [13]:
def unsigned_keyswitch(ctfft, kskfft):
    """Key-switch an FFT-domain RLWE ciphertext using unsigned gadget decomposition.

    Same as `keyswitch` except that the digits of `a` come from `decompose`
    (range [0, B)) instead of `signed_decompose`.

    Args:
        ctfft: ciphertext (b, a) in the negacyclic-FFT domain (ctfft[0] is b,
            ctfft[1] is a).
        kskfft: key-switching key from encrypt_rlwep_fft, shape (d, 2, N//2).

    Returns:
        (b, 0) + sum_i digit_i(a) * kskfft[i], shape (2, N//2), FFT domain.
    """
    # Only `a` is decomposed, so only that row needs an inverse transform.
    # Converting `b` too doubles the work and pushes its large coefficients
    # through the float->int conversion for nothing.
    a = negacyclic_ifft(ctfft[1], N, Q)

    da = decompose(a)  # (d, N) unsigned digits of a
    dafft = negacyclic_fft(da, N, Q)

    # a * RLWE'(s1): multiply each digit by its RLWE' row and sum over digits
    prodfft = dafft.view(d, 1, N // 2) * kskfft
    prodsumfft = torch.sum(prodfft, dim=0)

    # (b, 0) + a * RLWE'(s1)
    prodsumfft[0] += ctfft[0]

    return prodsumfft


In [14]:
switched_unsigned = unsigned_keyswitch(ctfft, kskfft)
umks = decrypt_from_fft(switched_unsigned, s2fft)

# First ten coefficients: reference message, raw decryption, decryption scaled by Q/4
for caption, coeffs in (("m:\n", m[:10]),
                        ("decrypted:\n", umks[:10]),
                        ("decrypted (scaled to 1):\n", umks[:10] / (Q // 4))):
    print(caption, coeffs)



m:
 tensor([       0,        0, 33554432, 33554432,        0, 33554432,        0,
               0,        0,        0], dtype=torch.int32)
decrypted:
 tensor([   12873,     8763, 33551106, 33551179,     4456, 33558776,    10529,
            6797,     7232,    14636], dtype=torch.int32)
decrypted (scaled to 1):
 tensor([3.8365e-04, 2.6116e-04, 9.9990e-01, 9.9990e-01, 1.3280e-04, 1.0001e+00,
        3.1379e-04, 2.0257e-04, 2.1553e-04, 4.3619e-04])


In [15]:
# Error variance (relative to Q) after key switching: signed vs unsigned digits
for label, decrypted in (("signed", mks), ("unsigned", umks)):
    print(label)
    print("Error variance: \n", torch.var((m - decrypted) / Q))


signed
Error variance: 
 tensor(7.3055e-09)
unsigned
Error variance: 
 tensor(9.8936e-09)


## 3.2. RGSW Encryption

$RGSW(\boldsymbol{m_1})$ is composed of two RLWE' ciphertexts, i.e.,

$$
RGSW(\boldsymbol{m_1}) = ( RLWE'(\boldsymbol{m_1}), RLWE'(\boldsymbol{m_1}\cdot s) )
$$

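The multiplication between RLWE and RGSW (the external product) is defined as follows:
$$
\circledast: RLWE \times RGSW \to RLWE
$$
Let $(\boldsymbol{b}, \boldsymbol{a})$ be an encryption of $\boldsymbol{m_0}$, i.e., $\boldsymbol{b} + \boldsymbol{a}s = \boldsymbol{m_0} + \boldsymbol{e}$. Then
$$
\begin{aligned}
(\boldsymbol{b}, \boldsymbol{a}) \circledast RGSW(\boldsymbol{m_1})
&= (\boldsymbol{b}, \boldsymbol{a}) \circledast ( RLWE'(\boldsymbol{m_1}), RLWE'(\boldsymbol{m_1}\cdot s) )\\
&= \boldsymbol{b} \odot RLWE'(\boldsymbol{m_1}) + \boldsymbol{a} \odot RLWE'(\boldsymbol{m_1}\cdot s)\\
&= RLWE(\boldsymbol{b}\boldsymbol{m_1}) + RLWE(\boldsymbol{a}s\boldsymbol{m_1})\\
&= RLWE((\boldsymbol{b}+\boldsymbol{a}s)\boldsymbol{m_1})\\
&= RLWE((\boldsymbol{m_0} + \boldsymbol{e})\boldsymbol{m_1}) \\
&\approx RLWE(\boldsymbol{m_0}\boldsymbol{m_1})
\end{aligned}
$$
Here $\odot$ is the RLWE' product used for key switching: decompose, then multiply and sum. The final approximation holds when $\boldsymbol{e}\boldsymbol{m_1}$ is small, e.g. when $\boldsymbol{m_1}$ is a monomial.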

As with RLWE', we can define the RGSW encryption function as follows.

In [16]:
def encrypt_rgsw_fft(z, skfft):
    """RGSW-encrypt polynomial `z` under the FFT-domain key `skfft`.

    Layout: (d, 2, 2, N//2), indexed [digit, row, (b, a), coefficient], in the
    negacyclic-FFT domain. Row 0 is RLWE'(z), with g_i*z added to b. Row 1 is
    RLWE'(z*s), with g_i*z added to the 'a' part, so that b + a*s picks up
    g_i*z*s.
    """
    raw = torch.zeros(d, 2, 2, N, dtype=torch.int32)

    # 'a' parts: uniform in [0, Q)
    # INSECURE: to be fixed later
    raw[:, :, 1, :] = torch.randint(Q, size=(d, 2, N), dtype=torch.int32)

    # 'b' parts start as the rounded Gaussian error
    # INSECURE: to be fixed later
    raw[:, :, 0, :] = torch.round(stddev * torch.randn(size=(d, 2, N)))

    # reduce mod Q (a power of two) and center into [-Q/2, Q/2)
    raw &= (Q - 1)
    raw = normalize(raw, logQ)

    rgswfft = negacyclic_fft(raw, N, Q)
    b_part = rgswfft[:, :, 0, :]
    a_part = rgswfft[:, :, 1, :]

    # b <- e - a*s  (in-place on the view, so `rgswfft` is updated)
    b_part -= a_part * skfft.view(1, 1, N // 2)

    # gadget-scaled message added to (row 0, b) and (row 1, a)
    gzfft = negacyclic_fft(gvector * z, N, Q)
    rgswfft[:, 0, 0, :] += gzfft
    rgswfft[:, 1, 1, :] += gzfft

    return rgswfft

In [17]:
# Decrypt the original ciphertext again, for reference
decrypt_from_fft(ctfft, s1fft)[0:10]

tensor([       4,       -1, 33554434, 33554431,        1, 33554431,       -2,
              -4,       -5,        2], dtype=torch.int32)

We will multiply $\boldsymbol{m}$ by $X^3$, so the resulting message $\boldsymbol{m} X^3$ has the coefficients of $\boldsymbol{m}$ shifted right by $3$. The shift is negacyclic: coefficients that wrap past $X^{N-1}$ come back negated.

In [18]:
# x1 is the monomial X^3 (coefficient 1 at index 3), despite its name; the name
# is kept because later cells refer to it. Integer dtype to match the message `m`.
x1 = torch.zeros(N, dtype=torch.int32)
x1[3] = 1

In [19]:
# RGSW encryption of the monomial X^3 under s1
rgswx1 = encrypt_rgsw_fft(z=x1, skfft=s1fft)

In [20]:
# External product by hand: decompose both rows (b, a) of the ciphertext,
# then multiply digit-wise with RGSW(X^3) and sum over digits and rows.
ct = negacyclic_ifft(ctfft, N, Q)
dct = signed_decompose(ct)  # (d, 2, N)
multfft = negacyclic_fft(dct, N, Q).view(d, 2, 1, N // 2) * rgswx1
multfftsum = multfft.sum(dim=(0, 1))

In [21]:
multfftsum.shape

torch.Size([2, 512])

We can see that $\boldsymbol{m}$ is shifted right by three: the nonzero coefficients moved from indices 2, 3, 5 to indices 5, 6, 8.

In [22]:
# Decrypt the external product: expect the nonzero entries of m at indices 5, 6, 8
decrypt_from_fft(multfftsum, s1fft)[0:10]

tensor([   19465,     3338,    16268,    22290,    12925, 33576169, 33574016,
           18176, 33579903,    15308], dtype=torch.int32)

Wrap the RGSW multiplication in a function.

In [23]:
def rgswmult(ctfft, rgswfft):
    """External product RLWE ⊛ RGSW, computed in the FFT domain.

    Brings the ciphertext (b, a) back to coefficients and signed-decomposes
    both rows. Each digit row is then multiplied with the matching RGSW row
    (b-digits with RLWE'(m), a-digits with RLWE'(m*s)) and the products are
    summed over digits and rows. `rgswfft` is laid out (d, 2, 2, N//2) as
    built by encrypt_rgsw_fft. Returns a (2, N//2) FFT-domain ciphertext.
    """
    digits = signed_decompose(negacyclic_ifft(ctfft, N, Q))  # (d, 2, N)
    digits_fft = negacyclic_fft(digits, N, Q).view(d, 2, 1, N // 2)
    return (digits_fft * rgswfft).sum(dim=(0, 1))