In [9]:
import struct

# ======== SM3 implementation ============
# Initial chaining value: eight 32-bit words loaded into A..H before the first block.
IV = [0x7380166F, 0x4914B2B9, 0x172442D7, 0xDA8A0600,
      0xA96F30BC, 0x163138AA, 0xE38DEE4D, 0xB0FB0E4E]
# Per-round constant T_j: one value for rounds 0-15, another for rounds 16-63.
T_j = [0x79CC4519 if j < 16 else 0x7A879D8A for j in range(64)]

def _rotl(x, n):
    return ((x << n) | (x >> (32 - n))) & 0xFFFFFFFF

def _P0(x):
    """SM3 permutation P0, used on TT2 in each round: x ^ (x <<< 9) ^ (x <<< 17)."""
    rotated_9 = _rotl(x, 9)
    rotated_17 = _rotl(x, 17)
    return rotated_17 ^ rotated_9 ^ x
def _P1(x):
    """SM3 permutation P1, used in message expansion: x ^ (x <<< 15) ^ (x <<< 23)."""
    rotated_15 = _rotl(x, 15)
    rotated_23 = _rotl(x, 23)
    return rotated_23 ^ rotated_15 ^ x
def FFj(x, y, z, j):
    """SM3 boolean function FF_j.

    Rounds 0-15 use plain XOR; rounds 16-63 use the bitwise majority of
    x, y and z.
    """
    if j >= 16:
        return (x & y) | (x & z) | (y & z)
    return x ^ y ^ z
def GGj(x, y, z, j):
    """SM3 boolean function GG_j.

    Rounds 0-15 use plain XOR; rounds 16-63 select bitwise: each result bit
    comes from y where x has a 1 and from z where x has a 0.
    """
    if j >= 16:
        # ~x is negative in Python, but ANDing with a non-negative z keeps
        # the result within z's bit width.
        return (x & y) | (~x & z)
    return x ^ y ^ z

def sm3_padding(msg: bytes, total_bits=None):
    """Return *msg* padded to a multiple of 64 bytes using SM3 padding.

    The padding is a single 0x80 byte, then zero bytes until the length is
    56 mod 64, then the 64-bit big-endian bit length.

    Args:
        msg: Message bytes to pad. The argument is never modified. A new
            ``bytes`` object is returned even when a ``bytearray`` is passed.
        total_bits: Bit length to write into the 8-byte trailer. Defaults to
            ``len(msg) * 8``. Overriding it lets a caller pad only the tail
            of a longer message whose earlier blocks were already compressed,
            as ``sm3_hash_from_iv`` does.

    Returns:
        The padded message as ``bytes``; its length is a multiple of 64.

    Raises:
        struct.error: if ``total_bits`` does not fit in an unsigned 64-bit int.
    """
    if total_bits is None:
        total_bits = len(msg) * 8
    # Zero bytes needed so that len(msg) + 1 (the 0x80 byte) + zeros == 56 (mod 64),
    # leaving exactly 8 bytes for the length trailer.
    zero_count = (55 - len(msg)) % 64
    # Build a fresh bytes object: `msg += ...` would extend a caller's
    # bytearray in place and corrupt their buffer.
    return b''.join((bytes(msg), b'\x80', bytes(zero_count), struct.pack('>Q', total_bits)))

def sm3_compress(v, b):
    """Apply the SM3 compression function to one 64-byte message block.

    Args:
        v: Current chaining value; eight 32-bit words unpacked as A..H.
        b: One message block; must be exactly 64 bytes.

    Returns:
        A new list of eight 32-bit words: ``v`` XORed with the state after 64
        rounds. ``v`` itself is not modified.

    Raises:
        ValueError: if ``b`` is not exactly 64 bytes long. A short block
            would otherwise be silently parsed into wrong (short or zero) words.
    """
    if len(b) != 64:
        raise ValueError(f"SM3 block must be 64 bytes, got {len(b)}")

    # Message expansion: W[0..15] are the block's big-endian 32-bit words,
    # W[16..67] are derived from earlier words, and W_1[j] = W[j] ^ W[j+4].
    W = list(struct.unpack('>16I', b))
    for i in range(16, 68):
        W.append(_P1(W[i-16] ^ W[i-9] ^ _rotl(W[i-3], 15)) ^ _rotl(W[i-13], 7) ^ W[i-6])
    W_1 = [W[i] ^ W[i+4] for i in range(64)]

    A, B, C, D, E, F, G, H = v
    for j in range(64):
        # A <<< 12 feeds both SS1 and SS2; compute it once per round.
        a_rot12 = _rotl(A, 12)
        SS1 = _rotl((a_rot12 + E + _rotl(T_j[j], j % 32)) & 0xFFFFFFFF, 7)
        SS2 = SS1 ^ a_rot12
        TT1 = (FFj(A, B, C, j) + D + SS2 + W_1[j]) & 0xFFFFFFFF
        TT2 = (GGj(E, F, G, j) + H + SS1 + W[j]) & 0xFFFFFFFF
        A, B, C, D = TT1, A, _rotl(B, 9), C
        E, F, G, H = _P0(TT2), E, _rotl(F, 19), G

    # Feed-forward: XOR the round output into the incoming chaining value.
    return [(old ^ new) & 0xFFFFFFFF for old, new in zip(v, (A, B, C, D, E, F, G, H))]

