In [None]:
%pip install -r requirements.txt

In [None]:
# Directory containing the input files. The trailing slash matters because
# the file names below are appended to it verbatim.
base = "../"
#base = "../victims/sphincsplus/ref/"
keys_file = f"{base}keys.txt"
sigs_file = f"{base}sigs.txt"

In [None]:
import fips205

In [None]:
def print_adrs(adrs: fips205.ADRS, end='\n'):
    """Print an address as space-separated groups of eight hex digits.

    The bytes returned by ``adrs.adrs()`` are hex-encoded and cut into
    8-character words (four bytes each). The words are always followed by
    one extra space and then by ``end``.

    Args:
        adrs: Object whose ``adrs()`` method returns the raw address bytes.
        end: Text emitted after the trailing space (newline by default).
    """
    digits = adrs.adrs().hex()
    words = (digits[pos:pos + 8] for pos in range(0, len(digits), 8))
    # print's default single-space separator yields the same text as ' '.join.
    print(*words, end=' ')
    print(end=end)

In [None]:
def group_keys_by_addr(self: fips205.SLH_DSA, pk: bytes, sigs: list[bytes]) -> dict[fips205.ADRS, list[fips205.WOTSKeyData]]:
    """Verify every signature in turn and flag the WOTS keys each one adds.

    Every element of ``sigs`` is a signature immediately followed by the
    signed message. After each ``self.slh_verify_internal`` call, the entries
    that appeared in ``self.wots_keys`` during that call get their ``valid``
    attribute set to the verification result.

    NOTE: a later cell defines another function with this same name; the
    definition from whichever cell was executed last is the one in effect.

    Args:
        self: SLH-DSA instance providing the parameters (``n``, ``a``, ``d``,
            ``hp``, ``k``, ``len``), ``slh_verify_internal`` and the
            ``wots_keys`` mapping.
        pk: Public key passed to the verifier.
        sigs: Signature-plus-message byte strings.

    Returns:
        ``self.wots_keys`` after all signatures have been processed.
    """
    # we assume SPHINCS-SHAKE-256s
    n = self.n
    wots_bytes = self.len * n
    xmss_bytes = self.hp * n
    fors_bytes = self.k * (n + self.a * n)  # 10560
    print("WOTS bytes:", wots_bytes)
    print("XMSS bytes:", xmss_bytes)
    print("FORS bytes:", fors_bytes)
    sig_len = n + fors_bytes + self.d * (wots_bytes + xmss_bytes)

    for sig in sigs:
        # The message is whatever follows the fixed-size signature.
        m = sig[sig_len:]
        sig = sig[:sig_len]
        # Snapshot each per-address list, not just the dict: a shallow
        # dict.copy() shares the list objects, so keys appended in place
        # during verification would also show up in the "before" view and
        # would never be detected as new.
        seen_before = {adrs: list(vals) for adrs, vals in self.wots_keys.items()}
        valid = self.slh_verify_internal(m, sig, pk)
        for adrs, vals in self.wots_keys.items():
            old_vals = seen_before.get(adrs, [])
            if old_vals != vals:
                for v in vals:
                    if v not in old_vals:
                        v.valid = valid
    return self.wots_keys

In [None]:
import multiprocessing
from copy import deepcopy

def process_sig(args):
    """Verify one signature and report the WOTS keys recorded while doing so.

    Takes a single tuple so that it can be used directly with
    ``multiprocessing.Pool.map``.

    Args:
        args: Tuple ``(self_copy, pk, sig, sig_len)`` where ``self_copy``
            provides ``wots_keys`` and ``slh_verify_internal``, ``pk`` is the
            public key, ``sig`` is the signature followed by the message, and
            ``sig_len`` is the length of the signature part in bytes.

    Returns:
        A list of ``(key, new_vals)`` pairs, one for every ``wots_keys`` entry
        whose list changed during verification. ``new_vals`` holds the values
        that were not present before, each with ``valid`` set to the
        verification result.
    """
    self_copy, pk, sig, sig_len = args
    m = sig[sig_len:]
    sig = sig[:sig_len]
    # Copy the per-address lists but keep the very same element objects.
    # A deepcopy would clone the elements too; for objects that compare by
    # identity the membership test below would then treat every pre-existing
    # key as new and overwrite its `valid` flag.
    wots_keys_before = {key: list(vals) for key, vals in self_copy.wots_keys.items()}
    valid = self_copy.slh_verify_internal(m, sig, pk)

    result = []
    for key, value in self_copy.wots_keys.items():
        old_values = wots_keys_before.get(key, [])
        if old_values != value:
            new_vals = [v for v in value if v not in old_values]
            for v in new_vals:
                v.valid = valid
            result.append((key, new_vals))
    return result

