#### Challenge 55:  MD4 Collisions

[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)

In [1]:
from Crypto.Random import random
from Crypto.Cipher import AES
from Crypto.Cipher import Blowfish
import math
import cryptopals as cp
import pdb

<div class="alert alert-block alert-info">   

MD4 is a `128-bit` cryptographic hash function, meaning it should take a work factor of roughly `2^64` to find collisions.

It turns out we can do much better.

The paper "Cryptanalysis of the Hash Functions MD4 and RIPEMD" by Wang et al. details a cryptanalytic attack that lets us find collisions in `2^8` operations or fewer.

Given a message block `M`, Wang outlines a strategy for finding a sister message block `M'`, differing only in a few bits, that will collide with it, so long as a short set of conditions holds true for `M`.

What sort of conditions? Simple bitwise equalities within the intermediate hash function state, e.g. `a[1][6] = b[0][6]`. This should be read as: "the sixth bit (zero-indexed) of `a[1]` (i.e. the first update to `a`) should equal the sixth bit of `b[0]` (i.e. the initial value of `b`)".

It turns out that a lot of these conditions are trivial to enforce. To see why, take a look at the first (of three) rounds in the MD4 compression function. In this round, we iterate over each word in the message block sequentially and mix it into the state. So we can make sure all our first-round conditions hold by doing this:

```
# calculate the new value for a[1] in the normal fashion
a[1] = (a[0] + f(b[0], c[0], d[0]) + m[0]).lrot(3)

# correct the erroneous bit
a[1] ^= ((a[1][6] ^ b[0][6]) << 6)

# use algebra to correct the first message block
m[0] = a[1].rrot(3) - a[0] - f(b[0], c[0], d[0])
```
    
Simply ensuring that all the first-round conditions hold puts us well within the range to generate collisions, but we can do better by correcting some additional conditions in the second round. This is a bit trickier, as we need to take care not to stomp on any of the first-round conditions.

