# 5. Blind rotation
Blind rotation is the core of FHEW bootstrapping and of many other HE applications that support non-arithmetic operations (i.e., operations other than $+$ and $\times$).

It is defined as follows.

**Definition (blind rotation)**
- Input: an LWE ciphertext $(\beta, \vec{\alpha})$ modulo $2N$, blind rotation keys $brk$ (encryptions of $\vec{s}$ under the RLWE secret $\boldsymbol{z}$), and a public polynomial $f$, where $\beta + \left< \vec{\alpha}, \vec{s}\right> \equiv u \pmod{2N}$
- Output: $RLWE_{\boldsymbol{z}}(f\cdot X^u)$
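Why the exponent $u$ only matters modulo $2N$: in $\mathbb{Z}_Q[X]/(X^N+1)$ we have $X^N = -1$, so multiplying by $X^u$ is a negacyclic rotation of the coefficients. The pure-Python sketch below (the helper `mul_monomial` is a hypothetical name, not part of the lecture code) computes $f\cdot X^u$ directly. Its constant coefficient is $\pm$ a single coefficient of $f$ selected by $u$, which is what lets blind rotation act as a lookup table on the hidden $u$.

```python
def mul_monomial(f, u, Q):
    """Return f * X^u in Z_Q[X]/(X^N + 1), where N = len(f).

    X^N = -1 implies X^(2N) = 1, so u is reduced modulo 2N first
    (negative u is allowed).
    """
    N = len(f)
    u %= 2 * N
    res = [0] * N
    for j, c in enumerate(f):
        k, sign = j + u, 1
        # every wrap past X^N flips the sign, since X^N = -1
        while k >= N:
            k -= N
            sign = -sign
        res[k] = (res[k] + sign * c) % Q
    return res


# toy example: N = 4, Q = 17, f = 1 + 2X + 3X^2 + 4X^3
f = [1, 2, 3, 4]
print(mul_monomial(f, 1, 17))   # [13, 1, 2, 3], i.e. -4 + X + 2X^2 + 3X^3
print(mul_monomial(f, 4, 17))   # [16, 15, 14, 13], since X^4 = -1
print(mul_monomial(f, 8, 17))   # [1, 2, 3, 4], since X^8 = 1
```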

It consists of the following three steps.
1. Initialize the accumulator $ACC$ as an encryption of $f\cdot X^{\beta}$. Since $f$ and $\beta$ are public, the trivial (noiseless) encryption $ACC = (f\cdot X^{\beta}, 0)$ suffices.
2. Accumulation: for each $i$, homomorphically multiply $ACC$ by $X^{\alpha_i s_i}$ and update it.
3. After accumulation, we get $RLWE_{\boldsymbol{z}}(f\cdot X^{\beta + \left< \vec{\alpha}, \vec{s}\right>}) = RLWE_{\boldsymbol{z}}(f\cdot X^u)$
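Ignoring encryption, the three steps are just repeated monomial multiplications. The sketch below runs them on cleartext values and checks that the result equals $f\cdot X^u$ with $u = \beta + \langle\vec\alpha,\vec s\rangle \bmod 2N$. It is plain Python; `mul_monomial` and `blind_rotate_plain` are hypothetical helpers, and the toy parameters (N = 8, Q = 2^10, n = 5) are chosen only for illustration. In the real procedure $s_i$ is hidden, which is why step 2 needs the RGSW machinery.

```python
import random


def mul_monomial(f, u, Q):
    """f * X^u in Z_Q[X]/(X^N + 1), N = len(f); u is taken modulo 2N."""
    N = len(f)
    u %= 2 * N
    res = [0] * N
    for j, c in enumerate(f):
        k, sign = j + u, 1
        while k >= N:          # X^N = -1
            k -= N
            sign = -sign
        res[k] = (res[k] + sign * c) % Q
    return res


def blind_rotate_plain(f, beta, alpha, s, Q):
    """Cleartext analogue of blind rotation (no encryption involved)."""
    acc = mul_monomial(f, beta, Q)           # step 1: ACC = f * X^beta
    for a_i, s_i in zip(alpha, s):           # step 2: ACC <- ACC * X^(a_i * s_i)
        acc = mul_monomial(acc, a_i * s_i, Q)
    return acc                               # step 3: f * X^(beta + <alpha, s>)


random.seed(1)
N, Q, n = 8, 2**10, 5                        # toy parameters
q = 2 * N                                    # modulus of the exponent u
f = [random.randrange(Q) for _ in range(N)]
s = [random.randrange(2) for _ in range(n)]  # binary LWE secret
alpha = [random.randrange(q) for _ in range(n)]
beta = random.randrange(q)
u = (beta + sum(a * b for a, b in zip(alpha, s))) % q

result = blind_rotate_plain(f, beta, alpha, s, Q)
assert result == mul_monomial(f, u, Q)
```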

In [54]:
# Functions from the previous lecture note
import torch
import math

stddev = 3.2
logQ = 27

N = 2**10
Q = 2**logQ

def keygen(dim):
    return torch.randint(2, size = (dim,))

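# rounded Gaussian error; a Python redefinition would silently override an
# earlier errgen, so a single function with a default N is used instead
# (returns a scalar when N == 1, otherwise a length-N vector)
def errgen(stddev, N=1):
    e = torch.round(stddev*torch.randn(N))
    e = e.squeeze()
    return e.to(torch.int)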

def uniform(dim, modulus):
    return torch.randint(modulus, size = (dim,), dtype=torch.int32)

def polymult(a, b, dim, modulus):
    res = torch.zeros(dim).to(torch.int)
    for i in range(dim):
        for j in range(dim):
            if i >= j:
                res[i] += a[j]*b[i-j]
                res[i] %= modulus
            else:
                res[i] -= a[j]*b[i-j] # Q - x mod Q = -x
                res[i] %= modulus

    res %= modulus
    return res

root_powers = torch.arange(N//2).to(torch.complex128)
root_powers = torch.exp((1j*math.pi/N)*root_powers)

root_powers_inv = torch.arange(0,-N//2,-1).to(torch.complex128)
root_powers_inv = torch.exp((1j*math.pi/N)*root_powers_inv)

def negacyclic_fft(a, N, Q):
    acomplex = a.to(torch.complex128)

    a_precomp = (acomplex[...,:N//2] + 1j * acomplex[..., N//2:]) * root_powers

    return torch.fft.fft(a_precomp)

def negacyclic_ifft(A, N, Q):
    b = torch.fft.ifft(A)
    b *= root_powers_inv

    a = torch.cat((b.real, b.imag), dim=-1)
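    # round to the nearest integer (plain truncation toward zero adds error),
    # and cast through int64 so that values outside the int32 range
    # (e.g., a*s before reduction) do not overflow before reducing mod Q
    aint = torch.round(a).to(torch.int64)
    # reduction mod Q by masking: only valid when Q is a power of two
    aint &= Q-1

    return aint.to(torch.int32)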

# make an RLWE encryption of message
def encrypt_to_fft(m, sfft):
    ct = torch.stack([errgen(stddev, N), uniform(N, Q)])
    ctfft = negacyclic_fft(ct, N, Q)

    ctfft[0] += -ctfft[1]*sfft + negacyclic_fft(m, N, Q)

    return ctfft

def normalize(v, logQ):
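    # branch-free version of the following, applied after reducing v mod Q
    # (the boundary value Q/2 is mapped to -Q/2 here)
    """
    if v > Q//2:
        v -= Q
    """
    # v mod Q (valid only when Q is a power of two)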
    Q = (1 << logQ)
    v &= Q-1
    # get msb
    msb = (v & Q//2) >> (logQ - 1)
    v -= (Q) * msb
    return v

def decrypt_from_fft(ctfft, sfft):
    assert len(ctfft.size()) == 2
    # normalization is optional
    # return normalize(negacyclic_ifft(ctfft[0] + ctfft[1]*sfft, N, Q), logQ)
    return negacyclic_ifft(ctfft[0] + ctfft[1]*sfft, N, Q)

# we will use B-ary decomposition, i.e., cut digits by logB bits
d = 3
logB = 7

decomp_shift = logQ - logB*torch.arange(d,0,-1).view(d,1)
mask = (1 << logB) - 1

gvector = 1 << decomp_shift

def decompose(a):
    
    assert len(a.size()) <= 2
    # for RLWE'
    if len(a.size()) == 1:
        res = (a.unsqueeze(0) >> decomp_shift.view(d, 1)) & mask
        return res
    # for RGSW
    elif len(a.size()) == 2:
        res = (a.unsqueeze(0) >> decomp_shift.view(d, 1, 1)) & mask
        return res

msbmask = 0
for i in decomp_shift:
    msbmask += (1<<(i+logB-1))

bin(msbmask)[2:]

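# roughly twice as costly as unsigned decomposition (it calls decompose twice)
# digits lie in [-B/2, B/2] rather than [-B/2, B/2), which is acceptable here
def signed_decompose(a):
    # add B/2 at every digit whose msb is set, so that digit overflows
    # and carries 1 into the next digit
    da = decompose(a + (a & msbmask))
    # subtract that B/2 again: a digit x >= B/2 ends up as x - B
    da -= decompose((a & msbmask))
    return da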

def encrypt_rgsw_fft(z, skfft):
    # RGSW has a dimension of d, 2, 2, N
    rgsw = torch.zeros(d, 2, 2, N, dtype=torch.int32)

    # generate the 'a' part
    # INSECURE: to be fixed later
    rgsw[:, :, 1, :] = torch.randint(Q, size = (d, 2 , N), dtype= torch.int32)

    # add error on b
    # INSECURE: to be fixed later
    rgsw[:, :, 0, :] = torch.round(stddev * torch.randn(size = (d, 2, N)))

    # following is equal to rgsw %= Q, but a faster version
    rgsw &= (Q-1)
    rgsw = normalize(rgsw, logQ)

    # do fft for easy a*s
    rgswfft = negacyclic_fft(rgsw, N, Q)

    # now b = -a*sk + e
    rgswfft[:, :, 0, :] -= rgswfft[:, :, 1, :] * skfft.view(1, 1, N//2)

    # add g*z so that row 0 decrypts to g*z and row 1 decrypts to g*z*sk
    gzfft = negacyclic_fft(gvector * z, N, Q)
    rgswfft[:, 0, 0, :] += gzfft
    rgswfft[:, 1, 1, :] += gzfft

    return rgswfft

def rgswmult(ctfft, rgswfft):
    ct = negacyclic_ifft(ctfft, N, Q)
    dct = signed_decompose(ct)
    multfft = negacyclic_fft(dct, N, Q).view(d, 2, 1, N//2) * rgswfft
    
    return torch.sum(multfft, dim = (0,1))
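To see why `signed_decompose` works, here is a scalar pure-Python re-implementation of the same carry trick. The function names mirror the lecture code, but these are hypothetical standalone versions operating on a single integer rather than on torch tensors. Adding $B/2$ at every digit whose top bit is set makes those digits overflow into the next one, and subtracting $B/2$ again leaves them as $x - B$. The sketch checks, for the lecture's parameters logQ = 27, logB = 7, d = 3, that every digit lies in $[-B/2, B/2]$ and that $\sum_k d_k 2^{\text{shift}_k}$ recovers $a$ modulo $Q$ except for the lowest $\log Q - d\log B$ bits, which this approximate gadget drops.

```python
import random


def decompose(a, logQ, logB, d):
    """Unsigned digits of a at bit offsets shift_k = logQ - logB*(d - k), k = 0..d-1."""
    mask = (1 << logB) - 1
    return [(a >> (logQ - logB * (d - k))) & mask for k in range(d)]


def signed_decompose(a, logQ, logB, d):
    """Signed digits in [-B/2, B/2] via the carry trick."""
    # top bit of every digit position (the lecture's `msbmask`)
    msbmask = sum(1 << (logQ - logB * (d - k) + logB - 1) for k in range(d))
    carried = decompose(a + (a & msbmask), logQ, logB, d)  # add B/2 where msb is set
    halves = decompose(a & msbmask, logQ, logB, d)         # the B/2 to take back
    return [x - y for x, y in zip(carried, halves)]


def recompose(digits, logQ, logB, d):
    """sum_k digits[k] * 2^shift_k mod Q."""
    Q = 1 << logQ
    return sum(x << (logQ - logB * (d - k)) for k, x in enumerate(digits)) % Q


logQ, logB, d = 27, 7, 3            # parameters used in the lecture code
Q, B = 1 << logQ, 1 << logB
drop = logQ - logB * d              # low bits ignored by this approximate gadget

random.seed(0)
samples = [0, Q - 1] + [random.randrange(Q) for _ in range(1000)]
for a in samples:
    digits = signed_decompose(a, logQ, logB, d)
    assert all(-B // 2 <= x <= B // 2 for x in digits)
    assert recompose(digits, logQ, logB, d) == (a >> drop) << drop

print(signed_decompose(Q - 1, logQ, logB, d))   # [-1, 0, 0]
```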

## 5.1 CMux gate

In each iteration, we need to multiply $ACC$ by $RGSW(X^{\alpha_i s_i})$, where $\alpha_i$ is public and $s_i$ is available only in encrypted form.
The blind rotation variants AP/DM, GINX/CGGI, and LMKCDEY differ in how this encryption of $X^{\alpha_i s_i}$ is generated.

In the GINX variant of blind rotation for *binary secrets*, we use the following CMux gate as a subroutine.
$$
X^{\alpha_i s_i} = (1-s_i) + X^{\alpha_i} \cdot s_i
$$
Since $s_i$ can only be $0$ or $1$, it suffices to check that the equation holds in both cases.

1. $s_i = 0$
$$
\begin{aligned}
(1-s_i) + X^{\alpha_i} \cdot s_i &= (1-0) + X^{\alpha_i} \cdot 0 \\
&= 1 = X^{0} = X^{\alpha_i s_i}
\end{aligned}
$$
2. $s_i = 1$
$$
\begin{aligned}
(1-s_i) + X^{\alpha_i} \cdot s_i &= (1-1) + X^{\alpha_i} \cdot 1 \\
&= X^{\alpha_i} = X^{\alpha_i s_i}
\end{aligned}
$$
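The two-case check can also be run mechanically. The sketch below (plain Python, with a hypothetical `monomial` helper and toy parameters) builds $(1-s) + s\cdot X^{\alpha}$ coefficient-wise in $\mathbb{Z}_Q[X]/(X^N+1)$ and compares it with $X^{\alpha s}$ for every $\alpha \in \mathbb{Z}_{2N}$. The last assertion shows that the identity breaks for a non-binary value such as $s = 2$, which is why this CMux form is tied to binary secrets.

```python
def monomial(k, N, Q):
    """Coefficient list of X^k in Z_Q[X]/(X^N + 1); k is taken modulo 2N."""
    k %= 2 * N
    res = [0] * N
    if k < N:
        res[k] = 1
    else:
        res[k - N] = Q - 1     # X^k = -X^(k - N)
    return res


def cmux_monomial(alpha, s, N, Q):
    """(1 - s) + s * X^alpha, computed coefficient-wise."""
    one = monomial(0, N, Q)
    x_alpha = monomial(alpha, N, Q)
    return [((1 - s) * o + s * x) % Q for o, x in zip(one, x_alpha)]


N, Q = 8, 2**10
for alpha in range(2 * N):
    for s in (0, 1):
        assert cmux_monomial(alpha, s, N, Q) == monomial(alpha * s, N, Q)

# fails for s = 2: (1 - 2) + 2X = -1 + 2X, which is not X^2
assert cmux_monomial(1, 2, N, Q) != monomial(2, N, Q)
```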

Hence, in each iteration we compute
$$
ACC \leftarrow ACC \circledast ((1-RGSW(s_i)) + X^{\alpha_i} \cdot RGSW(s_i)),
$$
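Since $(1-RGSW(s_i)) + X^{\alpha_i}\cdot RGSW(s_i) = 1 + (X^{\alpha_i}-1)\cdot RGSW(s_i)$, this can be computed with one external product and one plaintext-polynomial multiplication on the RLWE accumulator, without first forming an RGSW encryption of $X^{\alpha_i s_i}$:
$$
ACC \leftarrow ACC + (X^{\alpha_i} - 1) \cdot \left(ACC\circledast RGSW(s_i)\right).
$$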
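The update rule can be sanity-checked on cleartexts by replacing $ACC \circledast RGSW(s_i)$ with its plaintext $s_i \cdot ACC$. In the sketch below (plain Python; `polymul`, `monomial` and `ginx_step_plain` are hypothetical helpers, and the parameters are toy values), one step $ACC + (X^{\alpha}-1)\cdot(s\cdot ACC)$ is compared against $ACC\cdot X^{\alpha s}$ for all $\alpha\in\mathbb{Z}_{2N}$ and both values of $s$. It ignores the noise that the real external product introduces.

```python
import random


def polymul(a, b, Q):
    """Schoolbook product in Z_Q[X]/(X^N + 1)."""
    N = len(a)
    res = [0] * N
    for i in range(N):
        for j in range(N):
            if i + j < N:
                res[i + j] += a[i] * b[j]
            else:
                res[i + j - N] -= a[i] * b[j]   # X^N = -1
    return [c % Q for c in res]


def monomial(k, N, Q):
    """Coefficient list of X^k in Z_Q[X]/(X^N + 1); k is taken modulo 2N."""
    k %= 2 * N
    res = [0] * N
    if k < N:
        res[k] = 1
    else:
        res[k - N] = Q - 1
    return res


def ginx_step_plain(acc, alpha, s, Q):
    """ACC + (X^alpha - 1) * (s * ACC); s * ACC stands in for the external product."""
    N = len(acc)
    x_alpha_minus_1 = monomial(alpha, N, Q)
    x_alpha_minus_1[0] = (x_alpha_minus_1[0] - 1) % Q
    temp = [(s * c) % Q for c in acc]
    update = polymul(x_alpha_minus_1, temp, Q)
    return [(a + b) % Q for a, b in zip(acc, update)]


random.seed(2)
N, Q = 8, 2**10
acc = [random.randrange(Q) for _ in range(N)]
for alpha in range(2 * N):
    for s in (0, 1):
        assert ginx_step_plain(acc, alpha, s, Q) == polymul(acc, monomial(alpha * s, N, Q), Q)
```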

The following cells implement a single GINX accumulation step for one secret coefficient $s_i$.

In [55]:
s = 1
skN = keygen(N)
skNfft = negacyclic_fft(skN, N, Q)

sPoly = torch.zeros([N], dtype=torch.int32)
sPoly[0] = s

rgswKey = encrypt_rgsw_fft(sPoly, skNfft)

In [56]:
# ACC is a trivial (noiseless) encryption of m with m_i = (i mod 10) * 10^6
ACC = torch.zeros([2,N], dtype=torch.int32)
for i in range(N):
    ACC[0][i] = (i%10) * 1000000

ACCfft = negacyclic_fft(ACC, N, Q)

ACC

tensor([[      0, 1000000, 2000000,  ..., 1000000, 2000000, 3000000],
        [      0,       0,       0,  ...,       0,       0,       0]],
       dtype=torch.int32)

In [65]:
alpha = 1
# alphaPoly = X^alpha - 1, the plaintext factor of the GINX update (for 0 < alpha < N)
alphaPoly = torch.zeros([N], dtype=torch.int32)
alphaPoly[0] = -1
alphaPoly[alpha] = 1

alphaPolyfft = negacyclic_fft(alphaPoly, N, Q)

In [66]:
# Accumulation: temp = external product of ACC and RGSW(s),
# then ACC <- ACC + (X^alpha - 1) * temp
temp = rgswmult(ACCfft, rgswKey)
ACCfft += alphaPolyfft * temp
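For the parameters in these cells ($s = 1$, $\alpha = 1$, and $m_i = (i \bmod 10)\cdot 10^6$), the update should turn the message of $ACC$ into $X\cdot m$, i.e., a negacyclic shift by one position. The plain-Python sketch below computes that expected plaintext; the decrypted accumulator should be close to it up to the noise added by the external product. `mul_monomial` and `centered` are hypothetical helpers, not part of the lecture code.

```python
N, Q = 2**10, 2**27
alpha, s = 1, 1
m = [(i % 10) * 1000000 for i in range(N)]   # message of the trivial ACC above


def mul_monomial(f, u, Q):
    """f * X^u in Z_Q[X]/(X^N + 1), N = len(f); u is taken modulo 2N."""
    N = len(f)
    u %= 2 * N
    res = [0] * N
    for j, c in enumerate(f):
        k, sign = j + u, 1
        while k >= N:          # X^N = -1
            k -= N
            sign = -sign
        res[k] = (res[k] + sign * c) % Q
    return res


def centered(v, Q):
    """Map v in [0, Q) to [-Q/2, Q/2), as `normalize` does for power-of-two Q."""
    return v - Q if v >= Q // 2 else v


expected = mul_monomial(m, alpha * s, Q)
print([centered(v, Q) for v in expected[:5]])
# [-3000000, 0, 1000000, 2000000, 3000000]
```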

In [67]:
decrypt_from_fft(temp, skNfft)[:10]

tensor([107240719,  13388392,  67389239,  94891949, 108328708,  79879031,
        105447611, 120037201,  71632247,  49894554], dtype=torch.int32)

In [68]:
decrypt_from_fft(alphaPolyfft * temp, skNfft)[:10]

tensor([ 55621463,  93852326,  80216880, 106715019, 120780968,  28449677,
        108649149, 119628138,  48404954,  21737693], dtype=torch.int32)

In [61]:
decrypt_from_fft(ACCfft, skNfft)[:10]

tensor([131211942,    998322,   3002208,   5002741,   7006116,   9012927,
         11014387,  13009728,  15000824,  17002389], dtype=torch.int32)

In [62]:
dm = decrypt_from_fft(ACCfft, skNfft)[:10]

In [63]:
normalize(dm, logQ)

tensor([-3005786,   998322,  3002208,  5002741,  7006116,  9012927, 11014387,
        13009728, 15000824, 17002389], dtype=torch.int32)

First, we generate the test polynomial $f$ for the NAND gate, where $q = 2N$ is the modulus of the input LWE ciphertext.
$[-q/8, 3q/8)$ is mapped to $q/8$, and $[3q/8, 7q/8)$ is mapped to $-q/8$.

NOTE: $f$ depends on the binary gate to be evaluated.


In [64]:
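# modulus of the input LWE ciphertext (exponents of X live in Z_{2N})
q = 2*N
f_nand = torch.zeros(size=[N], dtype=torch.int32)

# coefficients [0, 3q/8) hold -Q/8 and coefficients [3q/8, N) hold +Q/8
f_nand -= (Q>>3)
f_nand[3*q//8:] += (Q>>2)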