# RSA implementation

Now we will look at a sample implementation of RSA and some of the hurdles we have to overcome in the process.

We have three steps to take care of:

1. Key generation
2. Encryption
3. Decryption

## Key generation

Key generation is the most difficult and complex part, not only of RSA but of many other cryptographic schemes. Many attacks exploit a poor choice of parameters that the implementer did not take into account. See: https://crypto.stanford.edu/~dabo/papers/RSA-survey.pdf

* If $p$ and $q$ are close to each other, $n = pq$ may be factored with Fermat's factorization algorithm
* If $p - 1$ or $q - 1$ is composed only of small factors, Pollard's $p - 1$ algorithm may be applied
* If the computed private key $d$ is small ($d < \frac{1}{3}n^{\frac{1}{4}}$), Wiener's attack may be used to expose $d$
* If the public exponent $e$ is small, Coppersmith's method could recover the message. If $M^e < n$, then $M$ can be recovered easily by computing an integer $e$-th root
* ... and more I may have missed
<br><br>
This is why many manuals advise against trying to program a cryptographic solution from scratch.
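To make the first risk in the list concrete, here is a minimal sketch of Fermat's factorization. It assumes an odd $n$ whose two factors are close together; the name `fermat_factor` and the `max_steps` cutoff are illustrative choices, not part of any standard or library.

```python
import math
from typing import Optional

def fermat_factor(n: int, max_steps: int = 1_000_000) -> Optional[tuple[int, int]]:
    '''
    Look for a and b with n = a ** 2 - b ** 2 = (a - b) * (a + b).
    n is assumed to be odd. Returns None if nothing is found in max_steps.
    '''
    # Start at the smallest a with a ** 2 >= n
    a = math.isqrt(n)
    if a * a < n:
        a += 1
    for _ in range(max_steps):
        b_squared = a * a - n
        b = math.isqrt(b_squared)
        if b * b == b_squared:
            return a - b, a + b
        a += 1
    return None

# The closer the factors, the fewer steps are needed
print("Factors of 10007 * 10009:", fermat_factor(10007 * 10009))
```

The number of steps grows with the gap between the two factors, which is why a minimum distance between $p$ and $q$ is usually required.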

What we can do is to follow standards and guidelines. For this example, we will use the National Institute of Standards and Technology's (NIST) Federal Information Processing Standards (FIPS). FIPS 186-4 contains guidelines for implementing generation of RSA keys.

<br>
The standard takes into account the above concerns and provides algorithms for the generation of secure values of $p$, $q$ and $e$. In particular:

* $2^{16} < e < 2 ^ {256}$, though no reason is given for this range of values other than having a large enough magnitude of $e$. In fact the standard requires $e$ to be chosen beforehand.
* $|p - q| > 2 ^{bitlength / 2 - 100}$, to ensure that $p$ and $q$ are not too close (note that the recommended bitlengths are 2048 or more)
* $\sqrt{2} * 2 ^ {bitlength / 2 - 1} \le p \le 2 ^{bitlength / 2} - 1$ and $\sqrt{2} * 2 ^ {bitlength / 2 - 1} \le q \le 2 ^{bitlength / 2} - 1$ to ensure $n$ has the proper bitlength and $q$ and $p$ are both big enough
* $2 ^{bitlength / 2} < d < LCM(p - 1, q - 1)$
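The conditions in this list can be checked with plain integer arithmetic. The following is only a sketch under some assumptions: the function name `check_fips_constraints` and the returned dictionary keys are made up for illustration, `nlen` is assumed to be even, and the lower bound on the primes is tested as $x^2 \ge 2^{nlen - 1}$ to avoid floating point square roots.

```python
import math

def check_fips_constraints(p: int, q: int, e: int, d: int, nlen: int) -> dict[str, bool]:
    '''Report which of the range conditions listed above hold (nlen assumed even).'''
    half = nlen // 2
    # |p - q| > 2 ** (half - 100). For toy sizes the bound is below 1,
    # so the condition reduces to p != q
    if half >= 100:
        far_apart = abs(p - q) > 2 ** (half - 100)
    else:
        far_apart = p != q

    def prime_in_range(x: int) -> bool:
        # sqrt(2) * 2 ** (half - 1) <= x  is equivalent to  x ** 2 >= 2 ** (nlen - 1)
        return x * x >= 2 ** (nlen - 1) and x <= 2 ** half - 1

    return {
        "e_range": 2 ** 16 < e < 2 ** 256,
        "p_q_far_apart": far_apart,
        "p_range": prime_in_range(p),
        "q_range": prime_in_range(q),
        "d_range": 2 ** half < d < math.lcm(p - 1, q - 1),
    }

# Toy 12-bit key: only the range of e fails, since e = 17 is too small
print(check_fips_constraints(p=61, q=53, e=17, d=413, nlen=12))
```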

$p$ and $q$ are generated randomly, so in addition to checking the previous conditions we need to test whether they are prime. For that we need a primality test and a secure random bit (number) generator.

<br>
We will start with the primality test. There are deterministic primality tests that guarantee the number is prime. For our purposes probabilistic tests are good enough and simpler, and they are also recognized by the standard.
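For contrast, the simplest deterministic test is trial division. The sketch below (the function name is an illustrative choice) is always correct, but its running time grows with $\sqrt{n}$, so it is only usable for small numbers and not for primes of hundreds of bits.

```python
import math

def is_prime_trial_division(n: int) -> bool:
    '''Deterministic primality test by trial division up to sqrt(n).'''
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    # Only odd divisors up to the integer square root need to be tried
    for divisor in range(3, math.isqrt(n) + 1, 2):
        if n % divisor == 0:
            return False
    return True

print("Primes below 30:", [x for x in range(30) if is_prime_trial_division(x)])
```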

In [None]:
import secrets

def power_mod(base: int, exp: int, m: int) -> int:
    '''
    Compute (base ** exp) % m

    Parameters
    ----------
    base : int
        Base
    exp : int
        Exponent
    m : int
        Modulo

    Returns
    -------
    int
        Result
    '''
    return pow(base, exp, m)

def bitlength(n: int) -> int:
    '''
    Return the number of bits required to represent n

    Parameters
    ----------
    n : int
        number

    Returns
    -------
    int
        Number of bits to represent n
    '''
    return n.bit_length()

In [None]:
# Some illustration
print("3 ^ 4 mod 31:", power_mod(3, 4, 31))
print("bitlength of 2 ** 5 - 1:", bitlength(2 ** 5 - 1))
print("bitlength of 2 ** 5:", bitlength(2 ** 5))

In [None]:
def miller_rabin(w: int, k: int = 10) -> bool:
    '''
    Runs the Miller-Rabin primality test on w. k is the number of rounds
    to be executed; a greater number increases the probability that w is actually
    prime if the test is positive.

    Parameters
    ----------
    w : int
        Number to be tested for primality.
    k : int, optional
        Number of rounds of the algorithm. The default is 10.

    Returns
    -------
    bool: whether w passes the test
    '''
    # Numbers below 2 are not prime (w = 1 would otherwise loop forever below)
    if w < 2:
        return False
    # We simply short-circuit some easy cases
    if w in [2, 3, 5, 7]:
        return True
    # If w is divisible by two do not bother with the algorithm: it is not prime
    if w % 2 == 0:
        return False
    a = 0
    m = w - 1
    while m % 2 == 0:
        a += 1
        m //= 2
    
    wlen = bitlength(w)
    for _ in range(k):
        b = secrets.randbits(wlen)
        while b <= 1 or b >= w - 1:
            b = secrets.randbits(wlen)
        z = power_mod(b, m, w)
        if z == 1 or z == w - 1:
            continue
        i = 0
        while i < a - 1 and z != 1:
            z = power_mod(z, 2, w)
            if z == w - 1:
                break
            i += 1
        else:
            return False
    return True

In [None]:
# we throw a couple of quick tests
print("34 is probably prime:", miller_rabin(34))
print("2 ** 456 * 3 is probably prime:", miller_rabin(2 ** 456 * 3))
print("5952322734258198408259587570747479451304815439783987563217060625033 is probably prime:", 
      miller_rabin(5952322734258198408259587570747479451304815439783987563217060625033))

For probabilistic primality tests, it is important to have an estimate of the error. For the Miller-Rabin test, the FIPS gives a way to estimate the number of iterations required to reach a concrete error bound. The formula is complex, but we can implement it for convenience, rather than having to refer to a precomputed table.

In [None]:
from decimal import Decimal, getcontext
import math

# We fix 2 ** - 128 to correspond with the estimated 128 bits of security
# An attacker can succeed in 2 ** bits_of_security operations
def estimate_k(bits: int, error : float = 2 ** -128) -> int:
    '''
    Compute the number of iterations of Miller-Rabin necessary so that the
    probability of a composite number of the given number of bits
    passing the test is lower than error.

    Parameters
    ----------
    bits : int
        Number of bits of the number to be tested
    error : float, optional
        Upper bound on the probability of a composite number passing the test.
        The default is 2 ** -128.

    Returns
    -------
    int
        Number of iterations of Miller-Rabin.
    '''
    max_t = math.ceil(- math.log2(error) / 2)
    max_m = math.floor(2 * math.sqrt(bits - 1) - 1)
    for t in range(1, max_t):
        for M in range(3, max_m + 1):
            first = Decimal(2.00743 * math.log(2) * bits) * pow(Decimal(2), -bits)
            summatory = sum(
                (
                    Decimal(2 ** (m - (m - 1) * t)) 
                    * sum(
                         Decimal(1 / Decimal(2) ** Decimal(j + (bits - 1) / j)) 
                         for j in range(2, m + 1)
                    )
                 )
                for m in range(3, M + 1)
            )
            summand = pow(Decimal(2), bits - 2 - M * t) 
            factor = (
                Decimal(8 * (math.pi ** 2 - 6) / 3) * pow(Decimal(2), bits - 2)
            )
            
            estimate = first * (summand + factor * summatory)
            if estimate < error:
                return t
    return max_t

In [None]:
# Test whether the estimate corresponds to NIST's estimates
print("For error bound 2 ** -100 with 512 bits:", estimate_k(512, 2 ** -100))
print("For error bound 2 ** -100 with 1024 bits:", estimate_k(1024, 2 ** -100))
print("For error bound 2 ** -100 with 1536 bits:", estimate_k(1536, 2 ** -100))

With the preliminaries done, we implement a function that returns a random probable prime (note the difference with a _provable_ prime).

In [None]:
from typing import Callable

def random_odd_number_nbits(nbits: int) -> Callable[[], int]:
    '''
    Returns a function that takes no arguments and returns a random odd number
    of at most nbits bits (the most significant bits may be zero)

    Parameters
    ----------
    nbits : int
        Number of bits

    Returns
    -------
    (Callable[[], int])
        Function that returns a random number
    '''
    return lambda: secrets.randbits(nbits) | 1

def random_probable_prime(generator_func: Callable[[], int], k: int = 50, 
                          test_func: Callable[[int], bool] = None,
                          limit: int = 30000) -> int:
    '''
    Generate a random probable prime from candidates produced by generator_func

    Parameters
    ----------
    generator_func : Callable[[], int]
        Function that returns a new random candidate each time it is called
    k: int
        Number of iterations of Miller-Rabin test for primality
    test_func : Callable[[int], bool], optional
        Defines other criteria for acceptance of the random number.
        If the returned value is False the number is discarded
    limit : int
        Maximum number of randomly generated numbers to be tested.
        If no number satisfies the criteria, raise a ValueError


    Returns
    -------
    A random probable prime that also satisfies test_func

    '''
    test_func = (lambda x: True) if test_func is None else test_func

    i = 0     
    while True:
        random_number = generator_func()
        
        if test_func(random_number) and miller_rabin(random_number, k=k):
            return random_number
        if limit is not None:
            i += 1
            if i > limit:
                raise ValueError("Could not find a random number satisfying properties")

In [None]:
rand_100 = random_odd_number_nbits(100)
rand_6 = random_odd_number_nbits(6)

random_bits_6 = rand_6()
random_bits_100 = rand_100()

print("Integer value of 6 random bits:", random_bits_6)
print("Integer value of 100 random bits:", random_bits_100)

print("Number of bits required to represent the random 6 bits:", bitlength(random_bits_6))
print("Number of bits required to represent the random 100 bits:", bitlength(random_bits_100))

In [None]:
# Some tests
bit_size = 512
r1 = random_probable_prime(random_odd_number_nbits(bit_size), estimate_k(bit_size, 1e-40))
print("random probable prime of 512 bits:", r1)

# Pass a function to limit acceptable values
r2 = random_probable_prime(random_odd_number_nbits(bit_size), estimate_k(bit_size, 1e-40), 
                           lambda candidate: candidate >= 2 ** (bit_size - 1))
print("random probable prime of 512 bits greater than 2 ** 511:", r2)

print("number of bits required for r1:", bitlength(r1))
print("number of bits required for r2:", bitlength(r2))

In [None]:
# Equivalent condition
r2 = random_probable_prime(random_odd_number_nbits(bit_size), 50, 
                           lambda candidate: bitlength(candidate) == bit_size)
print("random probable prime of 512 bits greater than 2 ** 511:", r2)
print("number of bits required for r2:", bitlength(r2))

Now, we implement the key generation of RSA.

<br>
A note on the default value of $e$: it is chosen because it is big enough according to the standard and has only two bits set to 1, which speeds up fast modular exponentiation based on the binary representation of the exponent:

https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method
<br>
https://es.planetcalc.com/8979/
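The effect of the number of set bits can be seen in a small sketch of the right-to-left binary method. This is not how Python's built-in `pow` is implemented internally; `power_mod_binary` is a hypothetical helper that also counts the extra multiplications triggered by the bits set to 1.

```python
def power_mod_binary(base: int, exp: int, m: int) -> tuple[int, int]:
    '''
    Right-to-left binary exponentiation for exp >= 0 and m > 0.
    Returns (base ** exp % m, number of multiplications caused by set bits).
    '''
    result = 1 % m
    base %= m
    multiplications = 0
    while exp > 0:
        # One extra multiplication for every bit of the exponent set to 1
        if exp & 1:
            result = result * base % m
            multiplications += 1
        # One squaring per bit of the exponent, whatever its value
        base = base * base % m
        exp >>= 1
    return result, multiplications

for exponent in (2 ** 16 + 1, 2 ** 16 - 1):
    value, count = power_mod_binary(3, exponent, 1000003)
    print("exponent:", exponent, "result:", value, "extra multiplications:", count)
```

Both exponents need about the same number of squarings, so the difference comes from the multiplications: 2 for $2^{16} + 1$ against 16 for $2^{16} - 1$.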

In [None]:
print("2 ** 16 + 1 in binary:", bin(2 ** 16 + 1))
print("2 ** 16 - 1 in binary:", bin(2 ** 16 - 1))

In [None]:
%timeit power_mod(2 ** 1024, 2 ** 16 + 1, 3 ** 222)

In [None]:
%timeit power_mod(2 ** 1024, 2 ** 16 - 1, 3 ** 222)

In [None]:
import warnings

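def multiplicative_inverse(number: int, m: int) -> int:
    '''
    Computes (number ** -1) % m.
    Relies on pow() accepting a negative exponent (Python 3.8+), which raises
    a ValueError if number is not invertible modulo m.
    '''
    return power_mod(number, -1, m)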

def coprimes(a: int, b: int) -> bool:
    '''
    Tests whether a and b are coprimes
    '''
    return math.gcd(a, b) == 1

In [None]:
print("Inverse of 5 mod 77:", multiplicative_inverse(5, 77))
print("Are 12 and 34 coprimes:", coprimes(12, 34))
print("Are 5 and 111 coprimes:", coprimes(5, 111))

In [None]:
def rsa_keygen(nlen: int = 2048, e: int = 2 ** 16 + 1, tries : int = 30000) -> tuple[tuple[int, int], int]:
    '''
    Compute public and private keys for RSA

    Parameters
    ----------
    nlen : int
        Number of bits of n
    e: int
        Public exponent.
    tries : int. Default is 30000
        The number of randomly generated numbers to be tested for p and q
        in each iteration.
        If a number of random numbers equal to tries is generated, raise an
        error.
    Returns
    -------
    (n, e), d:
        (n, e) is the public key and d the private key
    '''
    # This is a particularity of our implementation, we will see why
    if nlen < 9:
        raise ValueError("Number of bits of n must be greater than 8")
    # p - 1 and q - 1 are even, so an even e could never be coprime with them
    if e % 2 != 1:
        raise ValueError("e should be odd")
    # We are not going to enforce these limits, but they are NIST's recommendations
    if e <= 2 ** 16 or e >= 2 ** 256:
        warnings.warn("exponent e should be an odd integer between 2 ** 16 and 2 ** 256, got {}".format(e))
    if nlen not in [2048, 3072, 4096]:
        warnings.warn("bitlen should be in [2048, 3072, 4096], got {}".format(nlen))
    
    # Have to ensure p * q has nlen bits
    p_size = math.ceil(nlen / 2)
    q_size = nlen - p_size
    
    # NIST restrictions to ensure p and q are big enough but not too close
    # Why those values?
    min_p = Decimal(2 ** (p_size - 1)) * Decimal(2).sqrt()
    min_q = Decimal(2 ** (q_size - 1)) * Decimal(2).sqrt()
    min_d = 2 ** (nlen // 2)
    p_q_diff = Decimal(2) ** Decimal(nlen / 2 - 100)
    
    # Ensure we minimize the probability of error in the primality test.
    # The numbers being tested are p and q, so the estimate uses their bit size
    k = estimate_k(min(p_size, q_size), 2 ** - 128)
    
    valid_d = False
    # d must not be too small and the number of bits of n must be exactly nlen
    # in accordance to NIST specifications
    while not valid_d:
        def valid_p(p_candidate):
            return p_candidate >= min_p and coprimes(p_candidate - 1, e)
        
        p = random_probable_prime(random_odd_number_nbits(p_size),
                                  k = k, test_func = valid_p, limit = tries)            
        
        def valid_q(q_candidate):
            return (
                q_candidate >= min_q
                and coprimes(q_candidate - 1, e)
                and abs(p - q_candidate) >= p_q_diff
            )
        
        q = random_probable_prime(random_odd_number_nbits(q_size), k = k, 
                                  test_func = valid_q, limit = tries)
        
        # Preserves properties of RSA and gives smaller values of d, which accelerates computations
        carmichael_lambda = math.lcm(p - 1, q - 1)
        d = multiplicative_inverse(e, carmichael_lambda)
        n = p * q
        
        # Check loop conditions
        valid_d = d > min_d
    return (n, e), d

In [None]:
(n, e), d = rsa_keygen(16)
print("n:", n, "e:", e, "d:", d)

print("======================")
# There is a clear difference in execution time
(n, e), d = rsa_keygen(1024)
print("n:", n)
print("e:", e)
print("d:", d)

**Why are values of $p$ and $q$ selected to be greater than $2^{psize - 1}\sqrt{2}$ and $2^{qsize - 1}\sqrt{2}$ respectively?**
<br><br>
To ensure that the value of $n$ has $nlen = psize + \text{ } qsize$ bits.
<br><br>
Since $n = pq$ the minimum value of $n$ for these restrictions would be:

$$
n_{min} = p_{min}q_{min} = 2^{psize - 1}\sqrt{2} * 2^{qsize - 1}\sqrt{2} = 2 * 2^{psize + qsize - 2} = 2 ^{nlen - 1}
$$

The integers that need exactly $nlen$ bits are those in the interval $[2^{nlen - 1}, 2^{nlen} - 1]$, so the minimum value of $n$, $n_{min}$, is exactly the smallest value that requires $nlen$ bits to store.
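A quick numerical check of this argument, as a sketch: `smallest_allowed_factor` is a hypothetical helper that returns the smallest integer at or above $2^{size - 1}\sqrt{2}$, computed with exact integer arithmetic.

```python
import math

def smallest_allowed_factor(size: int) -> int:
    '''Smallest integer x with x ** 2 >= 2 ** (2 * size - 1).'''
    return math.isqrt(2 ** (2 * size - 1) - 1) + 1

for psize, qsize in [(8, 8), (5, 4), (512, 512)]:
    n_min = smallest_allowed_factor(psize) * smallest_allowed_factor(qsize)
    # Without the sqrt(2) factor the smallest product would be one bit short
    n_naive = 2 ** (psize - 1) * 2 ** (qsize - 1)
    print("nlen:", psize + qsize,
          "bits with the bound:", n_min.bit_length(),
          "bits without it:", n_naive.bit_length())
```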

## Encryption

Key generation is done, and now we need to implement encryption and decryption. The arithmetic itself is easy since the programming language does the heavy lifting, but we need to decide how we are going to deal with text.
<br><br>
Using base 26 to encode characters would be too limited, and ASCII does not support accented characters or ñ. Strings are ultimately stored as sequences of bytes, which are interpreted as characters according to a particular encoding.

In [None]:
a_string = "这他妈是什么意"
print("The original string:", a_string)

encoding = a_string.encode("utf-16LE")
print("The bytes of the string encoded in UTF-16LE:", encoding)

What we are going to do is to operate in base $2^8$; that way, we can deal with any possible byte and therefore with any possible character.
<br><br>
Another question to solve is how to convert bytes to numeric blocks ($M < n$, so an appropriate block size must be chosen) and encrypted blocks back to bytes.

To convert from bytes to blocks we will simply interpret the bytes as digits in base $2^8$, using a big endian representation (the first byte is the most significant):

$$
bytes = [0xFA, 0x12, 0x01]
$$

$$
value = 0x01 * (2^{8})^0 + 0x12 * (2 ^{8}) ^ 1 + 0xFA * (2 ^{8}) ^2
$$
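The sum above gives the first byte the highest power, which is what Python's built-in `int.from_bytes` calls `byteorder="big"`. A short sketch comparing the manual sum with the built-in (the variable names are illustrative):

```python
block = bytes([0xFA, 0x12, 0x01])

# Manual sum, exactly as in the formula above
manual = 0x01 * (2 ** 8) ** 0 + 0x12 * (2 ** 8) ** 1 + 0xFA * (2 ** 8) ** 2
# Built-in conversion: the first byte is the most significant one
big = int.from_bytes(block, byteorder="big")
# The opposite convention, shown only for comparison
little = int.from_bytes(block, byteorder="little")

print("manual:", manual, "big:", big, "little:", little)
```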

Since we want to ensure $M < n$ we will choose a number of bytes lower than the number of bytes required to store $n$

$$
blocksize = \lceil bitlength(n)/ 8 - 1 \rceil
$$

An example:

$$
n = 352
$$
$$
bitlength(n) = 9
$$

Therefore 2 bytes of storage are required. The size of the block of bytes we convert to integer is:

$$
blocksize =  \lceil bitlength(n)/ 8 - 1 \rceil = 1
$$
$$
block\_max\_value = 2^8 - 1 < n
$$

<br><br>
This has the implication that our implementation of RSA will require at least 9 bits in $n$, otherwise, we would select a block size of 0 bytes.

We still have one problem, though. $M^e \text{ } mod \text{ } n$ may well need more bits of storage than our $blocksize$ provides.
<br><br>
For example $n = 2479$, $blocksize = 1$, $e = 47$, $M = 78$

$$
M^e \equiv 1945 \text{ } mod \text{ } n
$$

$2 ^{10} = 1024 < 1945 < 2 ^{11} = 2048$, so we need 11 bits of storage, more than the 8 bits available in a block of $blocksize = 1$ byte. Therefore, each encrypted numeric block will have to be converted to $blocksize + 1$ bytes. This is somewhat inefficient, but RSA is designed to encrypt short messages such as symmetric keys, not bulk data, so what we are doing is unorthodox.
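The numbers in this example can be checked directly with built-in integer operations. The sketch prints the encrypted value and its bit length, and confirms that it does not fit in `block_size` bytes but does fit in `block_size + 1`. The block size is computed here with an equivalent integer formula rather than with the notebook's own function.

```python
n, e, M = 2479, 47, 78

# Equivalent to ceil(bitlength(n) / 8 - 1)
block_size = (n.bit_length() + 7) // 8 - 1
ciphertext = pow(M, e, n)
print("block size:", block_size)
print("ciphertext:", ciphertext, "bit length:", ciphertext.bit_length())

# to_bytes raises OverflowError when the value does not fit
try:
    ciphertext.to_bytes(block_size, byteorder="big")
    fits = True
except OverflowError:
    fits = False
print("fits in block_size bytes:", fits)

stored = ciphertext.to_bytes(block_size + 1, byteorder="big")
print("stored in block_size + 1 bytes:", stored)
```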

In [None]:
from typing import Iterable

def compute_block_size(n: int) -> int:
    '''
    Compute the size of the blocks of bytes to be converted to
    integers. The appropriate size for the encrypted blocks will
    be compute_block_size(n) + 1, to accommodate any possible value
    block_val < n
    '''
    quotient, remainder = divmod(bitlength(n), 8)
    if remainder == 0:
        quotient -= 1
    return quotient

def from_base_factors(factors: Iterable[int], base: int = 2 ** 8) -> int:
    '''
    Compute the integer value from the coefficients in factors, given in the
    selected base with the most significant coefficient first (big endian).
    
    The result is an integer whose value is
    factors[0] * base ** (n - 1) + factors[1] * base ** (n - 2) + ... + factors[n - 1] * base ** 0
    '''
    # Horner's method: works on any iterable, no need to know its length
    total = 0
    for factor in factors:
        total = total * base + factor
    return total

In [None]:
n1 = 312
print("bitlength of {}:".format(n1), bitlength(n1))
print("Block size for n = {}:".format(n1), compute_block_size(n1))

n2 = 2 ** 16 - 1
print("bitlength of {}:".format(n2), bitlength(n2))
print("Block size for n = {}:".format(n2), compute_block_size(n2))

In [None]:
# We take advantage of the fact that bytes are sequences of ints
block = b"hola"
print("Block value:", from_base_factors(block))
print("Manual block value:", 104  * 2 ** 24 + 111 * 2 ** 16 + 108 * 2 ** 8 + 97 * 2 ** 0)

In [None]:
factors = [0x01, 0x0F]
print("Value from factors:", from_base_factors(factors, base=16))

In [None]:
def iter_blocks(iterable: Iterable, n: int):
    '''
    Generator that returns blocks of size n from iterable. The last block may
    be truncated if the length of the iterable is not divisible by n
    '''
    if n <= 0:
        raise ValueError("n must be greater than 0")
    acum = []
    for elem in iterable:
        acum.append(elem)
        if len(acum) == n:
            yield acum
            acum = []
    if acum:
        yield acum

def block_from_bytes(byt: bytes) -> int:
    '''
    A simple alias.
    Translate the bytes to a numeric value in base 2 ** 8.
    Bytes are taken in big endian order (the first byte is the most significant).
    '''
    return from_base_factors(byt, 2 ** 8)

def blocks_from_bytes(by: bytes, block_size: int) -> list:
    if block_size <= 0:
        raise ValueError("Block size must be an integer greater than zero")
        
    return [block_from_bytes(byte_block) 
            for byte_block in iter_blocks(by, block_size)]

In [None]:
encoded = "dia".encode("utf-8")
for i, byte_block in enumerate(iter_blocks(encoded, 2)):
    print("block index: {}:".format(i), byte_block)

In [None]:
byte_block = b"hola"
print("Numeric value of {}:".format(byte_block), block_from_bytes(byte_block))

In [None]:
# block from bytes simply chains previous operations
encoded = "sample".encode("utf-8")
print("Integer blocks in {}:".format(encoded), blocks_from_bytes(encoded, 3))

first, second = iter_blocks(encoded, 3)
print("First block {}:".format(bytes(first)), 115 * 2 ** 16 + 97 * 2 ** 8 + 109 * 2 ** 0)

RSA applies $M^{exponent} \text{ mod } n$, with the exponent being $e$ or $d$ depending on whether encryption or decryption is applied. Since only the exponent changes, we factor out the operation.

In [None]:
def rsa_conversion(by: bytes, n: int, ex: int, extract_blocks_size: int
                   ) -> list[int]:
    # Transform byte sequence to blocks of integers
    blocks = blocks_from_bytes(by, extract_blocks_size)
    # Exponentiation modulo n of each block
    return [power_mod(block, ex, n) for block in blocks]
    

def rsa_encrypt(by: bytes, n: int, e: int) -> bytes:
    block_size = compute_block_size(n)
    encrypted_block_size = block_size + 1
    
    last_size = len(by) % block_size    
    last_size = last_size or block_size
    
    encrypted = rsa_conversion(by, n, e, block_size)
    encrypted = [block.to_bytes(encrypted_block_size, byteorder="big") 
                 for block in encrypted]
    
    # We add an additional block with size of the last one.
    # This is necessary to properly decrypt leading null bytes
    padding_block = rsa_conversion(
        last_size.to_bytes(block_size, byteorder="big"), n, e, block_size)
    padding_block = [block.to_bytes(encrypted_block_size, byteorder="big")
                     for block in padding_block]
    encrypted = (b'').join(encrypted + padding_block)
    return encrypted

In [None]:
# join simply concatenates
byte_blocks = [b"ab", b"\x68\x78", b"\x00"]
print("Joined byte blocks:", (b"").join(byte_blocks))

In [None]:
(n, e), d = rsa_keygen(130)
message = "A: Смерть Ивана Ильича"
encoded = message.encode("utf-8")
print("encoded message:", encoded)

# Note the increased length
encrypted = rsa_encrypt(encoded, n, e)
print("Encrypted message:", encrypted)
print("Encrypted length:", len(encrypted), "Original length:", len(encoded))

It must be noted that in the real world RSA includes padding in the messages: https://www.rfc-editor.org/rfc/rfc8017.
<br><br>
As we saw, small values of $e$ can create situations in which $M^e < n$ and the message can be obtained by computing an e-th root. Padding can help with that.
<br><br>
But most importantly, padding can introduce a probabilistic component to the message: a plaintext can correspond to many different ciphertexts, so several attacks are mitigated: common modulus, small $e$, etc.
<br><br>
We have omitted padding in this implementation for simplicity, but RSA without an appropriate padding scheme is not considered secure.
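The $e$-th root problem mentioned above is easy to reproduce. The following is a sketch under stated assumptions: `integer_root` is a hypothetical helper, $e = 3$ is chosen to make the effect visible, and `n` is a placeholder number standing in for a real modulus (only its size matters here, it is not a valid RSA key).

```python
def integer_root(value: int, degree: int) -> int:
    '''Floor of the degree-th root of a non-negative integer, by binary search.'''
    low, high = 0, 1
    while high ** degree <= value:
        high *= 2
    # Invariant: low ** degree <= value < high ** degree
    while low + 1 < high:
        mid = (low + high) // 2
        if mid ** degree <= value:
            low = mid
        else:
            high = mid
    return low

e = 3
n = 2 ** 512 + 75  # placeholder, NOT a real RSA modulus
secret = b"short key"
message = int.from_bytes(secret, byteorder="big")

# Without padding, message ** e is smaller than n, so the modulo does nothing
ciphertext = pow(message, e, n)
recovered = integer_root(ciphertext, e)
print("recovered:", recovered.to_bytes(len(secret), byteorder="big"))
```

No private key is involved in the recovery; a padding scheme that fills the block makes $M^e$ wrap around $n$ and defeats this shortcut.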

## Decryption

We have already defined most operations so at this stage decryption is easy.
<br><br>
First, we define the auxiliary functions

In [None]:
def to_base_factors(original: int, base: int = 2 ** 8) -> list[int]:
    '''
    Compute the coefficients of the decomposition of original in the selected
    base. original is assumed to be a non-negative integer.
    
    The result is a list l of n coefficients, most significant first, such that
    original = l[0] * base ** (n - 1) + l[1] * base ** (n - 2) + ... + l[n - 1] * base ** 0
    '''
    factors = []
    while original > 0:
        original, remainder = divmod(original, base)
        factors.insert(0, remainder)
    return factors

In [None]:
int_block = 7561581
print("Corresponding bytes to {}:".format(int_block), bytes(to_base_factors(int_block)))

A problem presents itself: what do we do with null bytes in the most significant position?

In [None]:
# Bytes in block
encoded = b'\x00\x73\x61\x6D'
print("bytes:", encoded)

# transform bytes to int, the integer value of \x00 is 0
# so the value will be the same as b'\x73\x61\x6D'
int_block = block_from_bytes(encoded)
print("int block value:", int_block)

# Recover the bytes from the integer value of the block
original = to_base_factors(int_block)
print("original:", bytes(original))

If we had processed the value as little endian we would have the same problem, except it would be applied to trailing null bytes. We need to take into account the size of the block and add missing null bytes, if any.

In [None]:
def bytes_from_block(block: int, blocksize : int = None) -> bytes:
    '''
    Extract the original bytes from a numeric block.
    Bytes are in big endian order (the first byte is the most significant).
    '''
    factors = to_base_factors(block, 2 ** 8)
    # deal with the null byte \x00
    if blocksize is not None:
        factors = [0] * (blocksize - len(factors)) + factors
    return bytes(factors)

In [None]:
# Bytes in block
encoded = b'\x00\x73\x61\x6D'
print("bytes:", encoded)

# transform bytes to int, the integer value of \x00 is 0
# so the value will be the same as b'\x73\x61\x6D'
int_block = block_from_bytes(encoded)
print("int block value:", int_block)

# Recover the bytes from the integer value of the block
original = bytes_from_block(int_block, len(encoded))
print("original:", bytes(original))

Note that this is not a problem with zeros in other positions:
<br><br>
bytes = b'\x63\x00\x00' = b'c\x00\x00'
<br><br>
$int\_value = 99 * (2^8)^2 + 0 * (2^8)^1 + 0 * (2^8)^0 = 6488064$
<br><br>
We convert it back to bytes:
<br><br>
$original = 6488064$
<br>
==================================
<br>
$quotient = 6488064 // 2^8  = 25344$
<br>
$remainder = 6488064 \text { mod } 2^8 = {\color{red}0}$
<br>
==================================
<br>
==================================
<br>
$quotient = 25344 // 2^8  = 99$
<br>
$remainder = 25344 \text { mod } 2^8 = {\color{red}0}$
<br>
==================================
<br>
==================================
<br>
$quotient = 99 // 2^8  = 0$
<br>
$remainder = 99 \text { mod } 2^8 = {\color{red}{99}}$
<br>
==================================
<br><br>
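The same trace can be reproduced with a few lines of plain Python, as a sketch that mirrors the repeated division step by step (the variable names are illustrative):

```python
value = 6488064
remainders = []
while value > 0:
    value, remainder = divmod(value, 2 ** 8)
    print("quotient:", value, "remainder:", remainder)
    # Each new remainder is more significant than the previous ones
    remainders.insert(0, remainder)

print("recovered bytes:", bytes(remainders))
```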

In [None]:
def rsa_decrypt(by: bytes, n: int, d: int) -> bytes:
    encrypted_block_size = compute_block_size(n) + 1
    
    decrypted = rsa_conversion(by, n, d, encrypted_block_size)
    last_size = decrypted[-1]
    # decrypt the last block independently
    last_block = [bytes_from_block(decrypted[-2], last_size)]
    
    decrypted = decrypted[:-2]
    decrypted = [bytes_from_block(block, encrypted_block_size - 1) 
                 for block in decrypted]
    decrypted = (b'').join(decrypted + last_block)
    
    return decrypted

In [None]:
original_bytes = rsa_decrypt(encrypted, n, d)
original_message = original_bytes.decode("utf-8")
print("decrypted message:", original_message)
print("original message:", message)

To reiterate, the division into blocks is unorthodox: most implementations just choose a size equal to the number of bytes required to accommodate $n$ and simply raise an error if the numerical value of the message is not smaller than $n$: https://github.com/Legrandin/pycryptodome/blob/master/lib/Crypto/PublicKey/RSA.py#L147
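A minimal sketch of that single-block convention follows. The function names are hypothetical and this is not the API of PyCryptodome or of any other library; it is textbook RSA without padding, shown with a toy key ($n = 61 \cdot 53$) that is far too small for real use.

```python
def encrypt_single_block(message: bytes, n: int, e: int) -> bytes:
    '''Encrypt message as one integer; refuse messages that do not fit below n.'''
    n_bytes = (n.bit_length() + 7) // 8
    m = int.from_bytes(message, byteorder="big")
    if m >= n:
        raise ValueError("Message too large for this modulus")
    # The ciphertext always uses as many bytes as n
    return pow(m, e, n).to_bytes(n_bytes, byteorder="big")

def decrypt_single_block(ciphertext: bytes, n: int, d: int) -> int:
    '''Return the integer value of the decrypted block.'''
    c = int.from_bytes(ciphertext, byteorder="big")
    if c >= n:
        raise ValueError("Ciphertext too large for this modulus")
    return pow(c, d, n)

# Toy key for illustration only
n_toy, e_toy, d_toy = 3233, 17, 413
encrypted_block = encrypt_single_block(b"A", n_toy, e_toy)
print("ciphertext bytes:", encrypted_block)
print("decrypted value:", decrypt_single_block(encrypted_block, n_toy, d_toy))
```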

## More professional implementations

We have implemented an (incomplete and insecure) version of RSA for educational purposes, but in any serious application it is recommended to use an established library for cryptographic operations. In Python's case the _cryptography_ library is one of the more popular choices: https://cryptography.io/en/latest/

As the documentation https://cryptography.io/en/latest/ says, even properly implemented cryptographic operations can be used incorrectly, so care must be taken to avoid security issues.

In [None]:
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes

# The private key is stored according to specifications in RFC 8017
# https://www.rfc-editor.org/rfc/rfc8017, section 3.2
private_key = rsa.generate_private_key(65537, 2048)
public_key = private_key.public_key()
print([x for x in dir(private_key.private_numbers()) if "__" not in x])
print("Private key:", private_key.private_numbers().d)
print("Public key:", public_key.public_numbers())
print("Bitlength of n:", bitlength(public_key.public_numbers().n))

To encrypt the message, padding is required, and it is very useful, as we will see.

In [None]:
message = "hello"
encoded = message.encode("utf-8")
print("Encoded message:", encoded)
encrypted = public_key.encrypt(
    encoded, 
    # Padding is a mandatory argument. Note that OAEP is the padding of choice for modern uses
    padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA512()),
        algorithm=hashes.SHA512(),
        label=None
    )
)

# Note how the length of the ciphertext is incremented through padding
# and how each execution produces a different ciphertext:
# the padding used is probabilistic
print("ciphertext:", encrypted)

In [None]:
original_message = private_key.decrypt(
    encrypted,
    padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA512()),
        algorithm=hashes.SHA512(),
        label=None
    )
)

print("Original encoded:", original_message)

In addition, the private key must be properly stored, which is another matter we ignored in our implementation. Since it is a private key, storage must be protected; saving in plaintext is advised against.
<br><br>
There are different encodings, but usually the process involves applying base64 so that the key contains only ASCII characters.
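As a small illustration of why base64 is used, the standard library's `base64` module turns arbitrary bytes into ASCII-only text and back without loss. The bytes below are arbitrary sample data, not a key.

```python
import base64

raw = bytes([0x00, 0xFF, 0x10, 0x80, 0x7F])
encoded_b64 = base64.b64encode(raw)

print("raw bytes:", raw)
print("base64 text:", encoded_b64.decode("ascii"))
# Decoding returns exactly the original bytes
print("round trip ok:", base64.b64decode(encoded_b64) == raw)
```

PEM files wrap this kind of base64 text between `-----BEGIN ...-----` and `-----END ...-----` lines.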

In [None]:
from cryptography.hazmat.primitives import serialization

pem = private_key.private_bytes(
   encoding=serialization.Encoding.PEM,
   format=serialization.PrivateFormat.PKCS8,
   # Obviously, the password should be something actually secure
   encryption_algorithm=serialization.BestAvailableEncryption(b'1234')
)
print("Private key encoded:", pem.decode("ascii"))

# Now save the key. open() simply returns a file object
with open("private_key.pem", "wb") as pf:
    pf.write(pem)

In [None]:
with open("private_key.pem", "rb") as pf:
    loaded_private_key = serialization.load_pem_private_key(
        # Need to pass the bytes of the saved key and the password if any
        # and there should be a password
        pf.read(),
        password=b"1234"
    )
    
print("Are the original private key and the loaded one equal?", 
      loaded_private_key.private_numbers() == private_key.private_numbers()
)