In [1]:
from aes_utils.aes_utils import *

import numpy as np

In [2]:
# Module-level NumPy random generator, seeded with 0; random_state() below
# reads this global each time it is called.
rng = np.random.default_rng(0)

def AES128_encrypt_with_fault(state, key, pos=(0, 0), fault=0, fault_round=8):
    """Encrypt one block with the AES-128 round structure, injecting a fault.

    The sequence is: XOR with round key 0, nine rounds of
    sub_bytes -> shift_rows -> mix_columns -> XOR round key, then a final
    round without mix_columns. Right before the sub_bytes call of round
    ``fault_round`` the single entry ``state[pos]`` is XOR-ed with ``fault``.

    Parameters
    ----------
    state : array-like supporting ``^`` and item assignment
        Input block. The very first step rebinds ``state`` to the result of
        ``state ^ keys[0]``, so for NumPy arrays the caller's object is not
        written to.
    key : passed unchanged to ``key_expansion(key, 10)``
        The expansion is expected to yield at least 11 round keys, since
        indices 0..10 are used.
    pos : index, default ``(0, 0)``
        Index of the state entry that receives the fault.
    fault : int, default 0
        Value XOR-ed into ``state[pos]``; 0 leaves the state unchanged.
    fault_round : int, default 8
        Round (1..10) at whose start the fault is applied. The default
        reproduces the previously hard-coded round 8.

    Returns
    -------
    The state after the final round key has been applied.

    Raises
    ------
    ValueError
        If ``fault_round`` is outside 1..10.
    """
    if not 1 <= fault_round <= 10:
        raise ValueError(f"fault_round must be in 1..10, got {fault_round!r}")

    round_keys = list(key_expansion(key, 10))

    # Round 0: initial round-key addition (also makes `state` a fresh object).
    state = state ^ round_keys[0]

    # Rounds 1..9: full rounds including mix_columns.
    # `round_index` rather than `round`, to avoid shadowing the builtin.
    for round_index in range(1, 10):
        if round_index == fault_round:
            state[pos] ^= fault

        state = sub_bytes(state)
        state = shift_rows(state)
        state = mix_columns(state)
        state = state ^ round_keys[round_index]

    # Round 10: final round, no mix_columns.
    if fault_round == 10:
        state[pos] ^= fault

    state = sub_bytes(state)
    state = shift_rows(state)
    state = state ^ round_keys[10]

    return state

# def it_AES128_encrypt_with_fault(state, key, fault=0):
#     keys = list(key_expansion(key, 10))
#     state = state ^ keys[0]
#     yield state.copy()

#     for round in range(1, 10):
#         if round == 8:
#             state[0, 0] ^= fault
#         state = mix_columns(shift_rows(sub_bytes(state))) ^ keys[round]
#         yield state.copy()
    
#     state = shift_rows(sub_bytes(state)) ^ keys[10]
#     yield state.copy()

def random_state():
    """Draw a 4x4 ``uint8`` array of values in [0, 256) from the global ``rng``."""
    shape = (4, 4)
    return rng.integers(low=0, high=256, size=shape, dtype=np.uint8)

def print_state(X, compact=True):
    """Print the entries of ``X`` in column-major order as two-digit hex.

    With ``compact`` true the bytes are concatenated with no separator;
    otherwise each byte is prefixed with ``0x`` and the bytes are separated
    by ``', '``. A single newline terminates the output.
    """
    column_major = X.flatten('F')
    if compact:
        line = ''.join(f'{byte:02x}' for byte in column_major)
    else:
        line = ', '.join(f'0x{byte:02x}' for byte in column_major)
    print(line)

def repr_state(X, compact=False):
    """Return the entries of ``X`` in column-major order as a hex string.

    By default each byte is rendered as ``0xNN`` and the bytes are joined
    with ``', '``; with ``compact`` true the bare two-digit hex values are
    concatenated without a separator.
    """
    column_major = X.flatten('F')
    if not compact:
        return ', '.join([f'0x{byte:02x}' for byte in column_major])
    return ''.join([f'{byte:02x}' for byte in column_major])

def fault_position_index_to_ij(index):
    """Convert a flat index into an ``(i, j)`` pair with ``index == i + 4 * j``.

    ``i`` is the remainder and ``j`` the quotient of ``index`` divided by 4,
    which matches the column-major order used by ``X.flatten('F')`` elsewhere
    in this notebook.
    """
    i = index % 4
    j = index // 4
    return i, j

def print_third_stage_test(X, Y, Y_, K0, fault_position):
    """Print one test tuple in C++ brace-initializer form.

    The output is a ``tuple<FlatState, FlatState, FlatState, FlatState,
    size_t> { ... },`` literal whose four brace-wrapped byte lists are
    ``X``, ``Y``, ``Y_`` and ``K0`` (each formatted by ``repr_state`` in its
    default, non-compact form), followed by ``fault_position``.
    """
    print("tuple<FlatState, FlatState, FlatState, FlatState, size_t> {")
    for block in (X, Y, Y_, K0):
        print(f"    {{{repr_state(block)}}},")
    print(f"    {fault_position}")
    print("},")

In [3]:
# NOTE(review): this rebinds the global `rng` to an UNSEEDED generator, replacing
# the seed-0 generator created earlier; everything drawn from here on differs
# between runs. Confirm that fresh, non-reproducible vectors are intended.
rng = np.random.default_rng()

K0 = random_state()
# Last item yielded by key_expansion. The round count is omitted here, unlike
# the key_expansion(key, 10) call in AES128_encrypt_with_fault — presumably the
# default is 10 rounds; TODO confirm.
K10 = list(key_expansion(K0))[-1]
# X = np.zeros((4, 4), dtype=np.uint8)
X = random_state()

Y  = AES128_encrypt(X, K0)
# Draw the fault from [1, 256): a fault of 0 is a no-op XOR and would make the
# "faulty" ciphertext identical to the correct one.
Y_ = AES128_encrypt_with_fault(X, K0, pos=(0, 0), fault=rng.integers(1, 256))

# range(0) is empty, so this loop body does not run as written; raise the count
# to print that many C++ test tuples.
for _ in range(0):
    K = random_state()
    X = random_state()

    fault_position = rng.integers(16)
    ij_position = fault_position_index_to_ij(fault_position)

    Y  = AES128_encrypt(X, K)
    Y_ = AES128_encrypt_with_fault(X, K, pos=ij_position, fault=rng.integers(1, 256))

    print_third_stage_test(X, Y, Y_, K, fault_position)

In [10]:
# Build one random example: plaintext X, key K, and a single-byte fault at a
# random flat position (converted to an (i, j) index for the cipher).
X = random_state()
K = random_state()

fault_position = rng.integers(16)
ij_position = fault_position_index_to_ij(fault_position)

Y  = AES128_encrypt(X, K)
# Fault drawn from [1, 256): a zero fault would leave Y_ equal to Y.
Y_ = AES128_encrypt_with_fault(X, K, pos=ij_position, fault=rng.integers(1, 256))

str_X = repr_state(X, compact=True)
str_Y = repr_state(Y, compact=True)
str_Y_ = repr_state(Y_, compact=True)

# First line: plaintext, correct ciphertext, faulty ciphertext, fault position.
# Second line: the key, as compact hex.
print(str_X, str_Y, str_Y_, fault_position)
print_state(K)

7a9e2ffeffd7d96127746ce6092462c5 b5a92733b663711e1659b4caab7fb22f eed659d47b0b302641c816fa79bb7ac9 9
af63494273c40dc98bb7974f60e9cb5d
