In [1]:
import z3
from css.mangle import mangle, mix, unmix, shift, unshift, table, reverse_table

# Fully symbolic byte vectors used to derive the MITM expressions below:
# key bytes k0..k7, plaintext bytes pt0..pt7 and ciphertext bytes ct0..ct7.
# NOTE: the sampling cell (In [4]) later rebinds `pts` / `cts` to concrete
# sample lists under the same names.
ks, pts, cts = (
    [z3.BitVec(f"k{i}", 8) for i in range(8)],
    [z3.BitVec(f"pt{i}", 8) for i in range(8)],
    [z3.BitVec(f"ct{i}", 8) for i in range(8)],
)

# Uninterpreted 8-bit -> 8-bit functions standing in for the S-box lookup
# (`t`) and its inverse (`rt`). The solver cells constrain them entry by
# entry to `table` / `reverse_table`.
t, rt = (
    z3.Function("t", z3.BitVecSort(8), z3.BitVecSort(8)),
    z3.Function("rt", z3.BitVecSort(8), z3.BitVecSort(8)),
)

def tabulate(x):
    """Apply the forward S-box function ``t`` to every element of ``x``.

    Each element is run through ``z3.simplify`` before ``t`` is applied, so
    elements are expected to be z3 expressions. Returns a new list.
    """
    return list(map(lambda expr: t(z3.simplify(expr)), x))
def untabulate(x):
    """Apply the inverse S-box function ``rt`` to every element of ``x``.

    Each element is run through ``z3.simplify`` before ``rt`` is applied, so
    elements are expected to be z3 expressions. Returns a new list.
    """
    looked_up = []
    for expr in x:
        looked_up.append(rt(z3.simplify(expr)))
    return looked_up

def mangle_mitm(key, pts, cts):
    """Compute both sides of the meet-in-the-middle split of ``mangle``.

    Forward side, applied to ``pts``:
        mix, (shift, mix) x 2, tabulate, shift, mix
    Reverse side, applied to ``cts``:
        (unmix, unshift) x 2, untabulate

    The solver cells require the two results to agree element by element
    for the correct key.

    Args:
        key: key bytes handed to mix / unmix (the symbolic ``ks`` in this
            notebook).
        pts: plaintext bytes (symbolic in In [1], concrete ints in In [7]).
        cts: ciphertext bytes (symbolic in In [1], concrete ints in In [7]).

    Returns:
        ``(fwd, rev)``: the forward and reverse per-byte expressions.
    """
    # Forward half: three mix rounds separated by shifts, then the S-box
    # lookup followed by one more shift + mix.
    state = mix(key, pts)
    for _ in range(2):
        state = mix(key, shift(state))
    state = mix(key, shift(tabulate(state)))

    # Reverse half: peel two unmix/unshift rounds off the ciphertext, then
    # invert the S-box lookup.
    back = cts
    for _ in range(2):
        back = unshift(unmix(key, back))
    back = untabulate(back)

    return state, back

# Symbolic MITM expressions for every byte position, derived once from the
# fully symbolic key / plaintext / ciphertext defined above.
fwd_sym, rev_sym = mangle_mitm(key=ks, pts=pts, cts=cts)

def mitm_report(which, fwd=None, rev=None):
    """Print the forward and reverse MITM expressions for one byte position.

    Args:
        which: index into the MITM expression lists.
        fwd: sequence of forward expressions. Defaults to the global
            ``fwd_sym`` (looked up at call time, so re-running In [1] is
            picked up).
        rev: sequence of reverse expressions. Defaults to the global
            ``rev_sym`` (looked up at call time).
    """
    # Explicit arguments let the report be reused on other expression lists
    # without depending on notebook globals.
    if fwd is None:
        fwd = fwd_sym
    if rev is None:
        rev = rev_sym

    fwd_str = str(fwd[which])
    rev_str = str(rev[which])

    print(f"fwd[{which}] = (\n{fwd_str}\n)")
    print(f"rev[{which}] = (\n{rev_str}\n)")

