In [1]:
from CompactFIPS202 import *
import binascii
import itertools
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from kyber import Kyber512, Kyber768, Kyber1024
from aes256_ctr_drbg import AES256_CTR_DRBG

def parse_kat_data(data):
    """Parse the text of a NIST-style KEM known-answer-test (.rsp) file.

    Records are sequences of ``key = value`` lines, each record starting with
    a ``count = N`` line followed by hex-encoded ``seed``, ``pk``, ``sk``,
    ``ct`` and ``ss`` fields. Lines without `` = `` (blank separators, a
    leading ``#`` header comment) are ignored, so the result does not depend
    on how many blank lines separate records, whether the file ends with a
    blank line, or whether it uses ``\\n`` or ``\\r\\n`` line endings.

    Args:
        data: full text of the .rsp file.

    Returns:
        dict mapping each ``count`` value (a string such as ``'0'``) to a dict
        with keys ``seed``, ``pk``, ``sk``, ``ct``, ``ss`` -- always inserted
        in that order, since callers unpack ``.values()`` -- whose values are
        the decoded ``bytes``. Unknown extra fields are ignored.

    Raises:
        ValueError: if a record lacks one of the expected fields, or a field
            is not valid hex.
    """
    fields = ("seed", "pk", "sk", "ct", "ss")

    # First pass: group raw hex strings by their record's count.
    raw_records = {}
    current = None
    for line in data.splitlines():
        key, sep, value = line.partition(" = ")
        if not sep:
            continue  # blank line or header/comment
        key = key.strip()
        value = value.strip()
        if key == "count":
            current = raw_records.setdefault(value, {})
        elif current is not None:
            current[key] = value

    # Second pass: validate and decode, with a fixed field order.
    parsed_data = {}
    for count, record in raw_records.items():
        missing = [name for name in fields if name not in record]
        if missing:
            raise ValueError(f"KAT record count={count} is missing fields: {missing}")
        parsed_data[count] = {name: bytes.fromhex(record[name]) for name in fields}
    return parsed_data

In [2]:
# Load the Kyber512 known-answer-test vectors and check vector '0'.
with open("assets/PQCkemKAT_1632.rsp") as f:
    kat_data = f.read()
parsed_data = parse_kat_data(kat_data)

data = parsed_data['0']
# Look fields up by name rather than relying on dict ordering.
seed, pk, sk, ct, ss = data["seed"], data["pk"], data["sk"], data["ct"], data["ss"]

# Seed the DRBG with the KAT seed before generating anything.
Kyber512.set_drbg_seed(seed)

# Key generation must reproduce the KAT key pair.
_pk, _sk = Kyber512.keygen()
assert (_pk, _sk) == (pk, sk), "keygen does not match KAT"

# Encapsulation against the generated public key must reproduce the KAT ct/ss.
_ct, _ss = Kyber512.enc(_pk)
assert (_ct, _ss) == (ct, ss), "encapsulation does not match KAT"

# TODO: decapsulation is not checked yet; intended form:
#     __ss = Kyber512.dec(ct, sk)

In [3]:
# Hex rendering of the KAT seed.
binascii.hexlify(seed).decode()

'061550234d158c5ec95595fe04ef7a25767f2e24cc2bc479d09d86dc9abcfde7056a8c266f9ef97ed08541dbd2e1ffa1'

In [4]:
# Does the regenerated public key equal the KAT public key?
_pk == pk

True

In [5]:
# True only if every regenerated value matches its KAT counterpart.
# (A chained `x == y == z` over the four booleans only checks that they agree
# with each other, so it would also be True if every comparison failed.)
all((pk == _pk, sk == _sk, ct == _ct, ss == _ss))

True

In [6]:
# Alias the OS randomness source under a descriptive name.
from os import urandom as random_bytes

