# Config

In [1]:
# All input files live under `base` (the commented line is an alternative location).
base = "../"
#base = "../victims/sphincsplus/ref/"
keys_file = f"{base}keys.txt"
sigs_file = f"{base}sigs_filtered.txt"
faulty_sigs_file = f"{base}sigs_faulty.txt"

# SLH-DSA parameter set used to parse and (re)generate signatures below.
params = 'SLH-DSA-SHAKE-256s'

# Optional steps: regenerate sigs[0] to validate the tooling / write a filtered signature file.
sanity_check = False
filter_sigs = False

# Setup

Install dependencies and define the helper functions used throughout the notebook.

In [2]:
# Install the notebook's dependencies into the kernel's environment (%pip targets the running kernel).
%pip install -r requirements.txt
# Record the interpreter version alongside the outputs.
!python --version

Note: you may need to restart the kernel to use updated packages.
Python 3.10.12


In [3]:
import fips205

In [4]:
def print_adrs(adrs: fips205.ADRS, end='\n', verbose=False):
    """Print an address as space-separated groups of 8 hex digits (4 bytes each).

    Args:
        adrs: address object; only its ``adrs()`` method (returning bytes) is used.
        end: text printed after the groups; a single space always precedes it.
        verbose: if True, first print a header line labelling the columns
            (LAYER, TREE ADDR, TYP, KADR, PADD = 0).
    """
    encoded = adrs.adrs().hex()
    if verbose:
        header = ''.join((
            'LAYER', ' ' * 4,
            'TREE ADDR', ' ' * 18,
            'TYP', ' ' * 6,
            'KADR', ' ' * 5,
            'PADD = 0',
        ))
        print(header)
    words = [encoded[pos:pos + 8] for pos in range(0, len(encoded), 8)]
    # Trailing space kept so the caller's `end` text is separated from the last group.
    print(' '.join(words) + ' ', end=end)

In [5]:
def extract_wots_keys(pk: bytes, sigs: list[bytes]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    """Extract the WOTS key records used by each signature, grouped by address.

    Each signature is handed to ``cryptanalysis_lib.process_sig`` in a worker
    process as ``(params, pk, sig_idx, sig, sig_len)``. The per-signature
    results are then folded together with ``merge_groups``.

    Args:
        pk: SLH-DSA public key.
        sigs: raw signature lines; ``sig_idx`` passed to the workers is the
            position in THIS list.

    Returns:
        Mapping from address to the set of key records found at it. The result
        is empty when ``sigs`` is empty.
    """
    import multiprocessing
    # Imported from a module rather than defined in the notebook, presumably so
    # that worker processes can unpickle it.
    from cryptanalysis_lib import process_sig

    if not sigs:
        # Nothing to do: avoid spawning a worker pool.
        return {}

    slh = fips205.SLH_DSA(params)
    n = slh.n
    wots_bytes = slh.len * n
    xmss_bytes = slh.hp * n
    fors_bytes = slh.k * (n + slh.a * n)
    # R || FORS signature || d x (WOTS+ signature + XMSS authentication path)
    sig_len = n + fors_bytes + slh.d * (wots_bytes + xmss_bytes)

    args = [(params, pk, sig_idx, sig, sig_len) for sig_idx, sig in enumerate(sigs)]
    with multiprocessing.Pool() as pool:
        print(f"Processing {len(args)} signatures in parallel")
        results = pool.map(process_sig, args)

    # Merge the per-signature results into a single address -> records mapping.
    merged = {}
    for item in results:
        merged = merge_groups(merged, item)
    return merged

def merge_groups(left: dict[fips205.ADRS, set], right: dict[fips205.ADRS, set]) -> dict[fips205.ADRS, set]:
    """Fold ``right`` into ``left`` by taking the per-address set union.

    ``left`` is updated in place and also returned. Every touched entry is
    replaced with a freshly built union, so ``left`` never ends up holding a set
    object owned by ``right``.
    """
    for address, incoming in right.items():
        left[address] = left.get(address, set()) | incoming
    return left


In [6]:
def has_shared_intermediate(v1: fips205.WOTSKeyData, valid: fips205.WOTSKeyData, n: int | None = None) -> bool:
    """Check whether any intermediate chain value of ``v1`` appears at the same chain in ``valid``.

    For every chain index ``i``, each value in ``v1.intermediates[i]`` is compared
    with the ``n``-byte slice ``valid.sig[i*n:(i+1)*n]``.

    Args:
        v1: record whose ``intermediates`` (one iterable of byte strings per
            chain) are searched; presumably populated beforehand by
            ``calculate_intermediates`` — confirm for new callers.
        valid: record whose ``sig`` is split into ``n``-byte chain values.
        n: chain value size in bytes; defaults to ``n`` of the global ``params`` set.

    Returns:
        True as soon as one match is found, otherwise False.
    """
    if n is None:
        # Read n from the parameter table (as print_groups does) instead of
        # constructing a full SLH_DSA object on every call.
        n = fips205.SLH_DSA_PARAM[params][1]
    return any(
        step == valid.sig[chain_idx * n:(chain_idx + 1) * n]
        for chain_idx, chain in enumerate(v1.intermediates)
        for step in chain
    )

def find_collisions(wots_sigs: dict[fips205.ADRS, set[fips205.WOTSKeyData]]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    """Return the sub-mapping of addresses whose set holds more than one key record.

    The returned values are the same set objects as in ``wots_sigs`` (not copies),
    and the input's key order is preserved.
    """
    collisions = {}
    for adrs, keys in wots_sigs.items():
        if len(keys) >= 2:
            collisions[adrs] = keys
    return collisions

def print_arr_w(arr: list[int], width: int = 2):
    """Print integers as ``[ a b c ]`` with each value zero-padded to ``width`` digits.

    Args:
        arr: integers to print; an empty list prints ``[ ]``.
        width: minimum number of digits per value. It defaults to 2, the width
            every caller in this notebook passes. (Previously the default was the
            type ``int`` itself, which made an omitted ``width`` raise ValueError.)
    """
    cells = ''.join(f"{x:0{width}d} " for x in arr)
    print(f"[ {cells}]")
    
    

def print_groups(pk_seed: bytes, groups: dict[fips205.ADRS, list[fips205.WOTSKeyData]]):
    """Report WOTS key collisions and the chain values exposed by invalid signatures.

    Steps:
      1. Keep addresses with more than one key record (``find_collisions``).
      2. For each such address that has at least one valid record, pick ONE
         valid record as the reference and run ``calculate_intermediates`` for
         every record of the group against it. Groups without any valid record
         cannot be analysed and are skipped (and counted).
      3. Keep addresses where some invalid record shares an intermediate chain
         value with the reference (``has_shared_intermediate``).
      4. Print those groups sorted by layer address: signature prefix, pk,
         the ``chains`` / ``chains_calculated`` values (labelled "msg"), and
         for invalid records the matching chain index / exposure count / offset.

    Args:
        pk_seed: public seed (the first ``n`` bytes of the public key), forwarded
            to ``calculate_intermediates``.
        groups: collection of key records per address, e.g. from ``extract_wots_keys``.
    """
    n = fips205.SLH_DSA_PARAM[params][1]
    print(f"Found {len(groups)} unique addresses")
    collisions = find_collisions(groups)
    print(f"Found {len(collisions)} groups with collisions")

    def find_all_shared_intermediate_offsets(v1, valid_sig):
        """List ``(chain_idx, exposed_count, offset)`` for each intermediate of ``v1`` found in ``valid_sig.sig``.

        ``offset`` is the byte position of the first match divided by ``n``. It is
        kept as a float, so a match not aligned to an n-byte boundary is visible.
        """
        results = []
        for chain_idx, chain in enumerate(v1.intermediates):
            for exposed_count, step in enumerate(chain):
                try:
                    # find substring in valid_sig.sig starting with `step` (i.e., the application of `exposed_count`+1 times to the value exposed in WOTS key)
                    offset = valid_sig.sig.index(step)
                except ValueError:
                    continue
                results.append((chain_idx, exposed_count, offset / n))
        return results

    # Choose exactly one valid reference per address and use that same record for
    # the intermediates computation AND every later check. The two disagree when an
    # address has several valid records and the choices are made independently.
    valid_sigs = {}
    for adrs, keys in collisions.items():
        valid_sig = next((v for v in keys if v.valid), None)
        if valid_sig is None:
            continue
        valid_sigs[adrs] = valid_sig
        valid_sig.calculate_intermediates(params, adrs, pk_seed, valid_sig)
        for key in keys:
            if key is not valid_sig:  # the reference was already processed just above
                key.calculate_intermediates(params, adrs, pk_seed, valid_sig)

    skipped = len(collisions) - len(valid_sigs)
    if skipped:
        print(f"Skipped {skipped} groups without a valid signature")

    exposed = {
        adrs: sigs
        for adrs, sigs in collisions.items()
        if adrs in valid_sigs
        and any(not sig.valid and has_shared_intermediate(sig, valid_sigs[adrs]) for sig in sigs)
    }
    print(f"Found {len(exposed)} groups with exposed WOTS secrets")

    # sort by layer address
    for adrs, value in sorted(exposed.items(), key=lambda item: item[0].get_layer_address()):
        valid_sig = valid_sigs[adrs]
        invalid_sigs = [v for v in value if not v.valid]
        print_adrs(adrs, end='', verbose=True)
        print(len(value))

        for v in [valid_sig] + invalid_sigs:
            shared_intermediates = find_all_shared_intermediate_offsets(v, valid_sig)
            if not v.valid and not shared_intermediates:
                continue
            print("Valid" if v.valid else "Invalid", end='\t')
            print(v.sig.hex()[:64] + "...")
            print('\t' + v.pk.hex())
            print('\t\t\t', end='')
            print_arr_w(list(range(len(v.chains))), 2)
            print('\t' + "msg\t\t", end='')
            print_arr_w(v.chains, 2)
            print('\t' + "msg (calc)\t", end='')
            print_arr_w(v.chains_calculated, 2)
            print(f"\tWOTS key is part of signature {v.sig_idx}")
            if not v.valid:
                for chain_idx, exposed_count, offset in sorted(shared_intermediates, key=lambda item: item[0]):
                    # Sanity check: an exposed value is expected at its own chain position.
                    assert offset == chain_idx
                    print(f"\t\tExposed {exposed_count} secret values at chain_idx {chain_idx} offset {offset}")
            
def filter_signatures(sigs: list[bytes], wots_keys: dict[fips205.ADRS, set[fips205.WOTSKeyData]]) -> set[bytes]:
    """Select the signatures worth keeping for further analysis.

    Only addresses holding both a valid and an invalid key record are considered.
    From each one, keep the signature of every valid record. Also keep the
    signature of every invalid record that shares an intermediate chain value
    with the address's (first) valid record.

    Args:
        sigs: raw signatures, indexed by each record's ``sig_idx``.
        wots_keys: key records per address.

    Returns:
        The kept signatures as a set (unordered, duplicates collapsed).

    NOTE(review): ``has_shared_intermediate`` reads ``intermediates``, which this
    function does not compute. Presumably ``print_groups`` must have run first.
    ``sig_idx`` is the position within the list originally passed to
    ``extract_wots_keys``. Records from simulated faulty signatures or from
    incremental runs may therefore index the wrong entry of ``sigs`` — confirm.
    """
    collisions = find_collisions(wots_keys)
    print(f"Found {len(collisions)} groups with collisions")
    filtered_sigs = set()
    for keys in collisions.values():
        valid_ref = next((key for key in keys if key.valid), None)
        # Only mixed groups (at least one valid and one invalid record) are interesting.
        if valid_ref is None or all(key.valid for key in keys):
            continue
        for key in keys:
            if key.valid or has_shared_intermediate(key, valid_ref):
                filtered_sigs.add(sigs[key.sig_idx])
    return filtered_sigs

# Clean Start

Run this cell (and all cells below it) to reset the accumulated state and start a clean analysis.

In [7]:
# Reset the accumulated analysis state: no key groups collected, no signatures processed yet.
groups: dict[fips205.ADRS, set[fips205.WOTSKeyData]] = dict()
processed_sigs = 0

In [8]:
# Load keys: each non-blank line of `keys_file` has the form `<name>: <hex value>`.
with open(keys_file, "r") as f:
    # maxsplit=1 so only the first ": " separates the name from the value.
    lines = [s.split(': ', 1) for s in f if s.strip()]
keys = {name: bytes.fromhex(value.strip()) for name, value in lines}
sk = keys['sk']
pk = keys['pk']
# The public seed is the first n bytes of the public key.
pk_seed = pk[:fips205.SLH_DSA(params).n]
pk.hex()

'0bc005bdb4dfb431bb250e109ca4430d42f4fd7e9270f515640701df308413952d854a500f3e893a8804ad88a600ee6812c3317e422848c5854c2b18588c1b9a'

# Load real signatures

This loads the real signatures from `sigs_file` (see config)

In [9]:
# One hex-encoded signature per line. Blank lines are skipped: bytes.fromhex
# would otherwise turn them into empty b'' "signatures".
with open(sigs_file, "r") as f:
    sigs = [bytes.fromhex(s.strip()) for s in f if s.strip()]
print(f"Loaded {len(sigs)} signatures")

Loaded 480 signatures


# Tooling sanity check

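This section regenerates the first signature in `sigs_file`, using the same key and the same randomizer `R` (the first `n` bytes of that signature). We expect the regenerated signature to match the stored one byte for byte.
This assumes that the first signature in `sigs_file` is a valid signature. The check only runs when `sanity_check` is enabled in the config.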

In [10]:
if sanity_check:
    slh_dsa = fips205.SLH_DSA(params)
    (_, n, _, d, hp, a, k, _, _) = fips205.SLH_DSA_PARAM[params]
    wots_bytes = slh_dsa.len * n
    xmss_bytes = hp * n
    fors_bytes = k * (n + a * n)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    # Each stored line is signature || message; the signature starts with the n-byte randomizer R.
    sig = sigs[0]
    msg = sig[sig_len:]
    r = sig[:n]
    pysig = slh_dsa.slh_sign_internal(msg, sk, None, r=r) + msg

    # Only report success when the regenerated signature actually matches.
    if pysig == sig:
        print("Passed sanity check")
    else:
        print("Signature mismatch")
        print(pysig.hex())
        print(sig.hex())
else:
    print("Skipping sanity check")

Skipping sanity check


# Update Experiment
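Run this cell (and all cells below it) to process new signatures. Signatures already handled by a previous run (tracked in `processed_sigs`) are skipped, so reload `sigs_file` first if it has grown.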

In [11]:
# Process only signatures that a previous run of this cell has not handled yet.
if processed_sigs > 0:
    print(f"Skipping {processed_sigs} signatures")

# NOTE(review): this rebinds `sigs` to just the unprocessed tail, so later cells
# (e.g. the filter step) only see this batch. The record `sig_idx` values are
# presumably positions within this slice (extract_wots_keys enumerates the list it
# is given) — confirm before indexing `sigs` with them after an incremental run.
sigs = sigs[processed_sigs:]

print(f"Processing {len(sigs)} signatures...", end=' ')

# Fold the newly extracted key records into the accumulated `groups`.
groups = merge_groups(groups, extract_wots_keys(pk, sigs))

print(f"done!")

# update processed_sigs for consecutive runs
processed_sigs += len(sigs)

Processing 480 signatures... Processing 480 signatures in parallel
done!


# Load simulated faults
This section loads simulated faulty signatures from `faulty_sigs_file` (if the file exists) and merges their WOTS keys into the collected groups.

In [12]:
import os
if os.path.exists(faulty_sigs_file):
    # One hex-encoded faulty signature per line; blank lines are skipped.
    with open(faulty_sigs_file, "r") as f:
        faulty_sigs = [bytes.fromhex(s.strip()) for s in f if s.strip()]
    print(f"Loaded {len(faulty_sigs)} faulty (simulated) signatures")
    # NOTE(review): sig_idx of these records refers to positions in `faulty_sigs`,
    # not `sigs` — keep that in mind before running the filter step.
    simulated_groups = extract_wots_keys(pk, faulty_sigs)
    groups = merge_groups(groups, simulated_groups)
    # print_groups expects the public seed (as in the Results cell), not the whole public key.
    print_groups(pk_seed, groups)
else:
    print("No simulated faulty signatures found")

No simulated faulty signatures found


# Results

...appear here!

In [13]:
# Report collisions and exposed WOTS secrets for everything collected so far.
print_groups(pk_seed=pk_seed, groups=groups)

Found 3597 unique addresses
Found 13 groups with collisions
Signature 2: Found preimage 923b72d6def48604e3fcd7162d020bde6655be916808c5932871975868fd0091 of 923b72d6def48604e3fcd7162d020bde6655be916808c5932871975868fd0091 at step 0
Signature 2: Found preimage 06b958e4dbb0dba8e13c09f66a758d570c97989d3ca320b26713f971a353c57e of 06b958e4dbb0dba8e13c09f66a758d570c97989d3ca320b26713f971a353c57e at step 0
Signature 2: Found preimage 38f39ab17d09cfccabea9bfccf3fbb994f1060ea828f13887ac919530b100174 of 38f39ab17d09cfccabea9bfccf3fbb994f1060ea828f13887ac919530b100174 at step 0
Signature 2: Found preimage 7598dd521adbe2e3f437db03a0d6f2749b2fed33aca401f9268bf7d0025c9000 of 7598dd521adbe2e3f437db03a0d6f2749b2fed33aca401f9268bf7d0025c9000 at step 0
Signature 2: Found preimage dd2b1d8ea2c7543725fd3f9ffb272f36a02b9779b0e0e10e6cf8776598c0224f of dd2b1d8ea2c7543725fd3f9ffb272f36a02b9779b0e0e10e6cf8776598c0224f at step 0
Signature 2: Found preimage f6c9ede6cbbef337c3e29c2d00b46c07d3774100383d52e87aca24aef

In [14]:
if filter_sigs:
    # only keep signatures that are valid or have exposed WOTS keys
    filtered_sigs = filter_signatures(sigs, groups)
    print(f"Kept {len(filtered_sigs)} of {len(sigs)} signatures")
    # Build the path from the configured `base` rather than a hard-coded "../".
    # NOTE: with the default config this equals `sigs_file`, i.e. the input file is overwritten.
    filtered_path = base + "sigs_filtered.txt"
    # Write in the original order of `sigs` (a set has no stable order), once per signature.
    written = set()
    with open(filtered_path, "w") as f:
        for sig in sigs:
            if sig in filtered_sigs and sig not in written:
                f.write(sig.hex() + "\n")
                written.add(sig)