# CSI4108 (Fundamentals of) Cryptography - Assignment 3

## Student Information
Name: Jake Wang


## Import Required Libraries

In [1]:
import random
import requests
import time
import operator
from functools import reduce
from Cryptodome.Util.number import getPrime as get_prime
from Cryptodome.Util.number import getStrongPrime as get_strong_prime
from Cryptodome.Util.number import isPrime as is_prime

## Question 1: Elgamal Public Key Encryption Algorithm

### Implementation

In [2]:
def generate_keypair(q, alpha):
    """Generate an Elgamal key pair modulo ``q`` with base ``alpha``.

    NOTE: the RSA section later in this notebook defines another
    ``generate_keypair(bits)``; once that cell has run, this name refers to
    the RSA version.

    Args:
        q: Modulus of the group. Must be at least 4 so that the private-key
            range [2, q - 2] is non-empty.
        alpha: Base that is raised to the private key.

    Returns:
        A tuple ``(private_key, public_key)`` where
        ``public_key == pow(alpha, private_key, q)``.

    Raises:
        ValueError: If ``q < 4`` (the private-key range is empty).
    """
    # The private key is secret key material, so draw it from the OS CSPRNG
    # rather than the default Mersenne Twister generator, whose output is
    # predictable. The range [2, q - 2] is the open interval (1, q - 1).
    private_key = random.SystemRandom().randint(2, q - 2)
    public_key = pow(alpha, private_key, q)

    return private_key, public_key


def encrypt(message, q, alpha, public_key, k=None):
    """Encrypt ``message`` with Elgamal and return the pair ``(c1, c2)``.

    Args:
        message: Integer plaintext; it is multiplied by the ephemeral key
            modulo ``q``.
        q: Modulus of the group (must be greater than 2 when ``k`` is omitted).
        alpha: Base used to derive ``c1``.
        public_key: Recipient's public key.
        k: Optional ephemeral exponent. When omitted, a fresh random value in
            [1, q - 2] is drawn. An explicitly supplied value is used as-is.

    Returns:
        ``(c1, c2)`` with ``c1 = alpha^k mod q`` and
        ``c2 = public_key^k * message mod q``.
    """
    if k is None:
        # Upper bound is q - 2, not q - 1: when alpha generates the group,
        # alpha^(q - 1) == 1 (mod q), so k = q - 1 would give c1 = 1 and
        # s = 1, i.e. c2 would be the plaintext itself. The exponent is
        # secret, so it comes from the OS CSPRNG.
        k = random.SystemRandom().randint(1, q - 2)

    # Generate ephemeral key s
    s = pow(public_key, k, q)

    c1 = pow(alpha, k, q)
    c2 = (s * message) % q

    return c1, c2


def decrypt(c1, c2, private_key, q):
    """Recover the plaintext from an Elgamal ciphertext ``(c1, c2)``.

    The ephemeral key is rebuilt as ``c1^private_key mod q`` and then divided
    out of ``c2`` by multiplying with its modular inverse.
    """
    ephemeral_key = pow(c1, private_key, q)
    return c2 * pow(ephemeral_key, -1, q) % q

### Verification

In [3]:
q = 89
alpha = 13

# Message should be in [1, q - 1]: 0 always encrypts to c2 = 0, and values
# >= q are reduced modulo q and cannot be recovered.
message = 42

private_key, public_key = generate_keypair(q, alpha)
c1, c2 = encrypt(message, q, alpha, public_key)
decrypted_message = decrypt(c1, c2, private_key, q)

# Fail loudly on a broken round trip instead of relying on reading the printout.
assert decrypted_message == message, "Elgamal round trip failed"

print("Plaintext:", message)
print("Generated keypair:", (private_key, public_key))
print("Ciphertext:", (c1, c2))
print("Decrypted plaintext:", decrypted_message)

Plaintext: 42
Generated keypair: (4, 81)
Ciphertext: (25, 79)
Decrypted plaintext: 42


### Decrypt ciphertext
* $m_1$ is given to be 62
* Choose $m_2 = 10$
* Compute ciphertext of $m_1$ and $m_2$ with $k = 43$:

In [4]:
# System parameters
q = 89
alpha = 13
k = 43
private_key, public_key = (55, 12)

# Encrypt messages
m_1 = 62
e_1 = encrypt(m_1, q, alpha, public_key, k=k)
m_2 = 10
e_2 = encrypt(m_2, q, alpha, public_key, k=k)

print(e_1, e_2)

(41, 69) (41, 14)


Trying to compute $m_2$, given $e_1$, $e_2$:

1. From the algorithm and the known value, we know
    * $s_1 = s_2 = s = K_e ^ k$, where $K_e$ is the public key.
    * $c_{2, 1} = sm_1 = 69, c_{2, 2} = sm_2 = 14$
2. We can compute:
    $$
    \begin{aligned}
        (sm_1) ^ {-1} \cdot (sm_2) &\equiv c_{2, 1} ^ {-1} \cdot c_{2, 2} \pmod q \\
        m_1 ^ {-1} \cdot m_2 &\equiv c_{2, 1} ^ {-1} \cdot c_{2, 2} \pmod q \\
    \end{aligned}
    $$
3. Multiply both sides of the equation by $m_1$:
    $$
    \begin{aligned}
        m_2 &\equiv c_{2, 1} ^ {-1} \cdot c_{2, 2} \cdot m_1 \pmod q \\
    \end{aligned}
    $$
4. Then we can compute $m_2$ trivially:

In [5]:
# m_2 = c_{2,2} * m_1 * c_{2,1}^{-1} (mod q), using only the second ciphertext components and the known m_1
print(f"m_2 = {e_2[1] * m_1 * pow(e_1[1], -1, q) % q}")

m_2 = 10


## Question 2: Miller-Rabin Primality Testing Algorithm

### Implementation

In [6]:
def miller_rabin(n, t):
    """Miller-Rabin probabilistic primality test.

    Args:
        n: Integer to test.
        t: Number of rounds, each with a distinct random base from
            [2, n - 2]. If ``t`` exceeds the number of available bases
            (n - 3), every base is tested exactly once.

    Returns:
        ``False`` if ``n`` is certainly composite (or less than 2), ``True``
        if ``n`` passed every round. A prime always yields ``True``.
    """
    if n <= 4:
        return n == 2 or n == 3
    if n & 1 == 0:
        return False

    # Compute q (odd) and k such that n - 1 = q * 2 ^ k.
    q = n - 1
    k = 0
    while q & 1 == 0:
        q >>= 1
        k += 1

    # Only n - 3 distinct bases exist in [2, n - 2]. Without this cap the
    # rejection loop below would never terminate once they are exhausted.
    rounds = min(t, n - 3)

    tested = set()
    for _ in range(rounds):
        # Select a not-yet-used random integer a in (1, n - 1)
        a = random.randint(2, n - 2)
        while a in tested:
            a = random.randint(2, n - 2)
        tested.add(a)

        x = pow(a, q, n)
        # a ^ q === 1 (mod n), or the j = 0 case a ^ q === n - 1 (mod n).
        # The j = 0 case must be checked here because the loop below squares
        # x before comparing; skipping it misreports primes as composite.
        if x == 1 or x == n - 1:
            # Inconclusive, continue to another round
            continue

        # for j in [1, k - 1]
        for _ in range(k - 1):
            x = pow(x, 2, n)
            # if a ^ (2 ^ j * q) === n - 1 (mod n)
            if x == n - 1:
                # Inconclusive, continue to another round
                break
        else:
            # Didn't encounter a break: composite number
            return False

    # Probably prime
    return True


def find_probable_prime(bits, confidence):
    """Return a random probable prime of at most ``bits`` bits.

    Candidates are drawn with ``random.getrandbits(bits)`` until one passes
    ``miller_rabin(candidate, confidence)``.

    Args:
        bits: Maximum bit length of the candidates; must be at least 2.
        confidence: Number of Miller-Rabin rounds applied to each candidate.

    Raises:
        ValueError: If ``bits < 2``. No prime fits in fewer than two bits, so
            the search would otherwise never terminate.
    """
    if bits < 2:
        raise ValueError("bits must be at least 2")

    while True:
        candidate = random.getrandbits(bits)
        if miller_rabin(candidate, confidence):
            return candidate

### Verification

In [7]:
# Fetch the reference prime table. The timeout keeps the cell from hanging on
# an unresponsive server, and raise_for_status() stops an HTTP error page from
# being silently parsed as an empty table.
response = requests.get("https://t5k.org/lists/small/10000.txt", timeout=10)
response.raise_for_status()

prime_set = {
    prime
    for tokens in (
        line.strip().split()
        for line in response.text.splitlines()
        # NOTE(review): keeps only lines that are exactly 71 characters long,
        # presumably the fixed-width rows of primes; confirm against the file layout.
        if len(line) == 71
    )
    for prime in tokens
}

bits = 15
confidence = 6

probable_prime = find_probable_prime(bits, confidence)
print(f"A probable prime: {probable_prime}")

# The table holds strings, so compare against the decimal representation.
if str(probable_prime) in prime_set:
    print("Found in the table")
else:
    print("NOT found in the table")

A probable prime: 8221
Found in the table


## Question 3: RSA and ECDH Exploration

### RSA

#### Implementation

In [8]:
def generate_keypair(bits):
    """Generate an RSA key pair from two random ``bits``-bit primes.

    NOTE: this reuses the name of the Elgamal ``generate_keypair(q, alpha)``
    defined earlier in the notebook and shadows it once this cell has run.

    Args:
        bits: Bit length requested from ``get_prime`` for each prime factor.

    Returns:
        ``(public_key, private_key)`` where ``public_key = (n, e)`` and
        ``private_key = (n, d, p, q)``, with ``e = 65537`` and
        ``d = e^-1 mod (p - 1)(q - 1)``.
    """
    e = 65537
    while True:
        p = get_prime(bits)
        q = get_prime(bits)
        phi = (p - 1) * (q - 1)
        # e is prime, so it is invertible modulo phi exactly when it does not
        # divide phi. Retry in the rare case it does, instead of letting
        # pow(e, -1, phi) raise ValueError.
        if phi % e != 0:
            break

    n = p * q
    d = pow(e, -1, phi)
    return (n, e), (n, d, p, q)


def rsa_encrypt(message, public_key, mod_exp=pow):
    """Encrypt ``message`` as ``message^e mod n`` for ``public_key = (n, e)``.

    ``mod_exp`` is the three-argument modular exponentiation routine to use;
    it defaults to the built-in ``pow``.
    """
    modulus, public_exponent = public_key
    ciphertext = mod_exp(message, public_exponent, modulus)
    return ciphertext


def rsa_decrypt(ciphertext, private_key, mod_exp=pow):
    """Decrypt ``ciphertext`` as ``ciphertext^d mod n``.

    ``private_key`` is the 4-tuple ``(n, d, p, q)``; the prime factors are
    not used here. ``mod_exp`` is the three-argument modular exponentiation
    routine to use; it defaults to the built-in ``pow``.
    """
    modulus, private_exponent, _p, _q = private_key
    plaintext = mod_exp(ciphertext, private_exponent, modulus)
    return plaintext


def rsa_decrypt_crt(ciphertext, private_key):
    """Decrypt ``ciphertext`` using the Chinese Remainder Theorem.

    ``private_key`` is the 4-tuple ``(n, d, p, q)``. The exponentiation is
    carried out separately modulo ``p`` and modulo ``q`` with the exponent
    reduced modulo ``p - 1`` and ``q - 1``; the two residues are then
    recombined into the result modulo ``p * q``.
    """
    _, d, p, q = private_key
    modulus = p * q

    # CRT recombination coefficients:
    #   coeff_p is 1 (mod p) and 0 (mod q); coeff_q is 0 (mod p) and 1 (mod q).
    coeff_p = pow(q % p, -1, p) * q
    coeff_q = pow(p % q, -1, q) * p

    # Half-size exponentiations with reduced base and exponent.
    residue_p = pow(ciphertext % p, d % (p - 1), p)
    residue_q = pow(ciphertext % q, d % (q - 1), q)

    return (residue_p * coeff_p + residue_q * coeff_q) % modulus

#### Benchmarking (decrypt with or without CRT)

In [9]:
# Generate key pair
bits = 1024
public_key, private_key = generate_keypair(bits)

# Message to be encrypted
message = 476921883457909

# Encryption
start_time = time.time()
ciphertext = rsa_encrypt(message, public_key)
encryption_time = time.time() - start_time

# Decryption
start_time = time.time()
decrypted_message = rsa_decrypt(ciphertext, private_key)
decryption_time_without_crt = time.time() - start_time

# Decryption with CRT
start_time = time.time()
decrypted_message_crt = rsa_decrypt_crt(ciphertext, private_key)
decryption_time_with_crt = time.time() - start_time

print("Original Message:", message)
print("Ciphertext:", ciphertext)
print("Encryption Time:", encryption_time)
print("Decrypted Message (Without CRT):", decrypted_message)
print("Decryption Time (Without CRT):", decryption_time_without_crt)
print("Decrypted Message (With CRT):", decrypted_message_crt)
print("Decryption Time (With CRT):", decryption_time_with_crt)

Original Message: 476921883457909
Ciphertext: 11566844198355385838477940479021552757576863519650336099629353454437660844502692443071529641281925204084117133791999528949189657176491383314868478568282264316769098455107758487009952012379924441642232876888602979139272468524974705408398870449919457077619029487362402099454177236674492198816532972501481344874383336217509892416211632470311744913305334675117496760149890148123721810310834490543336764425504017416821812541888501323420570921140079305106840146159104595668848523098145612103730503192398462742527929075471138632796019382682938692017553750576279390960535704879050867617043137749741788463731778304503063209373
Encryption Time: 0.00021886825561523438
Decrypted Message (Without CRT): 476921883457909
Decryption Time (Without CRT): 0.03512072563171387
Decrypted Message (With CRT): 476921883457909
Decryption Time (With CRT): 0.011171102523803711


#### Observation
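* Decryption with CRT is faster than normal decryption (about 3× in the run above: ≈0.011 s vs ≈0.035 s), because it replaces one exponentiation modulo $n$ with two exponentiations modulo the half-size primes $p$ and $q$, using exponents reduced modulo $p - 1$ and $q - 1$.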

### ECDH

#### ECDH Implementation

In [10]:
class Point:
    """Affine point ``(x, y)``.

    ``Point(0, 0)`` is the value ``ECGroup`` uses to represent the point at
    infinity (the group identity).
    """

    def __init__(self, x: int, y: int):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # Returning NotImplemented lets Python fall back to its default
        # comparison for unrelated types instead of raising AttributeError.
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __repr__(self):
        return f"Point({self.x}, {self.y})"

    def __str__(self):
        return repr(self)


class ECGroup:
    """Points of the curve ``y^2 = x^3 + a*x + b`` over the integers modulo ``p``.

    The point at infinity is represented by ``Point(0, 0)``, so the class is
    only meaningful for curves on which (0, 0) is not itself a point
    (i.e. ``b != 0 mod p``). Operands are assumed to lie on the curve.
    """

    def __init__(self, p: int, a: int, b: int):
        self.p = p
        self.a = a
        self.b = b

    def add(self, p: Point, q: Point):
        """Return the group sum ``p + q``."""
        identity = Point(0, 0)
        if p == identity:
            return q
        if q == identity:
            return p

        if (p.x - q.x) % self.p == 0:
            # Same x-coordinate: either q == -p, or q == p (doubling).
            if (p.y + q.y) % self.p == 0:
                # q is the inverse of p (this also covers doubling a point
                # with y == 0). The connecting line is vertical, so the sum
                # is the point at infinity; computing a slope here would try
                # to invert 0 and raise ValueError.
                return Point(0, 0)
            # Tangent slope for point doubling.
            slope = (3 * p.x * p.x + self.a) * pow(2 * p.y, -1, self.p) % self.p
        else:
            # Chord slope through two distinct points.
            slope = (q.y - p.y) * pow(q.x - p.x, -1, self.p) % self.p

        x = (slope * slope - p.x - q.x) % self.p
        y = (slope * (p.x - x) - p.y) % self.p

        return Point(x, y)

    def multiply(self, k: int, p: Point):
        """Return the scalar multiple ``k * p`` using double-and-add.

        A negative ``k`` is handled as ``(-k) * (-p)``; ``k == 0`` yields the
        point at infinity.
        """
        if k < 0:
            # Right-shifting a negative int never reaches 0, so the loop
            # below would not terminate; negate the point instead.
            return self.multiply(-k, Point(p.x, -p.y % self.p))

        result = Point(0, 0)
        addend = p

        while k:
            if k & 1:
                result = self.add(result, addend)
            addend = self.add(addend, addend)
            k >>= 1

        return result

    def y_squared(self, x: int):
        """Return the right-hand side ``x^3 + a*x + b mod p`` of the curve equation."""
        return (x ** 3 + self.a * x + self.b) % self.p

    def __repr__(self):
        return f"ECGroup({self.p}, {self.a}, {self.b})"

    def __str__(self):
        return repr(self)

#### ECDH Verification

In [11]:
# brainpoolP160r1
# Reference: https://datatracker.ietf.org/doc/html/rfc5639#section-3.1
E = ECGroup(
    0xe95e4a5f737059dc60dfc7ad95b3d8139515620f,
    0x340e7be2a280eb74e2be61bada745d97e8f7c300,
    0x1e589a8595423412134faa2dbdec95c8d8675e58
)
G = Point(
    0xbed5af16ea3f6a4f62938c4631eb5af7bdbcdbc3,
    0x1667cb477a1a8ec338f94741669c976316da6321
)

# Verify that G is on the curve
print(f"G is on curve: {E.y_squared(G.x) == pow(G.y, 2, E.p)}")

start_time = time.time()

# Alice side
n_a = random.getrandbits(160)
P_a = E.multiply(n_a, G)

# Bob side
n_b = random.getrandbits(160)
P_b = E.multiply(n_b, G)

# Alice and Bob exchange P_a and P_b

# Alice side
k_a = E.multiply(n_a, P_b)

# Bob side
k_b = E.multiply(n_b, P_a)

end_time = time.time()

print(f"Alice has key: {k_a}")
print(f"Bob has key: {k_b}")
# Two keys should be equal
print(f"Two keys are equal: {k_a == k_b}")

print(f"Ellpased time: {end_time - start_time}")

G is on curve: True
Alice has key: Point(1290523551237518590156283625125266111764555020684, 887800925528475370478952271981502477580143452726)
Bob has key: Point(1290523551237518590156283625125266111764555020684, 887800925528475370478952271981502477580143452726)
Two keys are equal: True
Ellpased time: 0.03154325485229492


#### DH Verification

In [12]:
# A 160-bit elliptic curve has 80-bit security.
# To achieve 80-bit security, we need 1228-bit modulus size.
# Reference: https://datatracker.ietf.org/doc/html/rfc3766#section-5

# Generated with `openssl dhparam 1228`
# PEM output:
# -----BEGIN DH PARAMETERS-----
# MIGkAoGaCg9o40rrQWqaxxCdFgwwnvPqJlvLt0M+iS5aOTMnKP6pe4/ogMX5iBCA
# lpsf1veK5u91kKjMtqy1AoRmJm+oQRzVOQRBbRHVNgQNYjdTH0Buop4kv/A18v2B
# aDZmijvPWCfkPmYx49d4pIpsKCIT17igBn+NI6QAGq3TE+cURNtU11yLPKC9Urmd
# QtA+1TL4O6Vy1oTHT97dRwIBAgICAMg=
# -----END DH PARAMETERS-----
# View with `openssl dhparam -inform PEM -in ./dh.pem -check -text -noout`
# DH Parameters: (1228 bit)
# P:
#     0a:0f:68:e3:4a:eb:41:6a:9a:c7:10:9d:16:0c:30:
#     9e:f3:ea:26:5b:cb:b7:43:3e:89:2e:5a:39:33:27:
#     28:fe:a9:7b:8f:e8:80:c5:f9:88:10:80:96:9b:1f:
#     d6:f7:8a:e6:ef:75:90:a8:cc:b6:ac:b5:02:84:66:
#     26:6f:a8:41:1c:d5:39:04:41:6d:11:d5:36:04:0d:
#     62:37:53:1f:40:6e:a2:9e:24:bf:f0:35:f2:fd:81:
#     68:36:66:8a:3b:cf:58:27:e4:3e:66:31:e3:d7:78:
#     a4:8a:6c:28:22:13:d7:b8:a0:06:7f:8d:23:a4:00:
#     1a:ad:d3:13:e7:14:44:db:54:d7:5c:8b:3c:a0:bd:
#     52:b9:9d:42:d0:3e:d5:32:f8:3b:a5:72:d6:84:c7:
#     4f:de:dd:47
# G:    2 (0x2)
# recommended-private-length: 200 bits
# DH parameters appear to be ok.
# Parse the modulus straight to an int so that `p` never holds two different types.
p = int(
    "0a0f68e34aeb416a9ac7109d160c309e"
    "f3ea265bcbb7433e892e5a39332728fe"
    "a97b8fe880c5f9881080969b1fd6f78a"
    "e6ef7590a8ccb6acb5028466266fa841"
    "1cd53904416d11d536040d6237531f40"
    "6ea29e24bff035f2fd816836668a3bcf"
    "5827e43e6631e3d778a48a6c282213d7"
    "b8a0067f8d23a4001aadd313e71444db"
    "54d75c8b3ca0bd52b99d42d03ed532f8"
    "3ba572d684c74fdedd47",
    16,
)
g = 0x2

# Private exponents are key material: they are drawn from the OS CSPRNG and
# start at 1, because an exponent of 0 would make the public value 1.
# perf_counter() is the monotonic, high-resolution clock for short intervals.
start_time = time.perf_counter()

# Alice side
n_a = random.SystemRandom().randrange(1, 1 << 160)
P_a = pow(g, n_a, p)

# Bob side
n_b = random.SystemRandom().randrange(1, 1 << 160)
P_b = pow(g, n_b, p)

# Alice and Bob exchange P_a and P_b

# Alice side
k_a = pow(P_b, n_a, p)

# Bob side
k_b = pow(P_a, n_b, p)

end_time = time.perf_counter()

print(f"Alice has key: {k_a}")
print(f"Bob has key: {k_b}")
# Two keys should be equal
print(f"Two key are equal: {k_a == k_b}")

print(f"Ellpased time: {end_time - start_time}")

Alice has key: 1183355598691882188483629711617224954819096246507720799341269615982173437104994254197494264856397346563392022081924361694355093065657624286365815466844196697440931532766378865245112698870856222246808909858607337538491920939307630952766356292268040815517699806962037040720053113767873656258783986866473485912332241258155031931440419283945498102392923401459439868677334845
Bob has key: 1183355598691882188483629711617224954819096246507720799341269615982173437104994254197494264856397346563392022081924361694355093065657624286365815466844196697440931532766378865245112698870856222246808909858607337538491920939307630952766356292268040815517699806962037040720053113767873656258783986866473485912332241258155031931440419283945498102392923401459439868677334845
Two key are equal: True
Ellpased time: 0.0041790008544921875


#### Observation
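* At the 80-bit security level, DH ran faster than ECDH in this experiment (about 0.004 s vs 0.032 s). Note that the ECDH arithmetic above is written in pure Python while DH uses the built-in `pow`, so the result reflects these two implementations rather than the algorithms themselves.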