#### Challenge 52: Iterated Hash Function Multicollisions

[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)

In [1]:
from Crypto.Random import random
from Crypto.Cipher import AES
from Crypto.Cipher import Blowfish
import cryptopals as cp
import pdb

<div class="alert alert-block alert-info">   

While we're on the topic of hash functions...

The major feature you want in your hash function is collision-resistance. That is, it should be hard to generate collisions, and it should be really hard to generate a collision for a given hash (aka preimage).

Iterated hash functions have a problem: the effort to generate lots of collisions scales sublinearly.

What's an iterated hash function? For all intents and purposes, we're talking about the Merkle-Damgard construction. It looks like this:

```
function MD(M, H, C):
  for M[i] in pad(M):
    H := C(M[i], H)
  return H
```   
<br>    
    
For message `M`, initial state `H`, and compression function `C`.

This should look really familiar, because SHA-1 and MD4 are both in this category. What's cool is you can use this formula to build a makeshift hash function out of some spare crypto primitives you have lying around (e.g. `C = AES-128`).

Back on task: the cost of collisions scales sublinearly. What does that mean? If it's feasible to find one collision, it's probably feasible to find a lot.

How? For a given state `H`, find two blocks that collide. Now take the resulting hash from this collision as your new H and repeat. Recognize that with each iteration you can actually double your collisions by subbing in either of the two blocks for that slot.

This means that if finding two colliding messages takes `2^(b/2)` work (where `b` is the bit-size of the hash function), then finding `2^n` colliding messages only takes `n*2^(b/2)` work.

Let's test it. First, build your own MD hash function. We're going to be generating a LOT of collisions, so don't knock yourself out. In fact, go out of your way to make it bad. Here's one way:

1. Take a fast block cipher and use it as `C`.
2. Make `H` pretty small. I won't look down on you if it's only 16 bits. Pick some initial `H`.
3. `H` is going to be the input key and the output block from `C`. That means you'll need to pad it on the way in and drop bits on the way out.
    
</div>

In [27]:
def zero_pad(m, blockSize=16):
    
    r = len(m) % blockSize
    
    if r != 0:
        pad_length = 16 - r
    else:
        pad_length = 0
    
    padded_m = m + b'\x00'*pad_length
    
    return(padded_m)

def MD(M, H, blockSize=2):
    
    key = zero_pad(H[:blockSize])
    myAES = AES.new(key, AES.MODE_ECB)
    
    Blocks = [M[ii:ii+blockSize] for ii in range(0, len(M), blockSize)]
    
    for block in Blocks:
        
        H = myAES.encrypt(zero_pad(block))[0:blockSize]
        myAES = AES.new(zero_pad(H), AES.MODE_ECB)

    return(H)
 

In [3]:
m = b'Complexity is the worst enemy of security, and our systems are getting more complex all the time. [Bruce Schneier]'
h = b'AA'

print(MD(m, h))

b'|\xc8'


<div class="alert alert-block alert-info">   

Now write the function `f(n)` that will generate `2^n` collisions in this hash function.
    
> How? For a given state `H`, find two blocks that collide. Now take the resulting hash from this collision as your new H and repeat. Recognize that with each iteration you can actually double your collisions by subbing in either of the two blocks for that slot.
    
</div>

In [28]:
def find_collision(h, block_size=2):
        
    max_state = 2**(8*block_size)
    
    for ii in range(max_state-1):
        
        block_1 = ii.to_bytes(block_size, 'little')
        digest_1 = MD(block_1, h)
        
        for jj in range(ii+1, max_state):
            
            block_2 = jj.to_bytes(block_size, 'little')
            digest_2 = MD(block_2, h)
            
            if digest_1 == digest_2:
                
                collision = [block_1, block_2]
                return(collision, digest_1)
           
    raise Exception('Space exhausted.  No collision found')

In [29]:
def extend_collision_list(collision_list):
    
    last_h = collision_list[-1][2]
    collision, next_h = find_collision(last_h)
    next_round_collision_list = []
    last_round_count = int((len(collision_list)+1)/2)
    last_round_collision_list = collision_list[-last_round_count:].copy()
    for c in last_round_collision_list:

        a = c[0] + collision[0]
        b = c[0] + collision[1]
        next_round_collision_list.append([a, b, next_h, last_h])

        a = c[1] + collision[0]
        b = c[1] + collision[1]
        next_round_collision_list.append([a, b, next_h, last_h])

    collision_list += next_round_collision_list.copy()
    del(next_round_collision_list)
    

In [30]:
h = b'\xf8\xf7'

collision, next_h = find_collision(h)
collision_list = [[collision[0], collision[1], next_h, h]]

N = 12
for ii in range(N):
    extend_collision_list(collision_list)
    print(f"Extended to {len(collision_list)} collisions")

