# [Homomorphic Encryption](https://en.wikipedia.org/wiki/Homomorphic_encryption)

Playing around with (fully) homomorphic encryption schemes.

In [1]:
import numpy as np
from math import gcd
from random import randint
from typing import List, Tuple, NamedTuple
from numpy import ndarray

In [2]:
# Taken from: https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm
def egcd(a: int, b: int) -> Tuple[int, int, int]:
    """Extended Euclidean algorithm.

    Iterative formulation of the recursive reference implementation; it
    returns identical triples without using the call stack.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        A tuple `(g, x, y)` such that `a * x + b * y == g`, where `g` is the
        greatest common divisor found by the Euclidean remainder sequence.
    """
    # Forward pass: run the remainder sequence and remember each quotient.
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a

    # Base case of the recursion: gcd(0, b) = b = 0 * 0 + b * 1.
    g, x, y = b, 0, 1

    # Backward pass: rebuild the Bezout coefficients from the innermost
    # step outwards, exactly as the recursive version unwinds.
    for q in reversed(quotients):
        x, y = y - q * x, x
    return (g, x, y)

def modinv(a: int, m: int) -> int:
    """Return the multiplicative inverse of `a` modulo `m`.

    Args:
        a: Value to invert.
        m: Modulus.

    Returns:
        The value `x % m` where `x` is the Bezout coefficient of `a`
        reported by `egcd(a, m)`.

    Raises:
        Exception: If `egcd(a, m)` reports a divisor other than 1, i.e. no
            inverse exists.
    """
    divisor, coefficient, _ = egcd(a, m)
    if divisor == 1:
        return coefficient % m
    raise Exception('modular inverse does not exist')

# Sanity checks: 17 * 2753 = 46801 = 15 * 3120 + 1, and
# 1071 * (-3) + 462 * 7 = 21 = gcd(1071, 462).
assert modinv(17, 3120) == 2753
assert egcd(1071, 462) == (21, -3, 7)

## [El Gamal](https://en.wikipedia.org/wiki/ElGamal_encryption)

El Gamal can be used to perform encrypted multiplications.

In [3]:
class PublicKey(NamedTuple):
    """El Gamal public key as produced by `keygen`."""
    # Modulus used by keygen, encrypt and decrypt.
    p: int
    # Randomly chosen base.
    a: int
    # a ** d mod p, where d is the exponent stored in the SecretKey.
    b: int

class SecretKey(NamedTuple):
    """El Gamal secret key as produced by `keygen`."""
    # Randomly chosen secret exponent.
    d: int

class Ciphertext(NamedTuple):
    """El Gamal ciphertext pair as produced by `encrypt`."""
    # a ** k mod p for the random exponent k drawn during encryption.
    r: int
    # (b ** k) * message mod p.
    t: int

In [4]:
def keygen(p: int) -> Tuple[PublicKey, SecretKey]:
    """Generate an El Gamal key pair over the integers modulo `p`.

    Args:
        p: Modulus. Presumably a prime (the matching `decrypt` inverts via
            `x ** (p - 2)`) -- this is not checked here. Must be at least 4
            so that the sampling ranges below are non-empty.

    Returns:
        `(PublicKey(p, a, b), SecretKey(d))` with `b = a ** d mod p`.

    Raises:
        ValueError: Raised by `randint` when `p < 4`.
    """
    # a: 1 < a < p - 1. The endpoints are excluded on purpose: with a = 1
    # both b and r collapse to 1 and `encrypt` would emit the message as-is.
    a: int = randint(2, p - 2)
    # d: 2 <= d <= p - 2
    d: int = randint(2, p - 2)
    # Three-argument pow reduces modulo p at every step instead of building
    # the full power a ** d first.
    b: int = pow(a, d, p)
    pk: PublicKey = PublicKey(p, a, b)
    sk: SecretKey = SecretKey(d)
    return (pk, sk)

