In [2]:
import hashlib
import numpy as np

from tqdm import tqdm

# DFA by Piret and Quisquater

This template helps you implement both the `simple` and the `full` variants of the differential fault analysis (DFA) on AES required for the homework.

In [31]:
# AES forward S-box (SubBytes lookup table) as a uint8 array: SBOX[v] is the
# substitution of byte v. Laid out 16 entries per line, so entry v sits on
# line v // 16, position v % 16.
SBOX = np.array([
        0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
        0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
        0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
        0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
        0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
        0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
        0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
        0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
        0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
        0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
        0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
        0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
        0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
        0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
        0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
        0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
], dtype=np.uint8)

# Inverse AES S-box. SBOX is a permutation of 0..255, so scattering the identity
# through it gives ISBOX[SBOX[v]] == v. This yields the same values and intp
# dtype as SBOX.argsort().
ISBOX = np.empty(SBOX.size, dtype=np.intp)
ISBOX[SBOX] = np.arange(SBOX.size, dtype=np.intp)

# AES MixCols matrix: each row is the row above rotated one place to the right.
MIXCOLS = np.array([np.roll([2, 3, 1, 1], k) for k in range(4)])



In [1]:
# Check whether you found the right 4 keybytes of the first column (state indices [0, 13, 10, 7])
# Note: these are the round 10 keybytes (no need to rewind with the key schedule)
# Note: the same key is used for both the `simple` and the `full` DFA

def check_keybytes(k_0: int, k_13: int, k_10: int, k_7: int):
    """Compare candidate key bytes against a stored SHA3-256 digest.

    The bytes are hashed in the order (k_0, k_13, k_10, k_7). A message is
    printed in both the success and the failure case.

    Args:
        k_0, k_13, k_10, k_7: candidate key bytes; each must be in 0..255
            (``bytes`` raises ValueError otherwise).

    Returns:
        True if the digest matches the expected one, otherwise False.
    """
    expected_hex = '4409976e63e88e6d0ef93405e6b6d678c2a498d22dcaa72b28c8c9cd6233ec7f'
    actual_hex = hashlib.sha3_256(bytes((k_0, k_13, k_10, k_7))).hexdigest()
    if actual_hex != expected_hex:
        print("Not quite right")
        return False
    print("Congratulations! Correct 4 keybytes found")
    return True

In [None]:
# Two (ciphertext, faulty ciphertext) pairs.
# The fault is injected in the *first byte* before MixCols in round 9.
# Note: used for the `simple` DFA (part C)

simple_ctxt1 = [174, 44, 204, 43, 18, 196, 238, 88, 3, 227, 92, 0, 137, 106, 205, 88]
simple_ftxt1 = [128, 44, 204, 43, 18, 196, 238, 171, 3, 227, 159, 0, 137, 186, 205, 88]

simple_ctxt2 = [41, 4, 148, 29, 23, 74, 41, 127, 125, 148, 36, 219, 29, 127, 4, 58]
simple_ftxt2 = [186, 4, 148, 29, 23, 74, 41, 160, 125, 148, 59, 219, 29, 172, 4, 58]


def _to_aes_state(flat_bytes):
    """Arrange 16 bytes into a 4x4 uint8 AES state, column by column (byte i -> row i % 4, col i // 4)."""
    return np.asarray(flat_bytes).reshape((4, 4), order='F').astype(np.uint8)


# Convert the ctext/ftext pairs to the AES column order
simple_ctxt1, simple_ftxt1, simple_ctxt2, simple_ftxt2 = (
    _to_aes_state(flat) for flat in (simple_ctxt1, simple_ftxt1, simple_ctxt2, simple_ftxt2)
)

print("First pair diff:")
print(simple_ctxt1 ^ simple_ftxt1)

print("Second pair diff:")
print(simple_ctxt2 ^ simple_ftxt2)

: 

In [33]:
# Load all plaintexts/ciphertexts/faulty texts in the correct AES column order
# Note: not needed for the `simple` DFA (part C)

row_type = np.dtype((np.uint8, (4, 4)))


def _load_column_major(path):
    """Read consecutive 16-byte records from `path` as an (N, 4, 4) array of AES states.

    Each record is read row-major as a 4x4 block and then transposed, so byte i
    of a record lands at (row i % 4, column i // 4).
    """
    return np.fromfile(path, dtype=row_type).transpose(0, 2, 1)


all_ctext = _load_column_major("full_dfa_data/ctext.bin")
all_ptext = _load_column_major("full_dfa_data/ptext.bin")
all_ftext = _load_column_major("full_dfa_data/ftext.bin")

print(all_ctext[0] ^ all_ftext[0])

[[  0   0   0 109]
 [  0   0  50   0]
 [  0   3   0   0]
 [206   0   0   0]]


In [34]:
# GF(2^8) multiplication helpers for MixCols (AES reduction constant 0x1B).
# Both are branch-free, so they accept a Python int, a NumPy integer scalar,
# or a NumPy integer array (applied elementwise).

def galois_mult_2(a):
    """Multiply `a` by 2 in GF(2^8) (AES "xtime").

    Shift left one bit, keep the low 8 bits, and XOR in 0x1B when bit 7 (0x80)
    of the input was set. There is no Python `if` on the value, so NumPy arrays
    work elementwise (an `if` on an array would raise).

    Args:
        a: non-negative int, NumPy integer scalar, or NumPy integer array.

    Returns:
        The product: an int for int input, or a NumPy scalar/array for NumPy input.
    """
    carry = (a >> 7) & 1  # 1 where bit 7 (0x80) is set, else 0
    return ((a << 1) & 0xFF) ^ (carry * 0x1B)


def galois_mult_3(a):
    """Multiply `a` by 3 in GF(2^8): 3*a = (2*a) XOR a. Works elementwise on arrays too."""
    return galois_mult_2(a) ^ a

In [35]:
# AES ShiftRows / InvShiftRows
# `mat` is a 4x4 AES state matrix


def shift(mat):
    """Apply AES ShiftRows: rotate row i of `mat` left by i positions.

    Returns a new array of the same dtype; `mat` is not modified.
    """
    shifted = np.zeros_like(mat)
    for i in range(4):
        shifted[i] = mat[i, np.arange(i, 4 + i) % 4]
    return shifted


def unshift(mat):
    """Apply AES InvShiftRows: rotate row i of `mat` right by i positions.

    This undoes `shift`: ``unshift(shift(m))`` and ``shift(unshift(m))`` both
    equal ``m``. Returns a new array of the same dtype; `mat` is not modified.
    """
    unshifted = np.zeros_like(mat)
    for i in range(4):
        # shift() put mat[i, (j + i) % 4] at column j, so read back from (j - i) % 4
        unshifted[i] = mat[i, np.arange(-i, 4 - i) % 4]
    return unshifted


In [36]:
# Precompute every possible MixCols(fault) difference.
# MixCols is linear over GF(2^8), so a fault difference x (1..255) in row r of a
# column changes the MixCols output column by MIXCOLS[:, r] * x.
# D_BY_ROW[r][x - 1] is that 4-byte difference for a fault in row r, which
# covers faults in any row (as needed for the `full` DFA).
# D is the row-0 table (fault in the first byte), used by the `simple` DFA.


def _gf_mult_const(coef, value):
    """Multiply `value` by MixCols coefficient `coef` (1, 2 or 3) in GF(2^8); None for any other coefficient."""
    if coef == 1:
        return value
    if coef == 2:
        return galois_mult_2(value)
    if coef == 3:
        return galois_mult_3(value)
    return None


def _mixcols_diff_table(column):
    """Return [[column[j] * x for each j] for x in 1..255], with products taken in GF(2^8)."""
    return [[_gf_mult_const(coef, x) for coef in column] for x in range(1, 255 + 1)]


mixcol = MIXCOLS[:, 0]
D_BY_ROW = [_mixcols_diff_table(MIXCOLS[:, row]) for row in range(4)]
D = D_BY_ROW[0]

print("Length of lookup table:", len(D))

Length of lookup table: 255


# Simple Attack Algorithm (finds 4 bytes):

## Variable Definitions
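Let $(c, c')$ be the first (ciphertext, faulty ciphertext) pair and $(c^{*}, c^{*\prime})$ be the second pair.
Note: $c_i$ refers to the byte at index $i$ of the ciphertext (e.g. $c_0$ is byte $0$). $D[x, j]$ denotes entry $j$ of the precomputed MixCols difference for fault value $x$; in code this is `D[x - 1][j]`, because `D` starts at $x = 1$.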

## Preliminary Filtering
1. For each $K_0 \in 0 \text{ to } 255$ and $K_{13} \in 0 \text{ to } 255$, and $x \in 1 \text{ to } 255$
    - Check if $ISBOX(K_0 \oplus c_0) \oplus ISBOX(K_0 \oplus c_0') \stackrel{?}{=} D[x, 0]$
    - Check if $ISBOX(K_{13} \oplus c_{13}) \oplus ISBOX(K_{13} \oplus c_{13}') \stackrel{?}{=} D[x, 1]$

    - If both are true, add candidates $\{K_0, K_{13}\}$ to group $K_{c2}$
2. For each $K_{10} \in 0 \text{ to } 255$, and $\{K_{0}, K_{13}\} \in K_{c2}$, and $x \in 1 \text{ to } 255$
    - Check if $ISBOX(K_0 \oplus c_0) \oplus ISBOX(K_0 \oplus c_0') \stackrel{?}{=} D[x, 0]$
    - Check if $ISBOX(K_{13} \oplus c_{13}) \oplus ISBOX(K_{13} \oplus c_{13}') \stackrel{?}{=} D[x, 1]$
    - Check if $ISBOX(K_{10} \oplus c_{10}) \oplus ISBOX(K_{10} \oplus c_{10}') \stackrel{?}{=} D[x, 2]$

    - If so, add candidates $\{K_0, K_{13}, K_{10}\}$ to group $K_{c3}$
3. For each $K_{7} \in 0 \text{ to } 255$, and $\{K_{0}, K_{13}, K_{10}\} \in K_{c3}$, and $x \in 1 \text{ to } 255$
    - Check if $ISBOX(K_0 \oplus c_0) \oplus ISBOX(K_0 \oplus c_0') \stackrel{?}{=} D[x, 0]$
    - Check if $ISBOX(K_{13} \oplus c_{13}) \oplus ISBOX(K_{13} \oplus c_{13}') \stackrel{?}{=} D[x, 1]$
    - Check if $ISBOX(K_{10} \oplus c_{10}) \oplus ISBOX(K_{10} \oplus c_{10}') \stackrel{?}{=} D[x, 2]$
    - Check if $ISBOX(K_{7} \oplus c_{7}) \oplus ISBOX(K_{7} \oplus c_{7}') \stackrel{?}{=} D[x, 3]$

    - If so, add candidates $\{K_0, K_{13}, K_{10}, K_{7}\}$ to group $K_{c4}$
    
## Finding Final Candidate
1. For each $\{K_{0}, K_{13}, K_{10}, K_7\} \in K_{c4}$, and $x \in 1 \text{ to } 255$
  - $ISBOX(K_0 \oplus c^{*}_0) \oplus ISBOX(K_0 \oplus c^{*\prime}_0) \stackrel{?}{=} D[x, 0]$
  - $ISBOX(K_{13} \oplus c^{*}_{13}) \oplus ISBOX(K_{13} \oplus c^{*\prime}_{13}) \stackrel{?}{=} D[x, 1]$
  - $ISBOX(K_{10} \oplus c^{*}_{10}) \oplus ISBOX(K_{10} \oplus c^{*\prime}_{10}) \stackrel{?}{=} D[x, 2]$
  - $ISBOX(K_{7} \oplus c^{*}_{7}) \oplus ISBOX(K_{7} \oplus c^{*\prime}_{7}) \stackrel{?}{=} D[x, 3]$
  - If all four hold for the same $x$, then $\{K_0, K_{13}, K_{10}, K_{7}\}$ are likely the correct keybytes


# Indexing Tips:
You can access elements such as $c_{0}, c_{13}, c_{10}, c_{7}$ by converting the flat byte index into a (row, column) position with `np.unravel_index(..., order='F')`, like this:

In [27]:
# Flat AES byte index 13 -> (row, col) of the column-major 4x4 state, i.e. (1, 3)
c13_pos = np.unravel_index(13, (4, 4), order='F')
c13 = simple_ctxt1[c13_pos]

In [28]:
# Perform the `simple` DFA below: use (`simple_ctxt1`, `simple_ftxt1`) and (`simple_ctxt2`, `simple_ftxt2`)
# Note: remember to check your answer with `check_keybytes()`!
def simple_attack():
    """Placeholder for the `simple` DFA (part C); currently does nothing and returns None.

    TODO: implement the key-byte filtering described in the markdown above.
    """
    return None

In [29]:
# Perform the `full` DFA below: use `all_ctext` and `all_ftext`
# Note: `check_keybytes()` can still verify your answer for the first 4 keybytes
# Note: after recovering the whole key, rewind the key schedule to get the original key;
#       you can use the python library `aeskeyschedule` (pip install aeskeyschedule) if you wish
def full_attack():
    """Placeholder for the `full` DFA; currently does nothing and returns None.

    TODO: recover the whole round-10 key from `all_ctext` / `all_ftext`.
    """
    return None

In [30]:
# After recovering the full key, decrypt the secret message below.
# You can use the `pycryptodome` library for this.
SECRET_HEX = "2a92fc6ad8006b658f49062c2843ad99"
secret = bytes.fromhex(SECRET_HEX)  # 16 bytes of ciphertext