In [7]:
# 256 bits (32 bytes) of OS randomness. Note that `a` is reassigned by a
# later cell before it is ever read.
a = random_bytes(256 // 8)

In [8]:
# Debug inspection: length of Kyber512's private `_xof_a` attribute.
len(Kyber512._xof_a)

1

In [9]:
# Debug inspection: length of Kyber512's private `_xof_input_bytes`, the
# input that later cells hash with both hashlib and CompactFIPS202.
len(Kyber512._xof_input_bytes)

34

In [10]:
# Debug inspection: hex dump of Kyber512's private `_xof_a` attribute.
Kyber512._xof_a.hex()

'01'

In [11]:
# Reference SHAKE128 output (hashlib) for `_xof_input_bytes`, truncated to
# `_xof_length` bytes and hex-encoded.
a = shake_128(Kyber512._xof_input_bytes).hexdigest(Kyber512._xof_length)

In [12]:
# Same SHAKE128 output from the CompactFIPS202 reference implementation.
b = SHAKE128(Kyber512._xof_input_bytes, Kyber512._xof_length).hex()
# Show whether it matches hashlib's result; without this the SHAKE128
# comparison is lost once `a`/`b` are reassigned for SHA3-256 below.
a == b

13945236032936597125, [133, 214, 41, 116, 73, 110, 135, 193]
a          13945236032936597125
o_out      9643941231485421505
10144667240262694917, [5, 120, 102, 52, 27, 25, 201, 140]
a          10144667240262694917
o_out      394177341373925772
2854837008212362804, [52, 6, 35, 38, 34, 107, 158, 39]
a          2854837008212362804
o_out      3748722386525724199
5510798064491279528, [168, 180, 233, 102, 164, 73, 122, 76]
a          5510798064491279528
o_out      12156597921232026188
10649859246619797596, [92, 168, 174, 54, 143, 230, 203, 147]
a          10649859246619797596
o_out      6676777996942494611
9618096872207591557, [133, 168, 58, 123, 37, 88, 122, 133]
a          9618096872207591557
o_out      9631012103713749637
3074827116302497218, [194, 245, 242, 156, 248, 250, 171, 42]
a          3074827116302497218
o_out      14048401368658127658
14279517561725204819, [83, 57, 149, 246, 144, 9, 43, 198]
a          14279517561725204819
o_out      5996989265031539654
15146932421457323213, [205

In [13]:
# Hard-coded snapshot: 5 rows of 5 integers, all below 2**64.
# Presumably a Keccak-f[1600] state (25 64-bit lanes) captured while
# debugging CompactFIPS202 -- TODO confirm.
aList = [
    [885738389246962277, 0, 0, 0, 9223372036854775808],
    [1554946951491850079, 0, 0, 0, 0],
    [10097308149613174153, 0, 0, 0, 0],
    [2466201460356326228, 0, 0, 0, 0],
    [2031873, 0, 0, 0, 0],
]


In [14]:
# Flatten aList (a list of lists) into one flat list, row by row.
bList = list(itertools.chain.from_iterable(aList))
# Bit width of the largest value. Kept in a variable because Jupyter only
# displays a cell's last expression, so a bare expression here is discarded.
max_bit_length = max(bList).bit_length()
bList

[885738389246962277,
 0,
 0,
 0,
 9223372036854775808,
 1554946951491850079,
 0,
 0,
 0,
 0,
 10097308149613174153,
 0,
 0,
 0,
 0,
 2466201460356326228,
 0,
 0,
 0,
 0,
 2031873,
 0,
 0,
 0,
 0]

In [None]:
my_list = [52, 6, 35, 38, 34, 107, 158, 39]
# Interpret the list as big-endian bytes. int.from_bytes replaces a manual
# hex-string round trip, where every byte had to be zero-padded to two hex
# digits or values below 0x10 would lose their leading zero. bytes() also
# rejects entries outside 0..255 instead of producing a wrong number.
my_list_as_int = int.from_bytes(bytes(my_list), "big")
my_list_as_int

235914677176737319

In [16]:
# Two-digit, zero-padded lowercase hex rendering of 6.
format(6, '02x')

'06'

In [17]:
# Reference SHA3-256 (hashlib) of `_xof_input_bytes`, hex-encoded.
a = sha3_256(Kyber512._xof_input_bytes).hexdigest()

In [18]:
# SHA3-256 of the same input via the CompactFIPS202 reference
# implementation, hex-encoded for comparison with hashlib's `a`.
b = SHA3_256(Kyber512._xof_input_bytes).hex()

17723320065212179744, [32, 101, 134, 119, 178, 225, 245, 245]
a          17723320065212179744
o_out      2334419830521853429
5332688556074578821, [133, 179, 157, 234, 252, 131, 1, 74]
a          5332688556074578821
o_out      9634217660459974986
15592808027434457827, [227, 26, 23, 102, 73, 200, 100, 216]
a          15592808027434457827
o_out      16364417924096091352
8187770676326191426, [66, 81, 200, 168, 12, 206, 160, 113]
a          8187770676326191426
o_out      4778821303711735921
11531033938493744927, [31, 167, 228, 228, 69, 118, 6, 160]
a          11531033938493744927
o_out      2281043405355484832
4103508557802861860, [36, 117, 157, 192, 74, 151, 242, 56]
a          4103508557802861860
o_out      2627179406851306040
16257386486144937476, [4, 174, 231, 178, 223, 214, 157, 225]
a          16257386486144937476
o_out      337461777544945121
3571200004506186818, [66, 244, 118, 170, 123, 115, 143, 49]
a          3571200004506186818
o_out      4824611575408332593
15550595238011099980,

In [19]:
# Do the CompactFIPS202 and hashlib SHA3-256 hex digests agree?
b == a

True

In [20]:
# Display hashlib's SHA3-256 hex digest.
a

'20658677b2e1f5f524759dc04a97f238856e5031e972cd9e52387e147e325bab'

In [21]:
# Display the CompactFIPS202 SHA3-256 hex digest.
b

'20658677b2e1f5f524759dc04a97f238856e5031e972cd9e52387e147e325bab'

In [22]:
# Two zero bytes; bytearray(n) allocates n zero-initialised bytes directly.
state = bytearray(2)
state

bytearray(b'\x00\x00')

In [23]:
# Use a dedicated loop variable: reusing `a` here would overwrite the digest
# string that the earlier `a == b` comparison cells read.
for i in range(5):
    print(i)

0
1
2
3
4
