## Setup

### Parameters

In [45]:
import math
from os import urandom
from hashlib import sha3_256, sha3_512, shake_256

from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.matrix.constructor import matrix
from sage.misc.prandom import randrange
from sage.rings.finite_rings.finite_field_constructor import FiniteField 
from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing
from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.rings.integer import Integer

# Kyber parameters (k = 2 matches the ML-KEM-512 parameter set).
q = 3329      # prime modulus of the coefficient field GF(q)
q_bytes = 16  # bytes written per coefficient by Poly2Bytes (2 would already fit q < 2**16)
k = 2         # module rank: A is k x k; s, e and t are k x 1
n = 256       # ring degree: every ring element has n coefficients


# R_q = GF(q)[x] / (x^n + 1)
rQ = PolynomialRing(FiniteField(q, 'x'), 'x', sparse=True)
x = rQ.gen()
# Use `**`, not `^`: `^` is exponentiation only after the Sage preparser runs;
# in plain Python it is XOR. `**` works in both settings.
f = x**n + 1
RQ = rQ.quotient(f)


# Master switch for dbg() output.
DEBUG = True

### Helper Functions

In [46]:
def RandomList(length, cbd=False):
    """Return `length` integers drawn with randrange(q), i.e. from [0, q).

    When `cbd` is true the draws are folded into [-2, 2] by FauxCbd.
    """
    values = [randrange(q) for _ in range(length)]
    return FauxCbd(values) if cbd else values

def FauxCbd(r: list):
    """Map every integer of `r` into [-2, 2] as (value mod 5) - 2.

    A cheap stand-in ("faux") for a centered binomial sampler: it only
    guarantees the small output range, not the binomial distribution.
    """
    return [(value % 5) - 2 for value in r]

def RandPolyUniform(length):
    """Return an RQ element whose `length` coefficients come from randrange(q)."""
    coefficients = RandomList(length, cbd=False)
    return RQ(coefficients)

def RandPolyCbd(length) -> polynomial:
    """Return an RQ element whose `length` coefficients are small values in [-2, 2]."""
    small_coefficients = RandListCbd(length)
    return RQ(small_coefficients)

def RandListCbd(length) -> list:
    """Return `length` small integers in [-2, 2]: randrange(q) draws folded by FauxCbd."""
    return FauxCbd(RandomList(length))

def BytesNeed4Bits(bits: int) -> int:
    """Return how many whole bytes are needed to hold `bits` bits, i.e. ceil(bits / 8)."""
    return (bits + 7) // 8

def RandInt(bits: int) -> int:
    """Return a random non-negative Sage Integer in [0, 2**bits).

    Draws ceil(bits / 8) bytes from os.urandom and clears any excess high
    bits so the result has at most `bits` bits.

    Raises:
        ValueError: if `bits` is negative.
    """
    if bits < 0:
        raise ValueError('bits must be non-negative')
    m = Integer(int.from_bytes(urandom(BytesNeed4Bits(bits)), 'big'))
    # Mask to the requested width. The previous mask used the global n and
    # ignored `bits`: RandInt(12) could return 16 bits, and RandInt(512) was
    # silently truncated to n bits.
    m &= (1 << bits) - 1
    return m

def Poly2Bytes(poly: polynomial) -> bytes:
    """Serialize an RQ element, or every entry of a matrix of them, to bytes.

    Each coefficient is written big-endian in `q_bytes` bytes, in the order
    returned by the element's .list() (ascending degree). For a matrix, the
    entries are concatenated in the order returned by matrix.list().

    Previously only `poly.coefficients()[0]` was serialized, i.e. the first
    non-zero entry of a matrix. Hashing the k x 1 vector t therefore ignored
    all of its other polynomials.
    """
    # Matrices (which expose .nrows) are flattened; a single ring element is
    # treated as a one-entry sequence.
    entries = poly.list() if hasattr(poly, 'nrows') else [poly]
    return b''.join(
        int(coefficient).to_bytes(q_bytes, 'big')
        for entry in entries
        for coefficient in entry.list()
    )

