# SPHINCS+ algorithm

**Few-time signature (FTS) schemes** are signature schemes that allow a key pair to produce a small number of signatures.

The basic idea of SPHINCS+ is to authenticate a huge number of FTS key pairs using a so-called hypertree.

A **hypertree** is a tree of hash-based many-time signatures (MTS).

In [21]:
from math import ceil, floor, log2
from copy import deepcopy
from hashlib import sha256
import hmac
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from random import randbytes
from bitarray import bitarray

In [48]:
sha256(b'').hexdigest()

'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'

## Useful functions

In [22]:
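def rot_left(n, d):
    # 32-bit left rotation (mask applied to the whole result)
    return ((n << d) | (n >> (32 - d))) & 0xFFFFFFFF

def rot_right(n, d):
    # 32-bit right rotation (mask applied to the whole result)
    return ((n >> d) | (n << (32 - d))) & 0xFFFFFFFF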

#### Base-w

A function that converts a byte string into an array of base-_w_ digits (each digit takes log2(_w_) bits, most significant first). The parameter _out_len_ sets the length of the output array.

In [23]:
def xor(a, b):
    return bytes(a[i] ^ b[i] for i in range(len(a)))

def num_to_bytes(num, length=8):
    return int.to_bytes(num, length=length, byteorder='big')

In [24]:
def bits_to_int(x):
    bit_string = x.to01()
    number = int(bit_string, 2)
    return number

def int_to_bits(x):
    bit_array = bitarray()
    bit_array.frombytes(num_to_bytes(x))
    return bit_array

In [25]:
def base_w(X, w, out_len):
    bit_array = bitarray()
    bit_array.frombytes(X)
    m = []
    for i in range(out_len):
        mpart = bit_array[:int(log2(w))]
        bit_array = bit_array[int(log2(w)):]
        number = bits_to_int(mpart)
        m.append(number)
    return m

base_w(b'\x12\x34', 16, 4)

[1, 2, 3, 4]
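The bitarray-based `base_w` above can also be written with integer shifts only, in the style of the base_w pseudocode of the SPHINCS+ specification. This is a sketch; the name `base_w_int` is hypothetical and it assumes _w_ is a power of two (4, 16 or 256).

```python
def base_w_int(x: bytes, w: int, out_len: int) -> list:
    """Integer-only base-w conversion, most significant digit first."""
    lg_w = w.bit_length() - 1      # log2(w) for a power of two
    total, bits, consumed = 0, 0, 0
    out = []
    for _ in range(out_len):
        if bits == 0:              # current byte used up: load the next one
            total = x[consumed]
            consumed += 1
            bits = 8
        bits -= lg_w
        out.append((total >> bits) & (w - 1))
    return out

print(base_w_int(b'\x12\x34', 16, 4))   # [1, 2, 3, 4]
```

Avoiding bitarray keeps the conversion dependency-free, at the cost of being limited to _w_ values whose digits do not straddle byte boundaries (which holds for 4, 16 and 256).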

In [26]:
def adrs_to_bytes(adrs):
    res = b''
    for byte in adrs:
        res += byte
    return res

## Hash Functions

<font color='red'>TODO: Shorter outputs. If a parameter set requires an output length n < 32 bytes for F, H, PRF, and PRF_msg, we take the first n bytes of the output and discard the rest.</font>
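One way to address the TODO is a small wrapper that truncates any 32-byte hash output to n bytes. A minimal sketch; `truncated` and `prf_full` are hypothetical helpers, not part of the notebook.

```python
from hashlib import sha256

def truncated(hash_fn, n):
    """Return a version of hash_fn whose output is cut to its first n bytes."""
    def wrapper(*args, **kwargs):
        return hash_fn(*args, **kwargs)[:n]
    return wrapper

def prf_full(seed: bytes, adrs: bytes) -> bytes:
    # hypothetical 32-byte PRF: SHA-256(seed || adrs)
    return sha256(seed + adrs).digest()

prf16 = truncated(prf_full, 16)   # e.g. for a parameter set with n = 16
print(len(prf16(bytes(16), bytes(32))))   # 16
```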


### Generating the masks

We generate the bitmasks for arbitrary-length messages using the MGF1 mask generation function (defined in RFC 8017).

In [27]:
def MGF1(m, size, h=sha256):
    hlen = h().digest_size  # 32 bytes for SHA-256
    output = b''
    for i in range(0, ceil(size/hlen)):
        ibytes = i.to_bytes(4, 'big')  # 4-byte big-endian counter
        tmp = h(m + ibytes).digest()
        output = output + tmp
    return output[:size]  # truncate to exactly size bytes

In [54]:
mgf_res = MGF1(bytes.fromhex('3b5c056af3ebba70d4c805380420585562b32410a778f558ff951252407647e3'), 34)
mgf_res.hex()

'5b7eb772aecf04c74af07d9d9c1c1f8d3a90dcda00d5bab1dc28daecdc86eb87611e'

In [28]:
def generate_mask(m, pk_seed, adrs, mask_len):
    mask = MGF1(pk_seed + adrs_to_bytes(adrs), mask_len)
    return xor(m,mask)
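Because the mask depends only on _pk_seed_ and the address, applying `generate_mask` twice with the same inputs returns the original message (XOR is its own inverse). A self-contained sketch of that property; `mgf1_sha256` and `apply_mask` are hypothetical names that mirror `MGF1` and `generate_mask`.

```python
from hashlib import sha256
from math import ceil

def mgf1_sha256(seed: bytes, length: int) -> bytes:
    """MGF1 with SHA-256: concatenate SHA-256(seed || 4-byte counter), truncate."""
    out = b''.join(sha256(seed + c.to_bytes(4, 'big')).digest()
                   for c in range(ceil(length / 32)))
    return out[:length]

def apply_mask(m: bytes, pk_seed: bytes, adrs: bytes) -> bytes:
    """XOR m with an MGF1 mask derived from pk_seed || adrs."""
    mask = mgf1_sha256(pk_seed + adrs, len(m))
    return bytes(x ^ y for x, y in zip(m, mask))

pk_seed, adrs = bytes(32), bytes(32)
msg = bytes(range(64))                       # e.g. two 32-byte nodes for H
masked = apply_mask(msg, pk_seed, adrs)
assert apply_mask(masked, pk_seed, adrs) == msg   # masking twice undoes it
```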

### SPHINCS+-SHA-256 Hash Functions

![image.png](attachment:image.png)
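In this notebook, F (one n-byte input), H (two n-byte inputs) and the T_k compression used later for FORS all follow one "mask, then hash" pattern: SHA-256(PK.seed || ADRS || (M XOR MGF1(PK.seed || ADRS, |M|))). The sketch below captures that pattern in a single hypothetical function `thash`. Later revisions of the specification additionally pad PK.seed to a full SHA-256 block and compress ADRS; this sketch follows the simpler form used here.

```python
from hashlib import sha256
from math import ceil

def thash(pk_seed: bytes, adrs: bytes, *blocks: bytes) -> bytes:
    """Robust tweakable hash over the concatenated blocks (F: 1 block, H: 2 blocks)."""
    m = b''.join(blocks)
    # MGF1-SHA-256 mask of len(m) bytes, seeded with pk_seed || adrs
    mask = b''.join(sha256(pk_seed + adrs + c.to_bytes(4, 'big')).digest()
                    for c in range(ceil(len(m) / 32)))[:len(m)]
    masked = bytes(x ^ y for x, y in zip(m, mask))
    return sha256(pk_seed + adrs + masked).digest()

pk_seed, adrs = bytes(32), bytes(32)
left, right = b'\x01' * 32, b'\x02' * 32
parent = thash(pk_seed, adrs, left, right)   # plays the role of H
leaf = thash(pk_seed, adrs, left)            # plays the role of F
```

Changing any field of the address changes the mask and the hashed prefix, which is why every call site sets ADRS carefully.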

In [59]:
def F(pk_seed, adrs, m1):
    # robust variant: hash pk_seed || ADRS || (m1 XOR MGF1 mask)
    masked = generate_mask(m1, pk_seed, adrs, len(m1))
    mess = pk_seed + adrs_to_bytes(adrs) + masked
    return sha256(mess).digest()

In [30]:
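def H(pk_seed, adrs, m1, m2=b''):
    # H hashes M1 || M2 (two n-byte nodes, or one 2n-byte string passed as m1);
    # a single MGF1 mask of the full length is applied to the concatenation
    m = m1 + m2
    masked = generate_mask(m, pk_seed, adrs, len(m))
    mess = pk_seed + adrs_to_bytes(adrs) + masked
    return sha256(mess).digest()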

In [31]:
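def H_msg(r, pk_seed, pk_root, m, out_len=None):
    # out_len is the digest length (bytes) needed for md, idx_tree and idx_leaf;
    # by default it is derived from the global parameters k, a, h and d
    if out_len is None:
        out_len = (k*a + 7)//8 + (h - h//d + 7)//8 + (h//d + 7)//8
    mess = r + pk_seed + pk_root + m
    sha = sha256(mess).digest()
    return MGF1(sha, out_len)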

In [32]:
def PRF(seed, adrs):
    mess = seed + adrs_to_bytes(adrs)
    return sha256(mess).digest()

In [33]:
def PRF_msg(sk_prf, opt_rand, m):
    # R = HMAC-SHA-256(SK.prf, OptRand || M)
    hmac_value = hmac.new(sk_prf, opt_rand + m, sha256)
    return hmac_value.digest()

### Hash Function Address Scheme

An address ADRS is a 32-byte value that follows a defined structure. We use 5 different address structures later in the algorithms.

1) WOTS+ hash address:

![image-3.png](attachment:image-3.png)

2) WOTS+ public key compression address:

![image-2.png](attachment:image-2.png)

3) Hash tree address:

![image-4.png](attachment:image-4.png)

4) FORS tree address:

![image-5.png](attachment:image-5.png)

5) FORS tree roots compression address:

![image-6.png](attachment:image-6.png)
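The code in this notebook uses a simplified address made of four 8-byte fields. For comparison, the specification packs ADRS as a 4-byte layer address, a 12-byte tree address, a 4-byte type, and three 4-byte type-specific words (for a WOTS+ hash address: key pair, chain and hash address; for a hash tree address: padding, tree height and tree index; for a FORS tree address: key pair, tree height and tree index). A sketch of such a packer; `make_adrs` is a hypothetical helper.

```python
import struct

# Address types, numbered as in the specification (and as used in this notebook)
WOTS_HASH, WOTS_PK, TREE, FORS_TREE, FORS_ROOTS = range(5)

def make_adrs(layer=0, tree=0, type_=WOTS_HASH, word1=0, word2=0, word3=0) -> bytes:
    """Pack a 32-byte ADRS: layer (4 B) | tree (12 B) | type (4 B) | 3 x 4-byte words."""
    return (struct.pack('>I', layer)
            + tree.to_bytes(12, 'big')
            + struct.pack('>IIII', type_, word1, word2, word3))

# WOTS+ hash address: layer 2, tree 5, key pair 7, chain 3, hash 14
adrs = make_adrs(layer=2, tree=5, type_=WOTS_HASH, word1=7, word2=3, word3=14)
print(len(adrs))   # 32
```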

## WOTS+

The **WOTS+ _(Winternitz One-Time Signature Plus)_** algorithm is a component of the SPHINCS+ digital signature scheme. It is a one-time signature scheme: in SPHINCS+, WOTS+ key pairs are the leaves of the XMSS trees and sign n-byte values (the roots of the trees on the layer below, or FORS public keys). Each WOTS+ private key **MUST NOT** be used to sign more than a single message.

Parameters:
* _n_ - the security parameter: the message length as well as the length of a private key, public key, or signature element, in bytes
* _w_ - the Winternitz parameter; it is an element of the set {4, 16, 256}

These parameters are used to compute values _len_, _len1_ and _len2_:

![image.png](attachment:image.png)

![image-2.png](attachment:image-2.png)

![image-3.png](attachment:image-3.png)
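Written out, the standard WOTS+ definitions (the message has 8n bits, and each base-w digit carries log2 w of them) are shown below, followed by a worked example for the values used in this notebook (n = 32, w = 16):

$$
\mathit{len}_1 = \left\lceil \frac{8n}{\log_2 w} \right\rceil,\qquad
\mathit{len}_2 = \left\lfloor \frac{\log_2\big(\mathit{len}_1\,(w-1)\big)}{\log_2 w} \right\rfloor + 1,\qquad
\mathit{len} = \mathit{len}_1 + \mathit{len}_2
$$

$$
\mathit{len}_1 = \left\lceil \frac{256}{4} \right\rceil = 64,\qquad
\mathit{len}_2 = \left\lfloor \frac{\log_2 960}{4} \right\rfloor + 1 = \lfloor 2.48 \rfloor + 1 = 3,\qquad
\mathit{len} = 67
$$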

In [14]:
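n = 32   # security parameter (bytes)
w = 16   # Winternitz parameter
h = 64   # hypertree height (used later)
len1 = ceil(8*n/log2(w))                     # base-w digits of an n-byte (8n-bit) message
len2 = floor(log2(len1*(w-1))/log2(w)) + 1   # base-w digits of the checksum
length = len1 + len2
print(len1, len2, length)

64 3 67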


### Chaining Function

The chaining function computes an iteration of F on an n-byte input: starting at position _i_ of a hash chain, it applies F _s_ times, setting the hash address in _adrs_ before each call.

Input:
* _X_ - input string
* _i_ - start index
* _s_ - a number of steps
* _pk_seed_ - public seed
* _adrs_ - 32-byte WOTS+ hash address

Output: value of F iterated s times on X

In [55]:
import os
def chain(X, i, s, pk_seed, adrs):
    if s == 0:
        return X, adrs
    if (i+s) > (w-1):
        return b'', adrs
    tmp, adrs = chain(X, i, s-1, pk_seed, adrs)
    adrs[3] = num_to_bytes(i + s - 1) # hash address
    tmp = F(pk_seed, adrs, tmp)
#     print(tmp)
    return tmp, adrs

ADRS = [os.urandom(8) for _ in range(4)]
PK_SEED = os.urandom(n)
chain(num_to_bytes(10), 0, 5, PK_SEED, ADRS)

(b"\x9d\xff\x9a\x9e\x94\x8eT\x991\xa1F\x90\x9az\na'\xfe\xe0\xde{d\x01\xccK\x96\xf1\xdc\x99/Q\xd0",
 [b'\x9dL\xcdO\xa0!\xb1\xb7',
  b'\xa1\n\xb7SQ<rV',
  b"+\xe4D\xab'\x85f.",
  b'\x00\x00\x00\x00\x00\x00\x00\x04'])
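The recursive `chain` above can equivalently be written as a loop, which avoids Python's recursion overhead. A sketch under the notebook's four-field list address (`adrs[3]` = hash address); `toy_F` is a hypothetical stand-in for F (plain SHA-256, no bitmask), and `W = 16` mirrors the notebook's _w_.

```python
from hashlib import sha256

W = 16  # Winternitz parameter (assumed, as in this notebook)

def toy_F(pk_seed: bytes, adrs: list, x: bytes) -> bytes:
    """Stand-in for F: SHA-256 over pk_seed || ADRS || x."""
    return sha256(pk_seed + b''.join(adrs) + x).digest()

def chain_iter(x: bytes, i: int, s: int, pk_seed: bytes, adrs: list):
    """Apply F s times starting at chain position i; None if the chain would pass w - 1."""
    if i + s > W - 1:
        return None
    for j in range(i, i + s):
        adrs[3] = j.to_bytes(8, 'big')   # hash address = current chain position
        x = toy_F(pk_seed, adrs, x)
    return x

pk_seed = bytes(32)
adrs = [bytes(8)] * 4
x0 = b'\x0a' * 32
# walking 0 -> 5 in one call equals walking 0 -> 2 and then 2 -> 5
assert chain_iter(x0, 0, 5, pk_seed, adrs) == chain_iter(chain_iter(x0, 0, 2, pk_seed, adrs), 2, 3, pk_seed, adrs)
```

The composition property checked at the end is exactly what WOTS+ verification relies on: the signer walks a chain part of the way, and the verifier finishes the walk.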

In [58]:
def wots_chain(msghex, show_csum=False):
    # Split hex string into list of 4-bit nibbles
    # (Cheat: we can just split hex string into separate digits)
    msg = [int(x, 16) for x in msghex]
    #print(msg)
    # Compute csum
    csum = 0
    for i in range(len(msg)):
        csum += int(w - 1 - msg[i]) 
    csum &= 0xfff   # truncate to 12 bits
    if show_csum: print(f"csum={csum:03x}")
    msg.append((csum >> 8) & 0xF)
    msg.append((csum >> 4) & 0xF)
    msg.append((csum >> 0) & 0xF)
    return msg

# Example: base-16 digits and 12-bit checksum of an 8-byte message
msg = wots_chain(num_to_bytes(10).hex(), True)
print(msg)
print([hex(x) for x in msg])
print(f"len={len(msg)}")

csum=0e6
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 14, 6]
['0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0xa', '0x0', '0xe', '0x6']
len=19

### Private key generation

The WOTS+ private key, denoted by _sk (s for secret)_, is a length _len_ array of n-byte strings. This private key MUST NOT be used to sign more than one message.

Input:
* _sk_seed_ - secret seed
* _adrs_ - WOTS+ hash address

Output: WOTS+ private key _sk_

In [34]:
def wots_SKgen(sk_seed, adrs, length):
    sk = b''
    for i in range(length):
        adrs[2] = num_to_bytes(i) # chain address
        sk += PRF(sk_seed, adrs)
    return sk, adrs

### Public key generation

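The public key is computed by applying the chaining function _w_ - 1 times to each of the _len_ private key elements and hashing the _len_ chain ends, together with _pk_seed_ and a WOTS+ public key compression address, into one n-byte value.

Input:
* _sk_ - WOTS+ private key (output of _wots_SKgen_)
* _pk_seed_ - public seed
* _adrs_ - WOTS+ hash address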

Output: WOTS+ public key _pk_

In [39]:
def wots_PKgen(sk, pk_seed, adrs):
    tmp = b''
    for i in range(length):
        adrs[2] = num_to_bytes(i)  # chain address
        sk_i = sk[i*n:(i+1)*n]      # i-th n-byte private key element
        chain_res, adrs = chain(sk_i, 0, w-1, pk_seed, adrs)
        tmp += chain_res
    # WOTS+ public key compression address: type WOTS_PK, same key pair, zero padding
    wotspkadrs = [num_to_bytes(1), adrs[1], num_to_bytes(0), num_to_bytes(0)]
    pk = sha256(pk_seed + adrs_to_bytes(wotspkadrs) + tmp).digest()
    return pk, adrs

### Signature Generation

A WOTS+ signature is a length _len_ array of n-byte strings. It is generated by mapping a message M to _len1_ integers between 0 and w − 1 using the _base_w_ method. Next, a checksum over M is computed and appended to the transformed message as _len2_ base-w digits, also using _base_w_. Each of the _len_ base-w integers selects a node from a different hash chain, and the signature is the concatenation of the selected nodes.

Input:
* _M_ - n-byte message
* _sk_ - WOTS+ private key
* _pk_seed_ - public seed
* _adrs_ - WOTS+ hash address
* _len1_ - number of base-w digits of the message

Output: WOTS+ signature _sig_

In [40]:
def wots_sign(m, sk, pk_seed, adrs, len1):
    # convert message to base w
    msg = base_w(m, w, len1)

    # compute checksum
    csum = 0
    for i in range(len1):
        csum += w - 1 - msg[i]

    # convert csum to base w: left-align it in ceil(len2 * log2(w) / 8) bytes
    lg_w = int(log2(w))
    csum <<= (8 - ((len2 * lg_w) % 8)) % 8
    msg = msg + base_w(num_to_bytes(csum, ceil(len2 * lg_w / 8)), w, len2)

    sig = b''
    for i in range(length):
        adrs[2] = num_to_bytes(i)  # chain address
        tmp, adrs = chain(sk[i*n:(i+1)*n], 0, msg[i], pk_seed, adrs)
        sig += tmp
    return sig, adrs

### Compute Public Key from Signature

In order to verify a WOTS+ signature _sig_ on a message M, the verifier computes a WOTS+ public key value from the signature: chain _i_ is continued from position msg[i] for the remaining w − 1 − msg[i] steps.

In [41]:
def wots_pk_from_sig(sig, m, pk_seed, adrs):
    # convert message to base w
    msg = base_w(m, w, len1)

    # compute checksum
    csum = 0
    for i in range(len1):
        csum += w - 1 - msg[i]

    # convert csum to base w (same encoding as in wots_sign)
    lg_w = int(log2(w))
    csum <<= (8 - ((len2 * lg_w) % 8)) % 8
    msg = msg + base_w(num_to_bytes(csum, ceil(len2 * lg_w / 8)), w, len2)

    tmp = b''
    for i in range(length):
        adrs[2] = num_to_bytes(i)  # chain address
        # continue chain i from position msg[i] up to w - 1
        t, adrs = chain(sig[i*n:(i+1)*n], msg[i], w - 1 - msg[i], pk_seed, adrs)
        tmp += t
    wotspkadrs = [num_to_bytes(1), adrs[1], num_to_bytes(0), num_to_bytes(0)]
    pk_sig = sha256(pk_seed + adrs_to_bytes(wotspkadrs) + tmp).digest()
    return pk_sig, adrs

In [42]:
# generate keys
SK_SEED = randbytes(n)
PK_SEED = randbytes(n)
ADRS = [b'\0' * 8] + [randbytes(8) for _ in range(3)]  # [type, key pair, chain, hash]
M = b"Ala ma kotka a kotek ma tez Ale!"  # 32-byte test message (n = 32)

wots_sk, adrs = wots_SKgen(SK_SEED, ADRS, length)
wots_pk, adrs = wots_PKgen(wots_sk, PK_SEED, adrs)

# generate a signature
wots_sig, adrs = wots_sign(M, wots_sk, PK_SEED, adrs, len1)

# recompute the public key from the signature
sig_pk, adrs = wots_pk_from_sig(wots_sig, M, PK_SEED, adrs)

In [46]:
# the public key recomputed from a valid signature must equal the real one
print(wots_pk == sig_pk)

True
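The checksum is what makes WOTS+ secure against forgeries: raising any message digit (moving further along a chain, which anyone can do by applying F) lowers the checksum, so some checksum chain would have to be walked backwards, which requires inverting F. A self-contained sketch of the message-to-chain-position mapping used by `wots_sign`; `wots_digits` is a hypothetical name.

```python
from math import ceil, floor, log2

def wots_digits(msg: bytes, w: int = 16) -> list:
    """Map an n-byte message to len1 + len2 chain positions (message digits + checksum)."""
    lg_w = int(log2(w))
    n = len(msg)
    len1 = ceil(8 * n / lg_w)
    len2 = floor(log2(len1 * (w - 1)) / lg_w) + 1

    # message digits, most significant first (same order as base_w)
    value = int.from_bytes(msg, 'big')
    digits = [(value >> (8 * n - lg_w * (i + 1))) & (w - 1) for i in range(len1)]

    # checksum, left-aligned in ceil(len2 * lg_w / 8) bytes, then split into len2 digits
    csum = sum(w - 1 - x for x in digits)
    csum_bytes = ceil(len2 * lg_w / 8)
    csum <<= (8 - (len2 * lg_w) % 8) % 8
    digits += [(csum >> (8 * csum_bytes - lg_w * (i + 1))) & (w - 1) for i in range(len2)]
    return digits

d = wots_digits(bytes(32))
print(len(d), d[-3:])   # 67 [3, 12, 0]
```

For the all-zero message every message digit is 0, so the checksum is maximal (64 · 15 = 960 = 0x3C0); for an all-0xFF message it is 0.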


## XMSS

XMSS is a method for signing a potentially large but fixed number of messages. It is based on the Merkle signature scheme. We use it to build the **hypertree**. The leaves are the WOTS+ public keys. The XMSS public key is the root node of the tree.

Parameters:
* _h'_ - the height (number of levels - 1) of the tree
* _n_ - the length in bytes of messages as well as of each node
* _w_ - the Winternitz parameter as defined for WOTS+

In [None]:
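n = 32           # bytes per node / message (same as for WOTS+)
w = 16           # Winternitz parameter (same as for WOTS+)
d = 8            # number of hypertree layers (assumption: SPHINCS+-256s-style parameters)
h_prim = h // d  # height h' of each XMSS tree (h = 64 -> h' = 8)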

In [None]:
def node_height(entry):
    # treehash stack entries are (node, height) pairs
    return entry[1]

### Treehash algorithm

Algorithm that computes the internal n-byte nodes of a Merkle tree: it returns the root of the subtree of height _z_ whose leftmost leaf has index _s_.

Input:
* _sk_seed_ - secret seed
* _s_ - start index
* _z_ - target node height
* _pk_seed_ - public seed
* _adrs_ - WOTS+ pk address

Output: n-byte root node

<font color='red'>It is REQUIRED that s % 2^z = 0, i.e. that the leaf at index s is the leftmost leaf of a subtree of height z. Otherwise the algorithm fails, as it would compute non-existent nodes.</font>

In [None]:
def treehash(sk_seed, s, z, pk_seed, adrs):
    # s must be the index of the leftmost leaf of a subtree of height z
    if s % (1 << z) != 0:
        raise ValueError("s must be a multiple of 2^z")
    stack = []  # (node, height) pairs; heights strictly decrease towards the top

    for i in range(2 ** z):
        adrs[:8] = num_to_bytes(0)          # type: WOTS+ hash address
        adrs[8:16] = num_to_bytes(s + i)    # key pair address
        # leaf = WOTS+ public key (spec signature: wots_PKgen(SK.seed, PK.seed, ADRS))
        node, adrs = wots_PKgen(sk_seed, pk_seed, adrs)
        adrs[:8] = num_to_bytes(2)          # type: hash tree address
        adrs[16:24] = num_to_bytes(1)       # tree height
        adrs[24:] = num_to_bytes(s + i)     # tree index
        height = 0
        while stack and stack[-1][1] == height:
            tree_index = int.from_bytes(adrs[24:], 'big')
            adrs[24:] = num_to_bytes((tree_index - 1) // 2)
            left, _ = stack.pop()
            node = H(pk_seed, adrs, left, node)
            height += 1
            adrs[16:24] = num_to_bytes(height + 1)
        stack.append((node, height))
    return stack.pop()[0], adrs

### XMSS Public Key Generation

The XMSS public key PK is the root of the binary hash tree which is
computed using treehash.

Input: 
* _sk_seed_ - secret seed
* _pk_seed_ - public seed
* _adrs_ - tree address

Output: XMSS public key _pk_

In [None]:
def xmss_PKgen(sk_seed, pk_seed, adrs):
    pk, adrs = treehash(sk_seed, 0, h_prim, pk_seed, adrs)
    return pk, adrs
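Stripped of addresses and WOTS+ key generation, treehash is a stack-based Merkle root computation: leaves are pushed left to right, and whenever the two topmost nodes have the same height they are merged into their parent. The stack therefore holds at most one node per height. A sketch comparing it with the recursive definition of the root; `node_hash` (plain SHA-256 of the children) is a hypothetical stand-in for H.

```python
from hashlib import sha256

def node_hash(left: bytes, right: bytes) -> bytes:
    """Toy H: SHA-256 of the two children (no ADRS, no bitmask)."""
    return sha256(left + right).digest()

def treehash_root(leaves: list) -> bytes:
    """Stack-based treehash over a power-of-two number of leaves."""
    if len(leaves) == 0 or len(leaves) & (len(leaves) - 1):
        raise ValueError("number of leaves must be a power of two")
    stack = []                       # entries are (node, height)
    for leaf in leaves:
        node, height = leaf, 0
        while stack and stack[-1][1] == height:
            left, _ = stack.pop()
            node, height = node_hash(left, node), height + 1
        stack.append((node, height))
    return stack.pop()[0]

def recursive_root(leaves: list) -> bytes:
    """Reference definition: root = H(root(left half), root(right half))."""
    if len(leaves) == 1:
        return leaves[0]
    mid = len(leaves) // 2
    return node_hash(recursive_root(leaves[:mid]), recursive_root(leaves[mid:]))

leaves = [sha256(bytes([i])).digest() for i in range(8)]
assert treehash_root(leaves) == recursive_root(leaves)
```

The stack version needs memory for only about h' nodes instead of the whole tree, which is why the specification uses it.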

### XMSS Signature

An XMSS signature is a _(len + h') ∗ n_-byte string consisting of:
* a WOTS+ signature sig taking _len * n_ bytes,
* the authentication path AUTH for the leaf associated with the used WOTS+ key pair taking _h' * n_ bytes. For the ith WOTS+ key pair, counting from zero, the jth authentication path node is:

![image-2.png](attachment:image-2.png)

where:
_N(x, y)_ denotes the yth node on level x with y = 0 being the leftmost node on a level. The leaves are on level 0, the root is on level h'.


The XMSS signature:
![image.png](attachment:image.png)

Input:
* _M_ - n-byte message M
* _sk_seed_ - secret seed
* _idx_ - index
* _pk_seed_ - public seed
* _adrs_ - tree type address

Output: XMSS signature _sig_xmss_

In [None]:
def xmss_sign(m, sk_seed, idx, pk_seed, adrs):
    # build authentication path: AUTH[j] = N(j, floor(idx / 2^j) XOR 1)
    auth = bytearray()
    for j in range(h_prim):
        sibling = (idx >> j) ^ 1
        auth_j, adrs = treehash(sk_seed, sibling * 2 ** j, j, pk_seed, adrs)
        auth += auth_j
    adrs[:8] = num_to_bytes(0)       # type: WOTS+ hash address
    adrs[8:16] = num_to_bytes(idx)   # key pair address
    sig, adrs = wots_sign(m, sk_seed, pk_seed, adrs, len1)
    sig_xmss = sig + auth
    return sig_xmss, adrs
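The authentication-path rule AUTH[j] = N(j, floor(i / 2^j) XOR 1) and the left/right rule used when recomputing the root can be checked on a toy tree. A sketch with plain SHA-256 as a hypothetical stand-in for H; `build_levels`, `auth_path` and `root_from_auth` are illustrative names.

```python
from hashlib import sha256

def node_hash(left: bytes, right: bytes) -> bytes:
    return sha256(left + right).digest()

def build_levels(leaves: list) -> list:
    """All levels of a Merkle tree: levels[0] = leaves, levels[-1] = [root]."""
    levels = [leaves]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        levels.append([node_hash(prev[2 * i], prev[2 * i + 1]) for i in range(len(prev) // 2)])
    return levels

def auth_path(levels: list, idx: int) -> list:
    """AUTH[j] = sibling of leaf idx's ancestor on level j."""
    return [levels[j][(idx >> j) ^ 1] for j in range(len(levels) - 1)]

def root_from_auth(leaf: bytes, idx: int, auth: list) -> bytes:
    """Even position: current node is the left child; odd: it is the right child."""
    node = leaf
    for j, sibling in enumerate(auth):
        if (idx >> j) % 2 == 0:
            node = node_hash(node, sibling)
        else:
            node = node_hash(sibling, node)
    return node

leaves = [sha256(bytes([i])).digest() for i in range(8)]
levels = build_levels(leaves)
root = levels[-1][0]
assert all(root_from_auth(leaves[i], i, auth_path(levels, i)) == root for i in range(8))
```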

### XMSS Compute Public Key from Signature

We do not implement XMSS verification directly. Instead, we compute the public key (root) from an XMSS signature, because that is the value the hypertree computations need.

Input:
* _idx_ - index
* _sig_xmss_ - XMSS signature
* _M_ - n-byte message
* _pk_seed_ - public seed
* _adrs_ - WOTS+ hash address

Output: _pk_ - n-byte root value node[0]

In [None]:
def xmss_pk_from_sig(idx, sig_xmss, m, pk_seed, adrs):
    # compute WOTS+ pk from WOTS+ sig
    adrs[:8] = num_to_bytes(0)       # type: WOTS+ hash address
    adrs[8:16] = num_to_bytes(idx)   # key pair address
    sig = sig_xmss[:length*n]        # WOTS+ signature: len n-byte strings
    auth = sig_xmss[length*n:]       # authentication path: h' n-byte nodes
    node, adrs = wots_pk_from_sig(sig, m, pk_seed, adrs)

    # compute root from WOTS+ pk and AUTH
    adrs[:8] = num_to_bytes(2)       # type: hash tree address
    adrs[24:] = num_to_bytes(idx)    # tree index
    for j in range(h_prim):
        adrs[16:24] = num_to_bytes(j + 1)   # tree height
        auth_j = auth[j*n:(j+1)*n]
        tree_index = int.from_bytes(adrs[24:], 'big')
        if (idx >> j) % 2 == 0:
            # current node is a left child
            adrs[24:] = num_to_bytes(tree_index // 2)
            node = H(pk_seed, adrs, node, auth_j)
        else:
            # current node is a right child
            adrs[24:] = num_to_bytes((tree_index - 1) // 2)
            node = H(pk_seed, adrs, auth_j, node)
    return node, adrs

## Hypertree

A **hypertree** is a tree of several layers of XMSS trees. The trees on the top and intermediate layers sign the public keys (root nodes) of the XMSS trees on the layer below. Trees on the lowest layer sign the actual messages, which in SPHINCS+ are FORS public keys. All XMSS trees in the HT have equal height h' = h/d.

Parameters:
* _h_ - hypertree height
* _d_ - number of tree layers
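The signing path through the hypertree is fully determined by the tree and leaf index on layer 0: at each step up, the h/d least significant bits of the current tree index become the leaf index on the next layer, and the remaining bits become the tree index. A sketch; `ht_path` is a hypothetical helper.

```python
def ht_path(idx_tree: int, idx_leaf: int, h: int, d: int) -> list:
    """(layer, tree index, leaf index) of every XMSS tree used when signing."""
    h_prim = h // d
    path = [(0, idx_tree, idx_leaf)]
    for layer in range(1, d):
        idx_leaf = idx_tree % (1 << h_prim)   # h/d least significant bits
        idx_tree = idx_tree >> h_prim         # remaining most significant bits
        path.append((layer, idx_tree, idx_leaf))
    return path

print(ht_path(0b1011, 1, h=6, d=3))   # [(0, 11, 1), (1, 2, 3), (2, 0, 2)]
```

Because idx_tree has h − h/d bits, the tree index on the top layer is always 0: there is a single top tree, whose root is the HT public key.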

### HT Key Generation

The HT public key is the public key (root node) of the single XMSS tree on the top layer.

Input:
* _sk_seed_ - private seed
* _pk_seed_ - public seed

Output: HT public key _pk_ht_

In [None]:
def ht_pk_gen(sk_seed, pk_seed):
    adrs = int_to_byte(0, length=32)
    adrs[:8] = num_to_byte(d-1)
    root, adrs = xmss_pk_gen(sk_seed, pk_seed, adrs)
    return root, adrs

### HT signature

An HT signature _SIG_HT_ is a byte string of length _(h + d ∗ len) ∗ n_. It consists of _d_ XMSS signatures of _(h/d + len) ∗ n_ bytes each.

![image.png](attachment:image.png)
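The size formula follows from stacking d XMSS signatures, each holding a WOTS+ signature (len · n bytes) and an authentication path of h/d nodes. A quick check for SPHINCS+-256s-style parameters (n = 32, h = 64, d = 8, w = 16, so len = 67); the helper names are hypothetical.

```python
def xmss_sig_bytes(n: int, h: int, d: int, length: int) -> int:
    """One XMSS signature: WOTS+ signature (len * n) plus h/d authentication nodes."""
    return (h // d + length) * n

def ht_sig_bytes(n: int, h: int, d: int, length: int) -> int:
    """d XMSS signatures = (h + d * len) * n bytes."""
    return d * xmss_sig_bytes(n, h, d, length)

print(xmss_sig_bytes(32, 64, 8, 67), ht_sig_bytes(32, 64, 8, 67))   # 2400 19200
```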


The signature algorithm uses _xmss_pk_from_sig_ to compute the root node of each XMSS instance after that instance was used for signing; this root is then signed by the XMSS tree on the next layer.

Input: 
* _M_ - message
* _sk_seed_ - secret seed
* _pk_seed_ - public seed
* _idx_tree_ - tree index
* _idx_leaf_ - leaf index

Output: HT signature _sig_ht_

In [None]:
def ht_sign(m, sk_seed, pk_seed, idx_tree, idx_leaf):
    # init
    adrs = bytearray(32)

    # sign m with XMSS tree idx_tree on layer 0
    adrs[:8] = num_to_bytes(0)            # layer address
    adrs[8:16] = num_to_bytes(idx_tree)   # tree address
    sig_tmp, adrs = xmss_sign(m, sk_seed, idx_leaf, pk_seed, adrs)
    sig_ht = bytearray(sig_tmp)
    root, adrs = xmss_pk_from_sig(idx_leaf, sig_tmp, m, pk_seed, adrs)

    # sign each root with the XMSS tree on the layer above
    for j in range(1, d):
        idx_leaf = idx_tree % (1 << h_prim)   # h/d least significant bits of idx_tree
        idx_tree = idx_tree >> h_prim         # remaining most significant bits of idx_tree
        adrs[:8] = num_to_bytes(j)
        adrs[8:16] = num_to_bytes(idx_tree)
        sig_tmp, adrs = xmss_sign(root, sk_seed, idx_leaf, pk_seed, adrs)
        sig_ht += sig_tmp
        if j < d - 1:
            root, adrs = xmss_pk_from_sig(idx_leaf, sig_tmp, root, pk_seed, adrs)
    return sig_ht, adrs
    

### HT Signature Verification

HT signature verification consists of _d_ calls to _xmss_pk_from_sig_ and one comparison with the given HT public key.

Input:
* _M_ - message
* _sig_ht_ - signature
* _pk_seed_ - public seed
* _idx_tree_ - tree index
* _idx_leaf_ - leaf index
* _pk_ht_ - HT public key

Output: Boolean

In [None]:
def ht_verify(m, sig_ht, pk_seed, idx_tree, idx_leaf, pk_ht):
    # init
    adrs = bytearray(32)
    xmss_len = (h_prim + length) * n   # bytes per XMSS signature

    # verify
    sig_tmp = sig_ht[:xmss_len]        # first XMSS signature
    adrs[:8] = num_to_bytes(0)
    adrs[8:16] = num_to_bytes(idx_tree)
    node, adrs = xmss_pk_from_sig(idx_leaf, sig_tmp, m, pk_seed, adrs)
    for j in range(1, d):
        idx_leaf = idx_tree % (1 << h_prim)   # h/d least significant bits of idx_tree
        idx_tree = idx_tree >> h_prim         # remaining most significant bits
        sig_tmp = sig_ht[j*xmss_len:(j+1)*xmss_len]
        adrs[:8] = num_to_bytes(j)
        adrs[8:16] = num_to_bytes(idx_tree)
        node, adrs = xmss_pk_from_sig(idx_leaf, sig_tmp, node, pk_seed, adrs)
    return node == pk_ht, adrs

## FORS: Forest Of Random Subsets

The SPHINCS+ hypertree HT is not used to sign the actual messages but the public keys of FORS instances which in turn are used to sign the messages.

Parameters:
* _n_ - the security parameter; it is the length of a private key, public key, or signature element in bytes.
* _k_ - the number of private key sets, trees and indices computed from the input string.
* _t_ - the number of elements per private key set, number of leaves per hash tree and upper bound on the index values. The parameter t MUST be a power of 2. If t = 2^a, then the trees have height a and the input string is split into bit strings of length a.

![image.png](attachment:image.png)
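The FORS code below uses the globals _k_, _a_ and _t_. A possible parameter cell, assuming SPHINCS+-256s-style values (k = 22 trees of height a = 14), together with a sketch of how the k indices are read from the message digest. Reading the bits most-significant-first matches `base_w` above (`base_w(md, t, k)` gives the same list); other implementations may use a different bit order, which changes the indices but not the construction. `message_to_indices` is a hypothetical name.

```python
# Assumed FORS parameters: k trees of height a, t = 2^a leaves each
k, a = 22, 14
t = 2 ** a

def message_to_indices(md: bytes, k: int, a: int) -> list:
    """Split the first k*a bits of md (most significant bit first) into k a-bit indices."""
    total_bits = 8 * len(md)
    if total_bits < k * a:
        raise ValueError("message digest too short")
    value = int.from_bytes(md, 'big')
    return [(value >> (total_bits - a * (i + 1))) & ((1 << a) - 1) for i in range(k)]

print(message_to_indices(bytes([0b10110011, 0b01000000]), k=3, a=3))   # [5, 4, 6]
```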

### FORS Private Key

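The FORS private key values are generated on demand using PRF with _sk_seed_ and a FORS tree address whose tree height is 0 and whose tree index is _idx_.

Input:
* _sk_seed_ - secret seed
* _adrs_ - FORS tree address
* _idx_ - index of the private key element

Output: FORS private key _sk_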

In [None]:
def fors_sk_gen(sk_seed, adrs, idx):
    adrs[16:24] = num_to_bytes(0)
    adrs[-8:] = num_to_bytes(idx)
    sk = PRF(sk_seed, adrs)
    return sk, adrs

### FORS Tree Hash

Function that computes the n-byte nodes of the FORS trees.

Input:
* _sk_seed_ - secret seed
* _s_ - start index
* _z_ - target node height
* _pk_seed_ - public seed
* _adrs_ - FORS tree address

Output: n-byte root node - top node on Stack

In [None]:
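def fors_treehash(sk_seed, s, z, pk_seed, adrs):
    # s must be the index of the leftmost leaf of a subtree of height z
    if s % (1 << z) != 0:
        raise ValueError("s must be a multiple of 2^z")
    stack = []  # (node, height) pairs
    for i in range(2 ** z):
        adrs[16:24] = num_to_bytes(0)       # tree height 0: leaves
        adrs[24:] = num_to_bytes(s + i)     # tree index
        sk = PRF(sk_seed, adrs)             # FORS private key element
        node = F(pk_seed, adrs, sk)         # leaf
        adrs[16:24] = num_to_bytes(1)
        height = 0
        while stack and stack[-1][1] == height:
            tree_index = int.from_bytes(adrs[24:], 'big')
            adrs[24:] = num_to_bytes((tree_index - 1) // 2)
            left, _ = stack.pop()
            node = H(pk_seed, adrs, left, node)
            height += 1
            adrs[16:24] = num_to_bytes(height + 1)
        stack.append((node, height))
    return stack.pop()[0], adrs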

### FORS Public Key

The FORS public key is the tweakable hash T_k of the k tree roots of the forest. It is NEVER generated alone: it is only computed together with a signature (via _fors_pk_from_sig_).

Input:
* _sk_seed_ - secret seed
* _pk_seed_ - public seed
* _adrs_ - FORS tree address

Output: FORS public key _pk_

In [None]:
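def fors_pkgen(sk_seed, pk_seed, adrs):
    forspkadrs = deepcopy(adrs)
    root = []
    for i in range(k):
        r, adrs = fors_treehash(sk_seed, i * t, a, pk_seed, adrs)  # root of tree i
        root.append(r)
    forspkadrs[:8] = num_to_bytes(4)    # type: FORS tree roots compression
    forspkadrs[8:16] = adrs[8:16]       # key pair address
    forspkadrs[16:] = bytes(16)         # padding
    # T_k: compress the k roots into the FORS public key
    pk = sha256(pk_seed + bytes(forspkadrs) + b''.join(root)).digest()
    return pk, adrs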

### FORS Signature Generation

A FORS signature is a length _k(log t + 1)_ array of n-byte strings. It contains _k_ private key values, n-bytes each, and their associated authentication paths, _log t_ n-byte values each.

![image.png](attachment:image.png)

Input:
* _M_ - message
* _sk_seed_ - secret seed
* _adrs_ - FORS tree address
* _pk_seed_ - public seed

Output: FORS signature _sig_fors_

In [None]:
def fors_sign(m, sk_seed, pk_seed, adrs):
    # compute signature elements
    sig_fors = bytearray()
    indices = base_w(m, t, k)  # k indices of a = log2(t) bits each, read from the start of m
    for i in range(k):
        # get next index
        idx = indices[i]

        # pick private key element
        adrs[16:24] = num_to_bytes(0)          # tree height
        adrs[24:] = num_to_bytes(i*t + idx)    # tree index
        sig_fors += PRF(sk_seed, adrs)

        # compute auth path
        for j in range(a):
            s = (idx >> j) ^ 1
            node, adrs = fors_treehash(sk_seed, i*t + s * 2**j, j, pk_seed, adrs)
            sig_fors += node
    return sig_fors, adrs

### FORS Compute Public Key from Signature

A FORS signature is used to compute a candidate FORS public key. This public key is used in further computations (message for the signature of the XMSS tree above) and implicitly verified by the outcome of that computation.

Input:
* _sig_fors_ - FORS signature
* _M_ - message
* _pk_seed_ - public seed
* _adrs_ - FORS tree address

Output: FORS public key

In [None]:
def fors_pk_from_sig(sig_fors, m, pk_seed, adrs):
    # compute roots
    roots = bytearray()
    indices = base_w(m, t, k)  # same index extraction as in fors_sign
    for i in range(k):
        # get next index
        idx = indices[i]
        offset = i * (a + 1) * n   # each tree contributes SK[i] plus a auth nodes

        # compute leaf
        sk = sig_fors[offset:offset + n]  # SK[i]
        adrs[16:24] = num_to_bytes(0)
        adrs[24:] = num_to_bytes(i*t + idx)
        node = F(pk_seed, adrs, sk)

        # compute root from leaf and AUTH
        auth = sig_fors[offset + n:offset + (a + 1) * n]
        for j in range(a):
            auth_j = auth[j*n:(j+1)*n]
            adrs[16:24] = num_to_bytes(j + 1)
            tree_index = int.from_bytes(adrs[24:], 'big')
            if (idx >> j) % 2 == 0:
                adrs[24:] = num_to_bytes(tree_index // 2)
                node = H(pk_seed, adrs, node, auth_j)
            else:
                adrs[24:] = num_to_bytes((tree_index - 1) // 2)
                node = H(pk_seed, adrs, auth_j, node)
        roots += node
    forspkadrs = deepcopy(adrs)
    forspkadrs[:8] = num_to_bytes(4)    # type: FORS tree roots compression
    forspkadrs[16:] = bytes(16)         # padding (key pair address is kept)
    # T_k: compress the k roots into the FORS public key
    pk = sha256(pk_seed + bytes(forspkadrs) + bytes(roots)).digest()
    return pk, adrs

# SPHINCS+

Parameters:
* n - the security parameter in bytes.
* w - the Winternitz parameter
* h - the height of the hypertree
* d - the number of layers in the hypertree
* k - the number of trees in FORS
* t - the number of leaves of a FORS tree
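These parameters determine the signature size: the randomizer R (n bytes), the FORS signature (k(log t + 1) n-byte strings) and the HT signature ((h + d · len) · n bytes). A sketch of that arithmetic; `spx_sig_bytes` is a hypothetical helper and a = log2 t.

```python
def spx_sig_bytes(n: int, h: int, d: int, k: int, a: int, w: int = 16) -> int:
    """SPHINCS+ signature size: R + FORS signature + HT signature, in bytes."""
    lg_w = w.bit_length() - 1
    len1 = (8 * n + lg_w - 1) // lg_w                         # ceil(8n / log2 w)
    len2 = ((len1 * (w - 1)).bit_length() - 1) // lg_w + 1   # floor(log_w(len1 (w - 1))) + 1
    length = len1 + len2
    return n + k * (a + 1) * n + (h + d * length) * n

print(spx_sig_bytes(32, 64, 8, 22, 14))   # 29792
```

With n = 32, h = 64, d = 8, k = 22, a = 14 this gives 32 + 10560 + 19200 = 29792 bytes.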

### SPHINCS+ Key Generation

SPHINCS+ **public key** consists of:
* _pk_root_ - HT public key (the root of the hypertree on the top layer)
* _pk_seed_ - public seed

SPHINCS+ **private key** consists of:
* _sk_seed_ - secret seed which is used to generate all the WOTS+ and FORS private key elements.
* _sk_prf_ - secret PRF key, used to deterministically generate a randomization value for the randomized message hash.
* _pk_root_ - HT public key (the root of the hypertree on the top layer)
* _pk_seed_ - public seed

![image.png](attachment:image.png)

In [None]:
def spx_keygen():
    sk_seed = os.urandom(n)   # secret seed for all WOTS+ and FORS private keys
    sk_prf = os.urandom(n)    # secret key for PRF_msg
    pk_seed = os.urandom(n)   # public seed
    pk_root, adrs = ht_pk_gen(sk_seed, pk_seed)
    return sk_seed, sk_prf, pk_seed, pk_root

### SPHINCS+ Signature

1) Generate a random value _R_ 
   
In the original algorithm, PRF_msg takes an n-byte string _opt_ when computing R. It is initialized with zeros but can be overwritten with randomness if the global variable RANDOMIZE is set; this option exists because otherwise SPHINCS+ signatures would always be deterministic. This implementation omits the randomization and always uses the all-zero _opt_, so its signatures are deterministic.

 
2) Compute a m byte message digest 

The message digest is split into a _floor((k log t + 7)/8)_-byte partial message digest _tmp_md_, a _floor((h − h/d + 7)/8)_-byte tree index _tmp_idx_tree_, and a _floor((h/d + 7)/8)_-byte leaf index _tmp_idx_leaf_.

3) Compute the actual values of _md_, _idx_tree_ and _idx_leaf_ by extracting the necessary number of bits.

4) Sign the partial message digest _md_ using the _idx_leaf_-th FORS key pair of the _idx_tree_-th XMSS tree on the lowest HT layer

5) Sign the public key of the FORS key pair using HT

Input:
* _M_ - message
* _sk = (sk_seed, sk_prf, pk_seed, pk_root)_ - private key

Output: SPHINCS+ signature _sig_

In [None]:
def spx_sign(m, sk_seed, sk_prf, pk_seed, pk_root):
    # init
    adrs = bytearray(32)

    # generate randomizer (opt is all zeros: deterministic signing)
    opt = bytes(n)
    r = PRF_msg(sk_prf, opt, m)
    sig = bytearray(r)

    # compute message digest and index
    md_len = (k*a + 7) // 8
    tree_len = (h - h//d + 7) // 8
    leaf_len = (h//d + 7) // 8
    digest = H_msg(r, pk_seed, pk_root, m)
    tmp_md = digest[:md_len]
    tmp_idx_tree = digest[md_len : md_len + tree_len]
    tmp_idx_leaf = digest[md_len + tree_len : md_len + tree_len + leaf_len]

    md = tmp_md  # fors_sign only reads the first k*a bits
    idx_tree = int.from_bytes(tmp_idx_tree, 'big') % (1 << (h - h//d))  # keep h - h/d bits
    idx_leaf = int.from_bytes(tmp_idx_leaf, 'big') % (1 << (h//d))      # keep h/d bits

    # FORS sign (this simplified ADRS has no separate layer/tree address fields)
    adrs[:8] = num_to_bytes(3)            # type: FORS tree address
    adrs[8:16] = num_to_bytes(idx_leaf)   # key pair address
    sig_fors, adrs = fors_sign(md, sk_seed, pk_seed, adrs)
    sig += sig_fors

    # get FORS public key
    pk_fors, adrs = fors_pk_from_sig(sig_fors, md, pk_seed, adrs)

    # sign FORS public key with HT
    sig_ht, adrs = ht_sign(pk_fors, sk_seed, pk_seed, idx_tree, idx_leaf)
    sig += sig_ht
    return sig
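Steps 2 and 3 of the signing procedure can be isolated into a small function that splits the H_msg output. A sketch; `split_digest` is a hypothetical name. The text above speaks of the "first" bits of each part; this sketch keeps the low-order bits (value mod 2^(h − h/d), as FIPS 205 does), so treat the exact bit choice as an assumption.

```python
def split_digest(digest: bytes, k: int, a: int, h: int, d: int):
    """Split the H_msg output into (md, idx_tree, idx_leaf)."""
    md_len = (k * a + 7) // 8
    tree_len = (h - h // d + 7) // 8
    leaf_len = (h // d + 7) // 8
    if len(digest) < md_len + tree_len + leaf_len:
        raise ValueError("digest too short")
    md = digest[:md_len]
    tmp_idx_tree = digest[md_len:md_len + tree_len]
    tmp_idx_leaf = digest[md_len + tree_len:md_len + tree_len + leaf_len]
    # keep only the needed number of bits (low-order bits)
    idx_tree = int.from_bytes(tmp_idx_tree, 'big') % (1 << (h - h // d))
    idx_leaf = int.from_bytes(tmp_idx_leaf, 'big') % (1 << (h // d))
    return md, idx_tree, idx_leaf

digest = bytes(range(47))   # 39 + 7 + 1 = 47 bytes for k = 22, a = 14, h = 64, d = 8
md, idx_tree, idx_leaf = split_digest(digest, k=22, a=14, h=64, d=8)
print(len(md), idx_leaf)    # 39 46
```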

### SPHINCS+ Signature Verification

SPHINCS+ signature verification can be summarized as recomputing message digest and index, computing a candidate FORS public key, and verifying the HT signature on that public key.

Input:
* _M_ - message
* _sig_ - signature
* _pk_ - public key