In [2]:
# Byte 0: the meet-in-the-middle position attacked in In [5].
mitm_report(which=0)

fwd[0] = (
t(pt0 ^ k7 ^ pt6 ^ k6 ^ pt5 ^ k0) ^
t(pt7 ^ k6 ^ pt5 ^ k7 ^ pt6 ^ k5 ^ pt4) ^
k0 ^
0
)
rev[0] = (
rt(k0 ^ ct7 ^ ct5 ^ ct3 ^ ct1 ^ k2 ^ k4 ^ k6 ^ ct0)
)


In [3]:
# Byte 7: the meet-in-the-middle position attacked in In [6].
mitm_report(which=7)

fwd[7] = (
t(pt7 ^ k6 ^ pt5 ^ k7 ^ pt6 ^ k5 ^ pt4) ^
k7 ^
t(pt6 ^ k5 ^ pt4 ^ k6 ^ pt5 ^ k4 ^ pt3)
)
rev[7] = (
rt(ct6 ^ ct4 ^ ct2 ^ ct0 ^ k1 ^ k3 ^ k5 ^ k7)
)


In [4]:
import random

# Generate n_samples (pt, ct) pairs with ct = mangle(demo_key, pt), all under
# one shared random key. Around 32 samples seems to be the sweet spot for the
# solver cells below.
n_samples = 32

# Set SEED to an int to replay a specific run. With None, a fresh seed is
# drawn and printed, so a failing run can still be reproduced afterwards.
SEED = None
seed = SEED if SEED is not None else random.randrange(2**32)
rng = random.Random(seed)

demo_key = list(rng.randbytes(8))

# NOTE: these rebind the symbolic `pts` / `cts` from In [1] to concrete
# sample lists; the solver cells below read these names.
pts = [list(rng.randbytes(8)) for _ in range(n_samples)]
cts = [list(mangle(demo_key, pt)) for pt in pts]

print(f"{seed = }")
print(f"{demo_key = }")

demo_key = [27, 25, 196, 182, 180, 11, 79, 81]


In [5]:
s = z3.Solver()

# Pin the uninterpreted S-box functions to the concrete tables so the solver
# evaluates t / rt exactly.
for idx, table_i in enumerate(table):
    s.add(t(idx) == table_i)
for idx, reverse_table_i in enumerate(reverse_table):
    s.add(rt(idx) == reverse_table_i)

# Meet in the middle at byte 0 (see the In [2] output). Every use of the key
# there collapses into 4 linear components:
#   k0, k6 ^ k7 (k67), k5 and k2 ^ k4 ^ k6 (k246)
# so the search space is only 32 bits.
k0, k67, k5, k246 = z3.BitVecs("k0 k67 k5 k246", 8)
for pti, cti in zip(pts, cts):
    pt0, pt1, pt2, pt3, pt4, pt5, pt6, pt7 = pti
    ct0, ct1, ct2, ct3, ct4, ct5, ct6, ct7 = cti

    fwd = (
        t(pt0 ^ pt5 ^ pt6 ^ k0 ^ k67) ^
        t(pt4 ^ pt5 ^ pt6 ^ pt7 ^ k5 ^ k67) ^
        k0
    )
    rev = rt(ct0 ^ ct1 ^ ct3 ^ ct5 ^ ct7 ^ k0 ^ k246)

    s.add(fwd == rev)

result = s.check()
print(result)
# Fail fast with a clear message instead of an opaque error from s.model().
assert result == z3.sat, "no byte-0 key components fit every sample"
m = s.model()

known_k0, known_k67, known_k5, known_k246 = [
    m[ki].as_long()
    for ki in [k0, k67, k5, k246]
]
print(f"""\
{known_k0   = }
{known_k67  = }
{known_k5   = }
{known_k246 = }""")

