# LWE key switching
## 4.1. LWE extract operation

From an RLWE ciphertext, we can extract an LWE ciphertext that encrypts only its constant term.
In fact, we can extract $N$ different LWE ciphertexts, one for each coefficient, but FHEW-like HE only needs the constant term.

First, we load the necessary functions from the previous lecture note.

In [1]:
```python
# Functions from the previous lecture note
import torch
import math

stddev = 3.2
logQ = 27
logQks = 14

n = 2**9
Qks = 2**logQks

N = 2**10
Q = 2**logQ

def keygen(dim):
    # uniform binary secret key
    return torch.randint(2, size = (dim,), dtype = torch.int32)

def errgen(stddev, N):
    # rounded Gaussian error with standard deviation stddev
    e = torch.round(stddev*torch.randn(N)).to(torch.int32)
    e = e.squeeze()
    return e.to(torch.int)

def uniform(dim, modulus):
    return torch.randint(modulus, size = (dim,), dtype = torch.int32)

def polymult(a, b, dim, modulus):
    # schoolbook multiplication in Z_modulus[X]/(X^dim + 1)
    res = torch.zeros(dim).to(torch.int)
    for i in range(dim):
        for j in range(dim):
            if i >= j:
                res[i] += a[j]*b[i-j]
                res[i] %= modulus
            else:
                res[i] -= a[j]*b[i-j] # X^dim = -1; b[i-j] wraps around to b[dim+i-j]
                res[i] %= modulus

    res %= modulus
    return res

# powers of the primitive 2N-th root of unity exp(i*pi/N) and of its inverse
root_powers = torch.arange(N//2).to(torch.complex128)
root_powers = torch.exp((1j*math.pi/N)*root_powers)

root_powers_inv = torch.arange(0,-N//2,-1).to(torch.complex128)
root_powers_inv = torch.exp((1j*math.pi/N)*root_powers_inv)

def negacyclic_fft(a, N, Q):
    acomplex = a.to(torch.complex128)

    # pack a_j + i*a_{j+N/2} into N/2 complex values and twist by the root powers
    a_precomp = (acomplex[...,:N//2] + 1j * acomplex[..., N//2:]) * root_powers

    return torch.fft.fft(a_precomp)

def negacyclic_ifft(A, N, Q):
    b = torch.fft.ifft(A)
    b *= root_powers_inv

    a = torch.cat((b.real, b.imag), dim=-1)

    # round to the nearest integer (a plain cast truncates toward zero);
    # go through int64 because coefficients can exceed the int32 range before reduction
    aint = torch.round(a).to(torch.int64)
    # reduce mod Q; only valid when Q is a power-of-two
    aint &= Q-1

    return aint.to(torch.int32)

# make an RLWE encryption of message m, returned in the FFT domain
def encrypt_to_fft(m, sfft):
    ct = torch.stack([errgen(stddev, N), uniform(N, Q)])
    ctfft = negacyclic_fft(ct, N, Q)

    # b = e - a*s + m
    ctfft[0] += -ctfft[1]*sfft + negacyclic_fft(m, N, Q)

    return ctfft

def normalize(v, logQ):
    # branch-free version of the following
    """
    if v >= Q//2:
        v -= Q
    """
    # v mod Q when Q is a power-of-two (without modifying the input tensor)
    Q = (1 << logQ)
    v = v & (Q-1)
    # get msb
    msb = (v & Q//2) >> (logQ - 1)
    v -= (Q) * msb
    return v

def decrypt_from_fft(ctfft, sfft):
    assert len(ctfft.size()) == 2
    # b + a*s = m + e, centered in [-Q/2, Q/2)
    return normalize(negacyclic_ifft(ctfft[0] + ctfft[1]*sfft, N, Q), logQ)
```

In [2]:
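```python
def encryptLWE(message, dim, modulus, key):
    # ct = (b, a) with a uniform and b = message * modulus/4 - <a, key> + e
    ct = uniform(dim + 1, modulus)

    ct[0] = message * modulus//4 - torch.dot(ct[1:], key)
    ct[0] += errgen(stddev, 1)
    # reduce mod modulus (a power of two)
    ct &= modulus -1

    return ct

def decryptLWE(ct, sk, modulus):
    # b + <a, sk> = message * modulus/4 + e (mod modulus)
    m_dec = torch.dot(ct, torch.cat((torch.ones(1, dtype=torch.int32), sk)))
    m_dec %= modulus

    m_dec = m_dec.to(torch.float)
    m_dec /= modulus/4.
    m_dec = torch.round(m_dec)
    # a phase just below modulus rounds to 4, which is the message 0
    return m_dec.to(torch.int) % 4
```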
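
As a quick check of the rounding performed by `decryptLWE`, the pure-Python sketch below mimics its arithmetic on a few hand-picked phases $b + \langle \boldsymbol{a}, \boldsymbol{s} \rangle$. The helper `decode` is hypothetical and not part of the notebook. A message $0$ with a small negative error has a phase just below $Q$, which rounds to $4$; reducing modulo 4 maps it back to $0$.

```python
# Hypothetical helper mirroring the decoding arithmetic of decryptLWE (pure Python, no torch).
Q = 2**27

def decode(phase, Q):
    """Round (phase mod Q) / (Q/4) to the nearest integer and reduce it into Z_4."""
    phase %= Q
    return round(phase / (Q / 4)) % 4

print(decode(1 * Q // 4 + 3, Q))  # 1
print(decode(0 * Q // 4 - 3, Q))  # 0: the phase Q - 3 rounds to 4, i.e. 0 mod 4
print(decode(3 * Q // 4 - 5, Q))  # 3
```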

In [3]:
```python
sk = keygen(N)
msg = 1

ctLWE = encryptLWE(msg, N, Q, sk)
decryptLWE(ctLWE, sk, Q)
```

```
tensor(1, dtype=torch.int32)
```

In [4]:
```python
skfft = negacyclic_fft(sk, N, Q)

msgRing = torch.tensor(range(N))
msgRing %= 2
msgRing ^= 1 # 1, 0, 1, 0, 1, ...
msgRing *= Q//4

ctRLWE = encrypt_to_fft(msgRing, skfft)
mDec = decrypt_from_fft(ctRLWE, skfft)
mDec
```

```
tensor([33554433,        1, 33554434,  ...,        2, 33554430,       -1],
       dtype=torch.int32)
```

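The constant (free) coefficient of `mDec`, that is, of $\boldsymbol{b} + \boldsymbol{a} \cdot \boldsymbol{s}$, is
$$
b_0 + (\boldsymbol{a} \cdot \boldsymbol{s})_0,
$$
where, since $X^N = -1$,
$$
(\boldsymbol{a} \cdot \boldsymbol{s})_0 = a_0 \cdot s_0 - \sum_{i = 1}^{N-1} a_{N-i} \cdot s_i .
$$
Hence $(b_0, a_0, -a_{N-1}, \dots, -a_1)$ is an LWE ciphertext of the constant term under the coefficient vector of $\boldsymbol{s}$; `extract` below builds exactly this vector.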
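
The same derivation works for every coefficient: for coefficient $k$, the value $b_k$ together with the vector $(a_k, a_{k-1}, \dots, a_0, -a_{N-1}, \dots, -a_{k+1})$ forms an LWE ciphertext of the $k$-th coefficient, which is how the $N$ different LWE ciphertexts mentioned at the beginning arise. The pure-Python sketch below checks this on random data with toy sizes (`N = 8`, `Q = 2**10`); `negacyclic_mul` and `extract_coeff` are hypothetical helpers, independent of the torch code in this notebook.

```python
import random

def negacyclic_mul(a, s, Q):
    """Schoolbook product of a and s in Z_Q[X]/(X^N + 1)."""
    N = len(a)
    res = [0] * N
    for j in range(N):
        for i in range(N):
            if i + j < N:
                res[i + j] += a[j] * s[i]
            else:
                # X^N = -1: the term wraps around with a sign flip
                res[i + j - N] -= a[j] * s[i]
    return [x % Q for x in res]

def extract_coeff(a, k, Q):
    """Vector alpha such that <alpha, s> = (a * s)_k mod Q."""
    N = len(a)
    return [a[k - i] % Q if i <= k else (-a[N + k - i]) % Q for i in range(N)]

N, Q = 8, 2**10
rng = random.Random(1)
a = [rng.randrange(Q) for _ in range(N)]
s = [rng.randrange(2) for _ in range(N)]   # binary secret, as in keygen

prod = negacyclic_mul(a, s, Q)
for k in range(N):
    alpha = extract_coeff(a, k, Q)
    assert sum(x * y for x, y in zip(alpha, s)) % Q == prod[k]

# k = 0 reproduces (a_0, -a_{N-1}, ..., -a_1)
print(extract_coeff(a, 0, Q) == [a[0]] + [(-x) % Q for x in reversed(a[1:])])  # True
```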

In [5]:
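```python
def extract(ctRLWE):
    # ctRLWE = (b, a) in coefficient form, shape (2, N)
    beta = ctRLWE[0][0]

    # clone so that the input ciphertext is not modified in place
    alpha = ctRLWE[1].clone()
    # (a_0, a_1, ..., a_{N-1}) -> (a_0, -a_{N-1}, ..., -a_1)
    alpha[1:] = -alpha[1:].flip(dims = [0])

    # LWE ciphertext (b_0, alpha), reduced mod Q (a power of two)
    return torch.cat((beta.unsqueeze(0), alpha)) & (Q - 1)
```
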
In [6]:
```python
ctExtract = extract(negacyclic_ifft(ctRLWE, N, Q))
```

In [7]:
```python
decryptLWE(ctExtract, sk, Q)
```

```
tensor(1, dtype=torch.int32)
```

## 4.2. LWE key switching

The LWE key switching key from $\vec{s}_1$ to $\vec{s}_2$ consists of LWE ciphertexts under the secret $\vec{s}_2$ that encrypt scaled multiples of each entry of $\vec{s}_1$:
$$
ksk = \{LWE_{\vec{s}_2}(j B^r s_{1, i}) : r\in [0,d), j \in [0, B), i \in [0, N) \}.
$$

Key switching a ciphertext $(\beta, \vec{\alpha}) = LWE_{\vec{s}_1}(m)$ proceeds as follows.
1. Decompose each $\alpha_i$ into digits $\alpha_{i,r} \in [0, B)$ so that $\alpha_i = \sum_r \alpha_{i,r} B^r$.
2. Select the key switching key entries indexed by these digits and add them up (this is ciphertext addition, so their errors add up as well):
$$
\begin{aligned}
\sum_i \sum_r ksk[r, \alpha_{i,r}, i] &= \sum_i \sum_r LWE_{\vec{s}_2}(\alpha_{i,r} B^r s_{1, i}) \\
&= \sum_i LWE_{\vec{s}_2}(\alpha_{i} \cdot s_{1, i}) \\
&= LWE_{\vec{s}_2}\left( \left< \vec{\alpha}, \vec{s}_1\right> \right).
\end{aligned}
$$
3. Add $(\beta, \vec{0})$ to the result to obtain $LWE_{\vec{s}_2}( \left< \vec{\alpha}, \vec{s}_1\right> + \beta) \approx LWE_{\vec{s}_2}(m)$, since $\beta + \left< \vec{\alpha}, \vec{s}_1\right> \approx m$.
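
The three steps above can be traced end to end with a toy pure-Python sketch. It assumes small, insecure parameters chosen only for illustration (`LOGQ = 12`, `LOGB = 4`, `D = 3`, so that $B^d = Q$ and the decomposition is exact, with $N = 16$ and $n = 8$), binary secrets, errors in $[-2, 2]$, and the notebook's convention $\beta + \langle \vec{\alpha}, \vec{s} \rangle \approx m$. All names are hypothetical; this is not the notebook's `LWEkskGen`/`LWEkeySwitch`.

```python
import random

rng = random.Random(0)

# Toy, insecure parameters for illustration only: B**D == Q, so digits are exact.
LOGQ, LOGB, D = 12, 4, 3
Q, B = 1 << LOGQ, 1 << LOGB
N, n = 16, 8                      # dimensions of s1 (input key) and s2 (output key)

def small_error():
    return rng.randint(-2, 2)

def lwe_encrypt(m, s):
    """Return [beta] + alpha with beta + <alpha, s> = m + e (mod Q)."""
    alpha = [rng.randrange(Q) for _ in s]
    beta = (m + small_error() - sum(x * y for x, y in zip(alpha, s))) % Q
    return [beta] + alpha

def lwe_phase(ct, s):
    """beta + <alpha, s> mod Q, as computed by decryptLWE before rounding."""
    return (ct[0] + sum(x * y for x, y in zip(ct[1:], s))) % Q

def ksk_gen(s1, s2):
    """ksk[r][j][i] encrypts j * B**r * s1[i] under s2."""
    return [[[lwe_encrypt(j * B**r * s1[i], s2) for i in range(N)]
             for j in range(B)] for r in range(D)]

def key_switch(ct, ksk):
    beta, alpha = ct[0], ct[1:]
    acc = [0] * (n + 1)
    for i in range(N):
        for r in range(D):
            digit = (alpha[i] >> (LOGB * r)) & (B - 1)                  # step 1: alpha_{i,r}
            acc = [(x + y) % Q for x, y in zip(acc, ksk[r][digit][i])]  # step 2
    acc[0] = (acc[0] + beta) % Q                                       # step 3: add (beta, 0)
    return acc

s1 = [rng.randrange(2) for _ in range(N)]
s2 = [rng.randrange(2) for _ in range(n)]
ksk = ksk_gen(s1, s2)

ct1 = lwe_encrypt(1 * Q // 4, s1)   # message 1, scaled by Q/4 as in encryptLWE
ct2 = key_switch(ct1, ksk)
print(len(ct1), len(ct2))                        # 17 9
print(round(lwe_phase(ct2, s2) / (Q / 4)) % 4)   # 1
```

The output error is the input error plus the errors of the $dN$ selected entries (at most $2 + 2 \cdot 48 = 98$ here), which stays far below $Q/8 = 512$.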

Thus, the LWE key switching key is an array of shape $(d, B, N, n + 1)$, where $N$ and $n$ are the dimensions of $\vec{s}_1$ and $\vec{s}_2$, respectively.
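
For the parameters used in this notebook (`dks = 2`, `Bks = 2**6`, `N = 2**10`, `n = 2**9`) this shape already means a sizable key. The short computation below assumes 4-byte (`int32`) entries, as in `LWEkskGen`. Storing one ciphertext per digit value $j$ costs memory, but it lets key switching use only additions, so each selected entry contributes just its own fresh error instead of an error scaled by the digit.

```python
# Parameters of the key switching key generated below (copied from the notebook)
dks, Bks, N, n = 2, 2**6, 2**10, 2**9

entries = dks * Bks * N * (n + 1)   # shape (d, B, N, n + 1)
size_bytes = 4 * entries            # int32 entries, as in LWEkskGen

print(entries)                      # 67239936
print(size_bytes / 2**20, "MiB")    # 256.5 MiB
```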

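NOTE: As in RLWE key switching, we can use an approximate decomposition that keeps only the top $d \log B$ bits of each $\alpha_i$, at the cost of extra error. The parameters below do this: $d \log B = 2 \cdot 6 = 12 < \log Q_{ks} = 14$, so the two lowest bits are dropped.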
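
With the notebook's parameters (`logQks = 14`, `logBks = 6`, `dks = 2`), the gadget is $(Q_{ks}/B^{d}) \cdot B^r = (4, 256)$, matching `gvector_ks`, so each $\alpha_i$ is effectively replaced by a nearby multiple of 4 before it is split into digits. The pure-Python sketch below assumes rounding to the nearest multiple; `approx_digits` is a hypothetical helper. It checks that the recomposition error $\alpha_i - \tilde{\alpha}_i$ is at most 2 in absolute value for every $\alpha_i \in [0, Q_{ks})$. Key switching then picks up the extra error term $\sum_i (\alpha_i - \tilde{\alpha}_i) s_{1,i}$, whose magnitude is at most $2N$ for a binary $\vec{s}_1$.

```python
# Notebook parameters: logQks = 14, logBks = 6, dks = 2
logQks, logBks, dks = 14, 6, 2
Qks, Bks = 1 << logQks, 1 << logBks
low_bits = logQks - logBks * dks          # 2 bits below the gadget are dropped
gadget = [1 << (low_bits + logBks * r) for r in range(dks)]

def approx_digits(alpha):
    """Base-Bks digits of alpha rounded to a multiple of 2**low_bits (assumes low_bits >= 1)."""
    rounded = (alpha + (1 << (low_bits - 1))) >> low_bits
    # masking each digit also drops a carry out of the top digit, i.e. reduces mod Qks
    return [(rounded >> (logBks * r)) & (Bks - 1) for r in range(dks)]

def centered(v, Q):
    """Representative of v mod Q in [-Q/2, Q/2)."""
    v %= Q
    return v - Q if v >= Q // 2 else v

worst = max(abs(centered(a - sum(d * g for d, g in zip(approx_digits(a), gadget)), Qks))
            for a in range(Qks))
print(gadget, worst)   # [4, 256] 2
```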

We generate a key switching key from `skN` to `sk` with modulus `Qks`.

In [8]:
```python
sk = keygen(n)
skN = keygen(N)
```

In [9]:
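```python
dks = 2      # number of digits d
logBks = 6   # log2 of the digit base B

# shifts logQks - logBks*(dks - r) for r = 0, ..., dks-1, shape (dks, 1): here (2, 8)
decomp_shift_ks = logQks - logBks*torch.arange(dks,0,-1).view(dks,1)
# mask extracting one base-B digit
mask_ks = (1 << logBks) - 1

# gadget vector (Qks / B^dks) * B^r = (4, 256); the lowest 2 bits are not covered
gvector_ks = 1 << decomp_shift_ks
```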

In [46]:
```python
Bks = 1 << logBks
# size: (dks, Bks, N, n+1); entry [r, j, i] encrypts j * gvector_ks[r] * skN[i] under sk
def LWEkskGen(sk, skN, logQks):
    Qks = 1 << logQks
    # a <- uniform mod Qks
    ksk = torch.randint(Qks, size = (dks, Bks, N, n+1), dtype= torch.int32)
    # b <- e
    ksk[..., 0] = torch.round(stddev * torch.randn(size = (dks, Bks, N))).to(torch.int32)
    # b <- e - a * s
    ksk[..., 0] -= torch.sum(ksk[:,:,:,1:] * sk, dim = -1)
    # b <- e - a * s + j * gvector_ks[r] * skN[i]
    ksk[..., 0] += (gvector_ks * torch.arange(Bks)).view(dks, Bks, 1) * skN
    ksk &= Qks - 1

    return ksk
```

In [None]:
```python
def LWEkeySwitch(ctLWE, kskLWE):
    pass
```