In [None]:
# Install the Python dependencies listed in requirements.txt into the
# environment of the running kernel (%pip, unlike !pip, targets the kernel).
%pip install -r requirements.txt

In [None]:
# Directory holding the key and signature dumps read by the cells below.
# Swap in the commented path to read the dumps from the sphincsplus ref victim.
base = "../"
#base = "../victims/sphincsplus/ref/"
keys_file = f"{base}keys.txt"
sigs_file = f"{base}sigs.txt"

In [None]:
import fips205

In [None]:
def print_adrs(adrs: fips205.ADRS, end='\n'):
    """Print an address's raw bytes as hex, split into 8-hex-digit (4-byte) words.

    Words are separated by single spaces; one more space follows the last
    word, and then ``end`` is written (forwarded to ``print``, so ``None``
    behaves like the default newline).
    """
    digits = adrs.adrs().hex()
    words = [digits[pos:pos + 8] for pos in range(0, len(digits), 8)]
    print(' '.join(words), end=' ')
    # Separate call so `end` keeps print()'s own semantics (e.g. None -> newline).
    print(end=end)

In [None]:
# The signatures file holds one hex-encoded signature per line.
# Blank lines (e.g. a trailing empty line) are skipped: bytes.fromhex ignores
# whitespace, so they would otherwise become zero-length "signatures" that get
# passed to the verifier.
# (An earlier version parsed "[CORRECT] <hex>" / "[FAULTY] <hex>" labelled lines.)
with open(sigs_file, "r") as f:
    sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
print(f"Loaded {len(sigs)} signatures")

In [None]:
def group_keys_by_addr(self: fips205.SLH_DSA, pk: bytes, sigs: list[bytes]) -> dict[fips205.ADRS, list[fips205.WOTSKeyData]]:
    """Verify each signed message and tag the WOTS key records it produced.

    Every entry of ``sigs`` is split into a signature of exactly ``sig_len``
    bytes (computed from ``self``'s parameters) followed by the message, and
    passed to ``self.slh_verify_internal``. Every record that appears in
    ``self.wots_keys`` during that call gets ``.valid`` set to the
    verification result.

    Args:
        self: SLH_DSA instance whose ``wots_keys`` attribute is assumed to map
            an ADRS to a list of WOTSKeyData records populated by
            ``slh_verify_internal`` (instrumented fips205 — confirm).
        pk: public key; split here as PK.seed (first n bytes) || PK.root.
        sigs: signed messages, each signature || message.

    Returns:
        ``self.wots_keys``. This function never clears it, so records from
        earlier calls on the same instance are included.
    """
    # Parameters come from `self`; the notebook uses SLH-DSA-SHAKE-256s.
    a = self.a
    d = self.d
    hp = self.hp
    n = self.n
    k = self.k
    wots_len = self.len
    wots_bytes = wots_len * n        # WOTS+ signature: len values of n bytes
    xmss_bytes = hp * n              # XMSS auth path: hp nodes of n bytes
    fors_bytes = k * (n + a * n)     # SIG_FORS per FIPS 205; 10560 for SHAKE-256s
    print("WOTS bytes:", wots_bytes)
    print("XMSS bytes:", xmss_bytes)
    print("FORS bytes:", fors_bytes)
    # R || SIG_FORS || d x (SIG_WOTS || AUTH), per FIPS 205
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    pk_seed = pk[:n]
    pk_root = pk[n:]
    print("pk_seed:", pk_seed.hex())
    print("pk_root:", pk_root.hex())
    for signed_msg in sigs:
        sig = signed_msg[:sig_len]
        m = signed_msg[sig_len:]
        # Snapshot every list, not just the dict: if the verifier appends to an
        # existing list in place, a shallow dict copy would share that list and
        # records added for an already-seen address would go undetected.
        wots_keys_before = {key: list(records) for key, records in self.wots_keys.items()}
        valid = self.slh_verify_internal(m, sig, pk)
        for key, records in self.wots_keys.items():
            # Identity, not equality: a new record that happens to equal an
            # older one was still produced by this signature.
            old_ids = {id(record) for record in wots_keys_before.get(key, ())}
            for record in records:
                if id(record) not in old_ids:
                    record.valid = valid
    return self.wots_keys

