#### Challenge 48: Bleichenbacher's PKCS 1.5 Padding Oracle (Complete Case)

[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)

In [19]:
from Crypto.Util import number
from Crypto.Random import random
from Crypto.Hash.SHA256 import SHA256Hash

import math
import base64
import cryptopals as cp

from decimal import *

import pdb

<div class="alert alert-block alert-info">   
    
<div class="alert alert-block alert-warning">
    
#### **Cryptanalytic MVP award**
    
This is an extraordinarily useful attack. PKCS#1v15 padding, despite being totally insecure, is the default padding used by RSA implementations. The OAEP standard that replaces it is not widely implemented. This attack routinely breaks SSL/TLS.

</div>
    


This is a continuation of challenge #47; it implements the complete BB'98 attack.

Set yourself up the way you did in #47, but this time generate a 768 bit modulus.

To make the attack work with a realistic RSA keypair, you need to reproduce step 2b from the paper, and your implementation of Step 3 needs to handle multiple ranges.

The full Bleichenbacher attack works basically like this:

- Starting from the smallest 's' that could possibly produce a plaintext bigger than 2B, iteratively search for an 's' that produces a conformant plaintext.
- For our known $s_1$ and $n$, solve $m_1 = m_0 s_1 - r n$ (again: just a definition of modular multiplication) for $r$, the number of times we've wrapped the modulus.
- $m_0$ and $m_1$ are unknowns, but we know both are conformant PKCS#1 v1.5 plaintexts, and so lie in $[2B, 3B)$.
- We substitute the known bounds for both, leaving only 'r' free, and solve for a range of possible 'r' values. This range should be small!
- Solve $m_1 = m_0 s_1 - r n$ again, but this time for $m_0$, plugging in each value of $r$ we generated in the last step. This gives us new intervals to work with. Rule out any interval that is outside $[2B, 3B)$.
- Repeat the process for successively higher values of 's'. Eventually, this process will get us down to just one interval, whereupon we're back to exercise #47.

Once we get down to one interval, we stop blindly incrementing $s$; instead, we start rapidly growing $r$ and backing it out to $s$ values by solving $m_1 = m_0 s_1 - r n$ for $s$ instead of $r$ or $m_0$. So much algebra! Make your teenage son do it for you! *Note: does not work well in practice* 

</div>    

---

In [20]:
KEY_LENGTH = 768

