In [1]:
from random import randint
global num_of_iterations
num_of_iterations = 40

In [2]:
def mod_pow(x, y, m):
    # Returns (x^y) % p
    res = 1

    x = x % m
    while y > 0:
        if y & 1:
            res = (res * x) % m

        y = y >> 1
        x = (x * x) % m

    return res

# This function is called
# for all k trials. It returns
# false if n is composite and
# returns false if n is
# probably prime. d is an odd
# number such that d*2<sup>r</sup> = n-1
# for some r >= 1
def miller_rabin_test(d, n):
    # Pick a random number in [2..n-2]
    # Corner cases make sure that n > 4
    a = 2 + randint(1, n - 4)

    # Compute a^d % n
    x = mod_pow(a, d, n)

    if x == 1 or x == n - 1:
        return True

    while d != n - 1:
        x = (x * x) % n
        d *= 2

        if x == 1:
            return False
        if x == n - 1:
            return True

    # Return composite
    return False

# It returns false if n is
# composite and returns true if n
# is probably prime. k is an
# input parameter that determines
# accuracy level. Higher value of
# k indicates more accuracy.
def is_prime(n, k):
    # Corner cases
    if n <= 1 or n == 4:
        return False
    if n <= 3:
        return True

    # Find r such that n =
    # 2^d * r + 1 for some r >= 1
    d = n - 1
    while d % 2 == 0:
        d //= 2

    # Iterate given number of 'k' times
    for i in range(k):
        if not miller_rabin_test(d, n):
            return False

    return True

In [3]:
def generate_random_prime(range_min, range_max):
    global num_of_iterations
    while True:
        number = randint(range_min, range_max)
        if is_prime(number, num_of_iterations):
            return number

def generate_smaller_bigger(range_min, range_max):
    a = generate_random_prime(range_min, range_max)
    b = generate_random_prime(range_min, range_max)
    return (a, b) if a <= b else (b, a)

def generate_pairs(range_min, range_max):
    p_p1 = generate_smaller_bigger(range_min, range_max)
    q_q1 = generate_smaller_bigger(range_min, range_max)
    return (p_p1[0], q_q1[0]), (p_p1[1], q_q1[1])


In [4]:
def gcd_extended(a, m):
    if a == 0:
        return m, 0, 1
    gcd, x1, y1 = gcd_extended(m % a, a)
    x = y1 - (m // a) * x1
    y = x1
    return gcd, x, y


def inverse_modulo(a, m):
    gcd, x, y = gcd_extended(a, m)
    if gcd == 1:
        return (x % m + m) % m
    else:
        return -1

In [5]:
def generate_keys(p, q):
    n = p * q
    phi = (p - 1) * (q - 1)
    while True:
        e = randint(2, phi - 1)
        if gcd_extended(e, phi)[0] == 1:
            d = inverse_modulo(e, phi)
            return e, n, d

In [6]:
class User:
    def __init__(self, name, p, q):
        self.name = name
        self.p = p
        self.q = q

        keys = generate_keys(p, q)
        self.e = keys[0]
        self.n = keys[1]
        self.d = keys[2]

    def regenerate_keys(self, receiver_n, range_min, range_max):
        while True:
            p = generate_random_prime(range_min, range_max)
            q = generate_random_prime(range_min, range_max)
            keys = generate_keys(p, q)
            if keys[1] <= receiver_n:
                self.p = p
                self.q = q
                self.e = keys[0]
                self.n = keys[1]
                self.d = keys[2]
                return

    def encrypt(self, m, e, n):
        c = mod_pow(m, e, n)
        return c

    def decrypt(self, c):
        m = mod_pow(c, self.d, self.n)
        return m

    def sign(self, m):
        s = mod_pow(m, self.d, self.n)
        return s

    def verify(self, m, s, e, n):
        m_check = mod_pow(s, e, n)
        return m == m_check

    def send_key(self, k, e, n):
        if self.n > n:
            print('Need to regenerate keys!')
            return
        k_enc = self.encrypt(k, e, n)
        s = self.sign(k)
        s_enc = self.encrypt(s, e, n)
        return k_enc, s_enc

    def receive_key(self, k_enc, s_enc, e, n):
        k = self.decrypt(k_enc)
        s = self.decrypt(s_enc)
        if self.verify(k, s, e, n):
            print('Success')
        else:
            print('Failure')


In [8]:
range_min = (2 ** 255) + 1
range_max = (2 ** 256) - 1
smaller,bigger = generate_pairs(range_min, range_max)

lil = User('Lil', smaller[0], smaller[1])
big = User('Big', bigger[0], bigger[1])

message = randint(1, 2 ** 30)
print(f'Generated message: {message}\n')

m_enc = lil.encrypt(message, big.e, big.n)
s = lil.sign(message)
print(f'Encrypted message from Lil to Big: {m_enc}')
print(f'Signature: {s}\n\n')


m_dec = big.decrypt(m_enc)
s_verify = big.verify(m_dec, s, lil.e, lil.n)
print(f'Big decrypted message: {m_dec}')
print(f'Signature verify: {s_verify}')

Generated message: 696218097

Encrypted message from Lil to Big: 4389782605843559830433250718211387723067700586443649655811637780092556688854354510119774171832702416548022542491203414840452132534119975643289020707706004
Signature: 8608556994551702746755689314883178697221366354083975974421632214827288055777717075055802250643702974801082832365051901684192396904969996927719713205635521


Big decrypted message: 696218097
Signature verify: True


In [9]:
k = randint(1, 2 ** 30)
print(f'Generated key: {k}\n')

k_enc, s_enc = lil.send_key(k, big.e, big.n)
big.receive_key(k_enc, s_enc, lil.e, lil.n)

Generated key: 554730538

Success