def Bytes2ListInt(b: bytes) -> list:
    """Return the bytes of `b` as a list of Sage Integers in [0, 255], in order."""
    return [Integer(octet) for octet in b]

def Bytes2ListBit(b: bytes) -> list:
    """Expand `b` into a flat list of 0/1 bits.

    Bytes are taken in order; within each byte, bits are emitted
    least-significant first.
    """
    return [(octet >> position) & 1 for octet in b for position in range(8)]

def Compress(poly: polynomial):
    """Scale a 0/1 message polynomial so every 1 coefficient becomes ceil(q/2).

    Despite its name, this is the message *encoding* step; FIPS 203 calls it
    Decompress_1. Decompress below performs the inverse rounding.
    """
    return poly * math.ceil(q / 2)

def Decompress(poly: polynomial) -> list:
    """Round each coefficient back to a single message bit.

    Args:
        poly: an iterable of coefficients. KPKE_Decrypt passes the coefficient
            list of the noisy message polynomial.

    Returns:
        A list with one entry per coefficient: 1 when the coefficient, as an
        integer in [0, q), lies strictly between q/4 and 3q/4 (closer to q/2
        than to 0 mod q); otherwise 0. Despite the name, this is the rounding
        that FIPS 203 calls Compress_1.

    The previous return annotation claimed `int`, but a list has always been
    returned.
    """
    # Decision thresholds are loop-invariant; compute them once.
    lower = q / 4
    upper = 3 * (q / 4)
    return [1 if lower < Integer(c) < upper else 0 for c in poly]

def dbg(label: str, *args: str):
    """Print a labelled debug message when the global DEBUG flag is truthy.

    With extra arguments, prints "label:" followed by each argument on its
    own line. With none, prints just the label (no trailing colon).
    """
    if not DEBUG:
        return
    if not args:
        print(f'{label}')
        return
    body = ''.join(f'{item}\n' for item in args)
    print(f'{label}:\n' + body)

## IND-CPA Public Key Encryption (K-PKE)

### K-PKE KeyGen

In [47]:
def KPKE_Keygen() -> (matrix, matrix, matrix):
    """Generate a K-PKE key pair.

    Returns:
        (A, t, s) where
          A: k x k matrix of RQ elements with coefficients from randrange(q),
          s: k x 1 secret vector with small coefficients (FauxCbd, in [-2, 2]),
          t: k x 1 vector A*s + e, with e another small k x 1 vector.
        (A, t) is the public key and s the private key.
    """
    dbg('===== kpke_keygen =====')

    # Public matrix A, filled row by row.
    A = matrix([[RandPolyUniform(n) for _col in range(k)] for _row in range(k)])
    dbg('A', A)

    # Secret vector s.
    s = matrix([[RandPolyCbd(n)] for _row in range(k)])
    dbg('s', s)

    # Error vector e.
    e = matrix([[RandPolyCbd(n)] for _row in range(k)])
    dbg('e', e)

    # t = A*s + e. For k = 2, row i is A[i,0]*s[0] + A[i,1]*s[1] + e[i];
    # every product and sum happens in RQ, so t is again k x 1.
    t = A * s + e
    dbg('t', t, '\n')

    return (A, t, s)

### K-PKE Encrypt

