#### Challenge 47: Bleichenbacher's PKCS 1.5 Padding Oracle (Simple Case)

[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)

In [1]:
from Crypto.Util import number
from Crypto.Random import random

import math
import cryptopals as cp     # author's helper module (provides invmod)

from decimal import Decimal, getcontext

<div class="alert alert-block alert-info">   
    
<div class="alert alert-block alert-warning">
    
#### **Degree of difficulty: Moderate**
These next two challenges are the hardest in the entire set.

</div>
    
Let us Google this for you: [Chosen ciphertext attacks against protocols based on the RSA encryption standard](http://lmgtfy.com/?q=%22Chosen+ciphertext+attacks+against+protocols+based+on+the+RSA+encryption+standard%22)

This is Bleichenbacher from CRYPTO '98; I get a bunch of .ps versions on the first search page.

Read the paper. It describes a padding oracle attack on PKCS#1v1.5. The attack is similar in spirit to the CBC padding oracle you built earlier; it's an "adaptive chosen ciphertext attack", which means you start with a valid ciphertext and repeatedly corrupt it, bouncing the adulterated ciphertexts off the target to learn things about the original.

This is a common flaw even in modern cryptosystems that use RSA.

It's also the most fun you can have building a crypto attack. It involves 9th grade math, but also has you implementing an algorithm that is on par in complexity with finding a minimum cost spanning tree.

The setup:

- Build an oracle function, just like you did in the last exercise, but have it check for plaintext[0] == 0 and plaintext[1] == 2.
- Generate a 256 bit keypair (that is, p and q will each be 128 bit primes), [n, e, d].
- Plug d and n into your oracle function.
- PKCS1.5-pad a short message, like "kick it, CC", and call it "m". Encrypt it to get "c".
- Decrypt "c" using your padding oracle.

For this challenge, we've used an untenably small RSA modulus (you could factor this keypair instantly). That's because this exercise targets a specific step in the Bleichenbacher paper --- Step 2c, which implements a fast, nearly O(log n) search for the plaintext.

Things you want to keep in mind as you read the paper:

- RSA ciphertexts are just numbers.
- RSA is "homomorphic" with respect to multiplication, which means you can multiply c * RSA(2) to get a c' that will decrypt to plaintext * 2. This is mindbending but easy to see if you play with it in code --- try multiplying ciphertexts with the RSA encryptions of numbers so you know you grok it.
- What you need to grok for this challenge is that Bleichenbacher uses multiplication on ciphertexts the way the CBC oracle uses XORs of random blocks.
- A PKCS#1v1.5 conformant plaintext, one that starts with 00:02, must be a number between 02:00:00...00 and 02:FF:FF...FF; in other words, between 2B and 3B-1, where B = 2^(8(k-2)) and k is the length of the modulus in bytes (B is 2 raised to the bit size of the modulus minus the first 16 bits). When you see 2B and 3B, that's the idea the paper is playing with.
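
To see the multiplicative property from the list above in code, here is a minimal sketch using textbook (unpadded) RSA with the classic toy key p = 61, q = 53, e = 17. The key values and the helper names `rsa_encrypt`/`rsa_decrypt` are illustrative assumptions, far too small for real use, and are prefixed with `toy_` so they don't clobber this notebook's `n`, `e`, and `d`. Multiplying a ciphertext by the encryption of s gives a ciphertext of m·s mod n:

```python
# Toy textbook-RSA key (illustrative only; far too small for real use).
toy_p, toy_q = 61, 53
toy_n = toy_p * toy_q                   # 3233
toy_phi = (toy_p - 1) * (toy_q - 1)     # 3120
toy_e = 17
toy_d = pow(toy_e, -1, toy_phi)         # modular inverse (Python 3.8+)

def rsa_encrypt(m):
    return pow(m, toy_e, toy_n)

def rsa_decrypt(c):
    return pow(c, toy_d, toy_n)

toy_m = 65
toy_c = rsa_encrypt(toy_m)

# (m^e)(2^e) = (2m)^e (mod n), so this ciphertext decrypts to 2m mod n.
c_times_2 = (toy_c * rsa_encrypt(2)) % toy_n
print(rsa_decrypt(c_times_2))           # 130

# Same idea for any multiplier: this is the lever Bleichenbacher pulls.
for mult in range(1, 100):
    assert rsa_decrypt((toy_c * rsa_encrypt(mult)) % toy_n) == (toy_m * mult) % toy_n
```
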

To decrypt "c", you'll need Step 2a from the paper (the search for the first "s" that, when encrypted and multiplied with the ciphertext, produces a conformant plaintext), Step 2c, the fast O(log n) search, and Step 3.

Your Step 3 code is probably not going to need to handle multiple ranges.

We recommend you just use the raw math from the paper (check, check, double check your translation to code) and not spend too much time trying to grok how the math works.
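
For checking the translation to code, here are the Step 2a, 2c, and 3 formulas as stated in the paper, with [a, b] an interval in M, B = 2^(8(k-2)), and the blinding Step 1 skipped because c is already conforming (s_0 = 1, c_0 = c). Roughly speaking, each pass through Step 2c about doubles s, so each Step 3 update about halves the single remaining interval, which is where the near O(log n) behaviour comes from.

$$
\begin{aligned}
\text{Step 2a:}\quad & s_1 = \min\left\{ s \ge \left\lceil \frac{n}{3B} \right\rceil \;:\; c\,s^e \bmod n \text{ is PKCS conforming} \right\} \\
\text{Step 2c:}\quad & r_i \ge 2\,\frac{b\,s_{i-1} - 2B}{n}, \qquad \frac{2B + r_i n}{b} \le s_i < \frac{3B + r_i n}{a} \\
\text{Step 3:}\quad & M_i = \bigcup_{[a,b] \in M_{i-1}} \; \bigcup_{r} \left[ \max\left(a, \left\lceil \frac{2B + r n}{s_i} \right\rceil\right),\; \min\left(b, \left\lfloor \frac{3B - 1 + r n}{s_i} \right\rfloor\right) \right] \\
& \text{for all integers } r \text{ with } \frac{a\,s_i - 3B + 1}{n} \le r \le \frac{b\,s_i - 2B}{n}
\end{aligned}
$$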

</div>    

---

<div class="alert alert-block alert-info">   

- Generate a 256 bit keypair (that is, p and q will each be 128 bit primes), [n, e, d].
    
</div> 

In [2]:
valid_params = False

while not(valid_params):
    
    print('.', end='')
    p = number.getPrime(256 // 2)
    q = number.getPrime(256 // 2)

    n = (p * q)

    et = (p-1) * (q-1)
    e = 3

    d = cp.invmod(e, et)

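    # Check parameters: a random value must round-trip through encryption and
    # decryption; this will almost always fail when e = 3 shares a factor with (p-1)(q-1).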
    PT = random.randint(0, 2**32-1)
    valid_params = (pow(pow(PT, e, n), d, n) == PT)

print(f"\nGenerated working parameters:\n")
print(f"e={e}\nd={d}\nn={n}")

........
Generated working parameters:

e=3
d=27333576007230398530034453279377827125001860239976180339592623082207262319915
n=41000364010845597795051679919066740687908140166231421125439837447715752896331


In [3]:
# Set Decimal precision (significant digits) high enough for exact interval math on ~256-bit numbers.
getcontext().prec = int(math.log2(n))
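
Why the precision setting matters: Steps 2 and 3 take ceilings and floors of quotients of roughly 256-bit integers, and Python's `/` on two ints returns a float with a 53-bit mantissa. A minimal sketch comparing float division, `Decimal` with enough precision (`prec` counts significant decimal digits), and exact integer ceiling division via the `-(-a // b)` idiom (`ceil_div` is a hypothetical helper name, not part of this notebook):

```python
import math
from decimal import Decimal, localcontext

def ceil_div(a, b):
    """Exact ceiling of a / b for integers with b > 0 (integer ops only)."""
    return -(-a // b)

big_odd = 2**200 + 1        # exact value of big_odd / 2 is 2**199 + 0.5

# True division returns a float: big_odd / 2 rounds to 2**199,
# so ceil() never sees the fractional part.
float_result = math.ceil(big_odd / 2)

# Decimal keeps every digit as long as the context precision is large enough.
with localcontext() as ctx:
    ctx.prec = 100
    decimal_result = math.ceil(Decimal(big_odd) / 2)

exact_result = ceil_div(big_odd, 2)

print(float_result == 2**199)          # True: off by one
print(decimal_result == 2**199 + 1)    # True
print(exact_result == 2**199 + 1)      # True
```

Integer-only arithmetic like `ceil_div` sidesteps the precision question entirely; the `Decimal` route also works as long as `prec` comfortably exceeds the number of decimal digits in n (about 77 here).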

In [4]:
# e, d, n = cp.genRSA_keypair(1024)

<div class="alert alert-block alert-info">   
    
- Build an oracle function, just like you did in the last exercise, but have it check for plaintext[0] == 0 and plaintext[1] == 2.
- Plug d and n into your oracle function.
- PKCS1.5-pad a short message, like "kick it, CC", and call it "m". Encrypt it to get "c".

</div> 

[RFC2313](https://tools.ietf.org/html/rfc2313) describes PKCS1.5 padding for RSA. For encryption, we assume block type 02.

The encryption block `EB` is constructed as:

`EB = 00 || BT || PS || 00 || D`

where:

- `BT = 02`
- `PS` is the padding string, made of pseudorandomly generated nonzero octets (it contains no `00` bytes)
- `D` is the data

The length of the encryption block `EB` is equal to `k`, where `k` is the length of the modulus in octets (bytes).  
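
Because `EB` starts with a `00` octet, reading it as a big-endian integer and converting back with a minimal-length encoding drops that leading zero. That is why the helpers below accept a `k-1`-byte string and why the simplified oracle checks for `k-1` bytes starting with `02`. A small sketch with Python's built-in `int.from_bytes`/`int.to_bytes` (the 4-byte block is made up for illustration):

```python
k_toy = 4                                # pretend the modulus is 4 bytes long
eb = b'\x00\x02\xab\x00'                 # 00 || 02 || PS (1 byte) || 00, empty D (illustrative)
eb_int = int.from_bytes(eb, 'big')       # EB read as an RSA plaintext integer

minimal = eb_int.to_bytes((eb_int.bit_length() + 7) // 8, 'big')
print(minimal)                           # b'\x02\xab\x00': the leading 00 is gone
print(len(minimal) == k_toy - 1)         # True

restored = eb_int.to_bytes(k_toy, 'big') # fixed-width encoding keeps the 00
print(restored == eb)                    # True
```
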

In [5]:
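def remove_pkcs15_padding(byte_data, n):
    """Strip PKCS#1 v1.5 block-type-02 padding; return the data D, or False if malformed."""
    k = math.ceil(n.bit_length() / 8)

    # Converting the plaintext integer to bytes drops the leading 00 octet; restore it.
    if len(byte_data) == (k-1):
        byte_data = b'\x00' + byte_data

    if not(len(byte_data) == k):
        return(False)

    if not(byte_data[0] == 0x00 and byte_data[1] == 0x02):
        return(False)

    # The first 00 after PS separates the padding from the data D.
    sep_idx = byte_data.find(b'\x00', 2)
    if sep_idx == -1:
        return(False)

    return(byte_data[sep_idx + 1:])

def validate_pkcs15_padding(byte_data, n):
    """Full check of 00 || 02 || PS || 00 || D (RFC 2313's 8-octet minimum PS is not enforced)."""
    k = math.ceil(n.bit_length() / 8)

    if len(byte_data) == (k-1):
        byte_data = b'\x00' + byte_data

    if not(len(byte_data) == k):
        return(False)

    if not(byte_data[0] == 0x00 and byte_data[1] == 0x02):
        return(False)

    # Search from index 3 so that PS contains at least one octet.
    return(byte_data.find(b'\x00', 3) != -1)

def simple_validate_padding(byte_data, n):
    """Challenge 47 check (plaintext[0] == 0 and plaintext[1] == 2).

    byte_data is the minimal big-endian encoding, so the leading 00 is already gone:
    a conforming plaintext shows up as k-1 bytes starting with 02."""
    k = math.ceil(n.bit_length() / 8)
    return(len(byte_data) == (k-1) and (byte_data[0] == 0x02))

def pkcs15_pad(data, n):
    """Build EB = 00 || 02 || PS || 00 || D, with PS made of random nonzero octets."""
    k = math.ceil(n.bit_length() / 8)
    ps_len = k - len(data) - 3
    if ps_len < 1:
        raise ValueError('data too long for this modulus')

    PS = bytes(random.randint(1, 255) for _ in range(ps_len))

    return(b'\x00' + b'\x02' + PS + b'\x00' + data)

def bytes_to_bigint(byte_data):
    return(int.from_bytes(byte_data, 'big'))

def bigint_to_bytes(int_data):
    # Minimal big-endian encoding (leading zero octets are dropped); 0 -> b'\x00'.
    int_data = int(int_data)
    return(int_data.to_bytes(max(1, (int_data.bit_length() + 7) // 8), 'big'))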


In [6]:
def challenge47_oracle(ciphertext):
    # Decrypt with the private key (module-level d and n) and leak a single bit:
    # does the plaintext start with 00 02?
    plaintext = pow(int(ciphertext), d, n)
    return(simple_validate_padding(bigint_to_bytes(plaintext), n))
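
Since the leading `00` disappears in the integer form, the oracle's byte check is the same as asking whether the decrypted integer m satisfies 2B <= m < 3B with B = 2^(8(k-2)). A hedged sketch (the names `bytes_check` and `range_check` are hypothetical, and `n_demo` is an arbitrary 255-bit number rather than a real RSA modulus) that confirms the two tests agree:

```python
import random

def bytes_check(m, n):
    """Byte-level test: the minimal big-endian encoding of m is k-1 bytes
    long and starts with 0x02 (the leading 00 has been dropped)."""
    k = (n.bit_length() + 7) // 8
    data = m.to_bytes(max(1, (m.bit_length() + 7) // 8), 'big')
    return len(data) == k - 1 and data[0] == 0x02

def range_check(m, n):
    """Equivalent integer test: 2B <= m < 3B with B = 2**(8*(k-2))."""
    k = (n.bit_length() + 7) // 8
    B = 2 ** (8 * (k - 2))
    return 2 * B <= m < 3 * B

n_demo = 2**254 + 12345                  # arbitrary 255-bit number, not a real RSA key
k_demo = (n_demo.bit_length() + 7) // 8  # 32 bytes
B_demo = 2 ** (8 * (k_demo - 2))

# Uniform samples below n almost never conform, so also probe the edges and
# the conforming range directly.
samples = [random.randrange(n_demo) for _ in range(1000)]
samples += [2*B_demo - 1, 2*B_demo, 2*B_demo + 1, 3*B_demo - 1, 3*B_demo, 3*B_demo + 1]
samples += [random.randrange(2*B_demo, 3*B_demo) for _ in range(1000)]

agree = all(bytes_check(m, n_demo) == range_check(m, n_demo) for m in samples)
print(agree)                             # True
```
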
        

In [7]:
x = pkcs15_pad(b'help', n)
x_int = bytes_to_bigint(x)
bigint_to_bytes(x_int)

b'\x02\xeb\xf1I\xbc\x02\xcc\x90\xed\xd7\xb3\x8b\x98\x9fMK\xa8\x9a\xbdY\xb6)>\x02\x93\xe8\x00help'

In [8]:
# Test everything out to make sure it works....
m = bytes_to_bigint(pkcs15_pad(b'This is a test message', n))
c = pow(m, e, n)
m_r = bigint_to_bytes(pow(c, d, n))
assert(validate_pkcs15_padding(m_r, n))
m_r = remove_pkcs15_padding(m_r, n)
print(m_r.decode())

This is a test message


In [9]:
m = bytes_to_bigint(pkcs15_pad(b'kick it, cc', n))
c = pow(m, e, n)
true_p = pow(c, d, n)

<div class="alert alert-block alert-info">   

- Decrypt "c" using your padding oracle.
    
</div>

In [10]:
# Implement Step #3 - Narrowing the set of solutions.

def update_intervals(M_Last, s):
    """Narrow every interval [a, b] in M_Last using the conforming multiplier s.

    Intervals are stored as [r, a, b]; r is kept only for bookkeeping. All bound
    computations use exact integer arithmetic (B and n are module-level globals)."""
    B_int = int(B)
    candidates = []

    for interval in M_Last:

        last_a, last_b = int(interval[1]), int(interval[2])

        # ceil((a*s - 3B + 1) / n) <= r <= floor((b*s - 2B) / n)
        r_min = -(-(last_a*s - 3*B_int + 1) // n)
        r_max = (last_b*s - 2*B_int) // n

        for r in range(r_min, r_max+1):

            new_a = max(last_a, -(-(2*B_int + r*n) // s))
            new_b = min(last_b, (3*B_int - 1 + r*n) // s)
            if new_a <= new_b:          # skip empty intersections
                candidates.append([r, new_a, new_b])

    if not candidates:
        print('Houston, we have a problem')
        return(M_Last)

    # Merge overlapping intervals so that M is a union of disjoint ranges.
    candidates.sort(key=lambda x: x[1])
    M = [candidates[0]]
    for r, new_a, new_b in candidates[1:]:
        if new_a <= M[-1][2]:
            M[-1][2] = max(M[-1][2], new_b)
        else:
            M.append([r, new_a, new_b])

    return(M)

In [11]:
# Implement Step 2.a - find the first s that results in conforming PKCS padding.
# c is already PKCS-conforming, so the paper's blinding Step 1 is skipped (s_0 = 1, c_0 = c).

k = math.ceil(n.bit_length() / 8)
B = Decimal(2**(8*(k-2)))
s0 = math.ceil(n / (3*B))
n_queries = 0

while True:

    c_ = c*(pow(s0, e, n)) % n
    n_queries += 1
    if(challenge47_oracle(c_)):
        break
    s0 += 1

In [12]:
s = s0

M = [[0, 2*B, 3*B - 1]]
M = update_intervals(M, s)

In [13]:
done = False

while not(done):

    conforming = False
    if len(M) > 1:
        
        # Step 2.b:  Searching with more than one interval left
        while not(conforming):
            s += 1
            c_ = int(c*(pow(s, e, n)) % n)            
            conforming = challenge47_oracle(c_)
            n_queries += 1
            
    else: 
        
        # Step 2.c:  Searching with one interval left
        a, b = M[0][1], M[0][2]
        r = math.ceil(2*((b*s - 2*B) / n))
        
        while not(conforming):
            
            # (2B + r*n)/b <= s < (3B + r*n)/a
            s_min = math.ceil((2*B + r*n) / b)
            s_max = math.ceil((3*B + r*n) / a) - 1

            for s in range(s_min, s_max + 1):
                c_ = int(c*(pow(s, e, n)) % n)
                conforming = challenge47_oracle(c_)
                n_queries += 1
                if conforming:
                    break
                
            r += 1
            
    M = update_intervals(M, s)
    
    if (len(M) == 1) and (M[0][1] == M[0][2]): 
        done = True

In [14]:
recovered_msg = remove_pkcs15_padding(bigint_to_bytes(M[0][1]), n)
print(recovered_msg)


b'kick it, cc'


[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)