In [5]:
def encrypt(message: int, pk: PublicKey) -> Ciphertext:
    """Encrypt `message` with the El Gamal public key `pk`.

    Args:
        message: Plaintext integer. Values are reduced modulo `pk.p`, so only
            messages below `pk.p` can be recovered unchanged.
        pk: Public key providing the modulus `p`, base `a` and `b = a ** d`.

    Returns:
        `Ciphertext(r, t)` with `r = a ** k mod p` and
        `t = (b ** k) * message mod p` for a fresh random exponent `k`.

    Raises:
        ValueError: Raised by `randint` when `pk.p < 3`.
    """
    # Ephemeral exponent, 1 <= k <= p - 2. k = 0 is excluded because it
    # gives r = 1 and t = message, i.e. the plaintext itself; the upper
    # bound follows the modulus instead of a fixed constant.
    k: int = randint(1, pk.p - 2)
    # Modular exponentiation keeps every intermediate value below p.
    r: int = pow(pk.a, k, pk.p)
    t: int = (pow(pk.b, k, pk.p) * message) % pk.p
    return Ciphertext(r, t)

In [6]:
def decrypt(c: Ciphertext, pk: PublicKey, sk: SecretKey) -> int:
    """Decrypt an El Gamal ciphertext.

    Args:
        c: Ciphertext pair `(r, t)`. The components may be unreduced
            products (as returned by `mult`); they are reduced here.
        pk: Public key; only the modulus `pk.p` is used.
        sk: Secret key providing the exponent `d`.

    Returns:
        `(r ** d) ** (p - 2) * t mod p`, the recovered plaintext.
    """
    # Shared secret r ** d, reduced modulo p while exponentiating.
    shared: int = pow(c.r, sk.d, pk.p)
    # Raising to p - 2 yields the modular inverse when p is prime (Fermat's
    # little theorem). Three-argument pow makes this cheap; the earlier
    # `(r ** d) ** (p - 2)` built an enormous integer before reducing.
    shared_inverse: int = pow(shared, pk.p - 2, pk.p)
    return (shared_inverse * c.t) % pk.p

In [7]:
def mult(a: Ciphertext, b: Ciphertext) -> Ciphertext:
    """Multiply two El Gamal ciphertexts component-wise.

    The products are returned as plain integer products; no modular
    reduction is applied here.

    Args:
        a: First ciphertext.
        b: Second ciphertext.

    Returns:
        `Ciphertext(a.r * b.r, a.t * b.t)`.
    """
    return Ciphertext(r=a.r * b.r, t=a.t * b.t)

In [8]:
# El Gamal round trip: generate a key pair for the modulus 47, encrypt one
# message and decrypt it again.
pk, sk = keygen(47)

print('--- Message Encryption / Decryption ---')
plaintext: int = 42
print(f'Message (Plaintext): {plaintext}')
    
# keygen and encrypt draw random values, so the printed ciphertext differs
# between runs.
ciphertext: Ciphertext = encrypt(plaintext, pk)
print(f'Message (Ciphertext): {ciphertext}')

# The decrypted value should equal the plaintext.
decrypted: int = decrypt(ciphertext, pk, sk)
print(f'Message (Decrypted): {decrypted}')

--- Message Encryption / Decryption ---
Message (Plaintext): 42
Message (Ciphertext): Ciphertext(r=27, t=16)
Message (Decrypted): 42


In [9]:
# El Gamal homomorphic multiplication: multiply two ciphertexts and check
# that the product decrypts to a * b.
pk, sk = keygen(47)

print('--- Encrypted Multiplication ---')
# a * b = 30 stays below the modulus 47, so the product is recoverable.
a: int = 6
b: int = 5
print(f'Numbers (Plaintext): {a}, {b}')
print(f'Result (Plaintext): {a * b}')

enc_a: Ciphertext = encrypt(a, pk)
enc_b: Ciphertext = encrypt(b, pk)
print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')

# mult works on ciphertexts only; no key is involved.
result: Ciphertext = mult(enc_a, enc_b)
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, pk, sk)
print(f'Result (Decrypted): {decrypted}')

--- Encrypted Multiplication ---
Numbers (Plaintext): 6, 5
Result (Plaintext): 30
Numbers (Ciphertext): Ciphertext(r=3, t=14), Ciphertext(r=8, t=41)
Result (Ciphertext): Ciphertext(r=24, t=574)
Result (Decrypted): 30