In [48]:
def KPKE_Encrypt(A: matrix, t: matrix, m: int, r: list) -> (matrix, matrix):
    """Encrypt an n-bit message under the K-PKE public key (A, t).

    Args:
        A: k x k public matrix of RQ elements.
        t: k x 1 public vector (A*s + e from KPKE_Keygen).
        m: the message, a Python int or Sage Integer with at most n bits.
        r: a list of at least n integers used as encryption randomness. All
           noise polynomials are derived deterministically from r plus a
           running nonce, so equal (A, t, m, r) give an equal ciphertext.

    Returns:
        (u, v) with u = A^T*rr + e1 (k x 1) and v = t^T*rr + e2 + Compress(m)
        (a 1 x 1 matrix).

    Raises:
        ValueError: if m has more than n bits, or r has fewer than n entries.
    """
    dbg('===== kpke_encrypt =====')

    # Integer() lets plain Python ints through as well; they have no .bits().
    mb = Integer(m).bits()
    if len(mb) > n:
        raise ValueError('m has more bits than n!')
    if len(r) < n:
        raise ValueError('r must contain at least n values!')

    dbg('Bits of m', mb)

    # .bits() is least-significant first, so bit i of m becomes the
    # coefficient of x^i; zero-pad up to n coefficients.
    mbp = RQ(mb + [0] * (n - len(mb)))
    dbg('Polynomial m', mbp)

    # Lift each message bit to 0 or ceil(q/2).
    mbpc = Compress(mbp)
    dbg('Compressed m', mbpc)

    def _noise(nonce):
        # Small-coefficient RQ element from r with every entry offset by nonce.
        return RQ(FauxCbd([r[j] + nonce for j in range(n)]))

    # Nonces: 0..k-1 for rr, k..2k-1 for e1, 2k for e2.
    rr = matrix([[_noise(i)] for i in range(k)])
    dbg('rr', rr)

    e1 = matrix([[_noise(k + i)] for i in range(k)])
    dbg('e1', e1)

    e2 = _noise(2 * k)
    dbg('e2', e2)

    u = A.transpose() * rr + e1
    v = t.transpose() * rr + e2 + mbpc

    dbg('u', u)
    dbg('v', v, '\n')

    return (u, v)

    

### K-PKE Decrypt

In [49]:
def KPKE_Decrypt(u: matrix, v: matrix, s: matrix) -> int:
    """Decrypt a K-PKE ciphertext (u, v) with the secret vector s.

    v - s^T*u equals the encoded message plus small noise
    (e^T*rr + e2 - s^T*e1). Each coefficient is rounded back to a bit with
    Decompress, and bit i (coefficient of x^i) is reassembled as 2**i.

    Returns:
        The recovered message as a Python int.
    """
    dbg('===== kpke_decrypt =====')

    # v - s^T*u is a 1 x 1 matrix; take its only entry by position.
    # (.coefficients()[0] returned the first *non-zero* entry and raised
    # IndexError when the result was zero.)
    mn = (v - s.transpose() * u)[0, 0]

    dbg('Noisy recovered m', mn)

    # Decompress is element-wise, so apply it in ascending-degree order:
    # bits[i] is the recovered bit i of m (least-significant first).
    bits = Decompress(mn.list())

    dbg('Decompressed m', bits, '\n')

    return int(sum(bit << i for i, bit in enumerate(bits)))

### K-PKE Test

In [50]:
# K-PKE round trip: encrypt a random n-bit message and check it decrypts.
# Alice generates a public key (A, t),
# and a private key s
A, t, s = KPKE_Keygen()

# Alice sends Bob her pk
# Bob chooses a random n-bit message m
# and encrypts it using Alice's pk
# and a list r of n small random values to produce the ciphertext (u, v)
m = RandInt(n)
r = RandListCbd(n)
u, v = KPKE_Encrypt(A, t, m, r)

# Bob sends Alice (u, v).
# Alice can then recover the message m
mr = KPKE_Decrypt(u, v, s)

dbg('Alice\'s m', mr)
dbg('Bob\'s m', m)
# dbg with only a label prints the label itself: here m's bits, least-significant first.
dbg(m.bits())
# NOTE(review): a mismatch means the decryption noise pushed some coefficient
# across Decompress's q/4 or 3q/4 threshold.
if m != mr:
    raise ValueError('Alice and Bob\'s messages do not match, final decompression likely failed. Try increasing the value of the prime q')


