#### Challenge 46: RSA parity oracle

[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)

In [None]:
from Crypto.Util import number
from Crypto.Random import random
from Crypto.Hash.SHA256 import SHA256Hash

import base64
import cryptopals as cp

<div class="alert alert-block alert-info">   

<div class="alert alert-block alert-warning">
    
#### **When does this ever happen?**

This is a bit of a toy problem, but it's very helpful for understanding what RSA is doing (and also for why pure number-theoretic encryption is terrifying). Trust us, you want to do this before trying the next challenge. Also, it's fun.

</div>
</div>    

<div class="alert alert-block alert-info">   
    
Generate a 1024 bit RSA key pair.
    
</div>

In [None]:
e,d,n = cp.genRSA_keypair(1024)

<div class="alert alert-block alert-info">   

Write an oracle function that uses the private key to answer the question "is the plaintext of this message even or odd" (is the last bit of the message 0 or 1). Imagine for instance a server that accepted RSA-encrypted messages and checked the parity of their decryption to validate them, and spat out an error if they were of the wrong parity.

Anyways: function returning true or false based on whether the decrypted plaintext was even or odd, and nothing else.

</div>

In [None]:
def pt_is_odd(ciphertext):
    """Return True if the plaintext is odd, False if it is even."""
    plaintext = pow(ciphertext, d, n)   # decrypt with the private key
    return plaintext % 2 == 1           # leak only the lowest bit

<div class="alert alert-block alert-info">   
    
Take the following string and un-Base64 it in your code (without looking at it!) and encrypt it to the public key, creating a ciphertext:

`VGhhdCdzIHdoeSBJIGZvdW5kIHlvdSBkb24ndCBwbGF5IGFyb3VuZCB3aXRoIHRoZSBGdW5reSBDb2xkIE1lZGluYQ==`

</div>

In [None]:
s = 'VGhhdCdzIHdoeSBJIGZvdW5kIHlvdSBkb24ndCBwbGF5IGFyb3VuZCB3aXRoIHRoZSBGdW5reSBDb2xkIE1lZGluYQ=='
s_int = int(base64.b64decode(s).hex(), 16)
ciphertext = pow(s_int, e, n)

<div class="alert alert-block alert-info">   

With your oracle function, you can trivially decrypt the message.

Here's why:

- RSA ciphertexts are just numbers. You can do trivial math on them. You can for instance multiply a ciphertext by the RSA-encryption of another number; the corresponding plaintext will be the product of those two numbers.
- If you double a ciphertext (multiply it by `(2**e)%n`), the resulting plaintext will (obviously) be either even or odd.
- If the plaintext after doubling is even, doubling the plaintext didn't wrap the modulus --- the modulus is a prime number. That means the plaintext is less than half the modulus.

You can repeatedly apply this heuristic, once per bit of the message, checking your oracle function each time.

Your decryption function starts with bounds for the plaintext of `[0,n]`.

Each iteration of the decryption cuts the bounds in half; either the upper bound is reduced by half, or the lower bound is.

After `log2(n)` iterations, you have the decryption of the message.

Print the upper bound of the message as a string at each iteration; you'll see the message decrypt "hollywood style".

Decrypt the string (after encrypting it to a hidden private key) above.

</div>
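As a quick sanity check of the first bullet, here is a sketch on a toy key. The helper names `toy_enc`/`toy_dec` are hypothetical, and 11 and 17 are far too small for real use. It verifies that Enc(a)·Enc(b) decrypts to a·b mod n for every pair, then doubles a few plaintexts by multiplying their ciphertexts by Enc(2). It uses `pow(e, -1, phi)`, which needs Python 3.8 or later.

```python
# Hypothetical toy key (same small primes used later in this notebook).
toy_p, toy_q = 11, 17
toy_n, toy_e = toy_p * toy_q, 3
toy_d = pow(toy_e, -1, (toy_p - 1) * (toy_q - 1))   # 107; needs Python 3.8+

def toy_enc(m):
    return pow(m, toy_e, toy_n)

def toy_dec(c):
    return pow(c, toy_d, toy_n)

# Enc(a) * Enc(b) = (a*b)**e mod n, so it decrypts to a*b mod n, for every pair.
for a in range(toy_n):
    for b in range(toy_n):
        assert toy_dec(toy_enc(a) * toy_enc(b) % toy_n) == a * b % toy_n

# Special case: multiplying by Enc(2) doubles the plaintext, mod n.
two_e = toy_enc(2)
print([toy_dec(toy_enc(m) * two_e % toy_n) for m in (3, 50, 100)])   # [6, 100, 13]
```

Note the last value: 2 · 100 = 200 wraps past n = 187, leaving 13, which is odd. That wrap is exactly what the parity oracle detects.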

In [None]:
def print_it(x):
    """Print integer x as the big-endian byte string it encodes."""
    hex_x = hex(x)[2:]
    if len(hex_x) % 2:          # pad to a whole number of bytes
        hex_x = '0' + hex_x
    print(bytes.fromhex(hex_x))

In [91]:
DISPLAY_PROGRESS = False

lower_bound = 0
upper_bound = n

tmp = ciphertext

while (upper_bound - lower_bound) > 0:
    tmp = (tmp * pow(2, e, n)) % n
    if pt_is_odd(tmp):
        #lower_bound += (upper_bound - lower_bound) // 2
        lower_bound = (lower_bound + upper_bound) // 2
    else:        
        #upper_bound -= (upper_bound - lower_bound) // 2
        upper_bound = (lower_bound + upper_bound) // 2
        
    if DISPLAY_PROGRESS:
        
        print_it(upper_bound)
       
print_it(upper_bound)

b"That's why I found you don't play around with the Funky Cold Medin\x1c"


---

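So close, but not quite. The last byte isn't being recovered properly, and integer math is the culprit: `//` floors every midpoint, so each new bound can lose up to half a unit relative to its exact value. Those losses all push in the same direction (downward), and over ~1024 iterations the drift is enough to corrupt the lowest byte.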

I tried floats, but a 64-bit float only carries a 53-bit significand, nowhere near enough precision for 1024-bit numbers. Decimals are the thing to use here, as long as the precision is set high enough to handle the big numbers we're dealing with.

---
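`Decimal` works, but Python's `fractions.Fraction` is another option (swapped in here; it is not what this notebook uses). It keeps the bounds as exact rationals, so there is no precision setting to get right. Below is a minimal sketch assuming a parity-oracle callable `is_odd`. The function `parity_decrypt`, the demo key built from the Mersenne primes `2**61 - 1` and `2**89 - 1`, and the `*_demo` names are all hypothetical, chosen so they don't clobber the notebook's `n`, `e`, `d`.

```python
from fractions import Fraction

def parity_decrypt(c, e, n, is_odd):
    """Bisection on exact rational bounds, so no bits are lost to rounding."""
    two_e = pow(2, e, n)                       # Enc(2) under the public key
    lower, upper = Fraction(0), Fraction(n)
    for _ in range(n.bit_length()):            # afterwards n / 2**bits < 1
        c = (c * two_e) % n                    # doubles the hidden plaintext
        mid = (lower + upper) / 2              # exact, never truncated
        if is_odd(c):
            lower = mid                        # doubling wrapped the modulus
        else:
            upper = mid                        # doubling did not wrap
    return int(upper)                          # floor of the upper bound

# Hypothetical demo key: two Mersenne primes, far too small for real use.
p_demo, q_demo = 2**61 - 1, 2**89 - 1
n_demo, e_demo = p_demo * q_demo, 65537
d_demo = pow(e_demo, -1, (p_demo - 1) * (q_demo - 1))   # Python 3.8+

def demo_is_odd(c):
    return pow(c, d_demo, n_demo) % 2 == 1

msg = b"Funky Cold Medina"
c_demo = pow(int.from_bytes(msg, "big"), e_demo, n_demo)
m = parity_decrypt(c_demo, e_demo, n_demo, demo_is_odd)
print(m.to_bytes((m.bit_length() + 7) // 8, "big"))   # b'Funky Cold Medina'
```

Against the notebook's own key, the call would be `parity_decrypt(ciphertext, e, n, pt_is_odd)`. After `n.bit_length()` halvings, the interval is narrower than 1. Because n is odd, a nonzero plaintext never lands exactly on a boundary, so flooring the upper bound yields the plaintext.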


In [90]:
DISPLAY_PROGRESS = False

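import math
from decimal import Decimal, getcontext

# prec is measured in decimal digits, not bits: a 1024-bit n has ~309 digits,
# so log2(n) (~1023) leaves roughly 700 digits for the fractional parts of the bounds.
getcontext().prec = int(math.log2(n))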

lower_bound = Decimal(0)
upper_bound = Decimal(n)

tmp = ciphertext

while (upper_bound - lower_bound) >= 0.5:
    tmp = (tmp * pow(2, e, n)) % n
    if pt_is_odd(tmp):
        lower_bound += (upper_bound - lower_bound) / 2
    else:        
        upper_bound -= (upper_bound - lower_bound) / 2
        
    if DISPLAY_PROGRESS:
        
        print_it(int(upper_bound))
       
print_it(int(upper_bound))

b"That's why I found you don't play around with the Funky Cold Medina"


Let's try to understand what's going on here a little better.  Let's choose some small RSA parameters:

In [99]:
p = 11
q = 17
n = p*q
et = (p-1)*(q-1)
e=3
d = cp.invmod(e, et)
print(f"e={e}, d={d}, n={n}")

e=3, d=107, n=187


Make sure my Simple RSA works properly:

In [100]:
x = 7
assert pow(pow(x, e, n), d, n) == x   # decrypt(encrypt(x)) == x

Now, demonstrate that math on the ciphertext ==> same math on the plaintexts:

In [101]:
pt = 3
ct = (pt**e) % n

The challenge says that:

- If you double a ciphertext (multiply it by `(2**e) % n`), the resulting plaintext will (obviously) be either even or odd.
- If the plaintext after doubling is even, doubling the plaintext didn't wrap the modulus --- the modulus is a prime number. That means the plaintext is less than half the modulus.

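The second one took me a bit to grasp. Strictly speaking, the modulus n = pq is not prime, but it is the product of two odd primes, so it is odd, and oddness is all the argument needs. Doubling a number should always give an even result in our normal math world. But the plaintext is less than n, so its double is less than 2n. If the doubled number is bigger than the modulus, reducing it subtracts the modulus exactly once, and an even number minus the odd modulus is odd.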

Therefore, if doubling a number gives an even result, the original number was less than half the modulus. If doubling it gives an odd result, the original was greater than half the modulus.
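Here is a brute-force check of that argument on the toy modulus from above (a sketch; `n_toy` is a hypothetical name). For every x < n, the parity of 2x mod n says whether doubling wrapped, and wrapping happens exactly when x > n/2.

```python
n_toy = 11 * 17   # 187: odd because both factors are odd, yet not prime

for x in range(n_toy):
    doubled = (2 * x) % n_toy
    wrapped = 2 * x >= n_toy         # did doubling pass the modulus?
    # x < n means 2x < 2n, so at most one (odd) n is subtracted:
    # the result is odd exactly when the modulus wrapped.
    assert (doubled % 2 == 1) == wrapped
    # Because n is odd, wrapping happens exactly when x > n/2.
    assert wrapped == (x > n_toy // 2)
print("checked all", n_toy, "plaintexts")
```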

Doing this iteratively, we're actually multiplying by powers of two:  `2**1 = 2`, `2**2 = 4`, `2**3 = 8`, etc. and narrowing down the possible value of the plaintext based on the parity of the result.
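There is another way to see what the oracle leaks. Because `2**i * m` is even and n is odd, the parity of `2**i * m mod n` equals the parity of `floor(2**i * m / n)`, which is the i-th binary digit of m/n after the point. So the oracle reads out m/n one bit at a time, and the plaintext can be rebuilt with integers only, leaving no fractional bounds to round. The sketch below assumes the same p = 11, q = 17, e = 3 toy key as above; `toy_is_odd` and `recover_via_bits` are hypothetical names.

```python
# Hypothetical toy key, same p = 11, q = 17, e = 3 as above.
toy_n, toy_e = 11 * 17, 3
toy_d = pow(toy_e, -1, 10 * 16)      # 107; needs Python 3.8+

def toy_is_odd(c):
    """Parity oracle: decrypts with the private key, reveals only the low bit."""
    return pow(c, toy_d, toy_n) % 2 == 1

def recover_via_bits(c):
    k = toy_n.bit_length()           # enough steps that n / 2**k < 1
    two_e = pow(2, toy_e, toy_n)
    j = 0                            # after step i: j == floor(2**i * m / n)
    for _ in range(k):
        c = (c * two_e) % toy_n      # hidden plaintext is now 2**i * m mod n
        j = 2 * j + int(toy_is_odd(c))   # next binary digit of m / n
    # m sits inside (j*n / 2**k, (j+1)*n / 2**k), an interval narrower
    # than 1, so flooring the upper end recovers m exactly.
    return ((j + 1) * toy_n) >> k

print(recover_via_bits(pow(75, toy_e, toy_n)))   # 75
```

For pt = 75, the oracle bits come out as 0, 1, 1, 0, 0, 1, 1, 0 (j = 102). These bits match the even/odd sequence in the bisection printout further down.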


Demonstrate: If I multiply the ciphertext by an "encrypted" 2, it will also multiply the plaintext by 2:

In [120]:
ct = (ct * pow(2, e, n)) % n   # multiply by Enc(2) = 2**e mod n
pt = pow(ct, d, n)             # decrypt
print(pt)

6


In [141]:
import numpy as np

pt = 75 
original_ct = (pt**e) % n

lower_bound = 0
upper_bound = n

count = 0

n_bits = int(math.log2(n))+1

for bit_idx in range(n_bits):
    
    print()
    print(f"Lower bound = {lower_bound}")
    print(f"Upper Bound = {upper_bound}")
    print()
    multiplier = (2**(bit_idx+1))
    print(f"bit = {bit_idx+1}")
    print(f"Multiplier = {multiplier}")
    ct = (original_ct * (multiplier**e % n) % n)
    
    # Peek at the current plaintext (for illustration only; a real attacker can't):
    print(f"Oracle saw: {ct**d % n}")
    
    pt_odd = pt_is_odd(ct)
    if pt_odd:
        print(f"PT was odd, so original PT is > ({lower_bound} + {n/multiplier}) = {lower_bound + n/multiplier}")
    else:
        print(f"PT was even, so original PT is < ({upper_bound} - {n/multiplier}) = {upper_bound - n/multiplier}")

    if pt_odd:
        lower_bound += (upper_bound - lower_bound) / 2
    else:        
        upper_bound -= (upper_bound - lower_bound) / 2

print()
print(f"Lower bound = {lower_bound}")
print(f"Upper Bound = {upper_bound}")
print(f"Guessed 'PT' = {int(np.round((upper_bound + lower_bound)/2))}")


Lower bound = 0
Upper Bound = 187

bit = 1
Multiplier = 2
Oracle saw: 150
PT was even, so original PT is < (187 - 93.5) = 93.5

Lower bound = 0
Upper Bound = 93.5

bit = 2
Multiplier = 4
Oracle saw: 113
PT was odd, so original PT is > (0 + 46.75) = 46.75

Lower bound = 46.75
Upper Bound = 93.5

bit = 3
Multiplier = 8
Oracle saw: 39
PT was odd, so original PT is > (46.75 + 23.375) = 70.125

Lower bound = 70.125
Upper Bound = 93.5

bit = 4
Multiplier = 16
Oracle saw: 78
PT was even, so original PT is < (93.5 - 11.6875) = 81.8125

Lower bound = 70.125
Upper Bound = 81.8125

bit = 5
Multiplier = 32
Oracle saw: 156
PT was even, so original PT is < (81.8125 - 5.84375) = 75.96875

Lower bound = 70.125
Upper Bound = 75.96875

bit = 6
Multiplier = 64
Oracle saw: 125
PT was odd, so original PT is > (70.125 + 2.921875) = 73.046875

Lower bound = 73.046875
Upper Bound = 75.96875

bit = 7
Multiplier = 128
Oracle saw: 63
PT was odd, so original PT is > (73.046875 + 1.4609375) = 74.5078125

Lower bo