Extended to 3 collisions
Extended to 7 collisions
Extended to 15 collisions
Extended to 31 collisions
Extended to 63 collisions
Extended to 127 collisions
Extended to 255 collisions
Extended to 511 collisions
Extended to 1023 collisions
Extended to 2047 collisions
Extended to 4095 collisions
Extended to 8191 collisions


In [31]:
# Verify that all of the collisions are valid
for ii in range(1, len(collision_list)):
    assert(MD(collision_list[ii][0], h) == MD(collision_list[ii][1], h))
    
print("All collisions validated")

All collisions validated


<div class="alert alert-block alert-info">   

Why does this matter? Well, one reason is that people have tried to strengthen hash functions by cascading them together. Here's what I mean:

1. Take hash functions `f` and `g`.
2. Build `h` such that `h(x) = f(x) || g(x)`.

The idea is that if collisions in `f` cost `2^(b1/2)` and collisions in `g` cost `2^(b2/2)`, collisions in h should come to the princely sum of `2^((b1+b2)/2)`.

But now we know that's not true!

Here's the idea:

1. Pick the "cheaper" hash function. Suppose it's `f`.
2. Generate `2^(b2/2)` colliding messages in `f`.
3. There's a good chance your message pool has a collision in `g`.
4. Find it.

And if it doesn't, keep generating cheap collisions until you find it.

Prove this out by building a more expensive (but not too expensive) hash function to pair with the one you just used. Find a pair of messages that collide under both functions. Measure the total number of calls to the collision function.

</div>

#### Resources:  

Paper:  [_Multicollisions in Iterated Hash Functions Application to Cascaded Constructions_](https://link.springer.com/content/pdf/10.1007/978-3-540-28628-8_19.pdf) by Joux

In [33]:
def f(m, s):
    
    return(MD(m, s, 2))

def g(m, s):
    
    return(MD(m, s, 3))

def h(m, s):
    
    return(f(m, s) + g(m, s))

---

I have `b1 = 16` bits and `b2 = 24` bits.

2^((b1+b2)/2) would be 2^((16 + 24)/2) = 2^(20)

Lets start by generating 2^(16/2) = 2^8 = 256 colliding messages in f:

In [38]:
def check_for_g_collision(collision_list, initial_state):
   
    multi_collision_list = []
    for c in collision_list:

        # a = g(c[0], initial_state)
        # b = g(c[1], initial_state)
        a = MD(c[0], initial_state, 3)
        b = MD(c[1], initial_state, 3)

        if a == b:

            print(f"Collision found for {c[0]}, {c[1]}")
            multi_collision_found = True
            multi_collision_list.append([c[0], c[1]])

    return(multi_collision_list)

In [None]:
multi_collision_found = False
multi_collision_list = []

n_s_tried = 0
MAX_N = 12 

while not(multi_collision_found):
    
    n_s_tried += 1
    initial_state = random.Random.get_random_bytes(2)
    print()
    print(f"Initial State: {int.from_bytes(initial_state, 'little')}")
    
    # Generate 2**(b2/2) collisions to start:   
    N = 12
    ii = N + 1
    
    collision, next_h = find_collision(initial_state)
    collision_list = [[collision[0], collision[1], next_h, initial_state]]
    for ii in range(N):
        extend_collision_list(collision_list)
    
    new_multi_collisions = check_for_g_collision(collision_list, initial_state)
    if len(new_multi_collisions) > 0:
        multi_collision_list += new_multi_collisions
        multi_collision_found = True
    else:
        print('No multi-collisions found in initial set')
    
    while not(multi_collision_found) and (ii <= MAX_N):
        print(f"Extending to: {2**ii}")
        extend_collision_list(collision_list)
        new_multi_collisions = check_for_g_collision(collision_list[2**(ii-1):].copy(), initial_state)
        if len(new_multi_collisions) > 0:
            multi_collision_list += new_multi_collisions
            multi_collision_found = True
        ii += 1
    
    del(collision_list)


Initial State: 31034
No multi-collisions found in initial set
Extending to: 2048
Extending to: 4096

Initial State: 56524
No multi-collisions found in initial set
Extending to: 2048
Extending to: 4096

Initial State: 36900
No multi-collisions found in initial set
Extending to: 2048
Extending to: 4096

Initial State: 62231
No multi-collisions found in initial set
Extending to: 2048
Extending to: 4096

Initial State: 53039
No multi-collisions found in initial set
Extending to: 2048
Extending to: 4096

Initial State: 61238
No multi-collisions found in initial set
Extending to: 2048
Extending to: 4096

Initial State: 3914
No multi-collisions found in initial set
Extending to: 2048
Extending to: 4096


[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)