===== kpke_keygen =====
A:
[                 1161*xbar^255 + 631*xbar^254 + 308*xbar^253 + 1272*xbar^252 + 2316*xbar^251 + 1872*xbar^250 + 460*xbar^249 + 292*xbar^248 + 1963*xbar^247 + 471*xbar^246 + 873*xbar^245 + 2209*xbar^244 + 3206*xbar^243 + 403*xbar^242 + 1494*xbar^241 + 3274*xbar^240 + 768*xbar^239 + 1552*xbar^238 + 2153*xbar^237 + 2353*xbar^236 + 1749*xbar^235 + 2437*xbar^234 + 2219*xbar^233 + 1913*xbar^232 + 198*xbar^231 + 1019*xbar^230 + 2918*xbar^229 + 432*xbar^228 + 199*xbar^227 + 1589*xbar^226 + 1882*xbar^225 + 956*xbar^224 + 584*xbar^223 + 1116*xbar^222 + 2867*xbar^221 + 471*xbar^220 + 734*xbar^219 + 1851*xbar^218 + 3243*xbar^217 + 2515*xbar^216 + 1224*xbar^215 + 694*xbar^214 + 331*xbar^213 + 1763*xbar^212 + 602*xbar^211 + 1410*xbar^210 + 672*xbar^209 + 2284*xbar^208 + 2948*xbar^207 + 152*xbar^206 + 2899*xbar^205 + 2773*xbar^204 + 2678*xbar^203 + 1326*xbar^202 + 3097*xbar^201 + 2207*xbar^200 + 479*xbar^199 + 601*xbar^198 + 1739*xbar^197 + 1542*xbar^196 + 2731*xbar^195 + 9

# IND-CCA Key Encapsulation Mechanism (ML-KEM)

## ML-KEM KeyGen

In [51]:
def MLKEM_KeyGen() -> ((matrix, matrix), (matrix, (matrix, matrix), bytes, Integer)):
    """Generate an ML-KEM encapsulation/decapsulation key pair.

    Returns:
        ek: the K-PKE public key (A, t).
        dk: (s, ek, Ht, z), where s is the K-PKE secret vector, Ht is
            SHA3-256 of Poly2Bytes(t), and z is a random n-bit Integer
            (ML-KEM reserves z for implicit rejection in Decaps).

    The previous return annotation described dk as a 3-tuple of polynomials.
    """
    dbg('===== MLKEM_KeyGen =====')

    z = RandInt(n)
    A, t, s = KPKE_Keygen()
    ek = (A, t)
    Ht = sha3_256(Poly2Bytes(t)).digest()

    dbg('SHA3-256(t)', Ht.hex(), '\n')

    dk = (s, ek, Ht, z)

    return ek, dk

## ML-KEM Encaps

In [52]:
def MLKEM_Encaps(ek: (matrix, matrix)) -> (bytes, (matrix, matrix)):
    """Encapsulate a fresh shared secret to the public key ek = (A, t).

    Picks a random n-bit m and derives (K, r) = SHA3-512(m || SHA3-256(t)).
    K (the first n/8 bytes) is the shared secret. The remaining bytes, expanded
    to bits and passed through FauxCbd, become the K-PKE randomness used to
    encrypt m.

    Returns:
        (K, c): the shared secret bytes and the ciphertext c = (u, v) from
        KPKE_Encrypt. The previous annotation wrongly described c as bytes.
    """
    dbg('===== MLKEM_Encaps =====')

    m = RandInt(n)
    dbg('m', m)

    A, t = ek
    Ht = sha3_256(Poly2Bytes(t)).digest()

    dbg('SHA3-256(t)', Ht.hex())

    key_len = BytesNeed4Bits(n)
    Kr = sha3_512(int(m).to_bytes(key_len, 'big') + Ht).digest()
    K, r_bytes = Kr[:key_len], Kr[key_len:]
    # Bits of r_bytes are 0/1, so FauxCbd maps them to -2/-1. KPKE_Encrypt
    # needs at least n entries, which holds for n = 256 (32 bytes -> 256 bits).
    r = FauxCbd(Bytes2ListBit(r_bytes))

    dbg('r', r, '\n')

    c = KPKE_Encrypt(A, t, m, r)

    return K, c

## ML-KEM Decaps

In [53]:
def MLKEM_Decaps(c: (matrix, matrix), dk: (matrix, (matrix, matrix), bytes, Integer)) -> bytes:
    """Recover the shared secret from ciphertext c with decapsulation key dk.

    Decrypts c to m', re-derives (K', r') = SHA3-512(m' || h), and
    re-encrypts m' with r' (Fujisaki-Okamoto check).
      * If the re-encryption equals c, K' is returned.
      * Otherwise the ciphertext is rejected implicitly: the result is
        SHAKE-256(z || c) truncated to the key length, which is unrelated to m'.

    Previously the re-encryption was computed but never compared and z was
    never used, so any ciphertext was accepted, defeating IND-CCA security.
    """
    dbg('===== MLKEM_Decaps =====')
    s, ek, h, z = dk
    A, t = ek
    u, v = c

    mprime = KPKE_Decrypt(u, v, s)

    dbg('m\'', mprime)

    key_len = BytesNeed4Bits(n)
    Krprime = sha3_512(int(mprime).to_bytes(key_len, 'big') + h).digest()
    Kprime, rprime = Krprime[:key_len], Krprime[key_len:]
    rprime = FauxCbd(Bytes2ListBit(rprime))

    dbg('r\'', rprime)

    uprime, vprime = KPKE_Encrypt(A, t, Integer(mprime), rprime)

    dbg('u\'', uprime)
    dbg('v\'', vprime, '\n')

    # KPKE_Encrypt is deterministic in (A, t, m, r), so an honest ciphertext
    # re-encrypts to exactly (u, v).
    if uprime == u and vprime == v:
        return Kprime

    dbg('Re-encryption mismatch', 'returning implicit-rejection key')
    # Serialize c with q_bytes bytes per coefficient (entries of u, then v).
    c_bytes = b''.join(
        int(coefficient).to_bytes(q_bytes, 'big')
        for entry in [*u.list(), *v.list()]
        for coefficient in entry.list()
    )
    return shake_256(int(z).to_bytes(key_len, 'big') + c_bytes).digest(key_len)



## ML-KEM Test

In [54]:
# ML-KEM round trip: both parties must end up with the same shared secret.
# Alice runs KeyGen()
pkA, skA = MLKEM_KeyGen()

# Bob receives pkA from Alice,
# then runs Encaps() to generate
# his copy of the shared secret ssB
# and a ciphertext c
ssB, c = MLKEM_Encaps(pkA)

# Bob sends c to Alice,
# who then uses her secret key
# to generate her copy of
# the shared secret ssA
ssA = MLKEM_Decaps(c, skA)

dbg('Bob\'s Shared Secret', ssB.hex())
dbg('Alice\'s Shared Secret', ssA.hex())
dbg('ssA == ssB?', ssA == ssB)
# Fail loudly if the two copies of the shared secret differ.
if(ssA != ssB):
    raise ValueError('The shared keys do not match!')

===== MLKEM_KeyGen =====
===== kpke_keygen =====
A:
[          2159*xbar^255 + 1897*xbar^254 + 3082*xbar^253 + 750*xbar^252 + 1702*xbar^251 + 2827*xbar^250 + 2061*xbar^249 + 1430*xbar^248 + 674*xbar^247 + 2171*xbar^246 + 25*xbar^245 + 1996*xbar^244 + 2577*xbar^243 + 2980*xbar^242 + 838*xbar^241 + 2183*xbar^240 + 765*xbar^239 + 286*xbar^238 + 3296*xbar^237 + 850*xbar^236 + 1190*xbar^235 + 1575*xbar^234 + 164*xbar^233 + 3002*xbar^232 + 2341*xbar^231 + 2925*xbar^230 + 2010*xbar^229 + 1837*xbar^228 + 2662*xbar^227 + 271*xbar^226 + 1456*xbar^225 + 3189*xbar^224 + 2574*xbar^223 + 3064*xbar^222 + 1684*xbar^221 + 539*xbar^220 + 2725*xbar^219 + 2885*xbar^218 + 132*xbar^217 + 3278*xbar^216 + 2180*xbar^215 + 2169*xbar^214 + 1411*xbar^213 + 763*xbar^212 + 1836*xbar^211 + 136*xbar^210 + 2552*xbar^209 + 3058*xbar^208 + 1353*xbar^207 + 1212*xbar^206 + 1021*xbar^205 + 402*xbar^204 + 342*xbar^203 + 1261*xbar^202 + 1602*xbar^201 + 2119*xbar^200 + 1787*xbar^199 + 1464*xbar^198 + 3081*xbar^197 + 173*xbar^

# Key Exchange Tests

## Init

In [55]:
# Silence dbg() output for the key-exchange demos below.
DEBUG = False

# Alice and Bob create independent *static* (pk, sk) pairs,
# reused by both the UAKE and AKE demos below.
pkA, skA = MLKEM_KeyGen()
pkB, skB = MLKEM_KeyGen()

## Unilaterally Authenticated Key Exchange (UAKE)

In [56]:
# UAKE: only Bob's static key (pkB/skB) is used; Alice contributes only a
# temporary key pair, so only Bob's identity is bound into the final key.
# Alice generates a new set of temporary keys
tpkA, tskA = MLKEM_KeyGen()
# Alice encapsulates to Bob's static public key (pkB)
# to generate a ciphertext and shared secret (cA, tssA); cA is sent to Bob
tssA, cA = MLKEM_Encaps(pkB)

# Alice sends (tpkA, cA) to Bob
# Bob encapsulates to Alice's temporary public key (tpkA)
# to generate a ciphertext and shared secret (cB, tssB)
# to send to Alice
tssB, cB = MLKEM_Encaps(tpkA)
# Bob decapsulates Alice's ciphertext cA with his static secret key (skB) to produce tssAprime
tssAprime = MLKEM_Decaps(cA, skB)
# Bob hashes tssB and tssAprime to create his copy of the final 32-byte shared key (ssB)
ssB = shake_256(tssB + tssAprime).digest(32)

# Bob sends Alice his ciphertext cB
# Alice decapsulates Bob's cB using her temporary secret key (tskA)
# to recover Bob's temporary shared secret
tssBprime = MLKEM_Decaps(cB, tskA)
# Alice hashes tssBprime and tssA in the same order Bob used (Bob's secret first)
# to produce her copy of the final shared key (ssA)
ssA = shake_256(tssBprime + tssA).digest(32)

print('Alice\'s SS:')
print(ssA.hex())
print()
print('Bob\'s SS:')
print(ssB.hex())
print()
print('Length of shared secret:', len(ssA), 'bytes')
print()
print('ssA == ssB?', ssA == ssB)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


## Mutually Authenticated Key Exchange (AKE)

In [57]:
# AKE: both static keys are used. Alice encapsulates to pkB and Bob
# encapsulates to pkA, so each side needs its static secret key to derive ss.
# Alice generates a new set of temporary keys
tpkA, tskA = MLKEM_KeyGen()
# Alice encapsulates to Bob's static public key (pkB)
# to generate a ciphertext and shared secret (cA, tssA); cA is sent to Bob
tssA, cA = MLKEM_Encaps(pkB)

# Alice sends (tpkA, cA) to Bob
# Bob encapsulates to Alice's temporary public key (tpkA)
# to generate a ciphertext and shared secret (cB, tssB)
tssB, cB = MLKEM_Encaps(tpkA)
# Bob encapsulates to Alice's static public key (pkA)
# to generate a second ciphertext and shared secret (cB2, tssB2)
tssB2, cB2 = MLKEM_Encaps(pkA)
# Bob decapsulates Alice's cA using his static secret key (skB)
# to recover Alice's temporary shared secret
tssAprime = MLKEM_Decaps(cA, skB)
# Bob hashes tssB, tssB2, and tssAprime (in that order) to create his copy of the final 32-byte shared key (ssB)
ssB = shake_256(tssB + tssB2 + tssAprime).digest(32)

# Bob sends cB and cB2 to Alice
# Alice decapsulates Bob's cB using her temporary secret key (tskA)
tssBprime = MLKEM_Decaps(cB, tskA)
# Alice decapsulates Bob's cB2 using her static secret key (skA)
tssB2prime = MLKEM_Decaps(cB2, skA)
# Alice hashes tssBprime, tssB2prime, and tssA in Bob's order to create her copy of the final shared key (ssA)
ssA = shake_256(tssBprime + tssB2prime + tssA).digest(32)

print('Alice\'s SS:')
print(ssA.hex())
print()
print('Bob\'s SS:')
print(ssB.hex())
print()
print('Length of shared secret:', len(ssA), 'bytes')
print()
print('ssA == ssB?', ssA == ssB)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
