# Config

In [1]:
# All input files live under `base`.
base = "../"
# Alternative location: base = "../victims/sphincsplus/ref/"
keys_file = f"{base}keys.txt"
sigs_file = f"{base}sigs.txt"
faulty_sigs_file = f"{base}sigs_faulty.txt"

# SLH-DSA parameter-set name, passed to fips205.SLH_DSA and fips205.SLH_DSA_PARAM.
params = 'SLH-DSA-SHAKE-256s'

# Set to True to regenerate the first signature in sigs_file and compare it byte-for-byte.
sanity_check = False

# Setup

Install the dependencies and define the helper functions used throughout the notebook.

In [2]:
# Install the packages listed in requirements.txt into the kernel's environment (%pip, not !pip).
%pip install -r requirements.txt
# Record the interpreter version used for this run.
!python --version

Note: you may need to restart the kernel to use updated packages.
Python 3.10.15


In [3]:
import fips205

In [4]:
def print_adrs(adrs: fips205.ADRS, end='\n', verbose=False):
    """Print an address as space-separated groups of 8 hex digits (4 bytes each).

    Args:
        adrs: address whose raw bytes are obtained via ``adrs.adrs()``.
        end: text printed after the address, following one separating space.
        verbose: if True, print a header line naming the address fields first.
    """
    hex_digits = adrs.adrs().hex()
    if verbose:
        header_parts = (
            'LAYER', ' ' * 4,
            'TREE ADDR', ' ' * 18,
            'TYP', ' ' * 6,
            'KADR', ' ' * 5,
            'PADD = 0',
        )
        print(''.join(header_parts))
    words = []
    for start in range(0, len(hex_digits), 8):
        words.append(hex_digits[start:start + 8])
    # One trailing space separates the address from whatever `end` is.
    print(' '.join(words), end=' ' + end)

In [5]:
def group_keys_by_addr(pk: bytes, sigs: list[bytes]) -> dict[fips205.ADRS, set[fips205.WOTSKeyData]]:
    """Extract per-address WOTS key data from every signature in parallel and merge it.

    Each signature is handed to ``cryptanalysis_lib.process_sig`` in a worker
    process; the per-signature mappings it returns are unioned with
    ``merge_groups``.

    Args:
        pk: public key the signatures were produced under.
        sigs: raw signatures (presumably with the signed message appended, as in
            sigs_file — TODO confirm against process_sig).

    Returns:
        Mapping from address to the set of WOTS key data seen at that address.
    """
    import multiprocessing
    from cryptanalysis_lib import process_sig

    scheme = fips205.SLH_DSA(params)
    n = scheme.n
    # FIPS 205 signature layout: R (n bytes) || SIG_FORS || SIG_HT (d layers,
    # each a WOTS+ signature followed by an XMSS authentication path).
    fors_part = scheme.k * (n + scheme.a * n)
    ht_layer_part = scheme.len * n + scheme.hp * n
    sig_len = n + fors_part + scheme.d * ht_layer_part

    jobs = [(params, pk, sig, sig_len) for sig in sigs]
    with multiprocessing.Pool() as pool:
        per_sig_groups = pool.map(process_sig, jobs)

    merged = {}
    for partial in per_sig_groups:
        merged = merge_groups(merged, partial)
    return merged

def merge_groups(left: dict[fips205.ADRS, set], right: dict[fips205.ADRS, set]) -> dict[fips205.ADRS, set]:
    """Union the key sets of ``right`` into ``left``, address by address.

    ``left`` is modified in place and also returned. Every touched entry is
    replaced by a freshly built set, so the sets held by ``right`` are neither
    mutated nor shared with ``left``.
    """
    for adrs, new_keys in right.items():
        left[adrs] = left.get(adrs, set()) | new_keys
    return left


