#### Challenge 53:  Kelsey and Schneier's Expandable Messages

[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)

In [1]:
from Crypto.Random import random
from Crypto.Cipher import AES
import cryptopals as cp

<div class="alert alert-block alert-info">   

One of the basic yardsticks we use to judge a cryptographic hash function is its resistance to second preimage attacks. That means that if I give you `x` and `y` such that `H(x) = y`, you should have a tough time finding `x'` such that `H(x') = H(x) = y`.

How tough? Brute-force tough. For a `b`-bit hash function, we want second preimage attacks to cost `2^b` operations.

This turns out not to be the case for very long messages.

Consider the problem we're trying to solve: we want to find a message that will collide with `H(x)` in the very last block. But there are a ton of intermediate blocks, each with its own intermediate hash state.

What if we could collide into one of those? We could then append all the following blocks from the original message to produce the original `H(x)`. Almost.

We can't do this exactly because the padding will mess things up.

What we need are expandable messages.

In the last problem we used multicollisions to produce `2^n` colliding messages for `n*2^(b/2)` effort. We can use the same principles to produce a set of messages of length `(k, k + 2^k - 1)` for a given `k`.
    
</div>
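
To see why the padding gets in the way, here is a minimal sketch on a toy hash. The compression function (truncated SHA-256), the 16-bit state, the 4-byte blocks and the length-block padding are illustrative assumptions, not the `cp.MD` used in this notebook. It brute-forces a single block that collides into one of the intermediate states of a long message, then splices on the remaining blocks: the chaining state matches before padding, but the spliced message is shorter, so the length block differs.

```python
import hashlib
from itertools import count

BLOCK = 4  # bytes per message block (toy value)
STATE = 2  # bytes of chaining state, i.e. a 16-bit hash (toy value)

def compress(state, block):
    """Toy compression function: truncated SHA-256 of state || block."""
    return hashlib.sha256(state + block).digest()[:STATE]

def chain(state, data):
    """Iterate the compression function over whole blocks, without padding."""
    for i in range(0, len(data), BLOCK):
        state = compress(state, data[i:i + BLOCK])
    return state

def md_hash(iv, data):
    """Append the message length as a final block (MD strengthening)."""
    return chain(iv, data + len(data).to_bytes(BLOCK, "big"))

iv = bytes(STATE)
# A 64-block target message with distinct, deterministic blocks.
M = b"".join(n.to_bytes(BLOCK, "big") for n in range(1000, 1064))

# Record the state after j blocks of M, for j >= 2 (so the splice is shorter).
states = {}
s = iv
for j in range(1, 65):
    s = compress(s, M[(j - 1) * BLOCK:j * BLOCK])
    if j >= 2:
        states.setdefault(s, j)

# Brute-force one block that lands on any recorded intermediate state.
for x in count():
    block = x.to_bytes(BLOCK, "big")
    j = states.get(compress(iv, block))
    if j is not None:
        break

spliced = block + M[j * BLOCK:]
print("same state before padding:", chain(iv, spliced) == chain(iv, M))
print("lengths in blocks:", len(spliced) // BLOCK, "vs", len(M) // BLOCK)
print("padded digests:", md_hash(iv, spliced).hex(), md_hash(iv, M).hex())
```

Because the length is hashed last, the two padded digests are unrelated unless the spliced message has exactly as many blocks as `M`, which is the gap an expandable message fills.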

--- 

Here's the reference paper by Kelsey and Schneier:

[https://www.schneier.com/academic/archives/2005/01/second_preimages_on.html](https://www.schneier.com/academic/archives/2005/01/second_preimages_on.html)
    
    
> __Abstract.__ We expand a previous result of Dean [Dea99] to provide a second preimage attack on all `n`-bit iterated hash functions with Damgård-Merkle strengthening and `n`-bit intermediate states, allowing a second preimage to be found for a `2^k`-message-block message with about `k × 2^(n/2+1) + 2^(n-k+1)` work. Using RIPEMD-160 as an example, our attack can find a second preimage for a `2^60` byte message in about `2^106` work, rather than the previously expected `2^160` work. We also provide slightly cheaper ways to find multicollisions than the method of Joux [Jou04]. Both of these results are based on expandable messages–patterns for producing messages of varying length, which all collide on the intermediate hash result immediately after processing the message. We provide an algorithm for finding expandable messages for any `n`-bit hash function built using the Damgård-Merkle construction, which requires only a small multiple of the work done to find a single collision in the hash function.
---

<div class="alert alert-block alert-info">  
    
Here's how:

- Starting from the hash function's initial state, find a collision between a single-block message and a message of `2^(k-1)+1` blocks. DO NOT hash the entire long message each time. Choose `2^(k-1)` dummy blocks, hash those, then focus on the last block.
    
</div>

---

In [2]:
b = 16
block_size = b//8
k = 5
N_blocks = 2**(k-1)+1
N_dummy_blocks = N_blocks - 1

# Start with some random b-bit message and initial state:
message = random.Random.get_random_bytes(block_size)
initial_state = random.Random.get_random_bytes(block_size)
original_H = cp.MD(message, initial_state, block_size)

# Now search for a collision...

collision_found = False
while not(collision_found):

    dummy_blocks = b''
    for ii in range(N_dummy_blocks):
        dummy_blocks += random.Random.get_random_bytes(block_size)    

    db = 0
    max_msg = 2**b
    dummy_state = cp.MD(dummy_blocks, initial_state, block_size)
    
    while not(collision_found) and (db < max_msg):

        db_bytes = db.to_bytes(block_size, 'little')
        this_H = cp.MD(db_bytes, dummy_state, block_size)

        if original_H == this_H:
            colliding_message = dummy_blocks + db_bytes
            colliding_H = cp.MD(colliding_message, initial_state, block_size)
            collision_found = True
        else:
            db += 1

    if not collision_found:
        print('No collision found for this dummy state.  Resetting.')

print("\nCollision Found:\n")
print(f"Original Message:  {message}")
print(f"Original Hash Digest:  {original_H}\n")
print(f"Colliding Message for k={k}:  {colliding_message}")
print(f"Colliding Hash Digest:  {colliding_H}")


Collision Found:

Original Message:  b'E\xc3'
Original Hash Digest:  b'\xfd\xc7'

Colliding Message for k=5:  b'-\xf8\xb9\x04\xe9X\x0c\x81\rr\x1b6\x19\xd7\x8buo@\xc4?Fn\xa7g\x16\xf8P\xfc\xa6\xefV\xe9@#'
Colliding Hash Digest:  b'\xfd\xc7'


<div class="alert alert-block alert-info">  

- Take the output state from the first step. Use this as your new initial state and find another collision between a single-block message and a message of `2^(k-2)+1` blocks.

</div>

In [3]:
starting_state = colliding_H
N_blocks = 2**(k-2)+1
N_dummy_blocks = N_blocks - 1

# Now search for a collision...

collision_found = False
while not(collision_found):

    dummy_blocks = b''
    for ii in range(N_dummy_blocks):
        dummy_blocks += random.Random.get_random_bytes(block_size)    

    db = 0
    max_msg = 2**b
    dummy_state = cp.MD(dummy_blocks, starting_state, block_size)
    
    while not(collision_found) and (db < max_msg):

        db_bytes = db.to_bytes(block_size, 'little')
        this_H = cp.MD(db_bytes, dummy_state, block_size)

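        # original_H is the digest of `message` from initial_state, and
        # starting_state == original_H, so a match here means these blocks
        # map starting_state back to itself (not a fresh short/long pair).
        if original_H == this_H: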
            new_colliding_message = colliding_message + dummy_blocks + db_bytes
            colliding_H = cp.MD(new_colliding_message, initial_state, block_size)
            collision_found = True
        else:
            db += 1

    if not collision_found:
        print('No collision found for this dummy state.  Resetting.')

print("\nCollision Found:\n")
print(f"Original Message:  {message}")
print(f"Original Hash Digest:  {original_H}\n")
print(f"Colliding Message for k={k}:  {new_colliding_message}")
print(f"Colliding Hash Digest:  {colliding_H}")


Collision Found:

Original Message:  b'E\xc3'
Original Hash Digest:  b'\xfd\xc7'

Colliding Message for k=5:  b'-\xf8\xb9\x04\xe9X\x0c\x81\rr\x1b6\x19\xd7\x8buo@\xc4?Fn\xa7g\x16\xf8P\xfc\xa6\xefV\xe9@#\xd3\xb1A\x9c\xf6\xe9\xf0.yQ\x04\xbc\x95\xfc\x93\xf0%@'
Colliding Hash Digest:  b'\xfd\xc7'


<div class="alert alert-block alert-info">  
    
- Repeat this process `k` total times. Your last collision should be between a single-block message and a message of `2^0+1 = 2` blocks.
    
</div>
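
The three steps above can be sketched as follows. This is a minimal sketch under stated assumptions: `compress` is a toy 16-bit compression function built from truncated SHA-256, and `find_pair` and `expandable_message` are hypothetical helper names, not functions from `cryptopals`. Each round collides a one-block message with a `2^(k-i)+1`-block message from the same starting state, and the shared output state seeds the next round.

```python
import hashlib
from itertools import count

BLOCK = 4  # bytes per message block (toy value)
STATE = 2  # bytes of chaining state, i.e. b = 16 bits (toy value)

def compress(state, block):
    """Toy compression function: truncated SHA-256 of state || block."""
    return hashlib.sha256(state + block).digest()[:STATE]

def chain(state, data):
    """Iterate the compression function over whole blocks, without padding."""
    for i in range(0, len(data), BLOCK):
        state = compress(state, data[i:i + BLOCK])
    return state

def find_pair(state, n_dummy):
    """Return (short, long, new_state): a 1-block message and an
    (n_dummy + 1)-block message that both map `state` to `new_state`."""
    # Table of 2^(b/2) single-block candidates, keyed by their output state.
    shorts = {}
    for x in range(1 << (4 * STATE)):
        blk = x.to_bytes(BLOCK, "big")
        shorts[compress(state, blk)] = blk
    # Hash the dummy blocks once, then vary only the last block.
    dummy = bytes(n_dummy * BLOCK)
    mid = chain(state, dummy)
    for y in count():
        last = y.to_bytes(BLOCK, "big")
        out = compress(mid, last)
        if out in shorts:
            return shorts[out], dummy + last, out

def expandable_message(iv, k):
    """Run k rounds; round i pairs 1 block with 2^(k-i) + 1 blocks."""
    pairs, state = [], iv
    for i in range(1, k + 1):
        short, long_, state = find_pair(state, 2 ** (k - i))
        pairs.append((short, long_))
    return pairs, state

pairs, final_state = expandable_message(bytes(STATE), k=5)
for short, long_ in pairs:
    print(len(short) // BLOCK, "block vs", len(long_) // BLOCK, "blocks")
```

Keeping the short candidates in a dictionary makes each round a birthday-style search instead of a hunt for one fixed target digest.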

In [4]:
k = 5
max_msg = 2**b
round_initial_state = initial_state
colliding_message = b''
collision_list = []

for ii in range(1, k + 1):

    print(f"Processing:  ii={ii}")
    N_blocks = 2**(k - ii) + 1
    N_dummy_blocks = N_blocks - 1

    collision_found = False
    while not(collision_found):

        # Random filler for all dummy blocks of this round.
        dummy_blocks = random.Random.get_random_bytes(N_dummy_blocks * block_size)
        dummy_hash = cp.MD(dummy_blocks, round_initial_state, block_size)
        
        db = 0
        while not(collision_found) and (db < max_msg):

            last_block = db.to_bytes(block_size, 'little')
            this_H = cp.MD(last_block, dummy_hash, block_size)

            if original_H == this_H:
                colliding_message = colliding_message + dummy_blocks + last_block
                collision_list.append(colliding_message)
                collision_found = True
                round_initial_state = this_H
            else:
                db += 1

        if not collision_found:
            print('No collision found--generating new dummy values....')

colliding_H = cp.MD(colliding_message, initial_state, block_size)
print("\nCollision Found:\n")
print(f"Original Message:  {message}")
print(f"Original Hash Digest:  {original_H}\n")
print("Hashes of generated collisions:")
for c in collision_list:
    print(cp.MD(c, initial_state, block_size))

Processing:  ii=1
Processing:  ii=2
Processing:  ii=3
No collision found--generating new dummy values....
No collision found--generating new dummy values....
No collision found--generating new dummy values....
Processing:  ii=4
Processing:  ii=5

Collision Found:

Original Message:  b'E\xc3'
Original Hash Digest:  b'\xfd\xc7'

Hashes of generated collisions:
b'\xfd\xc7'
b'\xfd\xc7'
b'\xfd\xc7'
b'\xfd\xc7'
b'\xfd\xc7'


<div class="alert alert-block alert-info">  
    
Now you can make a message of any length in `(k, k + 2^k - 1)` blocks by choosing the appropriate message (short or long) from each pair.

Now we're ready to attack a long message `M` of `2^k` blocks.

1. Generate an expandable message of length `(k, k + 2^k - 1)` using the strategy outlined above.
2. Hash `M` and generate a map of intermediate hash states to the block indices that they correspond to.
3. From your expandable message's final state, find a single-block "bridge" to an intermediate state in your map. Note the index `i` it maps to.
4. Use your expandable message to generate a prefix of the right length such that `len(prefix || bridge || M[i..]) = len(M)`.

The padding in the final block should now be correct, and your forgery should hash to the same value as `M`.

    
</div>
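
Choosing "the appropriate message from each pair" amounts to writing the number of extra blocks in binary. Here is a minimal sketch, assuming the pairs are stored as `(short, long)` tuples in round order, so that round `i`'s long message has `2^(k-i)+1` blocks. That layout and the helper name `prefix_of_length` are assumptions for illustration; the ordering returned by `cp.generate_expandable_message2` may differ. The demo uses placeholder blocks rather than real collisions, since only the lengths matter here.

```python
BLOCK = 2  # bytes per block (matches block_size in this notebook)

def prefix_of_length(pairs, n_blocks):
    """Concatenate one message per pair so the result has n_blocks blocks."""
    k = len(pairs)
    extra = n_blocks - k  # blocks needed beyond the all-short minimum of k
    if not 0 <= extra < 2 ** k:
        raise ValueError("length must lie between k and k + 2^k - 1 blocks")
    prefix = b""
    for i, (short, long_) in enumerate(pairs, start=1):
        # Round i's long message adds 2^(k-i) blocks over the short one.
        prefix += long_ if extra & (1 << (k - i)) else short
    return prefix

# Placeholder pairs with the right lengths only (these do NOT collide).
k = 3
dummy_pairs = [(b"s" * BLOCK, b"L" * BLOCK * (2 ** (k - i) + 1))
               for i in range(1, k + 1)]
for n_blocks in range(k, k + 2 ** k):
    prefix = prefix_of_length(dummy_pairs, n_blocks)
    print(n_blocks, "requested ->", len(prefix) // BLOCK, "blocks built")
```

Every length from `k` to `k + 2^k - 1` blocks is reachable because the extra block counts `2^(k-1), ..., 2, 1` are exactly the binary place values.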

In [5]:
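def MDwithIVs(M, H, blockSize=2):
    """Hash M block by block from state H, recording every intermediate state.

    Returns (final_state, IVs), where IVs[i] is the state after i blocks
    (IVs[0] is the initial state H).
    """
    # The current state, padded with cp.zero_pad, keys AES for the next block.
    myAES = AES.new(cp.zero_pad(H[:blockSize]), AES.MODE_ECB)

    Blocks = [M[ii:ii+blockSize] for ii in range(0, len(M), blockSize)]
    IVs = [H]
    for block in Blocks:
        # New state: first blockSize bytes of the encrypted, padded block.
        H = myAES.encrypt(cp.zero_pad(block))[0:blockSize]
        IVs.append(H)
        myAES = AES.new(cp.zero_pad(H), AES.MODE_ECB)

    return H[0:blockSize], IVs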

In [24]:
b = 16
block_size = b // 8
initial_state = random.Random.get_random_bytes(block_size)
#M_target = b"I am regularly asked what the average Internet user can do to ensure his security. My first answer is usually 'Nothing; you're screwed'. [Schneier]" 
M_target = b"ABCDEFGH"

# bit_length() of the block count: 4 blocks gives k = 3, so 2^k covers M_target.
k = (len(M_target) // block_size).bit_length()

# Step 1 - Generate an expandable message of length `(k, k + 2^k - 1)` 
# using the strategy outlined above.
collision_list, final_state = cp.generate_expandable_message2(k, initial_state, block_size)

Processing:  ii=1, N=5
Processing:  ii=2, N=3
Processing:  ii=3, N=2


In [10]:
current_state = initial_state
for c in collision_list:
    h0 = cp.MD(c[0], current_state, block_size)
    print(h0.hex())
    h1 = cp.MD(c[1], current_state, block_size)
    print(h1.hex())
    # Both messages of the pair must reach the same state.
    assert h0 == h1
    current_state = h0

297c
297c
0b42
0b42
8e66
8e66


In [21]:
# Step 2 - Hash `M` and generate a map of intermediate hash states to the 
# block indices that they correspond to.
_, IVs = MDwithIVs(M_target, initial_state, block_size) 

# Step 3 - From your expandable message's final state, find a single-
# block "bridge" to intermediate state in your map. Note the index i 
# it maps to.
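# Only indices above k are usable: the prefix needs at least k blocks and
# the bridge adds one, so prefix || bridge replaces at least k + 1 blocks.
usable_IVs = IVs[k + 1:]
bridge_bytes = None
for bridge_val in range(2**(8 * block_size)):
    candidate = bridge_val.to_bytes(block_size, 'little')
    bridge_hash = cp.MD(candidate, final_state, block_size)
    if bridge_hash in usable_IVs:
        bridge_bytes = candidate
        break
if bridge_bytes is None:
    raise Exception('No Valid Bridge Found')

# IVs[i] is the state after i blocks of M_target.
match_idx = IVs.index(bridge_hash, k + 1)

# Step 4 - Use your expandable message to generate a prefix of the 
# right length such that `len(prefix || bridge || M[i..]) = len(M)`.
# prefix || bridge must span exactly match_idx blocks, so the prefix is
# match_idx - 1 blocks long (a block count, not a byte count).
prefix_len = match_idx - 1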

print('Done')

Done

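The cells above stop once the prefix length is known. For completeness, here is a minimal end-to-end sketch of steps 1 to 4, including assembling and checking the forgery. It runs on a toy hash whose details are all assumptions for illustration (truncated SHA-256 compression, 16-bit state, 4-byte blocks, a length block as padding), and `expandable_message` and `second_preimage` are hypothetical names rather than `cryptopals` functions.

```python
import hashlib
from itertools import count

BLOCK, STATE = 4, 2  # toy sizes: 4-byte blocks, 16-bit chaining state

def compress(state, block):
    """Toy compression function: truncated SHA-256 of state || block."""
    return hashlib.sha256(state + block).digest()[:STATE]

def chain(state, data):
    """Iterate the compression function over whole blocks, without padding."""
    for i in range(0, len(data), BLOCK):
        state = compress(state, data[i:i + BLOCK])
    return state

def md_hash(iv, data):
    """MD strengthening: the message length goes into a final block."""
    return chain(iv, data + len(data).to_bytes(BLOCK, "big"))

def expandable_message(iv, k):
    """Step 1: k (short, long) pairs; round i's long has 2^(k-i)+1 blocks."""
    pairs, state = [], iv
    for i in range(1, k + 1):
        shorts = {}
        for x in range(256):  # 2^(b/2) single-block candidates
            blk = x.to_bytes(BLOCK, "big")
            shorts[compress(state, blk)] = blk
        dummy = bytes(2 ** (k - i) * BLOCK)
        mid = chain(state, dummy)
        for y in count():
            last = y.to_bytes(BLOCK, "big")
            out = compress(mid, last)
            if out in shorts:
                break
        pairs.append((shorts[out], dummy + last))
        state = out
    return pairs, state

def second_preimage(iv, M, k):
    pairs, final_state = expandable_message(iv, k)
    n = len(M) // BLOCK
    # Step 2: map "state after j blocks of M" -> j, keeping only j > k.
    index_of, s = {}, iv
    for j in range(1, n + 1):
        s = compress(s, M[(j - 1) * BLOCK:j * BLOCK])
        if j > k:
            index_of.setdefault(s, j)
    # Step 3: single-block bridge from final_state into the map.
    for z in count():
        bridge = z.to_bytes(BLOCK, "big")
        j = index_of.get(compress(final_state, bridge))
        if j is not None:
            break
    # Step 4: prefix of exactly j - 1 blocks, chosen bit by bit.
    extra, prefix = j - 1 - k, b""
    for i, (short, long_) in enumerate(pairs, start=1):
        prefix += long_ if extra & (1 << (k - i)) else short
    return prefix + bridge + M[j * BLOCK:]

iv = bytes(STATE)
k = 6
M = b"".join(v.to_bytes(BLOCK, "big") for v in range(1000, 1000 + 2 ** k))
forged = second_preimage(iv, M, k)
print("different message:", forged != M)
print("same length:", len(forged) == len(M))
print("same digest:", md_hash(iv, forged) == md_hash(iv, M))
```

Because the forged message has exactly as many blocks as `M`, the length block is identical and the digests agree even with the padding included.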

[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)