In [1]:
import random
import math
import numpy as np
import sympy
from sympy import jacobi_symbol

In [2]:
# Secret factors used by decrypt(); two distinct primes, kept tiny here so
# every intermediate number stays readable.
p, q = 101, 113

assert sympy.isprime(p) and sympy.isprime(q)

In [3]:
n = q * p  # modulus: encryption only needs n (and x); decryption needs its factors

In [4]:
# Rejection-sample x until it is coprime to n and a quadratic non-residue
# modulo BOTH p and q (a "pseudosquare" when viewed modulo n).
x = random.randint(1, n-1)
while not (
    math.gcd(x, n) == 1
    and jacobi_symbol(x, p) == -1
    and jacobi_symbol(x, q) == -1
):
    x = random.randint(1, n-1)

print(f'non-residue x: {x}')

non-residue x: 3361


In [5]:
# x is a non-residue mod p and mod q, so its Jacobi symbol mod n must be
# (-1)*(-1) = +1. Compare against 1 explicitly: a bare truthiness test would
# also accept -1, making the assertion vacuous.
assert jacobi_symbol(x, n) == 1, f'expected Jacobi symbol (x/n) == 1, got {jacobi_symbol(x, n)}'

# Encryption

In [6]:
def encrypt_bit(b: int, r: int):
    '''
    Encrypt one bit as r^2 * x^b mod n.

    Args:
        b (int): plaintext bit, normally 0 or 1 (larger exponents are accepted
            and simply raise x to that power, as the homomorphism check does)
        r (int): random integer coprime to n
    Returns:
        encrypted bit (int)
    '''
    square_mask = pow(r, 2, n)
    return square_mask * pow(x, b, n) % n

In [7]:
def generate_random():
    '''
    Rejection-sample an integer in [1, n] that is coprime to n.

    Returns:
        random number (int) with gcd(result, n) == 1
    '''
    candidate = random.randint(1, n)
    while math.gcd(candidate, n) != 1:
        candidate = random.randint(1, n)
    return candidate

In [8]:
def encrypt(m):
    '''
    Encrypt a non-negative integer message bit by bit.

    The message is written in binary, most significant bit first and without
    leading zeros (so 0 encodes as the single bit "0"). Each bit is encrypted
    independently with a fresh random value from generate_random().

    Args:
        m (int): plaintext, a non-negative integer
    Returns:
        ciphertext (list of int): one encrypted value per plaintext bit
    Raises:
        ValueError: if m is negative
    '''
    if m < 0:
        # bin() of a negative number starts with "-0b", which cannot be split into bits.
        raise ValueError(f'plaintext must be a non-negative integer, got {m}')

    # Arguments evaluate left to right, so each bit is parsed before its random
    # value is drawn -- the same call order as an explicit loop.
    return [encrypt_bit(int(bit), generate_random()) for bit in bin(m)[2:]]

In [9]:
m = 0b10001  # 17 -> five plaintext bits, so five ciphertext values

In [10]:
c = encrypt(m=m)  # list of per-bit ciphertexts

In [11]:
c  # one encrypted value (mod n) per plaintext bit

[7840, 348, 10743, 2612, 5414]

# Decryption

In [12]:
def decrypt(c: list):
    '''
    Decrypt a ciphertext produced by encrypt().

    Each element is tested for quadratic residuosity with Euler's criterion
    modulo p and modulo q: a residue modulo both decodes to "0", anything
    else to "1".

    Args:
        c (list): ciphertext - encrypted bits of plaintext, most significant first
    Returns:
        plaintext (int)
    '''
    # Integer division keeps the exponents exact for primes of any size;
    # float division loses precision beyond 2**53 and overflows for large keys.
    exp_p = (p - 1) // 2
    exp_q = (q - 1) // 2

    m_binaries = []
    for ci in c:
        is_residue = pow(ci % p, exp_p, p) == 1 and pow(ci % q, exp_q, q) == 1
        m_binaries.append("0" if is_residue else "1")

    m_binary = "".join(m_binaries)
    return int(m_binary, 2)

In [13]:
decrypt(c) == m

True

# Homomorphic property

Multiplying two ciphertexts modulo n yields an encryption of the XOR of their plaintext bits.

In [14]:
# Every (b1, b2) pair of input bits, in lexicographic order.
args = [[b1, b2] for b1 in (0, 1) for b2 in (0, 1)]

In [15]:
# For every pair of bits, the product of their ciphertexts mod n must decrypt
# to b1 XOR b2.
for b1, b2 in args:
    r1, r2 = generate_random(), generate_random()
    c1, c2 = encrypt_bit(b1, r1), encrypt_bit(b2, r2)

    # (r1^2 * x^b1) * (r2^2 * x^b2) == (r1*r2)^2 * x^(b1+b2)  (mod n)
    assert c1 * c2 % n == encrypt_bit(b1 + b2, r1 * r2)

    # x^2 is itself a square, so b1 + b2 == 2 behaves like 0 after decryption.
    b1_xor_b2 = b1 ^ b2
    assert decrypt([c1 * c2 % n]) == b1_xor_b2