In [6]:
def print_groups(groups: dict[fips205.ADRS, set[fips205.WOTSKeyData]]):
    """Print summary statistics for ``groups`` and dump every group with exposed WOTS secrets.

    Groups are narrowed in stages, printing a count after each stage:
      1. addresses with at least one invalid (faulty) key;
      2. of those, addresses observed with more than one key;
      3. of those, addresses holding at least one valid and one invalid key;
      4. of those, addresses where a chain intermediate of an invalid key occurs
         as a byte substring of a valid key's signature.
    The stage-4 groups are then printed in layer-address order.

    Args:
        groups: mapping from address to the set of WOTS key data observed at it,
            as built by ``group_keys_by_addr`` / ``merge_groups``.
    """
    (_, n, _, _, _, _, _, _, _) = fips205.SLH_DSA_PARAM[params]

    def has_shared_intermediate(v1, keys):
        """True if any intermediate of ``v1`` occurs in the signature of a valid key in ``keys``."""
        return any(
            inter in v2.sig
            for chain in v1.intermediates
            for inter in chain
            for v2 in keys if v2.valid
        )

    def find_all_shared_intermediate_offsets(v1, valid_sig):
        """List (chain_idx, exposed_count, offset / n) for each intermediate of ``v1``
        found in ``valid_sig.sig``; the offset is that of the first match, in units of n bytes."""
        results = []
        for chain_idx, chain in enumerate(v1.intermediates):
            for exposed_count, inter in enumerate(chain):
                try:
                    offset = valid_sig.sig.index(inter)
                except ValueError:
                    continue
                results.append((chain_idx, exposed_count, offset / n))
        return results

    print(f"Found {len(groups)} unique addresses")

    # Stage 1: addresses where at least one observed key is invalid.
    faulted = [(adrs, keys) for adrs, keys in groups.items() if not all(v.valid for v in keys)]
    print(f"Found {len(faulted)} groups with faulty keys")

    # Stage 2: faulted addresses seen with more than one distinct key.
    collided = [(adrs, keys) for adrs, keys in faulted if len(keys) > 1]
    # Expected number of colliding pairs among N items drawn uniformly from K values.
    N = len(faulted)
    K = 256  # NOTE(review): the origin of this constant is not documented here — confirm
    expected_collisions = (N * (N - 1)) / (2 * K)
    print(f"Found {len(collided)} groups with collisions (expected: {expected_collisions})")

    # Stage 3: only groups mixing valid and invalid keys can be compared.
    mixed = [
        (adrs, keys)
        for adrs, keys in collided
        if any(v.valid for v in keys) and not all(v.valid for v in keys)
    ]
    print(f"Found {len(mixed)} groups with at least one valid and one invalid key")

    # Stage 4: an invalid key's chain intermediate shows up in a valid signature.
    exposed = [
        (adrs, keys)
        for adrs, keys in mixed
        if any(not v1.valid and has_shared_intermediate(v1, keys) for v1 in keys)
    ]
    print(f"Found {len(exposed)} groups with exposed WOTS secrets")

    for adrs, keys in sorted(exposed, key=lambda item: item[0].get_layer_address()):
        valid_keys = [v for v in keys if v.valid]
        invalid_keys = [v for v in keys if not v.valid]
        # The stage-4 filter accepts a group if *some* valid signature contains an
        # intermediate of an invalid key, so report against such a signature
        # rather than whichever valid key comes first in the unordered set.
        valid_sig = next(
            (
                vs for vs in valid_keys
                if any(find_all_shared_intermediate_offsets(inv, vs) for inv in invalid_keys)
            ),
            valid_keys[0],
        )
        print_adrs(adrs, end='', verbose=True)
        print(len(keys))

        for v in [valid_sig] + invalid_keys:
            print("Valid" if v.valid else "Invalid", end='\t')
            print(v.sig.hex())
            print('\t', end='')
            print(v.pk.hex())
            print('\t', end='')
            for chain_idx in v.msg:
                print(f"{chain_idx}", end=' ')
            print()
            print("Exposed WOTS secrets: ", end='')
            print(find_all_shared_intermediate_offsets(v, valid_sig))

# Clean Start