## [RSA](https://en.wikipedia.org/wiki/RSA_(cryptosystem)) Cryptosystem

RSA can be used to perform encrypted multiplications.

In [10]:
class PublicKey(NamedTuple):
    """RSA public key as produced by `keygen`.

    Note: this replaces the El Gamal `PublicKey` defined earlier in the
    notebook.
    """
    # Public exponent, coprime with (p - 1) * (q - 1).
    e: int
    # Modulus p * q.
    n: int

class SecretKey(NamedTuple):
    """RSA secret key as produced by `keygen`.

    Note: this replaces the El Gamal `SecretKey` defined earlier in the
    notebook.
    """
    # Private exponent: the inverse of e modulo (p - 1) * (q - 1).
    d: int
    # Modulus p * q (same value as in the public key).
    n: int

class Ciphertext(NamedTuple):
    """RSA ciphertext as produced by `encrypt`.

    Note: this replaces the El Gamal `Ciphertext` defined earlier in the
    notebook.
    """
    # The encrypted value (message ** e mod n), despite the field name `m`.
    m: int

In [11]:
def keygen(p: int, q: int) -> Tuple[PublicKey, SecretKey]:
    """Derive an RSA key pair from the two factors `p` and `q`.

    Args:
        p: First factor of the modulus.
        q: Second factor of the modulus.

    Returns:
        `(PublicKey(e, n), SecretKey(d, n))` where `n = p * q`, `e` is the
        smallest integer >= 2 coprime with `(p - 1) * (q - 1)` and `d` is
        the inverse of `e` modulo `(p - 1) * (q - 1)`.
    """
    modulus: int = p * q
    totient: int = (p - 1) * (q - 1)
    # The public exponent must be greater than 1 and share no factor with
    # the totient; scan upward from 2 and stop at the first such value.
    exponent: int = 2
    while True:
        if gcd(totient, exponent) == 1:
            break
        exponent += 1
    inverse: int = modinv(exponent, totient)
    return (PublicKey(exponent, modulus), SecretKey(inverse, modulus))
    
# Worked example: p = 61, q = 53 gives n = 3233 and a totient of 3120;
# 7 is the first value coprime with 3120 and 7 * 1783 = 4 * 3120 + 1.
assert keygen(61, 53)[0] == PublicKey(7, 3233)
assert keygen(61, 53)[1] == SecretKey(1783, 3233)

In [12]:
def encrypt(message: int, pk: PublicKey) -> Ciphertext:
    """Encrypt `message` with the RSA public key `pk`.

    Args:
        message: Plaintext integer. The result is reduced modulo `pk.n`, so
            only messages below `pk.n` can be recovered unchanged.
        pk: Public key providing the exponent `e` and modulus `n`.

    Returns:
        `Ciphertext(message ** e mod n)`.
    """
    # Three-argument pow performs modular exponentiation and never
    # materialises the full power message ** e.
    return Ciphertext(pow(message, pk.e, pk.n))

In [13]:
def decrypt(c: Ciphertext, sk: SecretKey) -> int:
    """Decrypt an RSA ciphertext with the secret key `sk`.

    Args:
        c: Ciphertext; `c.m` may be an unreduced product (as returned by
            `mult`), it is reduced modulo `sk.n` here.
        sk: Secret key providing the exponent `d` and modulus `n`.

    Returns:
        `c.m ** d mod n`.
    """
    # The private exponent is large, so reduce while exponentiating instead
    # of computing c.m ** d in full and reducing afterwards.
    return pow(c.m, sk.d, sk.n)

In [14]:
def mult(a: Ciphertext, b: Ciphertext) -> Ciphertext:
    """Multiply two RSA ciphertexts.

    The product is returned as a plain integer product; no modular reduction
    is applied here.

    Args:
        a: First ciphertext.
        b: Second ciphertext.

    Returns:
        `Ciphertext(a.m * b.m)`.
    """
    product: int = a.m * b.m
    return Ciphertext(product)