def sm3_hash(msg: bytes):
    """Return the 32-byte SM3 digest of *msg*.

    The message is padded with ``sm3_padding``. Each 64-byte block is then
    folded through ``sm3_compress``, starting from a copy of ``IV``. The
    final eight state words are serialized big-endian.
    """
    padded = sm3_padding(msg)
    state = list(IV)
    for offset in range(0, len(padded), 64):
        state = sm3_compress(state, padded[offset:offset + 64])
    return b''.join(word.to_bytes(4, 'big') for word in state)

def sm3_hash_from_iv(msg: bytes, iv, total_bits):
    """Finish an SM3 computation starting from an intermediate chaining value.

    Args:
        msg: The remaining message bytes to absorb.
        iv: Eight 32-bit state words to resume from (e.g. words recovered
            from a previous digest). It is not modified.
        total_bits: Bit length written into the padding trailer. It is meant
            to be the length of the whole logical message, i.e. the prefix
            already compressed into ``iv`` plus ``msg``.

    Returns:
        The 32-byte digest, serialized big-endian.

    NOTE: the amount of zero padding is derived from ``len(msg)`` alone. The
    result therefore matches a genuine SM3 digest only when the
    already-compressed prefix is a whole number of 64-byte blocks.
    """
    state = iv
    padded = sm3_padding(msg, total_bits)
    for start in range(0, len(padded), 64):
        state = sm3_compress(state, padded[start:start + 64])
    return b''.join(word.to_bytes(4, 'big') for word in state)


# Simulated server secret and the original message.
secret = b'secret_key_123456'  # 17 bytes
original_msg = b'userid=1001&role=user'
append_data = b'&admin=true'

# The server computes hash(secret + original_msg) and publishes it as a MAC.
full_msg = secret + original_msg
original_hash = sm3_hash(full_msg)
print("[Server] Original hash:   ", original_hash.hex())

# The attacker has to guess the secret's length; this demo uses the real one.
guessed_len = len(secret)

# Recover the chaining state: sm3_hash's digest is exactly the eight 32-bit
# state words after the last compression, serialized big-endian.
iv = [int.from_bytes(original_hash[i*4:(i+1)*4], 'big') for i in range(8)]

# Build forged_msg = original_msg + glue padding + append_data. The glue
# padding is the SM3 padding the server appended to secret + original_msg.
# It depends only on that length, so any filler of the right size works here.
fake_msg_len = guessed_len + len(original_msg)
padding = sm3_padding(b'A'*fake_msg_len)[fake_msg_len:]  # keep only the padding part
forged_msg = original_msg + padding + append_data

# Forge the hash: resume from the recovered state and absorb append_data.
# The trailer must encode the bit length of the full forged input
# (secret + original_msg + glue padding + append_data).
total_bits = (fake_msg_len + len(padding) + len(append_data)) * 8
forged_hash = sm3_hash_from_iv(append_data, iv, total_bits)
print("[Attacker] Forged hash:    ", forged_hash.hex())
print("[Attacker] Forged message: ", forged_msg)

# Server-side verification: hash(secret + forged_msg) should equal the forgery.
true_hash = sm3_hash(secret + forged_msg)
print("[Server] True hash:        ", true_hash.hex())

print("✅ Attack success:", forged_hash == true_hash)


[Server] Original hash:    2da04cf2ebbcd3d63aa3e0341b181dcfdcd818aff32cbaa639219e60b4136cae
[Attacker] Forged hash:     e25aca5a209545e7472a281fc66275b219fe8a16bad55099404ed3455b0bb1d3
[Attacker] Forged message:  b'userid=1001&role=user\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x010&admin=true'
[Server] True hash:         e25aca5a209545e7472a281fc66275b219fe8a16bad55099404ed3455b0bb1d3
✅ Attack success: True