Run this cell (and all cells below it) to start a clean analysis; it resets `processed_sigs` and `groups`.

In [7]:
# Reset analysis state: nothing processed yet, no keys grouped yet.
processed_sigs = 0
groups: dict[fips205.ADRS, set[fips205.WOTSKeyData]] = dict()

In [8]:
# Load keys
with open(keys_file, "r") as f:
    lines = [s.split(': ') for s in f.readlines()]
    keys = {s[0]: bytes.fromhex(s[1].strip()) for s in lines}
sk = keys['sk']
pk = keys['pk']
pk.hex()

'0bc005bdb4dfb431bb250e109ca4430d42f4fd7e9270f515640701df308413952d854a500f3e893a8804ad88a600ee6812c3317e422848c5854c2b18588c1b9a'

# Load real signatures

This loads the real signatures from `sigs_file` (see Config), one hex-encoded signature per line.

In [9]:
# One hex-encoded signature per line; blank lines are skipped instead of
# turning into empty b'' "signatures".
with open(sigs_file, "r") as f:
    sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
print(f"Loaded {len(sigs)} signatures")

Loaded 20720 signatures


# Tooling sanity check

This section regenerates the first signature in `sigs_file` using the same secret key, randomizer R and message, and checks that the result matches the stored signature byte-for-byte.
This assumes that the first signature in `sigs_file` is valid. Set `sanity_check = True` in the Config section to enable it.

In [10]:
if sanity_check:
    slh_dsa = fips205.SLH_DSA(params)
    (_, n, _, d, hp, a, k, _, _) = fips205.SLH_DSA_PARAM[params]
    wots_bytes = slh_dsa.len * n
    xmss_bytes = hp * n
    fors_bytes = k * (n + a * n)
    # Byte length of the signature proper; entries in sigs_file carry the
    # signed message appended after these bytes.
    sig_len = n + fors_bytes + d * (wots_bytes + xmss_bytes)

    sig = sigs[0]
    msg = sig[sig_len:]  # the message appended after the signature
    r = sig[:n]          # the randomizer R is the first n bytes of the signature
    pysig = slh_dsa.slh_sign_internal(msg, sk, None, r=r) + msg

    if pysig == sig:
        print("Passed sanity check")
    else:
        # Report failure explicitly instead of falling through to the pass message.
        print("Signature mismatch")
        print(pysig.hex())
        print(sig.hex())
        print("Failed sanity check")
else:
    print("Skipping sanity check")

Skipping sanity check


# Update Experiment
Run this cell (and all cells below it) to process new signatures; the first `processed_sigs` signatures, already merged into `groups`, are skipped.

In [None]:
# processed_sigs is an absolute index into the loaded `sigs` list. Slice into a
# separate name so `sigs` itself stays intact for re-runs and other cells
# (e.g. the sanity check reads sigs[0]).
if processed_sigs > 0:
    print(f"Skipping {processed_sigs} signatures")

new_sigs = sigs[processed_sigs:]

print(f"Processing {len(new_sigs)} signatures...", end=' ')

groups = merge_groups(groups, group_keys_by_addr(pk, new_sigs))

print("done!")

# update processed_sigs for consecutive runs
processed_sigs += len(new_sigs)

Processing 20720 signatures... 

# Load simulated faults
This section loads simulated faulty signatures from `faulty_sigs_file` and merges their keys into `groups`.

In [None]:
# One hex-encoded faulty signature per line; blank lines are skipped.
with open(faulty_sigs_file, "r") as f:
    faulty_sigs = [bytes.fromhex(line.strip()) for line in f if line.strip()]
print(f"Loaded {len(faulty_sigs)} faulty (simulated) signatures")
simulated_groups = group_keys_by_addr(pk, faulty_sigs)
# Set union per address, so re-running this cell does not duplicate keys.
# The merged results are printed in the Results section below.
groups = merge_groups(groups, simulated_groups)

# Results

The analysis results are printed below.

In [None]:
# Summary counts, then every group with exposed WOTS secrets.
print_groups(groups=groups)