# Homomorphic Encryption
Playing around with (fully) homomorphic encryption schemes.

In [1]:
from math import gcd
from random import randint
from typing import List, NamedTuple, Optional, Tuple

In [2]:
# Adapted from: https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm
def egcd(a: int, b: int) -> Tuple[int, int, int]:
    """Extended Euclidean algorithm.

    Computes Bezout coefficients for ``a`` and ``b``.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        A triple ``(g, x, y)`` satisfying ``a * x + b * y == g``. For
        non-negative inputs ``g`` is the greatest common divisor of ``a``
        and ``b``.
    """
    # Iterative formulation: the recursive variant needs one stack frame per
    # division step and hits Python's recursion limit for large operands.
    x0, x1, y0, y1 = 0, 1, 1, 0
    while a != 0:
        q, remainder = divmod(b, a)
        b, a = a, remainder
        y0, y1 = y1, y0 - q * y1
        x0, x1 = x1, x0 - q * x1
    return (b, x0, y0)

def modinv(a: int, m: int) -> int:
    """Return the multiplicative inverse of ``a`` modulo ``m``.

    Args:
        a: Value to invert.
        m: Modulus.

    Returns:
        The value ``x`` in ``range(m)`` with ``(a * x) % m == 1``.

    Raises:
        ValueError: If ``a`` and ``m`` are not coprime, in which case no
            inverse exists.
    """
    g, x, _ = egcd(a, m)
    if g != 1:
        # ValueError subclasses Exception, so callers that catch Exception
        # keep working while getting a more specific error type.
        raise ValueError('modular inverse does not exist')
    # egcd may return a negative coefficient; normalise it into [0, m).
    return x % m

# Known-answer checks for the two number-theory helpers defined above.
assert egcd(1071, 462) == (21, -3, 7)
assert modinv(17, 3120) == 2753

## [El Gamal](https://en.wikipedia.org/wiki/ElGamal_encryption)

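El Gamal is multiplicatively homomorphic: multiplying two ciphertexts component-wise yields a ciphertext that decrypts to the product of the two plaintexts (modulo $p$).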

In [3]:
class PublicKey(NamedTuple):
    """El Gamal public key as produced by ``keygen``."""
    # Modulus passed to keygen.
    p: int
    # Randomly chosen base.
    a: int
    # a raised to the secret exponent d, reduced modulo p.
    b: int

class SecretKey(NamedTuple):
    """El Gamal secret key as produced by ``keygen``."""
    # Secret exponent; the public value b is a ** d mod p.
    d: int

class Ciphertext(NamedTuple):
    """El Gamal ciphertext pair as produced by ``encrypt``."""
    # a ** k mod p for the per-message random exponent k.
    r: int
    # (b ** k * message) mod p, i.e. the message masked by the shared value.
    t: int

In [4]:
def keygen(p: int) -> Tuple[PublicKey, SecretKey]:
    """Generate an El Gamal key pair for the modulus ``p``.

    Args:
        p: Modulus. It is expected to be prime (decryption inverts via the
            exponent ``p - 2``); this is not checked here.

    Returns:
        ``(PublicKey(p, a, b), SecretKey(d))`` with ``b == a ** d mod p``.

    Note:
        ``random.randint`` is not a cryptographically secure source of
        randomness, and ``a`` is not verified to generate a large subgroup.
        This is a toy implementation for illustration only.
    """
    # a: 1 < a < p - 1
    # Excluding 1 matters: a == 1 gives b == 1, so every ciphertext would
    # carry the plaintext unmasked. p - 1 is excluded because it has order 2.
    a: int = randint(2, p - 2)
    # d: 2 <= d <= p - 2
    d: int = randint(2, p - 2)
    # Three-argument pow reduces modulo p at every step instead of building
    # the full power a ** d first.
    b: int = pow(a, d, p)
    pk: PublicKey = PublicKey(p, a, b)
    sk: SecretKey = SecretKey(d)
    return (pk, sk)

In [5]:
def encrypt(message: int, pk: PublicKey) -> Ciphertext:
    """Encrypt ``message`` under the El Gamal public key ``pk``.

    Args:
        message: Plaintext integer. It is reduced modulo ``pk.p``, so only
            values in ``range(pk.p)`` decrypt back to themselves.
        pk: Public key of the recipient.

    Returns:
        ``Ciphertext(r, t)`` with ``r == a ** k mod p`` and
        ``t == b ** k * message mod p`` for a fresh random exponent ``k``.
    """
    # Ephemeral exponent drawn from [1, p - 2]. k == 0 must be excluded: it
    # yields r == 1 and t == message, i.e. the plaintext in the clear.
    k: int = randint(1, pk.p - 2)
    # Modular pow keeps every intermediate value below p.
    r: int = pow(pk.a, k, pk.p)
    t: int = (pow(pk.b, k, pk.p) * message) % pk.p
    return Ciphertext(r, t)

In [6]:
def decrypt(c: Ciphertext, pk: PublicKey, sk: SecretKey) -> int:
    """Decrypt an El Gamal ciphertext.

    Args:
        c: Ciphertext pair ``(r, t)``. The components may be unreduced
            products, as returned by ``mult``.
        pk: Public key; only the modulus ``pk.p`` is used.
        sk: Secret key holding the exponent ``d``.

    Returns:
        The plaintext ``t * (r ** d) ** -1 mod p``.
    """
    # Shared value r ** d mod p, computed with modular exponentiation so no
    # huge intermediate integer is ever materialised.
    shared: int = pow(c.r, sk.d, pk.p)
    # Inverse via Fermat's little theorem (x ** (p - 2) == x ** -1 mod p),
    # which is only valid when p is prime.
    shared_inv: int = pow(shared, pk.p - 2, pk.p)
    return (shared_inv * c.t) % pk.p

In [7]:
def mult(a: Ciphertext, b: Ciphertext, pk: Optional[PublicKey] = None) -> Ciphertext:
    """Homomorphically multiply two El Gamal ciphertexts.

    The component-wise product decrypts to the product of the two
    plaintexts modulo ``p``.

    Args:
        a: First ciphertext.
        b: Second ciphertext.
        pk: Optional public key. When given, both components are reduced
            modulo ``pk.p`` so that repeated multiplications do not grow
            without bound. When omitted, the raw products are returned.

    Returns:
        The ciphertext of the product.
    """
    r: int = a.r * b.r
    t: int = a.t * b.t
    if pk is not None:
        # Reduction does not change what decrypt returns, because decrypt
        # works modulo p anyway.
        r %= pk.p
        t %= pk.p
    return Ciphertext(r, t)

In [8]:
# El Gamal round trip with a toy modulus; 47 is prime but far too small to be secure.
pk, sk = keygen(47)

print('--- Message Encryption / Decryption ---')
# The plaintext must be smaller than the modulus (42 < 47) to decrypt to itself.
plaintext: int = 42
print(f'Message (Plaintext): {plaintext}')
    
# keygen and encrypt draw random values, so the ciphertext differs between runs.
ciphertext: Ciphertext = encrypt(plaintext, pk)
print(f'Message (Ciphertext): {ciphertext}')

decrypted: int = decrypt(ciphertext, pk, sk)
print(f'Message (Decrypted): {decrypted}')

--- Message Encryption / Decryption ---
Message (Plaintext): 42
Message (Ciphertext): Ciphertext(r=32, t=28)
Message (Decrypted): 42


In [9]:
# Fresh El Gamal key pair for the homomorphic multiplication demo.
pk, sk = keygen(47)

print('--- Encrypted Multiplication ---')
# The product 6 * 5 = 30 is below the modulus 47, so it is recovered exactly;
# a larger product would come back reduced modulo 47.
a: int = 6
b: int = 5
print(f'Numbers (Plaintext): {a}, {b}')
print(f'Result (Plaintext): {a * b}')

enc_a: Ciphertext = encrypt(a, pk)
enc_b: Ciphertext = encrypt(b, pk)
print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')

# Multiply the ciphertexts without touching the secret key. The components
# printed here may exceed the modulus; decrypt reduces them modulo p.
result: Ciphertext = mult(enc_a, enc_b)
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, pk, sk)
print(f'Result (Decrypted): {decrypted}')

--- Encrypted Multiplication ---
Numbers (Plaintext): 6, 5
Result (Plaintext): 30
Numbers (Ciphertext): Ciphertext(r=42, t=2), Ciphertext(r=9, t=26)
Result (Ciphertext): Ciphertext(r=378, t=52)
Result (Decrypted): 30


## [RSA](https://en.wikipedia.org/wiki/RSA_(cryptosystem)) Cryptosystem

Textbook (unpadded) RSA is multiplicatively homomorphic: the product of two ciphertexts decrypts to the product of the two plaintexts (modulo $n$).

In [10]:
class PublicKey(NamedTuple):
    """RSA public key as produced by the RSA ``keygen``.

    NOTE: this reuses the name of the El Gamal ``PublicKey`` defined earlier
    in the notebook and replaces it once this cell has run.
    """
    # Public exponent, coprime with (p - 1) * (q - 1).
    e: int
    # Modulus p * q.
    n: int

class SecretKey(NamedTuple):
    """RSA secret key as produced by the RSA ``keygen``.

    NOTE: this reuses the name of the El Gamal ``SecretKey`` defined earlier
    in the notebook and replaces it once this cell has run.
    """
    # Private exponent: the inverse of e modulo (p - 1) * (q - 1).
    d: int
    # Modulus p * q (same value as in the public key).
    n: int

class Ciphertext(NamedTuple):
    """RSA ciphertext wrapper.

    NOTE: this reuses the name of the El Gamal ``Ciphertext`` defined earlier
    in the notebook and replaces it once this cell has run.
    """
    # Encrypted value; despite the name this is the ciphertext, not the message.
    m: int

In [11]:
def keygen(p: int, q: int) -> Tuple[PublicKey, SecretKey]:
    """Build an RSA key pair from the two factors ``p`` and ``q``.

    The public exponent is the smallest integer >= 2 that is coprime with
    ``(p - 1) * (q - 1)``; the private exponent is its modular inverse.

    Args:
        p: First factor of the modulus (expected to be prime; not checked).
        q: Second factor of the modulus (expected to be prime; not checked).

    Returns:
        ``(PublicKey(e, n), SecretKey(d, n))`` with ``n == p * q``.
    """
    modulus = p * q
    totient = (p - 1) * (q - 1)
    # Walk upwards from 2 until an exponent coprime with the totient is found.
    exponent = 2
    while True:
        if gcd(totient, exponent) == 1:
            break
        exponent += 1
    private_exponent = modinv(exponent, totient)
    public_key = PublicKey(exponent, modulus)
    secret_key = SecretKey(private_exponent, modulus)
    return public_key, secret_key
    
# Known-answer check: both halves of the key pair for p = 61, q = 53.
assert keygen(61, 53) == (PublicKey(7, 3233), SecretKey(1783, 3233))

In [12]:
def encrypt(message: int, pk: PublicKey) -> Ciphertext:
    """Encrypt ``message`` with textbook (unpadded) RSA.

    Args:
        message: Plaintext integer. Values outside ``range(pk.n)`` are
            reduced modulo ``pk.n`` and will not decrypt to themselves.
        pk: RSA public key ``(e, n)``.

    Returns:
        ``Ciphertext(message ** e mod n)``.
    """
    # Three-argument pow performs modular exponentiation; the two-step
    # `message ** e % n` would first build the full, unreduced power.
    return Ciphertext(pow(message, pk.e, pk.n))

In [13]:
def decrypt(c: Ciphertext, sk: SecretKey) -> int:
    """Decrypt a textbook RSA ciphertext.

    Args:
        c: Ciphertext; its value may be an unreduced product as returned by
            ``mult``.
        sk: RSA secret key ``(d, n)``.

    Returns:
        ``c.m ** d mod n``.
    """
    # Modular exponentiation: the private exponent is large, so computing
    # c.m ** d before reducing would create an enormous intermediate integer.
    return pow(c.m, sk.d, sk.n)

In [14]:
def mult(a: Ciphertext, b: Ciphertext, pk: Optional[PublicKey] = None) -> Ciphertext:
    """Homomorphically multiply two textbook RSA ciphertexts.

    The product of the ciphertext values decrypts to the product of the two
    plaintexts modulo ``n``.

    Args:
        a: First ciphertext.
        b: Second ciphertext.
        pk: Optional public key. When given, the product is reduced modulo
            ``pk.n`` so that repeated multiplications stay bounded. When
            omitted, the raw product is returned.

    Returns:
        The ciphertext of the product.
    """
    product: int = a.m * b.m
    if pk is not None:
        # Reduction does not change what decrypt returns, because decrypt
        # works modulo n anyway.
        product %= pk.n
    return Ciphertext(m=product)

In [15]:
# RSA round trip with toy primes; n = 61 * 53 = 3233 is far too small to be secure.
pk, sk = keygen(61, 53)

print('--- Message Encryption / Decryption ---')
# The plaintext must be smaller than n (42 < 3233) to decrypt to itself.
plaintext: int = 42
print(f'Message (Plaintext): {plaintext}')
    
# Encryption here is deterministic: the same message always gives the same ciphertext.
ciphertext: Ciphertext = encrypt(plaintext, pk)
print(f'Message (Ciphertext): {ciphertext}')

decrypted: int = decrypt(ciphertext, sk)
print(f'Message (Decrypted): {decrypted}')

--- Message Encryption / Decryption ---
Message (Plaintext): 42
Message (Ciphertext): Ciphertext(m=240)
Message (Decrypted): 42


In [16]:
# Same toy RSA key pair, this time for the homomorphic multiplication demo.
pk, sk = keygen(61, 53)

print('--- Encrypted Multiplication ---')
# The product 6 * 5 = 30 is below n = 3233, so it is recovered exactly;
# a larger product would come back reduced modulo n.
a: int = 6
b: int = 5
print(f'Numbers (Plaintext): {a}, {b}')
print(f'Result (Plaintext): {a * b}')

enc_a: Ciphertext = encrypt(a, pk)
enc_b: Ciphertext = encrypt(b, pk)
print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')

# Multiply the ciphertexts without touching the secret key. The value printed
# here may exceed n; decrypt reduces it modulo n.
result: Ciphertext = mult(enc_a, enc_b)
print(f'Result (Ciphertext): {result}')
decrypted: int = decrypt(result, sk)
print(f'Result (Decrypted): {decrypted}')

--- Encrypted Multiplication ---
Numbers (Plaintext): 6, 5
Result (Plaintext): 30
Numbers (Ciphertext): Ciphertext(m=1898), Ciphertext(m=533)
Result (Ciphertext): Ciphertext(m=1011634)
Result (Decrypted): 30