In [15]:
# RSA round trip. From here on keygen / encrypt / decrypt refer to the RSA
# definitions, which replaced the El Gamal functions of the same names.
pk, sk = keygen(61, 53)

print('--- Message Encryption / Decryption ---')
plaintext: int = 42
print(f'Message (Plaintext): {plaintext}')
    
# RSA keygen and encrypt use no randomness, so this output is the same on
# every run.
ciphertext: Ciphertext = encrypt(plaintext, pk)
print(f'Message (Ciphertext): {ciphertext}')

# The decrypted value should equal the plaintext.
decrypted: int = decrypt(ciphertext, sk)
print(f'Message (Decrypted): {decrypted}')

--- Message Encryption / Decryption ---
Message (Plaintext): 42
Message (Ciphertext): Ciphertext(m=240)
Message (Decrypted): 42


In [16]:
# RSA homomorphic multiplication: multiply two ciphertexts and check that
# the product decrypts to a * b.
pk, sk = keygen(61, 53)

print('--- Encrypted Multiplication ---')
# a * b = 30 stays below the modulus 3233, so the product is recoverable.
a: int = 6
b: int = 5
print(f'Numbers (Plaintext): {a}, {b}')
print(f'Result (Plaintext): {a * b}')

enc_a: Ciphertext = encrypt(a, pk)
enc_b: Ciphertext = encrypt(b, pk)
print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')

# mult works on ciphertexts only; no key is involved.
result: Ciphertext = mult(enc_a, enc_b)
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, sk)
print(f'Result (Decrypted): {decrypted}')

--- Encrypted Multiplication ---
Numbers (Plaintext): 6, 5
Result (Plaintext): 30
Numbers (Ciphertext): Ciphertext(m=1898), Ciphertext(m=533)
Result (Ciphertext): Ciphertext(m=1011634)
Result (Decrypted): 30


## [Efficient Homomorphic Encryption on Integer Vectors and Its Applications](https://www.rle.mit.edu/sia/wp-content/uploads/2015/04/2014-zhou-wornell-ita.pdf)

