In [1]:
from CompactFIPS202 import *
import binascii
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from kyber import Kyber512, Kyber768, Kyber1024
from aes256_ctr_drbg import AES256_CTR_DRBG

def parse_kat_data(data):
    """Parse the text of a KAT response file into per-count test vectors.

    The text is treated as blank-line-separated blocks of ``name = value``
    lines. A block is a test vector when it contains a ``count`` field; any
    other block (a header/comment block, or the empty block produced by a
    trailing blank line) is skipped. Fields are looked up by name, so the
    result does not depend on the file ending with a blank line, on the line
    ending style (LF or CRLF), or on stray whitespace around lines.

    Args:
        data: Full text content of the response file.

    Returns:
        A dict mapping each ``count`` value (kept as a string, e.g. ``"0"``)
        to a dict with the keys ``"seed"``, ``"pk"``, ``"sk"``, ``"ct"`` and
        ``"ss"`` (in that order), each holding the hex-decoded bytes.

    Raises:
        ValueError: If a block with a ``count`` is missing one of the
            required fields, or a value is not valid hexadecimal.
    """
    field_names = ("seed", "pk", "sk", "ct", "ss")
    parsed_data = {}

    # Normalise line endings and surrounding whitespace before splitting, so
    # "\r\n" files and whitespace-only separator lines behave like "\n\n".
    normalized = "\n".join(line.strip() for line in data.splitlines())

    for block in normalized.split("\n\n"):
        entries = {}
        for line in block.split("\n"):
            # Ignore blank lines, comment lines and anything without a value.
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            entries[key.strip()] = value.strip()

        # Header / comment / empty blocks carry no count: nothing to record.
        if "count" not in entries:
            continue

        missing = [name for name in field_names if name not in entries]
        if missing:
            raise ValueError(
                "KAT block count=%s is missing field(s): %s"
                % (entries["count"], ", ".join(missing))
            )

        parsed_data[entries["count"]] = {
            name: bytes.fromhex(entries[name]) for name in field_names
        }
    return parsed_data

In [2]:
# Read the KAT response file and parse it into per-count test vectors.
with open("assets/PQCkemKAT_1632.rsp") as f:
    kat_data = f.read()
parsed_data = parse_kat_data(kat_data)

# Work with the vector whose count is '0'.
data = parsed_data['0']
# Look the fields up by name instead of relying on the dict's value order.
seed, pk, sk, ct, ss = (data[name] for name in ("seed", "pk", "sk", "ct", "ss"))

# Seed DRBG with KAT seed
Kyber512.set_drbg_seed(seed)

# Assert keygen matches
_pk, _sk = Kyber512.keygen()
assert _pk == pk, "generated public key differs from the KAT public key"
assert _sk == sk, "generated secret key differs from the KAT secret key"

# Assert encapsulation matches
_ct, _ss = Kyber512.enc(_pk)
assert _ct == ct, "ciphertext differs from the KAT ciphertext"
assert _ss == ss, "shared secret differs from the KAT shared secret"

# TODO: the decapsulation check is still disabled.
#    # Assert decapsulation matches
#    __ss = Kyber512.dec(ct, sk)

In [3]:
# Show the seed of KAT vector '0' as a hex string.
seed.hex()

'061550234d158c5ec95595fe04ef7a25767f2e24cc2bc479d09d86dc9abcfde7056a8c266f9ef97ed08541dbd2e1ffa1'

In [4]:
# Does the freshly generated public key equal the public key recorded in the KAT vector?
pk == _pk

True

In [5]:
# Every generated value must equal its KAT counterpart. A chained `==` between
# the four booleans would also be True when all of them are False (it only
# checks that the results agree with each other), so test each one explicitly.
all((pk == _pk, sk == _sk, ct == _ct, ss == _ss))

True

In [6]:
# Alias os.urandom under a more descriptive name.
random_bytes = os.urandom

In [7]:
# Draw 32 bytes from random_bytes. Note that `a` is rebound by later cells.
a = random_bytes(32)

In [8]:
# Length of Kyber512._xof_a.
# NOTE(review): _xof_a is not defined in this notebook — presumably debug
# state exposed by the kyber module; confirm.
len(Kyber512._xof_a)

1

In [9]:
# Length of Kyber512._xof_input_bytes.
# NOTE(review): attribute is not defined in this notebook — presumably debug
# state exposed by the kyber module; confirm.
len(Kyber512._xof_input_bytes)

34

In [10]:
# Hex representation of Kyber512._xof_a.
Kyber512._xof_a.hex()

'01'

In [11]:
# hashlib reference value: SHAKE128 over the recorded XOF input, squeezed to
# Kyber512._xof_length bytes and rendered as a hex string.
a = shake_128(Kyber512._xof_input_bytes).hexdigest(Kyber512._xof_length)

In [12]:
# Same input and output length through SHAKE128, hex result kept in `b`.
# NOTE(review): SHAKE128 is not defined in this notebook — presumably it comes
# from the CompactFIPS202 star import, and the nested lists printed below look
# like debug prints emitted during this call; confirm.
b = SHAKE128(Kyber512._xof_input_bytes, Kyber512._xof_length).hex()

