# RSA Oracle
* **Event:** HackTX CTF
* **Problem Type:** Cryptography
* **Point Value / Difficulty:**
* **(Optional) Tools Required / Used:**


## Background Information
- https://en.wikipedia.org/wiki/Modular_arithmetic
- https://en.wikipedia.org/wiki/RSA_(cryptosystem)#Operation

## Solution Idea

We are given a file called `rsa.py`, which is the code running on the server.
We notice that we can encrypt any string (even raw bytes!) and decrypt any ciphertext except the flag.
If we can't decrypt the flag, why not try decrypting $2 \cdot \text{flag}$ instead?
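The idea rests on a property of textbook RSA: $E(a) \cdot E(b) \equiv E(ab) \pmod N$. Below is a toy check of that property. It uses the small primes $p = 61$, $q = 53$ (the same ones used in the Wikipedia RSA example linked above) and $e = 17$. These parameters are purely illustrative. The server uses roughly 1024-bit primes and $e = 65537$, and the value `65` is a hypothetical stand-in for the flag.

```python
# Toy textbook-RSA key (illustrative only; the server uses ~1024-bit primes).
p, q = 61, 53
N = p * q                  # 3233
phi = (p - 1) * (q - 1)    # 3120, computed the same way as the server's phi
e = 17
d = pow(e, -1, phi)        # modular inverse (Python 3.8+)

def enc(m: int) -> int:
    """Textbook RSA encryption: m^e mod N."""
    return pow(m, e, N)

def dec(c: int) -> int:
    """Textbook RSA decryption: c^d mod N."""
    return pow(c, d, N)

m = 65                     # hypothetical stand-in for the flag
c1 = enc(m)                # what the server hands us
c2 = enc(2)                # what we can ask the server to compute
forged = c1 * c2           # a ciphertext different from c1

print(forged != c1)        # True
print(dec(forged))         # 130, i.e. 2 * m
```

Decrypting the product yields $2m$, not $m$. A filter that only refuses the flag itself therefore lets it through.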

```shell
# Necessary packages if you want to run the server code locally.
pip install gmpy2
pip install pycryptodome
```
I copied `rsa.py` here for convenience. Feel free to run it if you want to solve the problem locally (it reads the flag from a `FLAG.txt` file in the working directory).

```python
# Decryption oracle

import sys
import random
import gmpy2
from Crypto.Util.number import inverse, bytes_to_long, long_to_bytes
import codecs

def input_fix(string):
    return codecs.decode(string,"unicode_escape")


def gen_prime():
    base = random.getrandbits(1024)
    off = 0
    while True:
        if gmpy2.is_prime(base + off):
            break
        off += 1
    p = base + off

    return p

class RSA(object):
    def __init__(self):
        pass

    def generate(self, p, q, e=0x10001):
        self.p = p
        self.q = q
        self.N = p * q
        self.e = e
        phi = (p-1) * (q-1)
        self.d = inverse(e, phi)

    def encrypt(self, m):
        return pow(m, self.e, self.N)

    def decrypt(self, c):
        return pow(c, self.d, self.N)


def main():
    r = RSA()
    p = gen_prime()
    q = gen_prime()
    r.generate(p, q)

    f = open('FLAG.txt', 'rb')
    flag = f.readlines()[0].strip()
    flag_m = bytes_to_long(flag)

    print("Secret flag", r.encrypt(flag_m))


    # Can encrypt anything, even bytes!
    def encrypt_msg():
        print('input the message: ')
        m = input_fix(input().encode())
        M = bytes_to_long(m.encode())
        print(r.encrypt(M))

    # Can only decrypt numbers :(
    def decrypt_msg():
        print('input the ciphertext (as an integer): ')
        c = input()
        print(c)
        if c.isnumeric():
            dec = r.decrypt(int(c))
            if long_to_bytes(dec) == flag:
                print("No thanks.")
                sys.exit(1)
        else:
            print('Not an integer...')
            sys.exit(1)

        print(dec)

    menu = {
        '1' : encrypt_msg,
        '2' : decrypt_msg
    }

    cnt = 2
    while cnt > 0:
        ""
        options = '''Welcome to the RSA encryption and decryption tool!
        1. encrypt_msg
        2. decrypt_msg
        '''
        print(options)
        print('Select option: ')
        choice = input()
        if choice not in menu.keys():
            print("Not a valid choice...")
            sys.exit(1)

        menu[choice]()

        cnt -= 1

main()
```

## Solution

We first notice that `cnt = 2`, so the `while cnt > 0` loop runs only twice.
This means we get only two queries to the oracle, which, as we will see, is all we need.

Let us analyze the following two functions.

```python
def encrypt_msg():
    print('input the message: ')
    m = input_fix(input().encode())
    M = bytes_to_long(m.encode())
    print(r.encrypt(M))
```

Nice, we can encrypt any string we want.
More specifically, given some message $M$, the server converts it into an integer $m$ using the function `bytes_to_long` (https://pycryptodome.readthedocs.io/en/latest/src/util/util.html).
Textbook RSA encryption takes $m$ as its input and outputs $m^e \bmod N$.
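`bytes_to_long` reads a byte string as one big-endian unsigned integer, and `long_to_bytes` reverses that. The sketch below uses only the standard library (`int.from_bytes` / `int.to_bytes`) as a stand-in for the pycryptodome helpers. The names `to_long` and `to_bytes` are hypothetical and are not part of pycryptodome.

```python
def to_long(data: bytes) -> int:
    """Stdlib stand-in for bytes_to_long: big-endian unsigned integer."""
    return int.from_bytes(data, "big")

def to_bytes(n: int) -> bytes:
    """Stdlib stand-in for long_to_bytes: shortest big-endian encoding."""
    length = max(1, (n.bit_length() + 7) // 8)
    return n.to_bytes(length, "big")

print(to_long(b"\x02"))      # 2
print(to_long(b"ab"))        # 24930, i.e. 0x6162
print(to_bytes(0x6162))      # b'ab'
print(to_bytes(to_long(b"flag{x}")) == b"flag{x}")  # True
```

Leading zero bytes do not survive the round trip (`b"\x00a"` maps to 97, which maps back to `b"a"`). That does not matter for a printable flag.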

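The reason `input_fix` was included in this problem was to let the user encrypt raw bytes.
By default, Python's `input()` returns exactly the characters typed, so if we type `\x02` we get the four-character string `'\\x02'` instead of the single character `'\x02'`. `input_fix` turns such escape sequences into the characters they denote with `codecs.decode(..., "unicode_escape")`.
I added this as a hint since I realized it wasn't so clear.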
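Here is a quick local check of what `encrypt_msg` does with the text `\x02`. `input_fix` below is copied from the server. The `\xff` line shows a caveat: because the server re-encodes with UTF-8, escapes at or above `\x80` become two bytes. So `\x02` is safe, but not every byte value can be sent as-is.

```python
import codecs

def input_fix(string):
    # Same as the server: interpret backslash escapes such as \x02.
    return codecs.decode(string, "unicode_escape")

typed = r"\x02"                      # what input() returns when we type \x02
print(len(typed), repr(typed))       # 4 '\\x02'

fixed = input_fix(typed.encode())    # the server passes input().encode()
print(len(fixed), repr(fixed))       # 1 '\x02'

# encrypt_msg then calls .encode() (UTF-8) and bytes_to_long.
print(fixed.encode())                          # b'\x02'
print(int.from_bytes(fixed.encode(), "big"))   # 2

# Caveat: escapes >= \x80 turn into multi-byte UTF-8 sequences.
print(input_fix(rb"\xff").encode())            # b'\xc3\xbf'
```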

Now for the `decrypt_msg` function:
```python
# Can only decrypt numbers :(
def decrypt_msg():
    print('input the ciphertext (as an integer): ')
    c = input()
    print(c)
    if c.isnumeric():
        dec = r.decrypt(int(c))
        if long_to_bytes(dec) == flag:
            print("No thanks.")
            sys.exit(1)
    else:
        print('Not an integer...')
        sys.exit(1)

    print(dec)
```

The `decrypt_msg` function lets us decrypt any ciphertext except one whose plaintext is the flag. The check is made on the decrypted bytes, not on the ciphertext we send.
Since this is textbook RSA, the decrypt function takes our ciphertext $c$ and returns $c^d \bmod N$.
The function `long_to_bytes` simply converts an integer back into its byte representation (see the link above for more details).
The `Secret flag` value printed at startup is `r.encrypt(flag_m)`, so the flag was encrypted under the same key this service uses for our queries.
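Because `decrypt_msg` compares plaintexts, resending the flag ciphertext in a disguised form that is congruent modulo $N$ (for example $c_1 + N$) is still refused. Changing the plaintext itself is what gets past the filter. Here is a toy simulation. It reuses the small illustrative key $p = 61$, $q = 53$, $e = 17$, and `65` is a hypothetical stand-in for the flag. `oracle_accepts` is a hypothetical helper that compares integers rather than bytes for simplicity.

```python
# Toy key (illustrative only), built the same way as the server's RSA class.
p, q, e = 61, 53, 17
N = p * q
d = pow(e, -1, (p - 1) * (q - 1))

flag_m = 65                  # hypothetical stand-in for bytes_to_long(flag)
c_flag = pow(flag_m, e, N)   # the "Secret flag" value

def oracle_accepts(c: int) -> bool:
    """Mimics decrypt_msg: refuse if the *plaintext* equals the flag."""
    return pow(c, d, N) != flag_m

print(oracle_accepts(c_flag))                  # False
print(oracle_accepts(c_flag + N))              # False: same residue mod N
print(oracle_accepts(c_flag * pow(2, e, N)))   # True: decrypts to 2 * flag_m
```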

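Now for the solution:
1. Let $c_1$ be the ciphertext of the secret flag (the `Secret flag` value).
Since $m$, the flag, was encrypted with this service's key, $c_1 = m^e \bmod N$.
2. Encrypt the value `\x02` (query 1). `input_fix` turns it into the single byte `0x02`, i.e. the integer $2$, so we get $c_2 = 2^e \bmod N$.
3. Multiplying the two ciphertexts gives $c_1 c_2 \equiv m^e 2^e \equiv (2m)^e \pmod N$. The server never prints $N$, so we send the unreduced integer $c_1 c_2$; decryption reduces it modulo $N$ anyway.
4. Decrypting $c_1 c_2$ (query 2) gives $(c_1 c_2)^d \equiv (2m)^{ed} \equiv 2m \pmod N$. This is exactly $2m$ because $2m < N$: the flag is far smaller than the roughly 2048-bit modulus. Since $2m \neq m$, the "No thanks." check does not fire.
5. Divide $2m$ by $2$ to get $m$, then call `long_to_bytes(m)` to get the flag.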
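The five steps above can be checked end to end with a local stand-in for the server. This is a sketch under stated assumptions. The key uses the known Mersenne primes $2^{521} - 1$ and $2^{607} - 1$ instead of `gen_prime()` so that it is deterministic. `FLAG` is a placeholder. `server_encrypt` / `server_decrypt` are hypothetical helpers that mimic the two menu options. The attacker side uses only `c_flag` and the oracle's answers, never `N`.

```python
# Hypothetical local stand-in for the server (deterministic Mersenne-prime key).
P = 2**521 - 1
Q = 2**607 - 1
N = P * Q
E = 0x10001
D = pow(E, -1, (P - 1) * (Q - 1))

FLAG = b"flag{example_not_the_real_flag}"            # placeholder flag
c_flag = pow(int.from_bytes(FLAG, "big"), E, N)      # the "Secret flag" line

def server_encrypt(m: int) -> int:
    """Query 1 (option 1): sending the byte 0x02 yields Enc(2)."""
    return pow(m, E, N)

def server_decrypt(c: int) -> int:
    """Query 2 (option 2): refuse if the plaintext bytes equal the flag."""
    dec = pow(c, D, N)
    if dec.to_bytes(max(1, (dec.bit_length() + 7) // 8), "big") == FLAG:
        raise PermissionError("No thanks.")
    return dec

# --- attacker side: N is never used below ---
c_two = server_encrypt(2)            # step 2
forged = c_flag * c_two              # step 3: unreduced product
doubled = server_decrypt(forged)     # step 4: 2m
m = doubled // 2                     # step 5
recovered = m.to_bytes((m.bit_length() + 7) // 8, "big")
print(recovered)                     # b'flag{example_not_the_real_flag}'
```

Against the real service, you would read `c_flag` from the `Secret flag` line, choose option 1 and send `\x02`, then choose option 2 and send the decimal string of `forged`. The network plumbing is left out here.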