# Config

In [13]:
base = "../"
#base = "../victims/sphincsplus/ref/"
keys_file = base + "keys.txt"
sigs_file = base + "sigs.txt"
sigs_simulated_file = base + "sigs_simulated.txt"

params = 'SLH-DSA-SHAKE-256s'

sanity_check = True
simulate_faults = True
filter_sigs  = True

# Setup

Install dependencies, define some helper functions

In [14]:
%pip install -r requirements.txt
!python --version


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
Python 3.10.12


In [15]:
import fips205

In [16]:
slh = fips205.SLH_DSA(params)
a = slh.a
d = slh.d
hp = slh.hp
n = slh.n
k = slh.k

In [17]:
def print_adrs(adrs: fips205.ADRS, end='\n', verbose=False):
    hex = adrs.adrs().hex()
    if verbose:
        print('LAYER' + ' ' * 4 + 
              'TREE ADDR' + ' ' * 18 +
              'TYP' + ' ' * 6 +
              'KADR' + ' ' * 5 +
              'PADD = 0')
    print(' '.join([hex[i:i+8] for i in range(0, len(hex), 8)]), end=' ')
    print(end=end)

In [18]:
def extract_wots_keys(pk: bytes, sigs: list[bytes]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    import multiprocessing
    from cryptanalysis_lib import process_sig
    wots_bytes = slh.len * n
    xmss_bytes = hp * n
    fors_bytes = k * (n + a * n)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    with multiprocessing.Pool() as pool:
        args = [(params, pk, sig_idx, sig, sig_len) for sig_idx, sig in enumerate(sigs)]
        results = pool.map(process_sig, args)
    
    # Merge results
    merged = {}
    for item in results:
        merged = merge_groups(merged, item)
    return merged

def merge_groups(left: dict[fips205.ADRS, set], right: dict[fips205.ADRS, set]) -> dict[fips205.ADRS, set]:
    for key, items in right.items():
        if key not in left:
            left[key] = set()
        left[key] = left[key] | items
    return left


In [None]:
def has_shared_intermediate(v1: fips205.WOTSKeyData, valid: fips205.WOTSKeyData) -> bool:
    for chain_idx, chain in enumerate(v1.intermediates):
        for step in chain:
            if step == valid.sig[chain_idx*n:(chain_idx+1)*n]:
                return True
    return False

def find_collisions(wots_sigs: dict[fips205.ADRS, set[fips205.WOTSKeyData]]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    return {adrs: keys for adrs, keys in wots_sigs.items() if len(keys) > 1}

def print_arr_w(arr: list[int], width=int):
    print('[ ', end='')
    for x in arr:
        print(f"{x:0{width}d}", end=' ')
    print(']')
    
def valid_sigs_d(groups):
    return {adrs: key for adrs, keys in groups.items() for key in keys if key.valid}

def find_all_shared_intermediate_offsets(v1, valid_sig):
    if not valid_sig:
        return None
    results = []
    for chain_idx, chain in enumerate(v1.intermediates):
        for hash_iter, step in enumerate(chain[1:], start=1):
            if step == valid_sig.sig[chain_idx*n:(chain_idx+1)*n]:
                # find substring in valid_sig.sig starting with `step` (i.e., the application of `expose_count`+1 times to the value exposed in WOTS key)
                offset = chain_idx*n
                results.append((chain_idx, hash_iter, offset/n))
    return results

def hex(s: bytes | None) -> str:
    return s.hex() if s else "None"

def print_key_data(v: fips205.WOTSKeyData, adrs: fips205.ADRS, pk_seed: bytes, valid_key: fips205.WOTSKeyData = None):
    print("Valid" if v.valid else "Invalid" if v.valid == False else "--", end='\t')
    print(v.sig.hex())
    print('\tPK (from tree)\t' + hex(v.pk))
    pk = v.calculate_pk(params, adrs, pk_seed)
    print('\tPK (calculated)\t' + hex(pk))
    print(f"\tWOTS key is part of signature {v.sig_idx}")
    print('\t\t\t', end='')
    print_arr_w([i for i in range(len(v.chains))], 2)
    print('\t' + "chains\t\t", end='')
    print_arr_w(v.chains, 2)
    print('\t' + "chains (calc)\t", end = '')
    print_arr_w(v.chains_calculated, 2)
    if valid_key and not v.valid:
        shared_intermediates = find_all_shared_intermediate_offsets(v, valid_key)
        shared_intermediates = sorted(shared_intermediates, key=lambda item: item[0])
        for chain_idx, exposed, offset in shared_intermediates:
            assert offset == chain_idx
            print(f"\t\tExposed {exposed} secret values at chain_idx {chain_idx} offset {offset}")
    

def print_groups(pk_seed: bytes, groups: dict[fips205.ADRS, list[fips205.WOTSKeyData]], skip_no_exposed=True):
    #collisions: list[tuple[fips205.ADRS, set[fips205.WOTSKeyData]]] = [(adrs, value) for adrs, value in collisions if any(v.valid for v in value) and not all(v.valid for v in value)]
    #print(f"Found {len(collisions)} groups with at least one valid and one invalid key")
    # collisions where PK match. This is not necessary. Better: find exposed keys by running WOTS chain
    """collisions = [
        (adrs, value)
        for adrs, value in collisions
        if any(v1.msg != v2.msg and v1.pk == v2.pk for v1 in value for v2 in value if v1.valid and not v2.valid)
    ]"""
    
    valid_sigs = valid_sigs_d(groups)

    # sort by layer address
    groups = sorted(groups.items(), key=lambda item: item[0].get_layer_address())
    
    for adrs, value in groups:
        valid_sig = valid_sigs[adrs] if adrs in valid_sigs else None
        invalid_sigs = [v for v in value if not v.valid]
        print_adrs(adrs, end='', verbose=True)
        print(len(value))
        
        for v in [valid_sig] + invalid_sigs:
            if not v:
                continue
            shared_intermediates = find_all_shared_intermediate_offsets(v, valid_sig)
            if skip_no_exposed and valid_sig and not v.valid and not shared_intermediates:
                continue
            print_key_data(v, adrs, pk_seed, valid_sig)

# Clean Start

Run this cell (and below) for a clean analysis

In [20]:
processed_sigs = 0
groups: dict[fips205.ADRS, set[fips205.WOTSKeyData]] = {}

In [21]:
# Load keys
with open(keys_file, "r") as f:
    lines = [s.split(': ') for s in f.readlines()]
    keys = {s[0]: bytes.fromhex(s[1].strip()) for s in lines}
sk = keys['sk']
pk = keys['pk']
pk_seed = pk[:n]
pk.hex()

'0bc005bdb4dfb431bb250e109ca4430d42f4fd7e9270f515640701df308413952d854a500f3e893a8804ad88a600ee6812c3317e422848c5854c2b18588c1b9a'

# Update Experiment
Run this cell (and below) to process new signatures

# Load real signatures

This loads the real signatures from `sigs_file` (see config)

In [22]:
with open(sigs_file, "r") as f:
    sigs = [bytes.fromhex(s.strip()) for s in f.readlines()]
    # sigs = sigs[:1000]
print(f"Loaded {len(sigs)} signatures")

Loaded 29065 signatures


# Tooling sanity check

This section tries to generate a signature using the same key and randomization values as the first signature in `sigs_file`.
We expect them to match. This assumes that the first signature in `sigs_file` is a valid signature.

In [23]:
if sanity_check:
    wots_len = slh.len
    wots_bytes = wots_len * n
    xmss_bytes = hp * n
    fors_bytes = k*(n + a * n)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    sig = sigs[0]

    m = sig[sig_len:]
    r = sig[:n]
    pysig = slh.slh_sign_internal(m, sk, None, r=r)
    pysig += m

    if pysig != sig:
        print("Signature mismatch")
        print(pysig.hex())
        print(sig.hex())
    print("Passed sanity check")
else:
    print("Skipping sanity check")

Passed sanity check


# Extract WOTS keys

Extract all WOTS keys in all signatures

In [24]:
if processed_sigs > 0:
    print(f"Skipping {processed_sigs} signatures")

sigs = sigs[processed_sigs:]
print(f"Processing {len(sigs)} signatures...", end=' ')

groups = merge_groups(groups, extract_wots_keys(pk, sigs))
print(f"Found {len(groups)} unique addresses")

# update processed_sigs for consecutive runs
processed_sigs += len(sigs)

Processing 29065 signatures... Found 198057 unique addresses


# Load simulated faults
This section loads simulated faults from `sigs_simulated_file`

In [None]:
import os
if simulate_faults and os.path.exists(sigs_simulated_file):
    with open(sigs_simulated_file, "r") as f:
        faulty_sigs = [bytes.fromhex(s.strip()) for s in f.readlines()]
    print(f"Loaded {len(faulty_sigs)} faulty (simulated) signatures")
    simulated_groups = extract_wots_keys(pk, faulty_sigs)
    for adrs, keys in simulated_groups.items():
        for key in keys:
            key.simulated = True
    print_groups(pk_seed, simulated_groups)
    groups = merge_groups(groups, simulated_groups)
else:
    print("Fault simulation disabled or no simulated faulty signatures found")

Loaded 1591 faulty (simulated) signatures
LAYER    TREE ADDR                  TYP      KADR     PADD = 0
00000000 00000000 007535c4 59e8950d 00000001 000000af 00000000 00000000 1
Invalid	b85094b2f5b0dda09c9ef2fc1bd4a858a9fc8033ca1c493ad964294a1e9a90360914ea6676d8d14947c106bf78caae79bdb75c9fd5f14b4813f3a64b05e007d1f46274351b6e2e03c760c4c7e42e5cb65f04427d27fa0f0b68f42e6a886a33289cadfe0e9f50df27046ce070babd86f8004ef0a43bdfc7b1ce527cc69bd65575d1bda7f0b0fadabd7a15cc624540621851efd4c3898f2c8fbefdb9895f1e88c1cac060566001b1472ba247b3eda741228c285f0df05b156a136280de9fc282be2e0ac5c35882b9650cb65b1dd58cd18343280e0d115294087a5f6c768c3b4fe1b68149c9d4fbd1742f90342507665570978aaf55464938b93383bbad3735d2f73958574cd96ac2337ce6c8642e905fb0b68f3b4d6813c6bb2214240e15ca4251243e9ba6e3fa43d573fc0544809b1c46717cda642ab316fb0fc0b8fb91cb742c1e2115bb480ac99227ec44b41c7eee21bc3e5c41eb4d30185898d113fa21735e65671e6a0424852980c3008a86031b6925f12c0bd12ab3e21e6fb9d1542ba43b7cd26ba18d0937500dd0d3cfe9b30caa359d3569dd8b9

TypeError: 'NoneType' object is not iterable

# Results

...appear here!

In [None]:
for adrs, keys in groups.items():
    if len([v for v in keys if v.valid]) > 1:
        print("ERROR: found multiple valid keys for the same address", adrs, keys)
        raise ValueError("Multiple valid keys for the same address")


# maintain a dictionary of valid signatures per adrs
valid_sigs = valid_sigs_d(groups)

# only keep keys where we also observered valid signatures
groups = {adrs: sigs for adrs, sigs in groups.items() if adrs in valid_sigs}
print(f"Found {len(groups)} groups with valid signatures")

groups = find_collisions(groups)
print(f"Found {len(groups)} groups with collisions")

# post-process collided WOTS keys
for adrs, keys in groups.items():
    if adrs not in valid_sigs:
        continue
    valid_sig = valid_sigs[adrs]
    valid_sig.calculate_intermediates(params, adrs, pk_seed, valid_sig)
    for key in keys:
        key.calculate_intermediates(params, adrs, pk_seed, valid_sig)

print_groups(pk_seed, groups, skip_no_exposed=False)

In [None]:
# filter groups with exposed WOTS secrets
groups = {adrs: sigs 
    for adrs, sigs in groups.items()
    if any(not sig.valid and has_shared_intermediate(sig, valid_sigs[adrs]) for sig in sigs)
}
print(f"Found {len(groups)} groups with exposed WOTS secrets")
print_groups(pk_seed, groups)

In [None]:
def filter_signatures(sigs: list[bytes], wots_keys: dict[fips205.ADRS, set[fips205.WOTSKeyData]]) -> set[bytes]:
    filtered_sigs = set()
    collisions = find_collisions(wots_keys)
    print(f"Found {len(collisions)} groups with collisions")
    collisions = {adrs: sigs for adrs, sigs in collisions.items() if any(sig.valid for sig in sigs) and not all(sig.valid for sig in sigs)}
    valid_sigs = valid_sigs_d(wots_keys)
    for adrs, keys in collisions.items():
        for key in keys:
            if key.valid:
                # keep valid keys
                filtered_sigs.add(sigs[key.sig_idx])
            if adrs in valid_sigs and has_shared_intermediate(key, valid_sigs[adrs]):
                filtered_sigs.add(sigs[key.sig_idx])
    return filtered_sigs

if filter_sigs:
    # only keep signatures that are valid or have exposed WOTS keys
    filtered_sigs = filter_signatures(sigs, groups)
    print(f"Kept {len(filtered_sigs)} of {len(sigs)} signatures")
    with open("../sigs_filtered.txt", "w") as f:
        for sig in filtered_sigs:
            f.write(sig.hex() + "\n")

In [None]:
# Join WOTS keys to get a WOTS key usable for signing as many messages as possible
from copy import deepcopy


def join_sigs(wots_keys: dict[fips205.ADRS, set[fips205.WOTSKeyData]], params, pk_seed) -> dict[fips205.ADRS, fips205.WOTSKeyData]:
    joined_sigs = {}
    valid_sigs = valid_sigs_d(wots_keys)
    for adrs, keys in wots_keys.items():
        if len(keys) < 2:
            continue
        retval = deepcopy(valid_sigs[adrs])
        if not retval.chains_calculated:
            print_adrs(adrs)
            retval.calculate_intermediates(params, adrs, pk_seed, valid_sigs[adrs])
        for key in keys:
            retval = retval.join(key, params)
        joined_sigs[adrs] = retval
    return joined_sigs

joined_sigs = join_sigs(groups, params, pk_seed)
print(f"Joined {len(joined_sigs)} signatures")

"""
for adrs, key in joined_sigs.items():
    print_adrs(adrs, verbose=True)
    print_key_data(key, adrs, pk_seed, None)
"""

# Filter out signatures with unresolved chains
joined_sigs = {adrs: key for adrs, key in joined_sigs.items() if all(c < slh.w for c in key.chains_calculated)}

print(f"Filtered {len(joined_sigs)} (unresolved chains)")

# Filter out signatures with all chains set to 0
joined_sigs = {adrs: key for adrs, key in joined_sigs.items() if any(c > 0 for c in key.chains_calculated)}
print(f"Filtered {len(joined_sigs)} (wildcard chains)")

"""
for adrs, key in joined_sigs.items():
    print_adrs(adrs, verbose=True)
    print_key_data(key, adrs, pk_seed, None)
""" 

# Find the key where the maximum number of chains are exposed (i.e., valid_sig.chain[i] - c is maximized for all chains 0 <= i < len)
def score(key: fips205.WOTSKeyData, valid_sig: fips205.WOTSKeyData) -> int:
    score = sum(valid_sig.chains[i] - c for i, c in enumerate(key.chains_calculated))
    #msg_score = sum((slh.w - 1 - c for c in valid_sig.chains_calculated))
    #cksum_score = sum(slh.w - 1 - c for _, c in enumerate(key.chains_calculated[slh.len1:]))
    return score
most_exposed_adrs, most_exposed_key = max(
    joined_sigs.items(),
    key=lambda item: score(item[1], valid_sigs[item[0]])
)
print("Key with most exposed chains (score: {}):".format(
    score(most_exposed_key, valid_sigs[most_exposed_adrs])
))
print_adrs(most_exposed_adrs, verbose=True)
print_key_data(most_exposed_key, most_exposed_adrs, pk_seed, None)

# try to sign the original message with the exposed key
print("Signing original (valid) msg with exposed key...")
assert most_exposed_key.try_sign(most_exposed_key.chains, most_exposed_adrs, pk_seed, params)

# signing message with same checksum
# Find two offsets i, j where the chain is partially exposed
exposed_offsets = [(idx, c - cc) for idx, (c, cc) in enumerate(zip(most_exposed_key.chains[:slh.len1], most_exposed_key.chains_calculated[:slh.len1])) if abs(c - cc) > 0]
if len(exposed_offsets) >= 2:
    i, j = exposed_offsets[:2]
    print(f"Found exposed offsets: i={i}, j={j}")
    chains = most_exposed_key.chains_calculated.copy()
    chains[i[0]] += 1
    print("Signing modified message with exposed key...")
    assert most_exposed_key.try_sign(chains, most_exposed_adrs, pk_seed, params)
else:
    print("Not enough exposed offsets found")

print("="*64)
print("All keys for exposed address:")
print_adrs(most_exposed_adrs, verbose=True)
for key in groups[most_exposed_adrs]:
    print_key_data(key, most_exposed_adrs, pk_seed, valid_sigs[most_exposed_adrs])


In [None]:
import math
from multiprocessing import Pool
from os import cpu_count
from cryptanalysis_lib import sign_worker

def sign_message_batch_mp(total_msgs, joined_sigs_items, msg_len, pk_seed, params, num_procs=None):
    """
    Use a process pool to sign `total_msgs` messages *per* (adrs,key).
    Returns the grand total of successful signatures.
    """
    if num_procs is None:
        num_procs = cpu_count()

    # Split total_msgs into roughly equal chunks per process
    per_proc = math.ceil(total_msgs / num_procs)
    print("Total messages per process:", per_proc)

    # Build one work item per (process x key)
    work = []
    for adrs, key in joined_sigs_items.items():
        for _ in range(num_procs):
            work.append((per_proc, msg_len, adrs, key, pk_seed, params))

    # Spawn the pool
    with Pool(processes=num_procs) as pool:
        # map returns one result per work item
        results = pool.map(sign_worker, work, chunksize=1)

    return sum(results)

num_sigs = 2**10

total_success = sign_message_batch_mp(num_sigs, {most_exposed_adrs: most_exposed_key}, slh.len1, pk_seed, params, num_procs=cpu_count()-2)

print("Total messages signed:", num_sigs)
print("Total successful signatures:", total_success)
print("Success ratio:", total_success/num_sigs)