[[885738389246962277, 0, 0, 0, 9223372036854775808], [1554946951491850079, 0, 0, 0, 0], [10097308149613174153, 0, 0, 0, 0], [2466201460356326228, 0, 0, 0, 0], [2031873, 0, 0, 0, 0]]
[[13945236032936597125, 10144667240262694917, 2854837008212362804, 5510798064491279528, 10649859246619797596], [9618096872207591557, 3074827116302497218, 14279517561725204819, 15146932421457323213, 18193092708702448085], [6506339989559711229, 17814707969875772848, 13508111384288291931, 332785224959981807, 5836078173918574524], [10584102647138126943, 14822784191085864509, 9889073907843170860, 6266883757563277364, 8757664554044316764], [15540293597810366709, 10946536803342694163, 5358925043751425461, 233542234349471477, 5853087487196293284]]
[[9834587190136136295, 10449133524115970660, 13738987947225894368, 7163016936840495663, 8965792112117801652], [5502469702469101644, 3579287621522426969, 6137210619937726070, 17290108940877493852, 4618127811272740221], [1765384516403509340, 6126862305691337853, 73913401052

In [24]:
# 5x5 nested list of integers, identical to the first matrix printed by the
# SHAKE128 cell; laid out one row per line for readability.
lanes = [
    [885738389246962277, 0, 0, 0, 9223372036854775808],
    [1554946951491850079, 0, 0, 0, 0],
    [10097308149613174153, 0, 0, 0, 0],
    [2466201460356326228, 0, 0, 0, 0],
    [2031873, 0, 0, 0, 0],
]
# Display the first entry of the first row.
lanes[0][0]

885738389246962277

In [26]:
# 5x5 table: row x holds x, x + 2, x + 4, x + 6, x + 8
# (that is, x + 2*y for y = 0..4).
aList = [list(range(x, x + 10, 2)) for x in range(5)]
aList

[[0, 2, 4, 6, 8],
 [1, 3, 5, 7, 9],
 [2, 4, 6, 8, 10],
 [3, 5, 7, 9, 11],
 [4, 6, 8, 10, 12]]

In [14]:
import itertools
# Flatten the nested list row by row into a single flat list.
bList = list(itertools.chain.from_iterable(aList))
# Number of binary digits of the largest entry. This value is NOT displayed:
# only the last expression of a cell is shown.
len(format(max(bList), 'b'))
bList

[885738389246962277,
 0,
 0,
 0,
 9223372036854775808,
 1554946951491850079,
 0,
 0,
 0,
 0,
 10097308149613174153,
 0,
 0,
 0,
 0,
 2466201460356326228,
 0,
 0,
 0,
 0,
 2031873,
 0,
 0,
 0,
 0]

In [15]:
my_list = [52, 6, 35, 38, 34, 107, 158, 39]
# Interpret the values as a big-endian byte string and convert it to a single
# integer. int.from_bytes replaces the hand-rolled "zero-padded hex digits
# joined into one string" approach and rejects values outside 0..255 instead
# of silently producing a malformed hex string.
packed_value = int.from_bytes(bytes(my_list), "big")
packed_value

3748722386525724199

In [16]:
# Two-digit, zero-padded lowercase hex rendering of 6.
format(6, '02x')

'06'

In [17]:
# hashlib reference value: SHA3-256 hex digest of the recorded XOF input.
a = sha3_256(Kyber512._xof_input_bytes).hexdigest()

In [18]:
# Same input through SHA3_256, hex result kept in `b`.
# NOTE(review): SHA3_256 is not defined in this notebook — presumably it comes
# from the CompactFIPS202 star import, and the nested list printed below looks
# like a debug print emitted during this call; confirm.
b = SHA3_256(Kyber512._xof_input_bytes).hex()

[[885738389246962277, 0, 0, 0, 0], [1554946951491850079, 0, 0, 9223372036854775808, 0], [10097308149613174153, 0, 0, 0, 0], [2466201460356326228, 0, 0, 0, 0], [393473, 0, 0, 0, 0]]


In [19]:
# Do the hashlib digest (`a`) and the SHA3_256 digest (`b`) agree?
a == b

True

In [20]:
# Show the hex digest stored in `a`.
a

'20658677b2e1f5f524759dc04a97f238856e5031e972cd9e52387e147e325bab'

In [21]:
# Show the hex digest stored in `b`.
b

'20658677b2e1f5f524759dc04a97f238856e5031e972cd9e52387e147e325bab'

In [22]:
# Mutable two-byte buffer, zero-initialised.
state = bytearray(2)
state

bytearray(b'\x00\x00')

In [23]:
# Print 0..4. A dedicated loop variable is used on purpose: iterating with `a`
# would overwrite the digest string that other cells keep in `a`, so re-running
# the `a == b` comparison afterwards would give a misleading result.
for i in range(5):
    print(i)

0
1
2
3
4