Once you've adequately massaged `M`, you can simply generate `M'` by flipping a few bits and test for a collision. A collision is not guaranteed as we didn't ensure every condition. But hopefully we got enough that we can find a suitable `(M, M')` pair without too much effort.

Implement Wang's attack.

</div>

---

I'll start by implementing MD4 in Python to get a better handle on what we're doing here. The implementation is based on [RFC 1320](https://tools.ietf.org/html/rfc1320) by Rivest and on the paper by Wang.

In [2]:
def bitget(x, n):
    """Return bit number ``n`` (0 = least significant) of ``x`` as the int 0 or 1."""
    return 1 if x & (1 << n) else 0

def bit_in_place(x, n):
    """Isolate bit ``n`` of ``x``: the result is ``2**n`` when that bit is set, else 0."""
    mask = 2 ** n
    return mask & x

def bitset(x, n, bv):
    """Return a copy of ``x`` with bit ``n`` forced to 1 if ``bv == 1``, otherwise to 0."""
    mask = 2 ** n
    return x | mask if bv == 1 else x & ~mask

def lrot_32(n, d):
    """Circularly rotate the 32-bit word ``n`` left by ``d`` bits.

    Python ints are unbounded and only offer non-circular shifts, so the rotation
    is built from two shifts and the result is masked back to 32 bits. Without the
    mask, the bits shifted past position 31 would remain as a stale high-order copy
    and corrupt later bit edits or right rotations of the value.

    n : only its low 32 bits are used.
    d : rotation amount, taken modulo 32.
    Returns an int in [0, 2**32).
    """
    n &= 0xFFFFFFFF
    d %= 32
    return ((n << d) | (n >> (32 - d))) & 0xFFFFFFFF

def rrot_32(n, d):
    """Circularly rotate the 32-bit word ``n`` right by ``d`` bits.

    The input is first reduced to its low 32 bits. Any stale bits above bit 31
    (e.g. left over from an unmasked left rotation) would otherwise be shifted
    back down into the word by ``n >> d``. The result is masked to 32 bits.

    n : only its low 32 bits are used.
    d : rotation amount, taken modulo 32.
    Returns an int in [0, 2**32).
    """
    n &= 0xFFFFFFFF
    d %= 32
    return ((n << (32 - d)) | (n >> d)) & 0xFFFFFFFF

def byte_swap(data, word_size):
    """Reverse the byte order inside every ``word_size``-byte word of ``data``.

    data      : bytes-like input.
    word_size : word length in bytes (e.g. 4 for 32-bit words); must be >= 1.

    A trailing partial word (when len(data) is not a multiple of word_size) is
    reversed on its own. Returns a new ``bytes`` object of the same length.
    """
    if word_size < 1:
        raise ValueError("word_size must be a positive number of bytes")
    # Slice by word_size (not a fixed 4) so every word length is swapped correctly.
    return b''.join(
        bytes(data[i:i + word_size])[::-1] for i in range(0, len(data), word_size)
    )

In [3]:
# Define the MD4 auxiliary functions:
    
def F(X, Y, Z):
    """MD4 round-1 function: bitwise select, taking Y where X is 1 and Z where X is 0."""
    return Z ^ (X & (Y ^ Z))

def G(X, Y, Z):
    """MD4 round-2 function: bitwise majority of X, Y and Z."""
    return (X & Y) | (Z & (X | Y))

def H(X, Y, Z):
    """MD4 round-3 function: bitwise parity (XOR) of X, Y and Z."""
    parity = X ^ Y
    return parity ^ Z

def phi(j, a, b, c, d, w, s):
    """Compute one MD4 step: ``(a + f_j(b, c, d) + w + K_j) <<< s`` modulo 2**32.

    j          : round index -- 0 uses F with no constant, 1 uses G with
                 0x5a827999, 2 uses H with 0x6ed9eba1.
    a, b, c, d : current chaining words; ``a`` is the word being replaced.
    w          : message word consumed by this step.
    s          : left-rotation amount.

    Returns the new 32-bit value of ``a``.
    Raises ValueError if ``j`` is not 0, 1 or 2.
    """
    MGK_1 = 0x5a827999
    MGK_2 = 0x6ed9eba1

    if j == 0:
        total = a + F(b, c, d) + w
    elif j == 1:
        total = a + G(b, c, d) + w + MGK_1
    elif j == 2:
        total = a + H(b, c, d) + w + MGK_2
    else:
        raise ValueError(f"MD4 round index must be 0, 1 or 2, got {j!r}")

    # Mask after rotating too, so the returned word never carries bits above 31.
    return lrot_32(total % 2**32, s) & 0xFFFFFFFF

In [4]:
def MD4_pad_data(data):
    """Pad a message as specified by RFC 1320 and split it into 32-bit words.

    data : str (encoded with ``str.encode()``) or bytes-like / iterable of byte values.

    Returns a list of ints, each a little-endian 32-bit word. The list length is
    a multiple of 16 (one 512-bit block per 16 words).
    Raises TypeError for an int argument.

    The input is copied into an immutable ``bytes`` object first, so a caller's
    ``bytearray`` is never extended in place.
    """
    if isinstance(data, str):
        data = data.encode()
    if isinstance(data, int):
        raise TypeError("MD4 input must be str or bytes-like, not int")
    data = bytes(data)

    # Step 2 appends the original bit length; RFC 1320 keeps only its low 64 bits.
    bit_length = (len(data) * 8) % 2**64

    # Step 1: a single 1-bit (byte 0x80), then 0-bits until the length is
    # congruent to 448 mod 512 bits, i.e. 56 mod 64 bytes.
    padding_len = (56 - (len(data) + 1) % 64) % 64
    padded = (data + b'\x80' + b'\x00' * padding_len
              + bit_length.to_bytes(8, 'little', signed=False))

    return [
        int.from_bytes(padded[i:i + 4], byteorder='little', signed=False)
        for i in range(0, len(padded), 4)
    ]

In [15]:
def MD4_get_IVs(M):
    """Run the MD4 compression function over padded message words ``M``.

    M : list of 32-bit words whose length is a multiple of 16 (see MD4_pad_data);
        any trailing partial block is ignored.

    Returns the tuple (A, B, C, D) of lists holding every chaining value produced.
    Each list starts with the RFC 1320 initial value. Every block appends 12
    entries (4 steps in each of 3 rounds), and the last entry of a block is
    replaced by its feed-forward sum. For the first block, index i (1..4) is
    therefore the round-1 value a_i / b_i / c_i / d_i in Wang's notation.
    """
    # Order in which each round consumes the 16 message words.
    word_order = (
        tuple(range(16)),
        (0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15),
        (0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15),
    )
    # Rotation amounts for the a, d, c, b updates of each round.
    shifts = ((3, 7, 11, 19), (3, 5, 9, 13), (3, 9, 11, 15))

    A = [0x67452301]
    B = [0xefcdab89]
    C = [0x98badcfe]
    D = [0x10325476]

    for start in range(0, (len(M) // 16) * 16, 16):
        block = M[start:start + 16]
        saved = (A[-1], B[-1], C[-1], D[-1])

        for rnd, (order, (s_a, s_d, s_c, s_b)) in enumerate(zip(word_order, shifts)):
            for step in range(0, 16, 4):
                # Each update reads the most recent values of the other three words.
                A.append(phi(rnd, A[-1], B[-1], C[-1], D[-1], block[order[step]], s_a))
                D.append(phi(rnd, D[-1], A[-1], B[-1], C[-1], block[order[step + 1]], s_d))
                C.append(phi(rnd, C[-1], D[-1], A[-1], B[-1], block[order[step + 2]], s_c))
                B.append(phi(rnd, B[-1], C[-1], D[-1], A[-1], block[order[step + 3]], s_b))

        # Feed-forward: add the chaining values from the start of this block.
        for chain, previous in zip((A, B, C, D), saved):
            chain[-1] = (chain[-1] + previous) % 2**32

    return (A, B, C, D)

In [16]:
def MD4(data):
    """Return the 16-byte MD4 digest of ``data`` (a str is encoded first).

    The message is padded by MD4_pad_data and compressed by MD4_get_IVs. The
    digest is the final A, B, C, D chaining words, each serialised little-endian.
    """
    chains = MD4_get_IVs(MD4_pad_data(data))
    return b''.join(chain[-1].to_bytes(4, 'little') for chain in chains)

In [17]:
# Verify MD4 against the test vectors published in RFC 1320 (appendix A.5).

for message, expected_hex in (
    ('', '31d6cfe0d16ae931b73c59d7e0c089c0'),
    ('a', 'bde52cb31de33e46245e05fbdbd6fb24'),
    ('abc', 'a448017aaf21d8525fc10ae87aa6729d'),
    ('message digest', 'd9130a8164549fe818874806e1c7014b'),
    ('abcdefghijklmnopqrstuvwxyz', 'd79e1c308aa5bbcdeea8ed63df412da9'),
    ('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789',
     '043f8582f241db351ce627e153e7f0e4'),
    ('12345678901234567890123456789012345678901234567890123456789012345678901234567890',
     'e33b4ddc9c38f2199c3e7b164fcc0536'),
):
    assert MD4(message).hex() == expected_hex

print('If you can see this, all the tests passed.')

If you can see this, all the tests passed.


---

From the Wang paper, the attack includes 3 parts:
    
1. Find a collision differential in which `M` and `M'` produce a collision.
2. Derive a set of sufficient conditions that ensure the collision differential holds.
3. For any random message `M`, modify `M` so that almost all of the sufficient conditions hold.

---

In [18]:
# Table 3 of Wang's paper prints every 32-bit message word big-endian, while MD4
# (and the implementation above) reads message words little-endian. Each 4-byte
# word is therefore byte-swapped before hashing.


def _wang_table3_block(hex_text):
    """Turn a Table 3 hex dump into MD4 input bytes (each 32-bit word byte-swapped)."""
    return byte_swap(bytes.fromhex(hex_text), 4)


M_1 = _wang_table3_block(
    '4d7a9c83 56cb927a b9d5a578 57a7a5ee de748a3c dcc366b3 b683a020 3b2a5d9f'
    'c69d71b3 f9e99198 d79f805e a63bb2e8 45dd8e31 97e31fe5 2794bf08 b9e8c3e9')
M_1_c = _wang_table3_block(
    '4d7a9c83 d6cb927a 29d5a578 57a7a5ee de748a3c dcc366b3 b683a020 3b2a5d9f'
    'c69d71b3 f9e99198 d79f805e a63bb2e8 45dc8e31 97e31fe5 2794bf08 b9e8c3e9')
H_1 = bytes.fromhex('4d7e6a1d efa93d2d de05b45d 864c429b')

M_2 = _wang_table3_block(
    '4d7a9c83 56cb927a b9d5a578 57a7a5ee de748a3c dcc366b3 b683a020 3b2a5d9f'
    'c69d71b3 f9e99198 d79f805e a63bb2e8 45dd8e31 97e31fe5 f713c240 a7b8cf69')
M_2_c = _wang_table3_block(
    '4d7a9c83 d6cb927a 29d5a578 57a7a5ee de748a3c dcc366b3 b683a020 3b2a5d9f'
    'c69d71b3 f9e99198 d79f805e a63bb2e8 45dc8e31 97e31fe5 f713c240 a7b8cf69')
H_2 = bytes.fromhex('c6f3b3fe 1f4833e0 697340fb 214fb9ea')

# Both members of each colliding pair must hash to the published digest.
for message, expected_digest in ((M_1, H_1), (M_1_c, H_1), (M_2, H_2), (M_2_c, H_2)):
    assert MD4(message) == expected_digest

print('Table 3 Tests passed')

Table 3 Tests passed


---

We've verified that the MD4 implementation works with the sample data. Now let's see if we can follow how the colliding messages were derived from the originals.

I'll try to follow the notation used in Wang's paper.

---

In [22]:
# Round-1 sufficient conditions from Table 6 of Wang et al., converted from the
# paper's 1-based bit numbers to 0-based ones, plus the extra b4,31 = c4,31
# condition from NSK+05. Each entry is (register, step, bit, target). The target
# is either the required bit value (0 or 1), or a (register, step) pair whose bit
# of the same number must be copied. Entries are applied in order, so a copied
# bit always comes from a state that has already been corrected.
_WANG_ROUND1_CONDITIONS = (
    # a1
    ('a', 1, 6, ('b', 0)),
    # d1
    ('d', 1, 6, 0), ('d', 1, 7, ('a', 1)), ('d', 1, 10, ('a', 1)),
    # c1
    ('c', 1, 6, 1), ('c', 1, 7, 1), ('c', 1, 10, 0), ('c', 1, 25, ('d', 1)),
    # b1
    ('b', 1, 6, 1), ('b', 1, 7, 0), ('b', 1, 10, 0), ('b', 1, 25, 0),
    # a2
    ('a', 2, 7, 1), ('a', 2, 10, 1), ('a', 2, 13, ('b', 1)), ('a', 2, 25, 0),
    # d2
    ('d', 2, 13, 0), ('d', 2, 18, ('a', 2)), ('d', 2, 19, ('a', 2)),
    ('d', 2, 20, ('a', 2)), ('d', 2, 21, ('a', 2)), ('d', 2, 25, 1),
    # c2
    ('c', 2, 12, ('d', 2)), ('c', 2, 13, 0), ('c', 2, 14, ('d', 2)),
    ('c', 2, 18, 0), ('c', 2, 19, 0), ('c', 2, 20, 1), ('c', 2, 21, 0),
    # b2
    ('b', 2, 12, 1), ('b', 2, 13, 1), ('b', 2, 14, 0), ('b', 2, 16, ('c', 2)),
    ('b', 2, 18, 0), ('b', 2, 19, 0), ('b', 2, 20, 0), ('b', 2, 21, 0),
    # a3
    ('a', 3, 12, 1), ('a', 3, 13, 1), ('a', 3, 14, 1), ('a', 3, 16, 0),
    ('a', 3, 18, 0), ('a', 3, 19, 0), ('a', 3, 20, 0), ('a', 3, 21, 1),
    ('a', 3, 22, ('b', 2)), ('a', 3, 25, ('b', 2)),
    # d3
    ('d', 3, 12, 1), ('d', 3, 13, 1), ('d', 3, 14, 1), ('d', 3, 16, 0),
    ('d', 3, 19, 0), ('d', 3, 20, 1), ('d', 3, 21, 1), ('d', 3, 22, 0),
    ('d', 3, 25, 1), ('d', 3, 29, ('a', 3)),
    # c3
    ('c', 3, 16, 1), ('c', 3, 19, 0), ('c', 3, 20, 0), ('c', 3, 21, 0),
    ('c', 3, 22, 0), ('c', 3, 25, 0), ('c', 3, 29, 1), ('c', 3, 31, ('d', 3)),
    # b3
    ('b', 3, 19, 0), ('b', 3, 20, 1), ('b', 3, 21, 1), ('b', 3, 22, ('c', 3)),
    ('b', 3, 25, 1), ('b', 3, 29, 0), ('b', 3, 31, 0),
    # a4
    ('a', 4, 22, 0), ('a', 4, 25, 0), ('a', 4, 26, ('b', 3)),
    ('a', 4, 28, ('b', 3)), ('a', 4, 29, 1), ('a', 4, 31, 0),
    # d4
    ('d', 4, 22, 0), ('d', 4, 25, 0), ('d', 4, 26, 1),
    ('d', 4, 28, 1), ('d', 4, 29, 0), ('d', 4, 31, 1),
    # c4
    ('c', 4, 18, ('d', 4)), ('c', 4, 22, 1), ('c', 4, 25, 1),
    ('c', 4, 26, 0), ('c', 4, 28, 0), ('c', 4, 29, 0),
    # b4 (b4,25 = c4,25, which is already forced to 1 above)
    ('b', 4, 18, 0), ('b', 4, 25, 1), ('b', 4, 26, 1),
    ('b', 4, 28, 1), ('b', 4, 29, 0), ('b', 4, 31, ('c', 4)),
)


def MD4_Wang_SSM(data):
    """Apply Wang's round-1 single-step modification to the first message block.

    data : str or bytes. It is MD4-padded, and only the first 16 words
           (64 bytes) are used.

    The round-1 states a1..b4 of that block are forced to satisfy every entry
    of _WANG_ROUND1_CONDITIONS. Each message word m0..m15 is then solved back
    from the corrected states, m = (state >>> s) - previous - F(...). Hashing
    the returned block therefore reproduces exactly those corrected states.

    Returns the massaged block M as 64 bytes (16 little-endian words). This is
    NOT the colliding partner M'; obtain that by applying the message
    differential m1 += 2**31, m2 += 2**31 - 2**28, m12 -= 2**16 (mod 2**32).
    """
    M = MD4_pad_data(data)[:16]

    # Keep only the low 32 bits of every chaining value. Stray bits above bit 31
    # would not be touched by the bit edits below and would leak back into the
    # words through the right rotation used when solving for M.
    states = {
        name: [word & 0xFFFFFFFF for word in chain]
        for name, chain in zip('abcd', MD4_get_IVs(M))
    }

    for reg, step, bit, target in _WANG_ROUND1_CONDITIONS:
        if isinstance(target, tuple):
            src_reg, src_step = target
            target = bitget(states[src_reg][src_step], bit)
        states[reg][step] = bitset(states[reg][step], bit, target)

    A, B, C, D = (states[name] for name in 'abcd')

    # Back out the message words. Round 1 step i (1..4) updates a, d, c, b in
    # turn, consuming words 4*(i-1) .. 4*(i-1)+3.
    for i in range(1, 5):
        k = 4 * (i - 1)
        M[k]     = (rrot_32(A[i], 3)  - A[i - 1] - F(B[i - 1], C[i - 1], D[i - 1])) % 2**32
        M[k + 1] = (rrot_32(D[i], 7)  - D[i - 1] - F(A[i], B[i - 1], C[i - 1])) % 2**32
        M[k + 2] = (rrot_32(C[i], 11) - C[i - 1] - F(D[i], A[i], B[i - 1])) % 2**32
        M[k + 3] = (rrot_32(B[i], 19) - B[i - 1] - F(C[i], D[i], A[i])) % 2**32

    return b''.join(word.to_bytes(4, 'little') for word in M)

In [23]:
def fixA5(data):
    """Enforce Wang's round-2 conditions on a5 by modifying message word m0.

    data : str or bytes. It is MD4-padded, and only the first 16 words
           (64 bytes) are used. Presumably the block was already massaged
           by MD4_Wang_SSM -- confirm at the call site.

    Conditions (0-based bits): a5,18 = c4,18; a5,25 = 1; a5,26 = 0;
    a5,28 = 1; a5,31 = 1. For each violated bit k:

    * Add 2**(k - 3) to m0 (or subtract it). Because a5 = (... + m0 + K) <<< 3,
      bit k of a5 is bit k - 3 of the pre-rotation sum. When a 1 is wanted
      that bit is 0, so adding sets it without a carry; when a 0 is wanted
      it is 1, so subtracting clears it without a borrow. Exactly one bit of
      a5 flips, and earlier corrections are left intact.
    * Recompute a1 from the new m0, then solve m1..m4 again so that d1, c1,
      b1 and a2 -- and hence every later round-1 state -- are unchanged.

    Returns the corrected first block as 64 bytes (16 little-endian words).
    """
    M = MD4_pad_data(data)[:16]
    # Mask chaining values to 32 bits before using them in rotations/back-solves.
    A, B, C, D = ([word & 0xFFFFFFFF for word in chain] for chain in MD4_get_IVs(M))

    targets = ((18, bitget(C[4], 18)), (25, 1), (26, 0), (28, 1), (31, 1))
    a5 = A[5]

    for bit, want in targets:
        if bitget(a5, bit) == want:
            continue

        delta = 2 ** (bit - 3)
        M[0] = (M[0] + (delta if want == 1 else -delta)) % 2**32

        # a1 depends on m0; recompute it, then re-solve m1..m4 to keep d1, c1, b1, a2.
        A[1] = phi(0, A[0], B[0], C[0], D[0], M[0], 3) & 0xFFFFFFFF
        M[1] = (rrot_32(D[1], 7)  - D[0] - F(A[1], B[0], C[0])) % 2**32
        M[2] = (rrot_32(C[1], 11) - C[0] - F(D[1], A[1], B[0])) % 2**32
        M[3] = (rrot_32(B[1], 19) - B[0] - F(C[1], D[1], A[1])) % 2**32
        M[4] = (rrot_32(A[2], 3)  - A[1] - F(B[1], C[1], D[1])) % 2**32

        # a5 is the first round-2 step: (a4 + G(b4, c4, d4) + m0 + K1) <<< 3.
        a5 = phi(1, A[4], B[4], C[4], D[4], M[0], 3) & 0xFFFFFFFF

    return b''.join(word.to_bytes(4, 'little') for word in M)

In [None]:
# Search for an MD4 collision. Each try:
#   1. massages a random block with Wang's round-1 single-step modifications
#      (MD4_Wang_SSM);
#   2. derives the partner M' by applying Wang's message differential;
#   3. checks whether the two blocks hash to the same value.
# With just the round-1 changes the success probability is roughly 2**-25 per
# try, so in pure Python this can run for a very long time.
#
# For progress indicator, make sure ipywidgets is installed:
# https://ipywidgets.readthedocs.io/en/latest/user_install.html

import os

import ipywidgets as widgets


def wang_partner(block):
    """Apply Wang's MD4 message differential to a 64-byte block.

    Words are little-endian 32-bit. m1 += 2**31, m2 += 2**31 - 2**28 and
    m12 -= 2**16 (all mod 2**32); every other word is unchanged.
    """
    words = [int.from_bytes(block[i:i + 4], 'little') for i in range(0, 64, 4)]
    words[1] = (words[1] + 2**31) % 2**32
    words[2] = (words[2] + 2**31 - 2**28) % 2**32
    words[12] = (words[12] - 2**16) % 2**32
    return b''.join(word.to_bytes(4, 'little') for word in words)


collision_found = False
ctr = 0
max_tries = 2**28
progress = widgets.IntProgress(
    value=0,
    min=0,
    max=max_tries,
    step=1,
    description='Progress:',
    bar_style='',
    orientation='horizontal'
    )

display(progress)

while not collision_found and ctr < max_tries:

    M = MD4_Wang_SSM(os.urandom(64))
    M_ = wang_partner(M)

    # M and M' have equal length and differ only inside their first block, so a
    # collision of the first compression is a collision of the full hash.
    if MD4(M) == MD4(M_):
        collision_found = True

    ctr += 1
    if ctr % 2**8 == 0:
        progress.value = ctr

progress.value = ctr

if not collision_found:
    print('Boooo')
else:
    print('Collision found for:')
    print()
    print(f'M = {M.hex()}')
    print(f"M' = {M_.hex()}")
    print(f'Hash = {MD4(M).hex()}')

IntProgress(value=0, description='Progress:', max=268435456)

[Back to Index](CryptoPalsWalkthroughs_Cobb.ipynb)