sat
known_k0   = 27
known_k67  = 30
known_k5   = 11
known_k246 = 63


In [6]:
s = z3.Solver()

# Pin the uninterpreted S-box functions to the concrete tables so the solver
# evaluates t / rt exactly.
for idx, table_i in enumerate(table):
    s.add(t(idx) == table_i)
for idx, reverse_table_i in enumerate(reverse_table):
    s.add(rt(idx) == reverse_table_i)

# Meet in the middle at byte 7 (see the In [3] output). The key again
# collapses into 4 linear components, another 32-bit search:
#   k5 ^ k6 (k56), k7, k4 and k1 ^ k3 ^ k5 (k135)
k56, k7, k4, k135 = z3.BitVecs("k56 k7 k4 k135", 8)
for pti, cti in zip(pts, cts):
    pt0, pt1, pt2, pt3, pt4, pt5, pt6, pt7 = pti
    ct0, ct1, ct2, ct3, ct4, ct5, ct6, ct7 = cti

    fwd = (
        t(pt4 ^ pt5 ^ pt6 ^ pt7 ^ k56 ^ k7) ^
        t(pt3 ^ pt4 ^ pt5 ^ pt6 ^ k4 ^ k56) ^
        k7
    )
    rev = rt(ct0 ^ ct2 ^ ct4 ^ ct6 ^ k135 ^ k7)

    s.add(fwd == rev)

result = s.check()
print(result)
# Fail fast with a clear message instead of an opaque error from s.model().
assert result == z3.sat, "no byte-7 key components fit every sample"
m = s.model()

known_k56, known_k7, known_k4, known_k135 = [
    m[ki].as_long()
    for ki in [k56, k7, k4, k135]
]

print(f"""\
{known_k56  = }
{known_k7   = }
{known_k4   = }
{known_k135 = }""")

sat
known_k56  = 68
known_k7   = 81
known_k4   = 180
known_k135 = 164


In [7]:
s = z3.Solver()

# Pin the uninterpreted S-box functions to the concrete tables so the solver
# evaluates t / rt exactly.
for idx, table_i in enumerate(table):
    s.add(t(idx) == table_i)
for idx, reverse_table_i in enumerate(reverse_table):
    s.add(rt(idx) == reverse_table_i)

# k6 follows from the byte-7 component k5 ^ k6 and the byte-0 value of k5.
known_k6 = known_k56 ^ known_k5

# Cross-check the two independent MITM solves: byte 0 recovered k6 ^ k7
# directly, which must match the k6 / k7 implied by the byte-7 solve.
assert known_k67 == known_k6 ^ known_k7, "byte-0 and byte-7 MITM results disagree"

s.add(known_k0 == ks[0])
s.add(known_k4 == ks[4])
s.add(known_k5 == ks[5])
s.add(known_k6 == ks[6])
s.add(known_k7 == ks[7])
# The remaining recovered components pin k2 outright (from k2 ^ k4 ^ k6)
# and fix k1 ^ k3 (from k1 ^ k3 ^ k5), leaving only 8 unknown key bits.
s.add(ks[2] == known_k246 ^ known_k4 ^ known_k6)
s.add(ks[1] ^ ks[3] == known_k135 ^ known_k5)

for pti, cti in zip(pts, cts):
    # Every byte of the meet-in-the-middle state must agree for every sample.
    fwd, rev = mangle_mitm(ks, pti, cti)
    for fwdi, revi in zip(fwd, rev):
        s.add(fwdi == revi)

result = s.check()
print(result)
# Fail fast with a clear message instead of an opaque error from s.model().
assert result == z3.sat, "no key is consistent with all samples"
m = s.model()
derived_key = [
    m[ki].as_long()
    for ki in ks
]
print(f"{derived_key = }")
assert derived_key == demo_key, "derived key does not match the demo key"

sat
derived_key = [27, 25, 196, 182, 180, 11, 79, 81]
