# [AP14](https://web.eecs.umich.edu/~cpeikert/pubs/polyboot.pdf) Homomorphic Encryption Scheme

- [Faster Bootstrapping with Polynomial Error](https://web.eecs.umich.edu/~cpeikert/pubs/polyboot.pdf)
- [Homomorphic Encryption from Learning with Errors: Conceptually-Simpler, Asymptotically-Faster, Attribute-Based](https://eprint.iacr.org/2013/340.pdf)
- [Fully Homomorphic Encryption for Machine Learning](https://www.di.ens.fr/~minelli/docs/phd-thesis.pdf)
- [Building a Fully Homomorphic Encryption Scheme in Python](https://courses.csail.mit.edu/6.857/2019/project/15-Hedglin-Phillips-Reilley.pdf)

In [1]:
from math import log2, ceil, inf
import numpy as np
from numpy.testing import assert_array_equal

In [2]:
# NOTE: Uncomment to simplify debugging (fixes NumPy's global RNG seed so every run draws the same values)
# np.random.seed(1)

## Utility Functions

In [3]:
def generate_gadget_matrix(n, l, modulus):
    """
    Build the gadget matrix G = kron(I_n, g) with g = (1, 2, 4, ..., 2^(l-1)).

    The result is an integer matrix of shape (n, n * l). Row i holds the powers of
    two 2^0 .. 2^(l-1) in columns i*l .. i*l + l - 1 and zeros everywhere else.

    NOTE: `modulus` is accepted but not used; the number of bits `l` must be
    supplied by the caller (this notebook passes l = ceil(log2(q))).
    """
    # The paper writes the exponents as 0 .. l - 1; Python's range(l) yields exactly that.
    powers_of_two = np.array([2 ** k for k in range(l)])
    return np.kron(np.eye(n), powers_of_two).astype(int)

def test():
    q = 2 ** 16
    n = 3
    l = ceil(log2(q))
    # Hard-coded on purpose so the check does not mirror the implementation.
    powers = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
    zeros = [0] * len(powers)
    # Each row places the powers of two in its own block of l columns.
    expected = np.array([
        powers + zeros + zeros,
        zeros + powers + zeros,
        zeros + zeros + powers,
    ])
    result = generate_gadget_matrix(n, l, modulus=q)
    assert result.shape == (n, n * l)
    assert_array_equal(result, expected)

test()

In [4]:
def num_to_bin_vector(number, width):
    """
    Return the `width` lowest bits of `number` as a NumPy array, least significant bit first.

    `number` is converted with `int()` first, so floats such as 1024.0 are accepted.
    Bits above position `width - 1` are dropped.
    """
    value = int(number)
    bits = [(value >> position) & 1 for position in range(width)]
    return np.array(bits)

def test():
    cases = [
        # Integer input
        (64, 8, [0, 0, 0, 0, 0, 0, 1, 0]),
        # Float input (converted with int() inside num_to_bin_vector)
        (1024.0, 11, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
    ]
    for number, width, bits in cases:
        assert_array_equal(num_to_bin_vector(number, width), np.array(bits))

test()

In [5]:
def bit_decomp(matrix, padding):
    """
    Bit-decompose every entry of a 2-D `matrix` (this is G^-1 / G_inv in the literature).

    Entry (r, c) is expanded into `padding` bits, least significant bit first. Those bits
    occupy rows r*padding .. r*padding + padding - 1 of column c. A (rows, cols) input
    therefore yields a (rows * padding, cols) 0/1 matrix. For entries in [0, 2^padding)
    and the matching gadget matrix G, G @ bit_decomp(X, padding) == X.
    """
    # One flat bit list per input column: all bits of row 0, then row 1, ...
    bit_columns = [
        [(int(value) >> bit) & 1 for value in column for bit in range(padding)]
        for column in matrix.T
    ]
    # The reshape keeps the shape right for degenerate (zero-size) inputs too.
    stacked = np.array(bit_columns).reshape(matrix.shape[1], padding * matrix.shape[0])
    return stacked.T

def test():
    # Each expected matrix is written transposed (one line per output column) and
    # transposed back. Every line holds two 8-bit LSB-first groups, one per input row.
    cases = [
        # Non-square matrix
        (
            np.array([[64, 32, 16],
                      [8, 4, 2]]),
            np.array([
                [0, 0, 0, 0, 0, 0, 1, 0,  0, 0, 0, 1, 0, 0, 0, 0],  # 64 | 8
                [0, 0, 0, 0, 0, 1, 0, 0,  0, 0, 1, 0, 0, 0, 0, 0],  # 32 | 4
                [0, 0, 0, 0, 1, 0, 0, 0,  0, 1, 0, 0, 0, 0, 0, 0],  # 16 | 2
            ]).T,
        ),
        # Square matrix
        (
            np.array([[64, 32],
                      [16, 8]]),
            np.array([
                [0, 0, 0, 0, 0, 0, 1, 0,  0, 0, 0, 0, 1, 0, 0, 0],  # 64 | 16
                [0, 0, 0, 0, 0, 1, 0, 0,  0, 0, 0, 1, 0, 0, 0, 0],  # 32 | 8
            ]).T,
        ),
    ]
    for matrix, expected in cases:
        assert_array_equal(bit_decomp(matrix, padding=8), expected)

test()

## Security Parameters

In [6]:
# NOTE: Adding / multiplying ciphertexts grows their noise. If decryption starts to
# fail, increase `q` and `n`.
q = 2 ** 16
n = 3
l = ceil(log2(q))  # bits per gadget block
m = l * n          # number of columns of G, A and pk

print(f'q: {q}')
print(f'l: {l}')
print(f'n: {n}')
print(f'm: {m}')

q: 65536
l: 16
n: 3
m: 48


## Secret Key

In [7]:
# Secret vector: n - 1 entries drawn uniformly from {0, ..., q - 1}. np.random.choice(q, ...)
# already stays below q, so no extra reduction is needed.
# NOTE(review): np.random is not a cryptographic RNG; fine for a demo only.
s = np.random.choice(q, size=n - 1)

print(s)
print(s.shape)

assert s.shape == (n - 1,)

[25708 58288]
(2,)


In [8]:
# Full secret key sk = (s, 1). The trailing 1 makes sk . pk reduce to e (mod q).
sk = np.concatenate((s, [1]))

print(sk)
print(sk.shape)

assert sk.shape == (n,)

[25708 58288     1]
(3,)


## Public Key

In [9]:
# Small error vector: m standard-normal samples rounded to integers, then mapped into
# Z_q (negative errors become q - 1, q - 2, ...).
e = np.rint(np.random.normal(loc=0.0, scale=1.0, size=m)).astype(int) % q

print(e)
print(e.shape)

assert e.shape == (m,)

[65534 65535     0     0     0 65535 65534     0     0 65534 65534     0
     1 65535     2     0     0 65535 65535     1     0     1 65535     1
     0     0     0     1     1     0     0     0     2 65533     0     2
     0     0     1     1     1 65535 65535     0     1     1     0     0]
(48,)


In [10]:
# Uniformly random (n - 1) x m matrix over Z_q (values are already < q).
A = np.random.choice(q, size=(n - 1, m))

print(A)
print(A.shape)

assert A.shape == (n - 1, m)

[[64215 45153 49341 55156 21866  8970  5513 64479 37170 33768 35233 26892
  21533 12667 14561 32478 55939  6468 21045 17047 30137 49223 39669 57276
   8985 34606   728 48747 34413  6141 43548 32825 42637 62786 22535 34439
  46608 12739 23711 15682 51586 39983 57274 63859 63842 11879 11902 22578]
 [26696  4082 42398 46255 55206 35564 38096 46248  5960 37039 26188 27157
  14038 15815 53570 36947 35058  5331 49211  1027 53443 31511 45654  5497
  21867  2332 42120 13762 42124 27466 10084 20085 21037 51447 44144 36451
  13804   318 16071 39140 46208 58346 24600 54678 17659  5866  4189 57839]]
(2, 48)


In [11]:
# LWE samples: b = s.A + e (mod q).
b = (s @ A + e) % q

print(b)
print(b.shape)

assert b.shape == (m,)

[22578 55627  6748 39488 58584 30327 22730 35220 41624 60206 44074 36480
 17757 55731 17998  4024  6052 41663 56299 32453 23836 57029 61563 57729
 11548  5800 18592  7301 36413 21404 29840  2940 48494 28325 48116 10758
 25856 65252 52453 59545 25305 28339 25335   420 31465  3669 35864 65128]
(48,)


In [12]:
# Public key: -A stacked on top of the row b, an n x m matrix. The -A rows are left
# unreduced (negative) here; every later use reduces mod q.
# `np.row_stack` is a deprecated alias (NumPy 2.0) of `np.vstack`, so use the latter.
pk = np.vstack((-A, b))

print(pk)
print(pk.shape)

assert pk.shape == (n, m)
assert_array_equal(pk[:n - 1], -A)
assert_array_equal(pk[n - 1], b)

[[-64215 -45153 -49341 -55156 -21866  -8970  -5513 -64479 -37170 -33768
  -35233 -26892 -21533 -12667 -14561 -32478 -55939  -6468 -21045 -17047
  -30137 -49223 -39669 -57276  -8985 -34606   -728 -48747 -34413  -6141
  -43548 -32825 -42637 -62786 -22535 -34439 -46608 -12739 -23711 -15682
  -51586 -39983 -57274 -63859 -63842 -11879 -11902 -22578]
 [-26696  -4082 -42398 -46255 -55206 -35564 -38096 -46248  -5960 -37039
  -26188 -27157 -14038 -15815 -53570 -36947 -35058  -5331 -49211  -1027
  -53443 -31511 -45654  -5497 -21867  -2332 -42120 -13762 -42124 -27466
  -10084 -20085 -21037 -51447 -44144 -36451 -13804   -318 -16071 -39140
  -46208 -58346 -24600 -54678 -17659  -5866  -4189 -57839]
 [ 22578  55627   6748  39488  58584  30327  22730  35220  41624  60206
   44074  36480  17757  55731  17998   4024   6052  41663  56299  32453
   23836  57029  61563  57729  11548   5800  18592   7301  36413  21404
   29840   2940  48494  28325  48116  10758  25856  65252  52453  59545
   25305  28339  2

In [13]:
# Sanity check: sk . pk = -s.A + (s.A + e) = e (mod q).
assert_array_equal((sk @ pk) % q, e)

## Encryption

In [14]:
# Plaintext message to encrypt (a small integer).
mu = 42

print(mu)

42


In [15]:
# n x m gadget matrix used to embed the message in the ciphertext.
G = generate_gadget_matrix(n, l, modulus=q)

print(G)
print(G.shape)

assert G.shape == (n, m)

[[    1     2     4     8    16    32    64   128   256   512  1024  2048
   4096  8192 16384 32768     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     1     2     4     8    16    32    64   128
    256   512  1024  2048  4096  8192 16384 32768     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     1     2     4     8
     16    32    64   128   256   512  1024  2048  4096  8192 16384 32768]]
(3, 48)


In [16]:
# Random binary m x m matrix (entries 0 or 1; reducing mod q would be a no-op).
R = np.random.choice(2, size=(m, m))

print(R)
print(R.shape)

assert R.shape == (m, m)

[[1 1 0 ... 0 0 1]
 [1 0 0 ... 0 0 0]
 [1 0 1 ... 1 1 0]
 ...
 [0 1 1 ... 0 1 1]
 [1 0 1 ... 0 0 1]
 [0 0 1 ... 0 1 1]]
(48, 48)


In [17]:
# Ciphertext: C = pk.R + mu.G (mod q).
C = (pk @ R + mu * G) % q

print(C)
print(C.shape)

assert C.shape == (n, m)

[[ 5710 58646 63027 60887 20432 15002 29460 49609 25884 60422 10572 59629
   2726 36781 54479 22299  7826 56710 12694 16278 35086 60355  5784 36685
  41232 37847 58590 14799   800  2681 54920  4542  1748 53758  2872 42193
  49190 33652 19464 29554 26236 44235 29130 33487 20935 39222 38041 64979]
 [ 8674 40445  4531   807 15622 21739 10207 25420 27748  1709  9135 20603
  57437 10352 65079 55336 58944 47881 62907 13295  1305 34742 20109 27045
  48469 34416 58804 62710 55975     7 24309 56710 25977  7707 51959 49569
  21407 36200 51714 61164 29839 49131 19509 28358 41080 44484 45471 57315]
 [59241 49465 14669 47931 27041 37756 60226 61172 61287 34691  3742 33395
  29446 45569 53722 41238 24894 22539  9891 51816 47081  2717 51189 63250
   9165 43079 45717 13710 50867  5923 55087 11707 16517 36251 54839 62323
  41032 53458 12223 21163 36451 53199 23391 35467 43404 17015 61736 18928]]
(3, 48)


## Decryption

In [18]:
# sk . C = (sk . pk) . R + mu . (sk . G) = e.R + mu . (sk . G)  (mod q)
msg = (sk @ C) % q

print(msg)
print(msg.shape)

assert msg.shape == (m,)

[31153 62321 59105 52671 39809 14084 28162 56320 47095 28667 57342 49151
 32766 65533 65534 65530 23254 46531 27515 55040 44545 23553 47109 28670
 57341 49147 32765     2     3 65535 65535     3    37    83   167   335
   672  1346  2687  5379 10755 21507 43015 20479 40960 16383 32772     4]
(48,)


In [19]:
# The vector that msg is (approximately) mu times.
sg = (sk @ G) % q

print(sg)
print(sg.shape)

assert sg.shape == (m,)

[25708 51416 37296  9056 18112 36224  6912 13824 27648 55296 45056 24576
 49152 32768     0     0 58288 51040 36544  7552 15104 30208 60416 55296
 45056 24576 49152 32768     0     0     0     0     1     2     4     8
    16    32    64   128   256   512  1024  2048  4096  8192 16384 32768]
(48,)


In [20]:
# Candidate values for mu: coordinate-wise floor division of msg by sg. Some entries
# of sg can be 0 (mod q), which would raise "divide by zero" RuntimeWarnings, so
# those warnings are silenced here.
with np.errstate(divide='ignore', invalid='ignore'):
    r = np.floor_divide(msg, sg)

print(r)
print(r.shape)

assert r.shape == (m,)

[ 1  1  1  5  2  0  4  4  1  0  1  1  0  1  0  0  0  0  0  7  2  0  0  0
  1  1  0  0  0  0  0  0 37 41 41 41 42 42 41 42 42 42 42  9 10  1  2  0]
(48,)


In [21]:
# Pick the candidate `val` for which val . sg is closest to msg on the ring Z_q.
res = 0
dist = inf

for val in np.unique(r):
    # Coordinate-wise centered distance |msg - val . sg| on Z_q, in [0, q/2].
    d = (msg - (val * sg)) % q
    d = np.minimum(d, q - d)
    # Keep the Euclidean norm as a float and do NOT reduce it mod q: a wrong candidate
    # whose norm exceeds q would otherwise wrap around to a tiny "distance" and win.
    d = np.linalg.norm(d)
    if d < dist:
        res = val
        dist = d

print(f'The value is: {res}')

assert res == mu

The value is: 42


## Homomorphic Addition / Multiplication

In [22]:
def encrypt(mu, *, public_key=None, gadget=None, modulus=None):
    """
    Encrypt the integer `mu` as C = (pk @ R + mu * G) mod q.

    A fresh uniformly random binary matrix R is drawn on every call. Reusing one global
    R would make encryption deterministic: encrypt(a) - encrypt(b) = (a - b) * G (mod q)
    would reveal a - b to anyone.

    Parameters:
        mu: the plaintext integer.
        public_key: (n, m) public key; defaults to the notebook's global `pk`.
        gadget: (n, m) gadget matrix; defaults to the global `G`.
        modulus: ciphertext modulus; defaults to the global `q`.

    Returns:
        The (n, m) ciphertext matrix with entries in [0, modulus).
    """
    if public_key is None:
        public_key = pk
    if gadget is None:
        gadget = G
    if modulus is None:
        modulus = q

    width = public_key.shape[1]
    # Uses np.random so that np.random.seed(...) still makes runs reproducible.
    randomness = np.random.choice(2, (width, width))
    return ((public_key @ randomness) + (mu * gadget)) % modulus

In [23]:
def decrypt(C, *, secret_key=None, gadget=None, modulus=None):
    """
    Recover the plaintext integer from a ciphertext C (the decryption logic above).

    Computes msg = sk . C and sg = sk . G (mod q). It takes the coordinate-wise quotients
    msg // sg as candidate plaintexts and returns the candidate whose multiple of sg is
    closest to msg, measured as the Euclidean norm of the centered residual on Z_q.

    Parameters:
        C: (n, m) ciphertext.
        secret_key: length-n secret key; defaults to the notebook's global `sk`.
        gadget: (n, m) gadget matrix; defaults to the global `G`.
        modulus: ciphertext modulus; defaults to the global `q`.

    Returns:
        The decoded plaintext (a NumPy integer), or 0 if there are no candidates.
    """
    if secret_key is None:
        secret_key = sk
    if gadget is None:
        gadget = G
    if modulus is None:
        modulus = q

    msg = secret_key.dot(C) % modulus
    sg = secret_key.dot(gadget) % modulus
    # sg may contain zeros; NumPy integer floor-division by zero only warns.
    with np.errstate(divide='ignore', invalid='ignore'):
        candidates = np.unique(msg // sg)

    res = 0
    dist = inf
    for val in candidates:
        # Centered per-coordinate distance on Z_q, in [0, modulus / 2].
        d = (msg - (val * sg)) % modulus
        d = np.minimum(d, modulus - d)
        # Compare float norms directly. Reducing the norm mod q (as before) let a far-off
        # candidate wrap around to a small distance and be chosen instead of mu.
        norm = np.linalg.norm(d)
        if norm < dist:
            res = val
            dist = norm

    return res

In [24]:
# Homomorphic addition: the sum of ciphertexts decrypts to the sum of plaintexts.
res = decrypt(sum(encrypt(x) for x in (42, 28, 30)) % q)

print(res)

assert res == 42 + 28 + 30

100


In [25]:
# Homomorphic multiplication: C_a @ G^-1(C_b) decrypts to a * b. Here a = 2 + 4 is
# itself a homomorphic sum, and b = 3.
res = decrypt(
    (((encrypt(2) + encrypt(4)) % q) @ bit_decomp(encrypt(3), l)) % q
)

print(res)

assert res == (2 + 4) * 3

18
