In [1]:
def extensiveGCD(x, y):
    if x < y:
        x, y = y, x
    div = lambda x, y: (x//y, x%y)
    factors = []

    # this is totally unnecessary, it just felt appropriate to relabel
    # the variables in a form similar to the DA: n = pq + r
    n, q = x, y

    while q > 0:
        fac, rem = div(n, q)
        factors.append(fac)
        n = q
        q = rem

    # we throw out the last factor as we do not need it for divisibility
    factors = factors[::-1][1:]

    return n, factors
    

In [2]:
def pairBezout(x, y):
    """Return Bezout coefficients (p, q) such that p*x + q*y == gcd(x, y).

    Back-substitutes through the Euclidean algorithm using the reversed
    quotient list produced by extensiveGCD. Intended for non-negative
    integers (extensiveGCD stops as soon as the smaller value is <= 0).
    """
    # extensiveGCD works on (larger, smaller); remember whether that order
    # differs from (x, y) so the coefficients can be swapped back at the end.
    swapped = x < y
    _, quotients = extensiveGCD(x, y)

    if min(x, y) == 0:
        # No division step happened: gcd(big, 0) == big == 1*big + 0*0.
        # (Seeding with (0, 1) here would claim gcd == 0.)
        coefs = (1, 0)
    else:
        # At the last non-trivial step the smaller value of the pair divides
        # the larger, so gcd = 0*larger + 1*smaller.
        coefs = (0, 1)
        for quotient in quotients:
            # If g = s*b + t*r with r = a - quotient*b, then
            # g = t*a + (s - t*quotient)*b: step one equation back up.
            coefs = (coefs[1], coefs[0] - coefs[1] * quotient)

    if swapped:
        coefs = (coefs[1], coefs[0])

    return coefs

def testValidBezout(x, y):
    """Check that pairBezout(x, y) satisfies Bezout's identity.

    Returns True iff the coefficients (p, q) returned by pairBezout satisfy
    p*x + q*y == d, where d is the gcd reported by extensiveGCD.
    """
    d, _ = extensiveGCD(x, y)
    p, q = pairBezout(x, y)
    # Compare the identity itself rather than its reductions modulo x / y:
    # `p*x % y` lies in [0, y), so it can never equal d when d == y
    # (e.g. x = y = 1), and reducing modulo 0 raises ZeroDivisionError.
    return p * x + q * y == d


In [3]:
def is_prime(n):
    """Return True if the integer n is prime, using trial division.

    Values below 2 (0, 1 and all negatives) are not prime.
    """
    from math import isqrt

    if n < 2:
        return False
    # 2 is the only even prime; handling evens up front lets the loop below
    # test odd divisors only, halving the work.
    if n % 2 == 0:
        return n == 2
    # Any composite n = a*b with a <= b has a*a <= n, so some factor is at
    # most isqrt(n). isqrt is exact, unlike ceil(sqrt(n)) in floating point,
    # which can be off for large n.
    for divisor in range(3, isqrt(n) + 1, 2):
        if n % divisor == 0:
            return False

    return True

def generatePrime(low_seed):
    """Return the smallest prime >= low_seed, searching upward from the seed.

    The seed is converted to a Python int so that numpy integer seeds (such
    as the output of numpy.random.randint) do not turn the returned prime,
    and everything later computed from it, into a fixed-width int64 that can
    overflow for larger key sizes.
    """
    candidate = int(low_seed)
    while not is_prime(candidate):
        candidate += 1
    return candidate

In [4]:
class FiniteIntegerRing:
    """The ring of integers modulo n, Z/nZ.

    Elements are plain integers; every operation returns the canonical
    representative in [0, n).
    """

    def __init__(self, n):
        """Create Z/|n|Z. Raises ValueError for n == 0."""
        if n == 0:
            raise ValueError("Do not specify 0! Use normal operations instead!")
        elif n < 0:
            n = -n

        self._n = n

    def add(self, a, b):
        """Return (a + b) mod n."""
        return (a + b) % self._n

    def sub(self, a, b):
        """Return (a - b) mod n."""
        return self.add(a, -b)

    def id(self, a):
        """Return the canonical representative of a, i.e. a mod n."""
        return self.add(a, 0)

    def add_inv(self, a):
        """Return the additive inverse of a, i.e. (-a) mod n."""
        return self.sub(0, a)

    def mult(self, a, b):
        """Return (a * b) mod n."""
        return (a * b) % self._n

    def pow(self, b, e):
        """Return b**e mod n by fast modular exponentiation.

        Uses O(log e) multiplications instead of e of them, which matters
        for RSA-sized exponents. A negative e yields the modular inverse of
        b**(-e) (ValueError if b is not invertible mod n).
        """
        # Coerce to Python ints: numpy integer scalars (e.g. exponents drawn
        # with numpy.random) do not implement three-argument pow.
        # The bare name `pow` below is the builtin, not this method.
        return pow(int(b), int(e), self._n)

class FiniteField(FiniteIntegerRing):
    """The finite field Z/pZ for a prime p.

    Adds multiplicative inverses and division on top of FiniteIntegerRing.
    For p below _CACHE_LIMIT the inverse of every non-zero element is
    precomputed once, at construction.
    """

    # Fields with p below this bound get a full table of inverses.
    _CACHE_LIMIT = 1000

    def __init__(self, p):
        """Create Z/pZ. Raises ValueError if p is not prime."""
        if not is_prime(p):
            raise ValueError("Specified value {} should be prime!".format(p))
        self._n = p

        self._cached = bool(p < self._CACHE_LIMIT)
        if self._cached:
            # The inverses depend only on p (p == q implies
            # FiniteField(p) == FiniteField(q)), so they can be precomputed.
            # Index 0 holds a -1 placeholder: 0 has no inverse and
            # mult_inv rejects it before indexing.
            self._mult_inv_cache = [-1]
            for x in range(1, p):
                b, _ = pairBezout(x, p)
                self._mult_inv_cache.append(self.add(b, 0))

    def mult_inv(self, a):
        """Return the multiplicative inverse of a modulo p.

        Raises ZeroDivisionError if a is congruent to 0 mod p.
        """
        # Reduce first: every multiple of p (not only the literal 0) lacks
        # an inverse, and pairBezout expects a non-negative argument.
        a = self.id(a)
        if a == 0:
            raise ZeroDivisionError("0 has no multiplicative inverse mod {}".format(self._n))

        if self._cached:
            return self._mult_inv_cache[a]
        b, _ = pairBezout(a, self._n)
        return self.add(b, 0)

    def div(self, a, b):
        """Return a / b in the field, i.e. a * b^{-1} mod p."""
        return self.mult(a, self.mult_inv(b))


**TODO:** add sections on the Chinese Remainder Theorem and Euler's totient function (and perhaps their relation to RSA).

**TODO:** add some good ol' group theory and ring theory!

**TODO:** maybe cover Mersenne primes as a bit of trivia?

In [5]:
from numpy.random import randint

# Each prime is the first prime at or above a random seed in [2**9, 2**10).
p = generatePrime(randint(2**9, 2**10))
q = generatePrime(randint(2**9, 2**10))
# RSA needs two distinct primes, so redraw q until it differs from p.
while q == p:
    q = generatePrime(randint(2**9, 2**10))

n = p*q
print(n, p, q)

790393 977 809


In [6]:
def lcm(a, b):
    """Least common multiple of a and b, using the gcd from extensiveGCD."""
    gcd_ab, _ = extensiveGCD(a, b)
    return a * b // gcd_ab


def car_tot(p, q):
    """Carmichael totient lambda(p*q) = lcm(p - 1, q - 1) for distinct primes p, q."""
    return lcm(p - 1, q - 1)


lam = car_tot(p, q)
print(lam)

98576


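As a result, we have the public key $(n, e)$ and the private key $(n, d)$, where the public exponent $e$ is coprime to $\lambda(n)$ and the private exponent is its inverse, $d \equiv e^{-1} \pmod{\lambda(n)}$, i.e. $ed \equiv 1 \pmod{\lambda(n)}$. Together they satisfy
$$
m^{ed} \equiv m \pmod{n}
$$
for every message $0 \le m < n$.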

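## The Chinese Remainder Theorem to the rescue!

Decryption can be sped up by working modulo $p$ and $q$ separately, with the smaller exponents $d \bmod (p-1)$ and $d \bmod (q-1)$, and then recombining the two residues into the unique result modulo $n = pq$.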

In [10]:
class RSAKey:
    """An RSA key (n, key) that maps a message m to m**key mod n.

    Used directly as the public key (n, e); RSASecretKey extends it with the
    private exponent and CRT decryption.
    """

    def __init__(self, n, e):
        self._n = n
        self._key = e

        self._ring = FiniteIntegerRing(n)

    def encrypt(self, message):
        """Return message**key mod n.

        Raises ValueError unless 0 <= message < n: all arithmetic is done mod
        n, so a larger message would be silently reduced and could never be
        recovered by decryption.
        """
        if not 0 <= message < self._n:
            raise ValueError("message must satisfy 0 <= message < n = {}".format(self._n))
        return self._ring.pow(message, self._key)

class RSASecretKey(RSAKey):
    """RSA private key (n, d) that decrypts via the Chinese Remainder Theorem.

    Rather than computing c**d mod n directly, decryption works modulo the
    two primes p and q with the reduced exponents d mod (p-1) and
    d mod (q-1), then recombines the two residues (Garner's formula).
    Raises ValueError if p == q.
    """

    def __init__(self, n, d, p, q):
        if p == q:
            raise ValueError("p and q must be distinct primes")
        RSAKey.__init__(self, n, d)
        self._p = p
        self._q = q

        # Reduced exponents (Fermat's little theorem): mod p only d mod (p-1)
        # matters, likewise for q.
        self._dp = d % (p - 1)
        self._dq = d % (q - 1)

        self._fp = FiniteField(p)
        self._fq = FiniteField(q)
        # q^{-1} mod p is fixed for this key, so compute it once.
        self._q_inv = self._fp.mult_inv(q)

    def decrypt(self, ciphertext):
        """Return the plaintext m = ciphertext**d mod n."""
        m1 = self._fp.pow(ciphertext, self._dp)  # m mod p
        m2 = self._fq.pow(ciphertext, self._dq)  # m mod q
        # Garner: m = m2 + q * ((m1 - m2) * q^{-1} mod p). Uses this key's q,
        # never a notebook-level variable of the same name.
        h = self._fp.mult(self._q_inv, m1 - m2)

        return m2 + h * self._q

In [22]:
def generateRSAKeyPair(p, q) -> (RSAKey, RSASecretKey):
    """Build an RSA key pair from two distinct primes p and q.

    Uses the Carmichael totient lambda(n) = lcm(p-1, q-1), picks the public
    exponent e uniformly among the integers in [2, lambda(n)) coprime to
    lambda(n), and sets d = e^{-1} mod lambda(n).

    Returns (public_key, secret_key).
    Raises ValueError if p == q (RSA needs two distinct primes).
    """
    from math import lcm, gcd
    from numpy.random import randint

    # Plain Python ints keep fixed-width numpy integers out of the key
    # arithmetic (numpy scalars also reject three-argument pow).
    p, q = int(p), int(q)
    if p == q:
        raise ValueError("p and q must be distinct primes")
    n = p * q
    totient = lcm(p - 1, q - 1)

    # Rejection sampling is uniform over the coprime candidates without first
    # building the list of all of them (O(totient) gcd calls).
    while True:
        e = int(randint(2, totient))  # high bound is exclusive
        if gcd(e, totient) == 1:
            break

    d, _ = pairBezout(e, totient)
    # The Bezout coefficient may be negative; reduce it into [0, totient).
    d %= totient

    pub = RSAKey(n, e)
    sec = RSASecretKey(n, d, p, q)
    return (pub, sec)

Let's encrypt a message and send it to ourselves :)

In [23]:
from numpy.random import randint

# Each prime is the first prime at or above a random seed in [2**10, 2**11).
p = generatePrime(randint(2**10, 2**11))
q = generatePrime(randint(2**10, 2**11))
while q == p:
    # RSA needs p != q: with p == q, lcm(p-1, q-1) is not the Carmichael
    # function of n = p**2, so decryption would fail. Redraw q.
    q = generatePrime(randint(2**10, 2**11))
print(p, q)

pub, sec = generateRSAKeyPair(p, q)

1361 1531


In [26]:
message = 0xBEEF
# Encrypt with the public key; all arithmetic is mod n, so message must be < n.
ciphertext = pub.encrypt(message)
print(hex(ciphertext))

0x13eb14


In [28]:
decrypted_message = sec.decrypt(ciphertext)
print(hex(decrypted_message))
# The round trip must reproduce the original message exactly.
assert decrypted_message == message, "RSA round trip failed"

0xbeef