if KEY_LENGTH < 1024:

    # Small keys are generated by hand so that e = 3 can be forced. Keep drawing
    # prime pairs until the key pair actually round-trips.
    valid_params = False

    while not valid_params:

        print('.', end='')
        p = number.getPrime(KEY_LENGTH // 2)
        q = number.getPrime(KEY_LENGTH // 2)

        n = p * q

        et = (p - 1) * (q - 1)
        e = 3

        # e only has an inverse mod (p-1)(q-1) when they are coprime; otherwise
        # draw a fresh pair instead of asking invmod for something that does not exist.
        if math.gcd(e, et) != 1:
            continue

        d = cp.invmod(e, et)

        # Sanity check: a random 32-bit value must survive encrypt -> decrypt.
        PT = random.randint(0, 2**32 - 1)
        valid_params = (pow(pow(PT, e, n), d, n) == PT)

else:

    # NOTE(review): assumes cp.genRSA_keypair takes the modulus size in bits and
    # returns (e, d, n), as its use here implies -- confirm against cryptopals.py.
    e, d, n = cp.genRSA_keypair(KEY_LENGTH)

print(f"\nGenerated working parameters:\n")
print(f"e={e}\nd={d}\nn={n}")


........
Generated working parameters:

e=3
d=637771288823468390827832444586248412262575068825454780825261485341355807699823956180692567152552454513507818378933403549397600271050148710484685445459847966301765686209429357272894149109037974163413358690835702194446373704484302667
n=956656933235202586241748666879372618393862603238182171237892228012033711549735934271038850728828681770261727568400167195967814020157161457501887268683553846092122319074156798162812975049645077997200212312077125622271677218150810871


In [21]:
# Decimal precision counts significant *decimal digits*. Using the modulus' bit
# length gives several times more digits than the ~log10(n) this challenge needs.
decimal_ctx = getcontext()
decimal_ctx.prec = n.bit_length()

In [22]:
def remove_pkcs15_padding(byte_data, n):
    """Strip PKCS#1 v1.5 type-2 (encryption) padding from a decrypted block.

    Args:
        byte_data: the decrypted block as bytes. A block one byte shorter than the
            modulus is accepted and treated as having lost its leading 0x00, as
            happens when an integer is converted back to minimal-length bytes.
        n: the RSA modulus; its byte length k fixes the expected block size.

    Returns:
        The payload after the first 0x00 separator (searched from index 2), or
        False if the block is not k bytes long, does not start with 0x00 0x02,
        or has no separator.
    """
    k = (n.bit_length() + 7) // 8

    if len(byte_data) == k - 1:
        byte_data = b'\x00' + byte_data

    # The smallest possible block is 0x00 0x02 0x00.
    if len(byte_data) != k or k < 3:
        return False

    if byte_data[0] != 0x00 or byte_data[1] != 0x02:
        return False

    sep_idx = byte_data.find(b'\x00', 2)
    if sep_idx == -1:
        # No separator: this is not a padded block, so there is no payload to return.
        return False

    return byte_data[sep_idx + 1:]
    
def validate_pkcs15_padding(byte_data, n):
    """Check whether a decrypted block carries PKCS#1 v1.5 type-2 padding.

    Accepts k-byte blocks (k = byte length of n) of the form
    0x00 0x02 PS 0x00 payload, where a 0x00 separator must appear at index 3 or
    later. Also accepts k-1 byte blocks whose leading 0x00 was dropped by an
    integer-to-bytes conversion.
    NOTE: the standard's 8-byte minimum PS length is not enforced.

    Returns:
        True if the block conforms, otherwise False.
    """
    k = (n.bit_length() + 7) // 8

    if len(byte_data) == k - 1:
        byte_data = b'\x00' + byte_data

    if len(byte_data) != k or k < 3:
        return False

    # Both framing bytes matter: a k-byte block with a non-zero first byte is not
    # a valid encryption block, even if the second byte happens to be 0x02.
    if byte_data[0] != 0x00 or byte_data[1] != 0x02:
        return False

    # Searching from index 3 requires at least one padding byte before the separator.
    return byte_data.find(b'\x00', 3) != -1

def simple_validate_padding(byte_data, n):
    """Loose PKCS#1 v1.5 conformance check used by the padding oracle.

    `byte_data` is a decrypted block in minimal big-endian form, so a leading 0x00
    has already been dropped. It passes when it is exactly one byte shorter than
    the modulus and its first byte is 0x02. Nothing after that byte is inspected.
    """
    modulus_len = -(-n.bit_length() // 8)
    if len(byte_data) != modulus_len - 1:
        return False
    return byte_data[0] == 0x02

def pkcs15_pad(data, n):
    """Build a PKCS#1 v1.5 type-2 (encryption) block: 0x00 0x02 PS 0x00 data.

    Args:
        data: the message bytes to pad.
        n: the RSA modulus; the block is padded to its byte length k.

    Returns:
        A k-byte block whose padding string PS consists of random non-zero bytes.

    Raises:
        ValueError: if len(data) > k - 11, i.e. there is no room for the minimum
            8-byte padding string required by RFC 8017 section 7.2.1. Previously
            such input silently produced a malformed, over-long block.
    """
    k = (n.bit_length() + 7) // 8
    ps_len = k - len(data) - 3

    if ps_len < 8:
        raise ValueError('message too long for PKCS#1 v1.5 padding with this modulus')

    # PS bytes must be non-zero so the 0x00 separator after them is unambiguous.
    padding_string = bytes(random.randint(1, 255) for _ in range(ps_len))

    return b'\x00\x02' + padding_string + b'\x00' + data

def bytes_to_bigint(byte_data):
    """Interpret a byte string as an unsigned big-endian integer.

    Empty input yields 0 rather than raising.
    """
    return int.from_bytes(byte_data, 'big')

def bigint_to_bytes(int_data):
    """Encode a non-negative integer as a minimal big-endian byte string.

    Zero encodes as a single 0x00 byte. Leading zero bytes are not preserved, so a
    padded block that began with 0x00 comes back one byte shorter than the modulus.
    """
    digits = hex(int_data)[2:]
    if len(digits) & 1:
        # bytes.fromhex needs whole bytes, i.e. an even number of hex digits.
        digits = digits.rjust(len(digits) + 1, '0')
    return bytes.fromhex(digits)


In [23]:
def challenge47_oracle(ciphertext):
    """Padding oracle: decrypt `ciphertext` and report whether it looks PKCS-conforming.

    Decrypts with the module-level private exponent `d` and modulus `n`. It writes
    the result as a minimal big-endian byte string and returns the verdict of
    `simple_validate_padding`: k-1 bytes with a first byte of 0x02, i.e. the
    0x00 0x02 prefix once the leading zero is dropped.
    """
    decrypted = int(pow(ciphertext, d, n))
    hex_digits = hex(decrypted)[2:]
    # Left-pad to an even digit count so bytes.fromhex accepts it.
    hex_digits = hex_digits.zfill(len(hex_digits) + len(hex_digits) % 2)
    return simple_validate_padding(bytes.fromhex(hex_digits), n)
        

In [24]:
# Target: a PKCS#1 v1.5-padded message, encrypted under the public key (e, n).
m = bytes_to_bigint(pkcs15_pad(b'If something is free, you\'re not the customer; you\'re the product.', n))
c = pow(m, e, n)

# Ground-truth decryption, kept only to confirm the key pair round-trips.
true_p = pow(c, d, n)
assert true_p == m, 'RSA round trip failed: check the generated key pair'

In [25]:
# Implement Step #3 - Narrowing set of solutions.
    
def _ceil_div(numerator, denominator):
    """Exact ceiling of numerator / denominator for integers (denominator > 0)."""
    return -(-numerator // denominator)


def _merge_intervals(intervals):
    """Union of inclusive integer intervals given as [r, a, b], sorted by a.

    Overlapping or touching intervals are coalesced. Each merged entry keeps the
    r of the lowest-starting interval that contributed to it.
    """
    merged = []
    for r, lo, hi in sorted(intervals, key=lambda iv: (iv[1], iv[2])):
        if merged and lo <= merged[-1][2] + 1:
            merged[-1][2] = max(merged[-1][2], hi)
        else:
            merged.append([r, lo, hi])
    return merged


def update_intervals(M_Last, s, bound=None, modulus=None):
    """Bleichenbacher step 3: narrow the candidate plaintext intervals using s.

    For every interval [a, b] in `M_Last` and every r with
    ceil((a*s - 3B + 1) / n) <= r <= floor((b*s - 2B) / n), the plaintext must lie in
    [max(a, ceil((2B + r*n) / s)), min(b, floor((3B - 1 + r*n) / s))].
    Empty intersections are discarded, and the rest are merged into a true union,
    so no candidate value is ever dropped.

    Args:
        M_Last: list of [r, a, b] entries (a, b inclusive; int or integral Decimal).
        s: the conforming multiplier just found.
        bound: B; defaults to the module-level `B`.
        modulus: n; defaults to the module-level `n`.

    Returns:
        A new list of disjoint [r, a, b] entries with int bounds, sorted by a. If no
        interval survives (which should not happen with a correct oracle), a
        warning is printed and `M_Last` is returned unchanged.
    """
    B_int = int(B if bound is None else bound)
    n_int = int(n if modulus is None else modulus)

    candidates = []
    for interval in M_Last:
        last_a, last_b = int(interval[1]), int(interval[2])

        # Exact integer ceil/floor: no dependence on the Decimal context precision.
        r_min = _ceil_div(last_a * s - 3 * B_int + 1, n_int)
        r_max = (last_b * s - 2 * B_int) // n_int

        for r in range(r_min, r_max + 1):
            new_a = max(last_a, _ceil_div(2 * B_int + r * n_int, s))
            new_b = min(last_b, (3 * B_int - 1 + r * n_int) // s)
            if new_a <= new_b:
                candidates.append([r, new_a, new_b])

    M = _merge_intervals(candidates)

    if not M:
        print('Houston, we have a problem')
        return M_Last
    return M

In [26]:
# Implement Step 2.a - find the smallest s >= n / 3B whose blinded ciphertext
# c * s^e decrypts to a conforming PKCS block.

k = (n.bit_length() + 7) // 8          # modulus length in bytes
B = Decimal(2**(8*(k-2)))
s0 = -(-n // (3 * int(B)))              # exact ceil(n / 3B)
n_queries = 0

while True:

    c_ = (c * pow(s0, e, n)) % n
    n_queries += 1                      # count every oracle call, including the hit
    if challenge47_oracle(c_):
        break
    s0 += 1

In [27]:
s = s0

# Seed step 3 with M_0 = {[2B, 3B - 1]}, the range every conforming plaintext
# lies in, and narrow it with the first conforming multiplier.
M = update_intervals([[0, 2*B, 3*B - 1]], s)

In [28]:
done = False

while not(done):

    conforming = False
    if len(M) > 1:
        
        # Step 2.b:  Searching with more than one interval left
        while not(conforming):
            s += 1
            c_ = int(c*(pow(s, e, n)) % n)            
            conforming = challenge47_oracle(c_)
            n_queries += 1
            
    else: 
        
        # Step 2.c:  Searching with one interval left
        a, b = M[0][1], M[0][2]
        r = math.ceil(2*((b*s - 2*B) / n))
        
        while not(conforming):
            
            s_min = math.ceil((2*B + r*n) / b) 
            s_max = math.floor((3*B + r*n) / a) 
            
            for s in range(s_min, s_max + 1):
                c_ = int(c*(pow(s, e, n)) % n)
                conforming = challenge47_oracle(c_)
                if conforming:
                    break
                
            r += 1
            
    M = update_intervals(M, s)
    
    if (len(M) == 1) and (M[0][1] == M[0][2]): 
        done = True

In [29]:
# Step 4: the single surviving value is the padded plaintext. Confirm it
# re-encrypts to the target ciphertext before stripping the padding.
recovered_m = int(M[0][1])
assert pow(recovered_m, e, n) == c, 'recovered value does not re-encrypt to c'
recovered_msg = remove_pkcs15_padding(bigint_to_bytes(recovered_m), n)
print(recovered_msg)


b"If something is free, you're not the customer; you're the product."


[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)