# Config

In [None]:
# Path prefix prepended to every input file name below.
base = "../"
# Alternative prefix (commented out):
# base = "../victims/sphincsplus/ref/"

# Input files, each resolved relative to `base`.
keys_file = f"{base}keys.txt"
sigs_file = f"{base}sigs.txt"
faulty_sigs_file = f"{base}sigs_faulty.txt"

# SLH-DSA parameter set name handed to fips205.SLH_DSA.
params = 'SLH-DSA-SHAKE-256s'

# Set to True to run the tooling sanity check section.
sanity_check = False

# Setup

Install dependencies and define some helper functions.

In [None]:
# Install the packages listed in requirements.txt into the running kernel's environment.
%pip install -r requirements.txt

In [None]:
import fips205

In [None]:
def print_adrs(adrs: fips205.ADRS, end='\n', verbose=False):
    """Print an address as space-separated 8-hex-digit (32-bit) words.

    The bytes come from ``adrs.adrs()``. With ``verbose=True`` a column header
    line (LAYER / TREE ADDR / TYP / KADR / PADD = 0) is printed first. The hex
    words are followed by a single space and then ``end``.
    """
    digits = adrs.adrs().hex()
    if verbose:
        # Each hex word plus its separating space is 9 characters wide.
        columns = [('LAYER', 9), ('TREE ADDR', 27), ('TYP', 9), ('KADR', 9)]
        print(''.join(label.ljust(width) for label, width in columns) + 'PADD = 0')
    words = (digits[pos:pos + 8] for pos in range(0, len(digits), 8))
    print(' '.join(words), end=' ')
    print(end=end)

In [None]:
import multiprocessing

def process_sig(args):
    """Verify one signature and tag every recorded WOTS key with the outcome.

    ``args`` is a ``(params, pk, sig, sig_len)`` tuple, packed into one value so
    it can be passed through ``Pool.map``. ``sig`` holds the signature in its
    first ``sig_len`` bytes, followed by the message. After verification, every
    WOTSKeyData in the instance's ``wots_keys`` mapping (presumably filled in by
    ``slh_verify_internal``) gets ``.valid`` set to the verification result, and
    the mapping is returned.
    """
    param_set, public_key, blob, length = args
    verifier = fips205.SLH_DSA(param_set)
    message = blob[length:]
    signature = blob[:length]
    ok = verifier.slh_verify_internal(message, signature, public_key)
    for recorded in verifier.wots_keys.values():
        for key_data in recorded:
            key_data.valid = ok
    return verifier.wots_keys

def group_keys_by_addr(pk: bytes, sigs: list[bytes], params_name=None) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    """Verify signatures in parallel and group the recorded WOTS keys by address.

    Each entry of ``sigs`` is a signature of ``sig_len`` bytes followed by the
    signed message. ``sig_len`` is derived from the parameter set. Every entry is
    verified in a worker process via ``process_sig``, and the per-signature
    ``wots_keys`` mappings are unioned with ``merge_groups``.

    Args:
        pk: public key used for verification.
        sigs: signature-plus-message byte strings.
        params_name: SLH-DSA parameter set name. When omitted, the notebook-level
            ``params`` from the Config cell is used (the previous behavior).

    Returns:
        Mapping from ADRS to the set of WOTSKeyData observed at that address.
        Empty when ``sigs`` is empty.
    """
    if params_name is None:
        params_name = params
    if not sigs:
        # Nothing to verify; don't spin up a worker pool for no work.
        return {}

    slh = fips205.SLH_DSA(params_name)
    n = slh.n
    # Signature layout: R (n bytes) | FORS signature | d x (WOTS sig + XMSS auth path).
    wots_bytes = slh.len * n
    xmss_bytes = slh.hp * n
    fors_bytes = slh.k * (n + slh.a * n)
    sig_len = n + fors_bytes + slh.d * (wots_bytes + xmss_bytes)

    # NOTE(review): Pool pickles `process_sig` by reference. With the "spawn" start
    # method (default on macOS/Windows), workers may be unable to find functions
    # defined in a notebook's __main__; moving process_sig into a .py module avoids that.
    with multiprocessing.Pool() as pool:
        results = pool.map(process_sig, [(params_name, pk, sig, sig_len) for sig in sigs])

    merged = {}
    for wots_keys in results:
        merged = merge_groups(merged, wots_keys)
    return merged

def merge_groups(left: dict[fips205.ADRS, set], right: dict[fips205.ADRS, set]) -> dict[fips205.ADRS, set]:
    """Union ``right`` into ``left`` per address and return ``left``.

    ``left`` is updated in place: addresses only in ``right`` are appended, and
    each touched address gets a fresh set (the union), so sets owned by
    ``right`` are never shared with or mutated through ``left``.
    """
    for adrs, incoming in right.items():
        current = left.setdefault(adrs, set())
        left[adrs] = current | incoming
    return left


In [None]:
def print_groups(groups: dict[fips205.ADRS, set[fips205.WOTSKeyData]], only_collisions=False):
    """Print collision statistics for grouped WOTS keys, then dump the groups.

    The statistics form a funnel; each stage filters the previous one:
      1. groups containing at least one invalid (faulty) entry;
      2. of those, groups with more than one entry (shown next to a
         birthday-style estimate);
      3. of those, groups with both a valid and an invalid entry;
      4. of those, groups where a valid and an invalid entry share the same
         ``pk`` but differ in ``msg``.

    Afterwards the groups are dumped, ordered by layer address. Each group shows
    its ADRS header and entry count, then every entry's validity, ``sig``, ``pk``
    and ``msg`` values, with valid entries listed first.

    Args:
        groups: mapping from ADRS to the WOTSKeyData entries observed there.
        only_collisions: if True, dump only the groups that survive stage 4.
            Defaults to False, which dumps every group.
    """
    print(f"Found {len(groups)} unique addresses")

    faulted = [(adrs, value) for adrs, value in groups.items()
               if any(not v.valid for v in value)]
    print(f"Found {len(faulted)} groups with faulty keys")

    collisions = [(adrs, value) for adrs, value in faulted if len(value) > 1]
    # Birthday-style estimate N(N-1)/(2K) of collisions expected by chance.
    # NOTE(review): K = 256 is the assumed number of equally likely outcomes — confirm.
    N = len(faulted)
    K = 256
    expected_collisions = (N * (N - 1)) / (2 * K)
    print(f"Found {len(collisions)} groups with collisions (expected: {expected_collisions})")

    collisions = [(adrs, value) for adrs, value in collisions
                  if any(v.valid for v in value) and any(not v.valid for v in value)]
    print(f"Found {len(collisions)} groups with at least one valid and one invalid key")

    collisions = [
        (adrs, value)
        for adrs, value in collisions
        if any(v1.msg != v2.msg and v1.pk == v2.pk
               for v1 in value for v2 in value if v1.valid and not v2.valid)
    ]
    print(f"Found {len(collisions)} groups where at least one valid and one invalid key share the same pk")

    to_show = dict(collisions) if only_collisions else groups
    for adrs, value in sorted(to_show.items(), key=lambda item: item[0].get_layer_address()):
        print_adrs(adrs, end='', verbose=True)
        print(len(value))
        # Stable sort on the boolean flag: valid (True) entries first.
        for v in sorted(value, key=lambda entry: entry.valid, reverse=True):
            print("Valid" if v.valid else "Invalid", end='\t')
            print(v.sig.hex())
            print('\t', end='')
            print(v.pk.hex())
            print('\t', end='')
            print(''.join(f"{chain_idx} " for chain_idx in v.msg))

# Clean Start

Run this cell (and all cells below it) to reset `processed_sigs` and `groups` and start a clean analysis.

In [None]:
# Reset the accumulated analysis state: no grouped keys and no processed signatures.
groups: dict[fips205.ADRS, set[fips205.WOTSKeyData]] = dict()
processed_sigs = 0

In [None]:
# Load keys: each non-blank line of `keys_file` has the form "<name>: <hex bytes>".
with open(keys_file, "r") as f:
    keys = {}
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing empty lines
        name, sep, value = line.partition(': ')
        if not sep:
            raise ValueError(f"Malformed line in {keys_file!r}: {line!r}")
        keys[name] = bytes.fromhex(value.strip())
sk = keys['sk']
pk = keys['pk']
pk.hex()

# Load real signatures

This loads the real signatures from `sigs_file` (see Config).

In [None]:
with open(sigs_file, "r") as f:
    # One hex-encoded signature (followed by its message) per line; blank lines
    # are skipped so they don't turn into empty b'' "signatures".
    sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
print(f"Loaded {len(sigs)} signatures")

# Tooling sanity check

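This section regenerates the first signature in `sigs_file` with the Python `fips205` implementation. It uses the same secret key and the same randomizer `R` (the first `n` bytes of that signature), then compares the two signatures byte-for-byte.
We expect them to match. The check only runs when `sanity_check` is `True` in the Config section.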

In [None]:
# Regenerate sigs[0] with the Python implementation and compare byte-for-byte.
if not sanity_check:
    print("Skipping sanity check")
elif not sigs:
    print("Skipping sanity check: no signatures loaded")
else:
    slh_dsa = fips205.SLH_DSA(params)
    # Only the size parameters are used here. The tuple's last field is bound to
    # `_m` so it is not confused with the message `m` extracted below.
    (_, n, _h, d, hp, a, k, _lgw, _m) = fips205.SLH_DSA_PARAM[params]
    wots_len = slh_dsa.len
    wots_bytes = wots_len * n
    xmss_bytes = hp * n
    fors_bytes = k * (n + a * n)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    sig = sigs[0]

    # Each loaded entry is the signature followed by the signed message.
    m = sig[sig_len:]
    # The first n bytes of an SLH-DSA signature are the randomizer R.
    r = sig[:n]
    pysig = slh_dsa.slh_sign_internal(m, sk, None, r=r)
    pysig += m

    if pysig == sig:
        print("Passed sanity check")
    else:
        print("Signature mismatch")
        print(pysig.hex())
        print(sig.hex())

# Update Experiment
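First re-run the *Load real signatures* cell so that `sigs` includes any signatures appended to `sigs_file`. Then run this cell (and below) to process only the signatures that have not been processed yet.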

In [None]:
# Process only the signatures not handled by an earlier run. The slice goes into
# `new_sigs` so that `sigs` still holds the full loaded list afterwards (e.g. for
# the sanity check's `sigs[0]`).
if processed_sigs > 0:
    print(f"Skipping {processed_sigs} signatures")

new_sigs = sigs[processed_sigs:]

print(f"Processing {len(new_sigs)} signatures...", end=' ')

if new_sigs:
    groups = merge_groups(groups, group_keys_by_addr(pk, new_sigs))

print("done!")

# update processed_sigs for consecutive runs
processed_sigs += len(new_sigs)

# Load simulated faults
This section loads simulated faulty signatures from `faulty_sigs_file` and merges their grouped WOTS keys into `groups`.

In [None]:
# Same line format as sigs_file: one hex-encoded signature+message per line.
with open(faulty_sigs_file, "r") as f:
    faulty_sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
print(f"Loaded {len(faulty_sigs)} faulty (simulated) signatures")

simulated_groups = group_keys_by_addr(pk, faulty_sigs)
groups = merge_groups(groups, simulated_groups)
# NOTE(review): re-running this cell merges the same faulty signatures again;
# whether that duplicates entries depends on WOTSKeyData's __eq__/__hash__ — confirm.
# The merged groups are printed by the Results cell below.

# Results

...appear here!

In [None]:
print_groups(groups=groups)