def group_keys_by_addr(self: fips205.SLH_DSA, pk: bytes, sigs: list[bytes]) -> dict[fips205.ADRS, list[fips205.WOTSKeyData]]:
    """Verify all signatures in a process pool and merge the recorded keys.

    Each signature is paired with its own deep copy of ``self`` and passed to
    ``process_sig`` through ``multiprocessing.Pool.map``. The per-signature
    results are then concatenated per key, in signature order.

    Args:
        self: SLH-DSA instance supplying the size parameters; it is deep-copied
            once per signature.
        pk: Public key passed to every verification.
        sigs: Signature-plus-message byte strings.

    Returns:
        Dict mapping each key reported by ``process_sig`` to the combined list
        of values reported for it across all signatures.
    """
    # Sizes of the individual signature components, all in bytes.
    n = self.n
    wots_bytes = self.len * n
    xmss_bytes = self.hp * n
    fors_bytes = self.k * (n + self.a * n)
    sig_len = n + fors_bytes + self.d * (wots_bytes + xmss_bytes)

    print("WOTS bytes:", wots_bytes)
    print("XMSS bytes:", xmss_bytes)
    print("FORS bytes:", fors_bytes)

    with multiprocessing.Pool() as pool:
        tasks = []
        for sig in sigs:
            tasks.append((deepcopy(self), pk, sig, sig_len))
        per_sig_results = pool.map(process_sig, tasks)

    # Fold the per-signature (key, values) pairs into one dict.
    merged = {}
    for found in per_sig_results:
        for adrs, new_keys in found:
            merged.setdefault(adrs, []).extend(new_keys)

    return merged


In [None]:
# Parse the key file: one "<name>: <hex>" entry per line.
with open(keys_file, "r") as f:
    # Ignore blank lines (for example a trailing empty line); they contain no
    # ': ' separator and would otherwise raise IndexError below.
    lines = [s.split(': ') for s in f if s.strip()]
    keys = {s[0]: bytes.fromhex(s[1].strip()) for s in lines}
# NOTE(review): this displays every loaded key; if the file also holds a
# secret key, clear the cell output before sharing the notebook.
keys

In [None]:
# Each non-blank line of the signature file is one hex-encoded blob.
# (An earlier variant of this cell parsed '[CORRECT]' / '[FAULTY]' labelled
# lines; the current code expects bare hex.)
with open(sigs_file, "r") as f:
    # Skip blank lines: bytes.fromhex('') yields b'', which would otherwise be
    # counted and later verified as an empty signature.
    sigs = [bytes.fromhex(s.strip()) for s in f if s.strip()]
print(f"Loaded {len(sigs)} signatures")

In [None]:
slh_dsa = fips205.SLH_DSA('SLH-DSA-SHAKE-256s')
pk = keys['pk']
# The public key is split at n bytes into its two halves.
pk_seed, pk_root = pk[:slh_dsa.n], pk[slh_dsa.n:]
groups = group_keys_by_addr(slh_dsa, pk, sigs)

# Turn the dict into a list of (address, values) pairs, largest group first.
# Negating the length keeps ties in their original order, exactly like a
# stable sort with reverse=True.
groups = sorted(groups.items(), key=lambda entry: -len(entry[1]))

In [None]:
# Re-create the SLH_DSA object with the same parameter set as the cell above.
slh_dsa = fips205.SLH_DSA('SLH-DSA-SHAKE-256s')

print(f"Found {len(groups)} unique addresses")

# One marker per address that has at least one value flagged invalid.
faulted = [1 for _, entries in groups if False in (e.valid for e in entries)]
print(f"Found {len(faulted)} groups with faulty keys")

# Step 1: addresses with more than one recorded value.
collisions = [(adrs, entries) for adrs, entries in groups if len(entries) > 1]
print(f"Found {len(collisions)} groups with collisions")

# Step 2: of those, addresses holding both a valid and an invalid value.
collisions = [
    (adrs, entries)
    for adrs, entries in collisions
    if True in (e.valid for e in entries) and False in (e.valid for e in entries)
]
print(f"Found {len(collisions)} groups with at least one valid and one invalid key")

# Step 3: of those, addresses where a valid and an invalid value have equal pk.
collisions = [
    (adrs, entries)
    for adrs, entries in collisions
    if any(
        good.pk == bad.pk
        for good in entries if good.valid
        for bad in entries if not bad.valid
    )
]
print(f"Found {len(collisions)} groups where at least one valid and one invalid key share the same pk")

# Within each group, list the valid values before the invalid ones.
collisions = [
    (adrs, sorted(entries, key=lambda e: e.valid, reverse=True))
    for adrs, entries in collisions
]

for adrs, value in collisions:
    # Header line: the address words followed by the number of values.
    print_adrs(adrs, end='')
    print(len(value))
    # NOTE(review): `pks` is never read in this cell; kept in case another
    # cell relies on it.
    pks = []
    for v in value:
        print("Valid" if v.valid else "Invalid", v.pk.hex(), sep='\t')
        # Tab-indented row with every element of v.msg followed by a space.
        print('\t' + ''.join(f"{chain_idx} " for chain_idx in v.msg))