# 6. FHEW Homomorphic Encryption

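So far, we have studied the building blocks of FHEW-like homomorphic encryption (HE). In this note we put them together into a complete bootstrapped NAND gate.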

In [1]:
# Functions from the previous lecture notes
import torch
import math

# Standard deviation of the rounded Gaussian error distribution.
stddev = 3.2

# Bit-widths of the RLWE modulus Q and of the key-switching modulus Qks.
logQ = 27
logQks = 14

# RLWE ring dimension and the power-of-two moduli derived from the widths above.
N = 2 ** 10
Q = 2 ** logQ
Qks = 2 ** logQks

# LWE dimension n and LWE modulus q. q = 2N because X has order 2N in
# Z[X]/(X^N + 1), so LWE values mod q can be used directly as exponents of X.
n = 512
q = 2 * N

def keygen(dim):
    """Sample a uniformly random binary secret key.

    Returns an int32 tensor of shape (dim,) with entries in {0, 1}.
    """
    shape = (dim,)
    return torch.randint(0, 2, shape, dtype=torch.int32)

def errgen(stddev, N=1):
    """Sample rounded Gaussian errors.

    Args:
        stddev: standard deviation of the (continuous) Gaussian before rounding.
        N: number of samples. Defaults to 1, in which case the result is a
            0-dim tensor (the single-sample form previously provided by a
            separate one-argument definition that was shadowed and unusable).

    Returns:
        int32 tensor of shape (N,), or shape () when N == 1.
    """
    e = torch.round(stddev * torch.randn(N))
    # squeeze only collapses the N == 1 case to a scalar tensor
    e = e.squeeze()
    return e.to(torch.int)

def uniform(dim, modulus):
    """Sample ``dim`` integers uniformly from [0, modulus) as an int32 vector."""
    shape = (dim,)
    return torch.randint(0, modulus, shape, dtype=torch.int32)

def polymult(a, b, dim, modulus):
    """Multiply two polynomials in Z_modulus[X] / (X^dim + 1) (negacyclic product).

    Args:
        a, b: integer tensors with at least ``dim`` coefficients each, lowest
            degree first; only the first ``dim`` coefficients are used.
        dim: ring dimension.
        modulus: coefficient modulus (need not be a power of two; must be
            below 2**31 so products of residues fit in int64).

    Returns:
        int32 tensor of ``dim`` coefficients in [0, modulus).
    """
    # int64 arithmetic: an int32 accumulator wraps on products and is only
    # correct when modulus happens to divide 2**32.
    a64 = a[:dim].to(torch.int64) % modulus
    b64 = b[:dim].to(torch.int64) % modulus

    res = torch.zeros(dim, dtype=torch.int64)
    for j in range(dim):
        # a_j * X^j * b: rotate b by j places; coefficients that wrap past
        # X^dim pick up a minus sign because X^dim = -1.
        shifted = torch.roll(b64, j)
        shifted[:j] = -shifted[:j]
        res = (res + a64[j] * shifted) % modulus

    return res.to(torch.int32)

# Twisting factors for the negacyclic FFT:
#   root_powers[k]     = exp( i*pi*k/N)  (powers of a primitive 2N-th root of unity)
#   root_powers_inv[k] = exp(-i*pi*k/N)  (undoes that twist after the inverse FFT)
_twist_scale = 1j * math.pi / N
root_powers = torch.exp(_twist_scale * torch.arange(N // 2).to(torch.complex128))
root_powers_inv = torch.exp(_twist_scale * torch.arange(0, -N // 2, -1).to(torch.complex128))

def negacyclic_fft(a, N, Q):
    """Map coefficient vectors of Z[X]/(X^N + 1) into the FFT domain.

    The last axis of ``a`` holds N coefficients. The first half becomes the
    real part and the second half the imaginary part of N/2 complex values.
    These are twisted by ``root_powers`` and transformed with an ordinary
    FFT of length N/2 along the last axis. ``Q`` is accepted for symmetry
    with ``negacyclic_ifft`` and is not used here.
    """
    folded = a.to(torch.complex128)
    low_half = folded[..., :N // 2]
    high_half = folded[..., N // 2:]
    twisted = (low_half + 1j * high_half) * root_powers
    return torch.fft.fft(twisted)

def negacyclic_ifft(A, N, Q):
    """Inverse of ``negacyclic_fft``: return integer coefficients reduced mod Q.

    Args:
        A: complex FFT-domain tensor whose last axis has N/2 entries.
        N: ring dimension.
        Q: coefficient modulus; must be a power of two (reduction is a bit mask).

    Returns:
        int32 tensor whose last axis has N coefficients in [0, Q).
    """
    b = torch.fft.ifft(A) * root_powers_inv
    # Unfold: real parts are the first N/2 coefficients, imaginary parts the rest.
    a = torch.cat((b.real, b.imag), dim=-1)

    # Round to nearest. The FFT leaves small floating-point error, and
    # truncation toward zero would turn e.g. 4.9999999 into 4.
    # Convert through int64: FFT-domain products are never reduced mod Q, so
    # coefficients may exceed the int32 range, where a direct float -> int32
    # cast is not guaranteed to wrap.
    aint = torch.round(a).to(torch.int64)
    # only when Q is a power-of-two
    aint &= Q - 1

    return aint.to(torch.int32)

# RLWE encryption, result kept in the FFT domain
def encrypt_to_fft(m, sfft):
    """RLWE-encrypt polynomial ``m`` and return the ciphertext in FFT form.

    Samples an error polynomial e (rounded Gaussian with ``stddev``) and a
    uniform polynomial a in [0, Q). Returns FFT((b, a)) with b = e - a*s + m,
    so that b + a*s = m + e. ``sfft`` is the secret key s already passed
    through ``negacyclic_fft``.
    """
    error_poly = errgen(stddev, N)
    a_poly = uniform(N, Q)
    ctfft = negacyclic_fft(torch.stack([error_poly, a_poly]), N, Q)

    # b <- e - a*s + m, computed entirely in the FFT domain
    ctfft[0] += -ctfft[1] * sfft + negacyclic_fft(m, N, Q)
    return ctfft

def normalize(v, logQ):
    """Reduce ``v`` modulo Q = 2**logQ into the centered range [-Q/2, Q/2).

    Branch-free equivalent of::

        v %= Q
        if v >= Q // 2:
            v -= Q

    Works on Python ints and elementwise on integer tensors. The input tensor
    is not modified; a new value is returned.
    """
    Q = 1 << logQ
    # v mod Q when Q is a power of two (also correct for negative v in two's complement)
    v = v & (Q - 1)
    # 1 where the reduced value lies in the upper half [Q/2, Q), else 0
    msb = (v & (Q >> 1)) >> (logQ - 1)
    return v - Q * msb

def decrypt_from_fft(ctfft, sfft):
    """Decrypt an FFT-form RLWE ciphertext.

    ``ctfft`` must be 2-D (rows b and a); ``sfft`` is the secret key in FFT
    form. Returns the phase b + a*s as integer coefficients centered into
    [-Q/2, Q/2). The centering is optional for correctness, since the
    uncentered phase mod Q carries the same information.
    """
    assert ctfft.dim() == 2
    phase = negacyclic_ifft(ctfft[0] + ctfft[1] * sfft, N, Q)
    return normalize(phase, logQ)

# Gadget (B-ary) decomposition for RGSW: d digits of logB bits each. Only the
# top d * logB bits of a logQ-bit coefficient are kept (the assert keeps this
# strictly below logQ), so the lowest logQ - d * logB bits are dropped.
d = 3
logB = 8

assert d * logB < logQ

# Right-shift that brings digit k to the bottom, least significant digit
# first; shape (d, 1).
decomp_shift = logQ - logB * torch.arange(d, 0, -1).unsqueeze(1)
# Mask selecting a single logB-bit digit.
mask = 2 ** logB - 1
# Gadget vector g_k = 2**decomp_shift[k]; shape (d, 1).
gvector = 1 << decomp_shift

def decompose(a):
    """Split every coefficient of ``a`` into d unsigned logB-bit digits.

    Digit k is ``(a >> decomp_shift[k]) & mask``. Digits are stacked on a new
    leading axis, so the result has shape (d, *a.shape). Any rank is
    accepted: 1-D for an RLWE' component, 2-D for a whole RLWE ciphertext as
    used in RGSW products. A 0-dim input yields shape (d,) instead of
    silently returning None.
    """
    # Reshape the (d, 1) shift column so it broadcasts over every axis of a.
    shifts = decomp_shift.view(d, *([1] * a.dim()))
    return (a.unsqueeze(0) >> shifts) & mask

# msbmask has the top bit of every logB-bit gadget digit set. signed_decompose
# uses it to find digits >= B/2 that should become negative and carry into
# the next digit. (It ends up as a one-element tensor, because decomp_shift
# rows are tensors.)
msbmask = sum(1 << (shift + logB - 1) for shift in decomp_shift)

# show the mask in binary (notebook cell output)
bin(msbmask)[2:]

# Signed gadget decomposition; costs about two unsigned decompositions.
def signed_decompose(a):
    """Decompose ``a`` into d signed digits in [-B/2, B/2] (both ends possible).

    Adding ``a & msbmask`` to ``a`` doubles the top bit of each digit that is
    >= B/2. That overflows the digit and carries 1 into the next one.
    Subtracting the digits of ``a & msbmask`` then removes the added B/2, so
    such a digit x becomes x - B (plus any incoming carry). Returns a tensor
    of shape (d, *a.shape).
    """
    top_bits = a & msbmask
    with_carries = decompose(a + top_bits)
    return with_carries - decompose(top_bits)

def encrypt_rgsw_fft(z, skfft):
    """Encrypt polynomial ``z`` as an RGSW ciphertext and return it in FFT form.

    Coefficient-domain layout is (d, 2, 2, N) =
    (gadget digit k, row, component, coefficient), with component 0 = b and
    component 1 = a. Every (k, row) pair starts as an RLWE encryption of 0
    under s (b = e - a*s). Then g_k * z is added to b in row 0 and to a in
    row 1, so row 0 decrypts to g_k*z and row 1 to g_k*z*s.
    ``skfft`` is the secret key in FFT form.
    Returns a complex tensor of shape (d, 2, 2, N//2).
    """
    rgsw = torch.zeros(d, 2, 2, N, dtype=torch.int32)

    # INSECURE (to be fixed later): uniform 'a' parts from torch's non-cryptographic RNG
    rgsw[:, :, 1, :] = torch.randint(Q, size=(d, 2, N), dtype=torch.int32)
    # INSECURE (to be fixed later): rounded Gaussian error in the 'b' parts
    rgsw[:, :, 0, :] = torch.round(stddev * torch.randn(size=(d, 2, N)))

    # reduce mod Q and center into [-Q/2, Q/2) (normalize does both)
    rgsw = normalize(rgsw, logQ)

    # move to the FFT domain so that a*s is a pointwise product
    rgswfft = negacyclic_fft(rgsw, N, Q)
    rgswfft[:, :, 0, :] -= rgswfft[:, :, 1, :] * skfft.view(1, 1, N // 2)

    # gadget multiples of z: on b for row 0 (encrypts z), on a for row 1 (encrypts z*s)
    gzfft = negacyclic_fft(gvector * z, N, Q)
    for row, component in ((0, 0), (1, 1)):
        rgswfft[:, row, component, :] += gzfft

    return rgswfft

def rgswmult(ctfft, rgswfft):
    """External product of an RLWE ciphertext with an RGSW ciphertext (FFT form).

    ``ctfft`` (shape (2, N//2)) is brought back to coefficients and
    signed-decomposed into d digit polynomials per component. Each digit
    polynomial is multiplied with the matching RGSW row, and everything is
    summed over digits and rows. Returns an FFT-form RLWE ciphertext of
    shape (2, N//2).
    """
    digits = signed_decompose(negacyclic_ifft(ctfft, N, Q))
    digits_fft = negacyclic_fft(digits, N, Q).view(d, 2, 1, N // 2)
    return (digits_fft * rgswfft).sum(dim=(0, 1))

In [2]:
def encryptLWE(message, dim, modulus, key):
    """LWE-encrypt ``message`` scaled by modulus/4 under ``key``.

    Returns an int32 vector (b, a_1, ..., a_dim) with uniform a and
    b = message*modulus//4 - <a, key> + e, reduced mod ``modulus``. The
    reduction is a bit mask, so ``modulus`` must be a power of two.
    """
    ct = uniform(dim + 1, modulus)
    a_part = ct[1:]

    ct[0] = message * modulus // 4 - torch.dot(a_part, key)
    ct[0] += errgen(stddev, 1)

    ct &= modulus - 1
    return ct

def decryptLWE(ct, sk, modulus):
    """Decrypt an LWE ciphertext (b, a) as produced by ``encryptLWE``.

    Computes the phase b + <a, sk> mod ``modulus``, rounds it to the nearest
    multiple of modulus/4 and returns that multiple's index mod 4 as a 0-dim
    int32 tensor.
    """
    extended_key = torch.cat((torch.ones(1, dtype=torch.int32), sk))
    phase = torch.dot(ct, extended_key) % modulus
    scaled = phase.to(torch.float) / (modulus / 4.)
    return torch.round(scaled).to(torch.int) % 4

In [3]:
def extract(ctRLWE):
    """Sample-extract the constant coefficient of an RLWE ciphertext as LWE.

    For ctRLWE = (b, a) with rows of length N, returns the (N+1)-vector
    (b_0, a_0, -a_{N-1}, ..., -a_1). Its LWE phase under (1, s_0, ..., s_{N-1})
    equals the constant coefficient of b + a*s in Z[X]/(X^N + 1).
    The input tensor is left unchanged.
    """
    beta = ctRLWE[0][0]

    # Indexing/slicing returns a view; clone so the negate-and-flip below does
    # not overwrite the caller's 'a' row.
    alpha = ctRLWE[1].clone()
    alpha[1:] = -alpha[1:].flip(dims=[0])

    return torch.cat((beta.unsqueeze(0), alpha))

In [4]:
# Key-switching gadget: dks digits of logBks bits each over the logQks-bit
# modulus Qks.
dks = 2
logBks = 7

# Right-shift for digit r, least significant digit first; int32, shape (dks, 1).
decomp_shift_ks = logQks - logBks * torch.arange(dks, 0, -1).to(torch.int32).unsqueeze(1)
# One-element int32 mask selecting a single logBks-bit digit.
mask_ks = torch.tensor([2 ** logBks - 1], dtype=torch.int32)
# Gadget vector g_r = 2**decomp_shift_ks[r]; int32, shape (dks, 1).
gvector_ks = 1 << decomp_shift_ks

def decompose_ks(a):
    """Split a 1-D tensor into dks unsigned logBks-bit digits.

    Returns shape (dks, len(a)), least significant digit first.
    """
    assert a.dim() == 1
    shifted = a.unsqueeze(0) >> decomp_shift_ks.view(dks, 1)
    return shifted & mask_ks

Bks = 2 ** logBks  # number of distinct values one key-switching digit can take
# size: (dks, Bks, N, n+1)
def LWEkskGen(sk, skN, logQks):
    """Generate the LWE key-switching key from skN (length N) to sk (length n).

    Entry ``ksk[r, j, i]`` is an LWE ciphertext (b, a) modulo ``2**logQks``.
    Its ``a`` is uniform and ``b = e - <a, sk> + j * gvector_ks[r] * skN[i]``.

    Args:
        sk: target LWE secret key, length n.
        skN: source secret key, length N.
        logQks: bit-width of the key-switching modulus. It is used for both
            the sampling of ``a`` and the final reduction (previously ignored
            in favor of the global Qks).

    Returns:
        int32 tensor of shape (dks, Bks, N, n+1) with entries in [0, 2**logQks).

    NOTE(review): gvector_ks is built from the module-level logQks, so passing
    a different value here only makes sense if that gadget is rebuilt to match.
    """
    modulus = 1 << logQks
    ksk = torch.randint(modulus, size=(dks, Bks, N, n + 1), dtype=torch.int32)
    # b <- e
    ksk[..., 0] = torch.round(stddev * torch.randn(size=(dks, Bks, N))).to(torch.int32)
    # b <- e - a * s
    ksk[..., 0] -= torch.sum(ksk[..., 1:] * sk, dim=-1)
    # b <- e - a * s + j B^r skN_i
    ksk[..., 0] += (gvector_ks * torch.tensor(range(Bks))).view(dks, Bks, 1) * skN
    # reduction mod a power of two as a bit mask
    ksk &= modulus - 1

    return ksk

def LWEkeySwitch(ctLWE, kskLWE, Qks):
    """Switch an LWE ciphertext from key skN to key sk using ``kskLWE``.

    ctLWE = (b, alpha). Each alpha_i is split into dks digits. For digit
    value j at position (r, i), the key-switching ciphertext kskLWE[r, j, i]
    is added to (b, 0, ..., 0). All 2 * len(alpha) additions are done as one
    gather plus a sum instead of a Python double loop.

    Args:
        ctLWE: int tensor (b, alpha_1, ..., alpha_N) with entries in [0, Qks).
        kskLWE: key-switching key of shape (dks, Bks, N, n+1).
        Qks: key-switching modulus; must be a power of two (bit-mask reduction).

    Returns:
        int32 tensor of length n+1 with entries in [0, Qks).
    """
    alpha = ctLWE[1:]
    # (dks, len(alpha)) digit matrix; int64 so it can serve as an index tensor
    dalpha = decompose_ks(alpha).long()

    rows = torch.arange(dalpha.shape[0]).view(-1, 1)
    cols = torch.arange(dalpha.shape[1]).view(1, -1)
    # picked[r, i] = kskLWE[r, dalpha[r, i], i]  ->  shape (dks, len(alpha), n+1)
    picked = kskLWE[rows, dalpha, cols]

    switched = picked.to(torch.int64).sum(dim=(0, 1))
    switched[0] += ctLWE[0]
    switched &= Qks - 1

    return switched.to(torch.int32)
    

In [5]:
def brkgen(sk, skNfft, n, N):
    """Build the blind-rotation key: brk[i] is an FFT-form RGSW encryption of sk[i].

    Each key entry is encrypted as the constant polynomial 0 (when sk[i] == 0)
    or 1 (any other value), of length N, under the key whose FFT form is
    ``skNfft``. Returns a list of n RGSW ciphertexts.
    """
    zero_poly = torch.zeros([N], dtype=torch.int32)
    one_poly = zero_poly.clone()
    one_poly[0] = 1

    return [
        encrypt_rgsw_fft(zero_poly if sk[i] == 0 else one_poly, skNfft)
        for i in range(n)
    ]

In [6]:
def precompute_alpha(q, N, Q):
    """Precompute FFT forms of X^i - 1 in Z[X]/(X^N + 1) for every i in [0, q).

    Exponents i >= N use X^i = -X^(i-N). Returns a list of q complex tensors,
    indexed by the exponent i.
    """
    def x_pow_minus_one(i):
        # coefficient vector of X^i - 1 (negacyclic wrap for i >= N)
        poly = torch.zeros([N], dtype=torch.int32)
        poly[0] = -1
        if i < N:
            poly[i] += 1
        else:
            poly[i - N] -= 1
        return poly

    return [negacyclic_fft(x_pow_minus_one(i), N, Q) for i in range(q)]

In [7]:
def nand_map(i):
    """Test-vector value for the NAND gate at exponent ``i``.

    ``i`` is reduced mod 2N (= q). Returns -Q/8 when the reduced value lies in
    [3q/8, 7q/8), and +Q/8 otherwise (i.e. on [-q/8, 3q/8) mod q).
    """
    r = (i + 2 * N) % (2 * N)
    eighth = q >> 3
    return -(Q >> 3) if 3 * eighth <= r < 7 * eighth else Q >> 3

# NAND test vector: coefficient j holds nand_map(-j). gate_bootstrapping
# starts its accumulator from this polynomial.
f_nand = torch.tensor([nand_map(-j) for j in range(N)], dtype=torch.int32)

In [8]:
def _mod_switch(ct, from_modulus, to_modulus):
    """Rescale an LWE ciphertext from ``from_modulus`` to ``to_modulus`` with rounding."""
    # Integer round-to-nearest. Truncation biases every coefficient down by
    # about 1/2, and that bias accumulates over the secret key and eats into
    # the noise margin. Float32 scaling also loses low bits of 27-bit values.
    scaled = (ct.to(torch.int64) * to_modulus + from_modulus // 2) // from_modulus
    return (scaled % to_modulus).to(torch.int32)


def gate_bootstrapping(ct0, ct1, gate="NAND"):
    """Evaluate a binary gate on two LWE ciphertexts with FHEW bootstrapping.

    Args:
        ct0, ct1: LWE ciphertexts mod q under sk (dimension n), as produced by
            ``encryptLWE``.
        gate: gate name; only "NAND" is implemented.

    Returns:
        A fresh int32 LWE ciphertext mod q under sk that decrypts (with
        ``decryptLWE``) to NAND of the two inputs.

    Raises:
        ValueError: if ``gate`` is not "NAND".

    Uses the module-level keys ``brk`` and ``ksk`` and the table ``alphapoly``.
    """
    if gate != "NAND":
        raise ValueError(f"unsupported gate {gate!r}; only 'NAND' is implemented")

    ctsum = (ct0 + ct1) & (q - 1)

    # accumulator (b, a) = (f_nand, 0), moved to the FFT domain
    acc = torch.zeros([2, N], dtype=torch.int32)
    acc[0] = f_nand
    accfft = negacyclic_fft(acc, N, Q)

    # rotate the accumulator by X^beta (X^beta = -X^(beta-N) for beta >= N)
    beta = ctsum[0]
    xbeta = torch.zeros([N], dtype=torch.int32)
    if beta < N:
        xbeta[beta] = 1
    else:
        xbeta[beta - N] = -1
    accfft *= negacyclic_fft(xbeta, N, Q)

    # blind rotation: acc += (X^alpha_i - 1) * rgswmult(acc, RGSW(s_i)),
    # i.e. acc <- acc * X^(alpha_i * s_i) for binary s_i
    alpha = ctsum[1:]
    for i in range(n):
        accfft += alphapoly[alpha[i]] * rgswmult(accfft, brk[i])

    # sample extraction, plus the NAND offset of Q/8
    acc = negacyclic_ifft(accfft, N, Q)
    accLWE = extract(acc)
    accLWE[0] += Q >> 3
    accLWE &= Q - 1

    # mod switching 1: Q -> Qks
    accLWE_ms = _mod_switch(accLWE, Q, Qks)

    # key switching: skN -> sk
    accLWE_ks = LWEkeySwitch(accLWE_ms, ksk, Qks)

    # mod switching 2: Qks -> q
    return _mod_switch(accLWE_ks, Qks, q)

# Test

In [9]:
# Key generation. keygen, LWEkskGen and brkgen draw from torch's global RNG,
# so reordering these statements changes the sampled keys.

# LWE secret key (length n) and RLWE secret key (length N), plus the FFT form of the latter
sk = keygen(n)
skN = keygen(N)
skNfft = negacyclic_fft(skN, N, Q)

# key-switching key from skN to sk
ksk = LWEkskGen(sk, skN, logQks)
# blind-rotation key: RGSW encryptions of the entries of sk under skN
brk = brkgen(sk, skNfft, n, N)
# FFT forms of X^i - 1 for i in [0, q)
alphapoly = precompute_alpha(q, N, Q)

In [10]:
m0, m1 = 1, 1

# Encrypt both input bits under sk (m0 first, then m1) and evaluate NAND.
ct0, ct1 = (encryptLWE(m, n, q, sk) for m in (m0, m1))
ctNAND = gate_bootstrapping(ct0, ct1, gate="NAND")

# Plaintext reference next to the decrypted bootstrapped result.
print(f"NAND output: {int(not (m0 and m1))}")
print(f"encrypted result: {decryptLWE(ctNAND, sk, q)}")


NAND output: 0
encrypted result: 0
