In [1]:
from hashlib import sha256
from math import log2
import os

class XHASH:
    """XHASH: the XOR of hashes of index-prefixed message blocks.

    The message is split into ``blocksize``-byte blocks (padded by ``pad``
    when its length is not a multiple of ``blocksize``).  Block ``i`` is
    prefixed with ``i`` as an ``n``-byte big-endian integer, hashed with
    ``hashfunc``, and all block digests are XORed together.  Because the
    combiner is XOR, the digest is a GF(2)-linear function of the per-block
    hashes.

    Known limitation: messages whose length is already a multiple of
    ``blocksize`` are not padded, so a message ``m`` and the aligned message
    ``pad(m)`` produce the same digest.
    """

    def __init__(self, blocksize = 32 , max_block_num = 2**16, hashfunc = sha256, hashbyte = 32):
        """
        Args:
            blocksize: size of a message block in bytes.
            max_block_num: a message must have fewer than this many blocks;
                it also determines the width ``n`` of the index prefix.
            hashfunc: hashlib-style constructor; ``hashfunc(data).digest()``
                must return ``hashbyte`` bytes.
            hashbyte: digest length of ``hashfunc`` in bytes.

        Raises:
            ValueError: if ``max_block_num`` is smaller than 1.
        """
        if max_block_num < 1:
            raise ValueError("max_block_num must be at least 1")
        self.blocksize = blocksize
        self.max_block_num = max_block_num
        # Bytes needed to encode every index 0 .. max_block_num - 1, rounded
        # UP.  ``log2(N) // 8`` rounds down, e.g. N = 2**12 would give a
        # 1-byte prefix and indices >= 256 would overflow in ``to_bytes``.
        index_bits = int(max_block_num - 1).bit_length()
        self.n = int((index_bits + 7) // 8)
        self.l = blocksize + self.n  # length of one encoded (index + data) block
        self.hashfunc = hashfunc
        self.hashbyte = hashbyte

    def encode_blcok(self, msg: bytes, idx: int) -> bytes:
        """Prefix block *msg* with *idx* as an ``n``-byte big-endian integer."""
        return idx.to_bytes(self.n, "big") + msg

    def pad(self, msg: bytes) -> bytes:
        """Append ``pad_len`` bytes of value ``pad_len`` to reach a multiple of blocksize.

        ``pad_len`` is in ``1..blocksize`` (an aligned message gets a full
        extra block).  ``pad_len`` is written as a single byte, so values
        above 255 raise OverflowError.
        """
        pad_len = int(self.blocksize - len(msg) % self.blocksize)
        return msg + pad_len * pad_len.to_bytes(1, "big")

    @staticmethod
    def xor(b1: bytes, b2: bytes) -> bytes:
        """Return the bytewise XOR of two equal-length byte strings."""
        assert len(b1) == len(b2), "xor bytes length does not equal"
        # int.__xor__ instead of an operator: this line means the same thing in
        # plain Python and under the Sage preparser (where ``^`` is power).
        return bytes(map(int.__xor__, b1, b2))

    def digest(self, data: bytes) -> bytes:
        """Return the ``hashbyte``-byte XHASH digest of *data*.

        Raises AssertionError if the (padded) message has ``max_block_num``
        or more blocks.
        """
        if len(data) % self.blocksize != 0:
            data = self.pad(data)
        result = b"\x00" * self.hashbyte
        nblocks = len(data) // self.blocksize
        assert nblocks < self.max_block_num, "message too long"
        for i in range(0, nblocks):
            block = self.encode_blcok(data[i * self.blocksize:(i + 1) * self.blocksize], i)
            assert len(block) == self.l
            result = self.xor(result, self.hashfunc(block).digest())
        return result

    def __call__(self, message: bytes) -> bytes:
        """Alias for ``digest``."""
        return self.digest(message)

    def __str__(self):
        return f"XHASH Implementation with b = {self.blocksize*8} , hash_nbit = {self.hashbyte*8}, N = {self.max_block_num}"

def encode_blcok_with_idx(msg, idx, idx_bytes_num = 2):
    """Return *msg* prefixed by *idx* written as *idx_bytes_num* big-endian bytes."""
    header = idx.to_bytes(idx_bytes_num, byteorder="big")
    return header + msg

def bytes2bitvector(data, padlen = 256):
    """Return the bits of *data* (big-endian, most significant bit first) as a GF(2) vector.

    The bit string is left-padded with zeros to *padlen* bits.  It is not
    truncated, so inputs longer than *padlen* bits yield a longer vector.
    """
    as_int = int.from_bytes(data, "big")
    bit_string = format(as_int, "b").zfill(padlen)
    return vector(GF(2), [int(bit) for bit in bit_string])
    
def _hash_bitvectors(blocks, idx_bytes_num, hashfunc, nbits):
    """Hash each block (prefixed with its position) and return the digests as GF(2) bit vectors."""
    return [
        bytes2bitvector(hashfunc(encode_blcok_with_idx(x, i, idx_bytes_num)).digest(), nbits)
        for i, x in enumerate(blocks)
    ]


def _build_linear_system(alpha0_list, alpha1_list, k):
    """Build the GF(2) matrix over the selector unknowns (u_0..u_{n-1}, v_0..v_{n-1}).

    Rows 0..n-1 encode ``u_i + v_i``; rows n..n+k-1 encode bit j of
    ``sum_i u_i * alpha0_i + v_i * alpha1_i``.
    """
    n = len(alpha0_list)
    rows = []
    for i in range(n):
        row = [0] * (2 * n)
        row[i] = row[i + n] = 1
        rows.append(row)
    for j in range(k):
        rows.append([a[j] for a in alpha0_list] + [a[j] for a in alpha1_list])
    return matrix(GF(2), rows)


def _solve_for_random_target(M, n, hashbyte, attempts=10):
    """Draw random targets z until ``M * s == (1,...,1 | bits(z))`` is solvable; return ``(z, s)``.

    ``M`` may be rank-deficient, in which case only some targets are
    reachable, hence the retries.

    Raises:
        RuntimeError: if none of the ``attempts`` targets was solvable.
    """
    for _ in range(attempts):
        z = os.urandom(hashbyte)
        target = vector(GF(2), [1] * n + list(bytes2bitvector(z, hashbyte * 8)))
        try:
            return z, M.solve_right(target)
        except ValueError:
            # Sage's solve_right reports an inconsistent system with ValueError.
            continue
    raise RuntimeError(f"no solvable target found in {attempts} attempts")


def _assemble_message(solution, x0_blocks, x1_blocks):
    """Concatenate, per position i, x0_i when the selector u_i is 1 and x1_i otherwise."""
    return b"".join(
        x0 if solution[i] == 1 else x1
        for i, (x0, x1) in enumerate(zip(x0_blocks, x1_blocks))
    )


def linearization_attacks(blocksize = 32 , max_block_num = 2**16, hashfunc = sha256, hashbyte = 32):
    """Demonstrate a preimage attack on XHASH by linear algebra over GF(2).

    For each of ``n = k + 1`` positions (``k`` = digest bits), two random
    candidate blocks x0_i and x1_i are drawn.  The code solves for a choice
    of one candidate per position whose XORed block hashes equal a random
    target ``z``.  It prints the particular solution plus one extra solution
    per kernel basis vector, together with their XHASH digests.

    All parameters are forwarded to ``XHASH``, so the attack and the final
    check use the same construction (block size, index width, hash).

    Returns:
        ``(z, results)``: the target digest and the list of forged messages.

    Raises:
        ValueError: if ``max_block_num`` does not allow ``k + 1`` blocks.
        RuntimeError: if no random target turned out to be solvable.
    """
    xhash = XHASH(blocksize, max_block_num, hashfunc, hashbyte)
    k = hashbyte * 8
    n = k + 1  # one more selector pair than digest bits
    if n >= max_block_num:
        raise ValueError(f"max_block_num must exceed {n} for this attack")

    x0_blocks = [os.urandom(blocksize) for _ in range(n)]
    x1_blocks = [os.urandom(blocksize) for _ in range(n)]
    # Use the hash's own index width so the encoded blocks match what XHASH hashes.
    alpha0_list = _hash_bitvectors(x0_blocks, xhash.n, hashfunc, k)
    alpha1_list = _hash_bitvectors(x1_blocks, xhash.n, hashfunc, k)

    M = _build_linear_system(alpha0_list, alpha1_list, k)
    print(f"[+] {M.dimensions() = }, { M.rank() = }")
    bs = M.right_kernel().basis()
    print(f"[+] kernel basis {len(bs) = }")

    z, sol = _solve_for_random_target(M, n, hashbyte)
    # Kernel vectors satisfy u_i + v_i = 0 and map to the zero digest, so
    # adding them preserves both the selection constraint and the target.
    sols = [sol] + [b + sol for b in bs]
    print(f"[+] solutions number : {len(sols)}")
    print(f"[+] total solutions number : {2**(len(bs))}")

    results = [_assemble_message(s, x0_blocks, x1_blocks) for s in sols]
    print(f"[+] target hash = {z.hex()}")
    print(f"[+] we only show sha256 hash of resultant messages since they are too long")
    for result in results:
        print(f"[+] blocks num = { n }, {sha256(result).hexdigest() = }")
        print(f"[+] XHASH_SHA256(result).hex() = {xhash(result).hex()!r}")
    return z, results

if __name__ == "__main__":
    # Run the demo only when executed directly (script or notebook cell), not on import.
    XHASH_SHA256 = XHASH(blocksize=32, max_block_num=2**16, hashfunc=sha256, hashbyte=32)  # default XHASH-SHA256 instance
    linearization_attacks()

[+] M.dimensions() = (513, 514),  M.rank() = 512
[+] kernel basis len(bs) = 2
[+] solutions number : 3
[+] total solutions number : 4
[+] target hash = a1f7a5ea6dcba6e0fc2f4655b97fe77d4edd4236a49bcec5ab927fc42c6592ce
[+] we only show sha256 hash of resultant messages since they are too long
[+] blocks num = 257, sha256(result).hexdigest() = 'f97183eef6e999699b8ab50677643c6bb86ba1c1080525c79821f9632028d967'
[+] XHASH_SHA256(result).hex() = 'a1f7a5ea6dcba6e0fc2f4655b97fe77d4edd4236a49bcec5ab927fc42c6592ce'
[+] blocks num = 257, sha256(result).hexdigest() = '015eca67d8fbdf8edb4af3c7c91fa6015bb893392193932ac4223c77165bc55b'
[+] XHASH_SHA256(result).hex() = 'a1f7a5ea6dcba6e0fc2f4655b97fe77d4edd4236a49bcec5ab927fc42c6592ce'
[+] blocks num = 257, sha256(result).hexdigest() = '2831f3eb19729e27d0c4b85abf99e7365f8953e8394c221701c34d5ec82b95b2'
[+] XHASH_SHA256(result).hex() = 'a1f7a5ea6dcba6e0fc2f4655b97fe77d4edd4236a49bcec5ab927fc42c6592ce'