**NOTE:** The code written here was produced by following the blog post ["Building Safe A.I."](http://iamtrask.github.io/2017/03/17/safe-ai/) by Andrew Trask.

### Terminology

- **S**: Matrix which represents the secret / private key
- **M**: Public Key (also used to perform Math operations)
- **c**: Vector which contains the encrypted data
- **x**: Plaintext (some papers use the variable **m** instead)
- ***w***: (Weighting) Scalar used to control signal / noise ratio of **x**
- **e**: Random noise (e.g. noise added to the data before encrypting it via the public key) which makes the decryption difficult

Homomorphic Encryption has 4 kinds of operations we care about:

1. Public / private keypair generation
1. One-way encryption
1. Decryption
1. Math operations

$$
\textit{S}c = \textit{w}x + e
$$

$$
x = \lceil \frac{Sc}{\textit{w}} \rfloor
$$

In [17]:
def generate_key(w: int, m: int, n: int) -> ndarray:
    """Draw a random `m x n` secret-key matrix.

    Args:
        w: Weighting scalar; the entries are scaled by `w / 2 ** 16`.
        m: Number of rows.
        n: Number of columns.

    Returns:
        An `(m, n)` float array of uniform samples from `np.random.rand`,
        multiplied by `w` and divided by `2 ** 16`.
    """
    uniform: ndarray = np.random.rand(m, n)
    return uniform * w / (2 ** 16)

def encrypt(x: ndarray, S: ndarray, m: int, n: int, w: int) -> ndarray:
    """Encrypt the vector `x` under the secret key `S`.

    Solves `S c = w x + e` for `c` by multiplying with the inverse of `S`,
    where `e` is fresh uniform noise.

    Args:
        x: Plaintext vector; must have as many entries as `S` has rows.
        S: Square secret-key matrix (it is inverted).
        m: Length of the noise vector to draw.
        n: Accepted for signature symmetry; not used by this function.
        w: Weighting scalar applied to `x` before the noise is added.

    Returns:
        The ciphertext vector `inv(S) . (w * x + e)`.
    """
    assert len(x) == len(S)
    noise: ndarray = np.random.rand(m)
    S_inverse: ndarray = np.linalg.inv(S)
    noisy_plaintext: ndarray = (w * x) + noise
    return S_inverse.dot(noisy_plaintext)

def decrypt(c: ndarray, S: ndarray, w) -> ndarray:
    """Decrypt the ciphertext vector `c` with the secret key `S`.

    Args:
        c: Ciphertext vector.
        S: Secret-key matrix.
        w: Weighting scalar that was used during encryption.

    Returns:
        `S . c / w` converted with `astype('int')`, i.e. with the fractional
        part truncated.
    """
    weighted: ndarray = S.dot(c)
    recovered: ndarray = weighted / w
    return recovered.astype('int')

def switch_key(c: ndarray, S: ndarray, m: int, n: int, T) -> Tuple[ndarray, ndarray]:
    """Re-encrypt the integer vector `c` (readable with `S`) under a new key.

    `c` is expanded into signed bits (`c_star`) and `S` into the matching
    power-of-two columns (`S_star`), so that `S_star . c_star == S . c`.
    A matrix `M` is then built such that `S_prime . M == S_star + E`.

    Args:
        c: Integer ciphertext vector with `m` entries.
        S: Key matrix of shape `(m, n)` under which `c` is currently readable.
        m: Number of entries in `c` / rows in `S`.
        n: Number of columns in `S`.
        T: Random integer matrix (see `get_T`) forming the right-hand part of
            the new key.

    Returns:
        `(c_prime, S_prime)`: the switched ciphertext `M . c_star` and the new
        key `[I | T]`.
    """
    # Bits needed for the largest magnitude in c. bit_length() is used rather
    # than ceil(log2(.)) because the latter is one bit short for exact powers
    # of two (64 needs 7 bits, log2 gives 6) and is undefined for 0; in both
    # cases the bit expansion below would be corrupted or fail.
    l: int = max(1, int(np.max(np.abs(c))).bit_length())
    c_star: ndarray = get_c_star(c, m, l)
    S_star: ndarray = get_S_star(S, m, n, l)
    n_prime = n + 1
    # New key [I | T]: identity next to T.
    S_prime = np.concatenate((np.eye(m), T.T), 0).T
    A: ndarray = (np.random.rand(n_prime - m, n * l) * 10).astype('int')
    # Uniform [0, 1) samples truncated to int are all zero, so E currently
    # contributes no noise.
    E: ndarray = (1 * np.random.rand(S_star.shape[0], S_star.shape[1])).astype('int')
    # Stack (S_star - T.A + E) on top of A, so that [I | T] . M = S_star + E.
    M: ndarray = np.concatenate(((S_star - T.dot(A) + E), A), 0)
    c_prime: ndarray = M.dot(c_star)
    return c_prime, S_prime

def get_c_star(c: ndarray, m: int, l: int) -> ndarray:
    """Expand each of the first `m` entries of `c` into `l` signed bits.

    Every entry gets its own slot of `l` positions, most significant bit
    first and right-aligned; the bits of a negative entry are negated.

    Args:
        c: Integer vector.
        m: Number of entries of `c` to expand.
        l: Number of bit positions reserved per entry.

    Returns:
        An integer vector of length `l * m`.
    """
    expanded: ndarray = np.zeros(l * m, dtype='int')
    for idx in range(m):
        value = c[idx]
        digits: ndarray = np.array(list(np.binary_repr(np.abs(value))), dtype='int')
        if value < 0:
            digits *= -1
        # Right-align the digits inside this entry's slot of l positions.
        slot_end = (idx + 1) * l
        slot_start = (idx * l) + (l - len(digits))
        expanded[slot_start:slot_end] += digits
    return expanded

def get_S_star(S: ndarray, m: int, n: int, l: int) -> ndarray:
    """Expand every column of `S` into `l` columns scaled by powers of two.

    Column `j` of `S` becomes the `l` adjacent columns
    `S[:, j] * 2 ** (l - 1), ..., S[:, j] * 2 ** 0`.

    Args:
        S: Matrix of shape `(m, n)`.
        m: Number of rows of `S`.
        n: Number of columns of `S`.
        l: Number of bit positions per column.

    Returns:
        A matrix of shape `(m, n * l)`.
    """
    # One scaled copy of S per bit position, highest power first.
    scaled_copies: List = [S * 2 ** (l - bit - 1) for bit in range(l)]
    stacked: ndarray = np.array(scaled_copies)
    # (l, m, n) -> (m, n, l) -> (m, n * l): the bit axis becomes innermost.
    return stacked.transpose(1, 2, 0).reshape(m, n * l)

def get_T(n: int) -> ndarray:
    """Draw the random integer matrix `T` used for key switching.

    Args:
        n: Number of rows.

    Returns:
        An `(n, 1)` integer array: uniform samples from `np.random.rand`
        scaled by 10 and truncated with `astype('int')`.
    """
    extra_columns = (n + 1) - n
    samples: ndarray = np.random.rand(n, extra_columns)
    return (10 * samples).astype('int')

def encrypt_via_switch(x: ndarray, w: int, m: int, n: int, T: ndarray) -> (ndarray, ndarray):
    """Encrypt `x` by key-switching away from the identity key.

    The weighted plaintext `x * w` is treated as a ciphertext under the
    identity matrix and handed to `switch_key`.

    Args:
        x: Plaintext vector.
        w: Weighting scalar.
        m: Size of the identity key; forwarded to `switch_key`.
        n: Forwarded to `switch_key`.
        T: Random matrix forwarded to `switch_key` (see `get_T`).

    Returns:
        `(c, S)`: the ciphertext and the key returned by `switch_key`.
    """
    weighted_plaintext = x * w
    identity_key = np.eye(m)
    ciphertext, new_key = switch_key(weighted_plaintext, identity_key, m, n, T)
    return (ciphertext, new_key)

In [18]:
# Plaintext vector used by all examples below.
x: ndarray = np.array([0, 1, 2, 5])
    
# Square key: as many rows and columns as x has entries.
m: int = len(x)
n: int = m
# Weighting scalar applied to x before noise is added.
w: int = 16

# The key is random, so the matrix displayed below differs between runs.
S: ndarray = generate_key(w, m, n)
S

array([[1.82191813e-04, 2.26648882e-04, 1.59727698e-04, 6.46591058e-05],
       [2.27827232e-04, 6.19370823e-05, 1.11027165e-04, 5.93790272e-06],
       [1.16780459e-04, 2.19654105e-04, 2.05948472e-04, 3.87387963e-05],
       [1.42767530e-04, 1.31332099e-04, 2.05187437e-04, 1.29092801e-04]])

### Basic addition / multiplication

In [19]:
# Encrypt x under S; the noise is random, so the values differ between runs.
c: ndarray = encrypt(x, S, m, n, w)
c

array([-119380.40219085, -373082.9341732 ,  586973.15077334,
        200100.36312486])

In [20]:
# Round trip: decrypting c with S should give back x.
decrypt(c, S, w)

array([0, 1, 2, 5])

In [21]:
# Adding ciphertexts adds the plaintexts: expect 2 * x.
decrypt(c + c, S, w)

array([ 0,  2,  4, 10])

In [22]:
# Scaling the ciphertext scales the plaintext: expect 10 * x.
decrypt(c * 10, S, w)

array([ 0, 10, 20, 50])

### Key-switching addition / multiplication

In [23]:
# Random integer matrix that forms part of the key after switching.
T: ndarray = get_T(n)

In [24]:
# Encrypt x through the key switch. This rebinds c and S, so the cells below
# operate on the switched ciphertext and key.
c, S = encrypt_via_switch(x, w, m, n, T)

In [25]:
# Round trip with the switched key: expect x.
decrypt(c, S, w)

array([0, 1, 2, 5])

In [26]:
# Addition on key-switched ciphertexts: expect 2 * x.
decrypt(c + c, S, w)

array([ 0,  2,  4, 10])

In [27]:
# Scaling a key-switched ciphertext: expect 10 * x.
decrypt(c * 10, S, w)

array([ 0, 10, 20, 50])