In [None]:
# The keys file holds one "<name>: <hex>" entry per line (the notebook reads keys['pk']).
# Blank lines are skipped, and only the first ': ' separates name from value.
with open(keys_file, "r") as f:
    lines = [line.split(': ', 1) for line in f if line.strip()]
    keys = {s[0]: bytes.fromhex(s[1].strip()) for s in lines}
# NOTE: this displays every entry, including any secret-key material in the
# file; clear the output before sharing the notebook.
keys

In [None]:
slh_dsa = fips205.SLH_DSA('SLH-DSA-SHAKE-256s')
pk = keys['pk']
pk_seed = pk[:slh_dsa.n]
pk_root = pk[slh_dsa.n:]
groups = group_keys_by_addr(slh_dsa, pk, sigs)

# Fresh instance for recomputing WOTS public keys below.
# NOTE(review): presumably this keeps those calls away from the instance whose
# wots_keys were just collected — confirm against the instrumented fips205.
slh_dsa = fips205.SLH_DSA('SLH-DSA-SHAKE-256s')
# Keep only addresses with more than one recorded WOTS signature, largest group first.
groups = sorted(groups.items(), key=lambda item: len(item[1]), reverse=True)
groups = [(adrs, value) for adrs, value in groups if len(value) > 1]
print(f"Found {len(groups)} groups with collisions")
for adrs, value in groups:
    # Address as space-separated 4-byte hex words, then the group size.
    print_adrs(adrs, end='')
    print(len(value))
    for v in value:
        print("Valid" if v.valid else "Invalid", end='\t')
        # FIPS 205 wots_pkFromSig (Alg. 8) sets chain/hash fields on the ADRS it
        # receives; pass a copy so the grouping key object is not mutated.
        wots_pk = slh_dsa.wots_pk_from_sig(v.sig, v.msg, pk_seed, adrs.copy())
        print(v.sig.hex(), wots_pk.hex())
        print('\t', end='')
        for chain_idx in v.msg:
            print(f"{chain_idx}", end=' ')
        print()

In [None]:
import fips205 as sp
import postprocessing as pp

inst = 'SLH-DSA-SHAKE-256s'
instance = sp.SLH_DSA_PARAM[inst]

# NOTE(review): instance[1] is presumably n (hash output length in bytes) — confirm
# against the SLH_DSA_PARAM tuple layout.
n = instance[1]

# Fixed key material for deterministic key generation (this prints secret seeds).
# Alternative source: keys = pp.read_keys("../victims/sphincsplus/ref/keys.txt"),
# then sk_seed = keys["sk"][:n], sk_prf = keys["sk"][n:2*n],
# pk_seed = keys["pk"][:n], pk_root = keys["pk"][n:].
# This cell rebinds slh_dsa, pk, pk_seed and pk_root from earlier cells.
sk_seed = bytes.fromhex("e17e72290e49a44c9c534f211195257cf13b0d45405782ceda2d7f982a551721")
sk_prf = bytes.fromhex("b47cfcf1b7764296d81055df05ff8295e4641ad9aa2db29b7b678e788bb8ea62")
pk_seed = bytes.fromhex("0847487e02a874ef8feee587f5359dfcd722f10e1cb50ac538d74320a5bfd242")
pk_root = bytes.fromhex("a9375b1b113255ce32578768690cd17431de27c356de7c7f34c057d81327b746")
print("\n".join(
    f"{label} {material.hex()}"
    for label, material in (
        ("sk_seed", sk_seed),
        ("sk_prf", sk_prf),
        ("pk_seed", pk_seed),
        ("pk_root", pk_root),
    )
))
slh_dsa = sp.SLH_DSA(inst)
pk, sk = slh_dsa.slh_keygen_internal(
    sk_seed=sk_seed,
    sk_prf=sk_prf,
    pk_seed=pk_seed,
    param=inst,
)
print("keygen done")
print("sk", sk.hex())
print("pk", pk.hex())