# AES-128 and Differential Fault on AES-128 using one single byte

In [32]:
## Uses https://github.com/bozhu/AES-Python/blob/master/aes.py and the book The Design of Rijndael for references
import random
import time
from multiprocessing import Pool

## Understanding the use of Hexadecimal

In [33]:
print("{0:b}".format(0xFF))### Converting hexadecimal to binary
"{0:X}".format(int('100011011',2))### Converting binary to hexadecimal

11111111


'11B'

In [51]:
0x81 & 0x80

128

## S-Boxes

In [34]:
s_box = (
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
)


inv_s_box = (
    0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB,
    0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB,
    0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E,
    0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25,
    0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92,
    0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84,
    0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06,
    0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B,
    0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73,
    0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E,
    0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B,
    0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4,
    0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F,
    0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF,
    0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61,
    0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D,
)

## Operation
Input be a 4x4 matrix (list of list). with each list inside be the column of the matrix.

### Sub-Bytes

In [35]:
def subbytes(matrix):
    size_row = len(matrix)
    size_col = len(matrix[0])
    assert size_row == 4 and size_col == 4
    for i in range(size_row):
        for j in range(size_col): 
            matrix[i][j] = s_box[matrix[i][j]]
            
def inv_subbytes(matrix): 
    size_row = len(matrix)
    size_col = len(matrix[0])
    assert size_row == 4 and size_col == 4
    for i in range(size_row):
        for j in range(size_col): 
            matrix[i][j] = inv_s_box[matrix[i][j]]

In [36]:
##Testing subbytes
matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
print(matrix1)
subbytes(matrix1)
print(matrix1)
inv_subbytes(matrix1)
print(matrix1)

[[82, 9, 106, 213], [58, 145, 17, 65], [167, 141, 157, 132], [7, 18, 128, 226]]
[[0, 1, 2, 3], [128, 129, 130, 131], [92, 93, 94, 95], [197, 201, 205, 152]]
[[82, 9, 106, 213], [58, 145, 17, 65], [167, 141, 157, 132], [7, 18, 128, 226]]


### ShiftRow

In [37]:
def shiftrow(matrix):
    size_row = len(matrix)
    size_col = len(matrix[0])
    assert size_row == 4 and size_col == 4
    matrix[0][1],matrix[1][1],matrix[2][1], matrix[3][1] = matrix[1][1], matrix[2][1], matrix[3][1], matrix[0][1]
    matrix[0][2],matrix[1][2],matrix[2][2], matrix[3][2] = matrix[2][2], matrix[3][2], matrix[0][2], matrix[1][2]
    matrix[0][3],matrix[1][3],matrix[2][3], matrix[3][3] = matrix[3][3], matrix[0][3], matrix[1][3], matrix[2][3]

def inv_shiftrow(matrix):
    size_row = len(matrix)
    size_col = len(matrix[0])
    assert size_row == 4 and size_col == 4
    matrix[0][1],matrix[1][1],matrix[2][1], matrix[3][1] = matrix[3][1], matrix[0][1], matrix[1][1], matrix[2][1] 
    matrix[0][2],matrix[1][2],matrix[2][2], matrix[3][2] = matrix[2][2], matrix[3][2], matrix[0][2], matrix[1][2]
    matrix[0][3],matrix[1][3],matrix[2][3], matrix[3][3] = matrix[1][3], matrix[2][3], matrix[3][3], matrix[0][3]

In [38]:
#Testing shiftrow
matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
print(matrix1)
shiftrow(matrix1)
print(matrix1)
inv_shiftrow(matrix1)
print(matrix1)

[[82, 9, 106, 213], [58, 145, 17, 65], [167, 141, 157, 132], [7, 18, 128, 226]]
[[82, 145, 157, 226], [58, 141, 128, 213], [167, 18, 106, 65], [7, 9, 17, 132]]
[[82, 9, 106, 213], [58, 145, 17, 65], [167, 141, 157, 132], [7, 18, 128, 226]]


### Mix Column

In [39]:
def xtimes(a): ### note this is actually vulnerable to power analysis, actually right now is mostly implemented using a table of all the values of x.
    if(a & 0x80): ##this will return true if there is '1' in the x^7 position.
        return (((a << 1)^0x1B) & 0xFF) ### ^0x1B is because x^8 = x^4+x^3+x+1 so just sub and xoring, &0xFF is trunc all those is above degree 7. 
    else:
        return (a<<1)
    
def mix_single_column(column):
    assert len(column) == 4
    t = column[0]^column[1]^column[2]^column[3]
    u = column[0]
    column[0] ^= xtimes(column[0]^column[1]) ^ t
    column[1] ^= xtimes(column[1]^column[2]) ^ t
    column[2] ^= xtimes(column[2]^column[3]) ^ t
    column[3] ^= xtimes(column[3]^u) ^ t
    

def mix_column(matrix): ## Look at The Design of Rijndael section Sec 4.1.2 this is to optimize the code so as to not use xtimes many times.
    assert len(matrix) == 4 and len(matrix[0])==4
    for i in range(4):
        mix_single_column(matrix[i])
        
def inv_mix_column(matrix): ##Can split the matrix into 2 a small one then followed by mix column matrix. Look at The Design of Rijndael section Sec 4.1.3
    assert len(matrix) == 4 and len(matrix[0])==4
    for i in range(4):
        col = matrix[i]
        u = xtimes(xtimes(col[0]^col[2]))
        v = xtimes(xtimes(col[1]^col[3]))
        col[0] ^= u
        col[1] ^= v
        col[2] ^= u
        col[3] ^= v
    mix_column(matrix)


In [40]:
##Testing mix column
matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
print('orginal matrix:')
print(matrix1)
mix_column(matrix1)
print('mixed matrix:')
print(matrix1)
##Testing inv_mix_column
inv_mix_column(matrix1)
print('inv mix matrix:')
print(matrix1)

orginal matrix:
[[82, 9, 106, 213], [58, 145, 17, 65], [167, 141, 157, 132], [7, 18, 128, 226]]
mixed matrix:
[[0, 43, 235, 36], [140, 113, 74, 76], [192, 158, 156, 241], [90, 90, 51, 68]]
inv mix matrix:
[[82, 9, 106, 213], [58, 145, 17, 65], [167, 141, 157, 132], [7, 18, 128, 226]]


### Add keys

In [43]:
def add_key(matrix, key_matrix):
    for i in range(4):
        for j in range(4):
            matrix[i][j]^= key_matrix[i][j]

## Key Schedule

In [42]:
Rcon = [] ### How r_con is created. Found in The Design of Rijndael section 3.6.2
for j in range(11):
    if (j == 0):
        Rcon.append(0x01)
        print(hex(0x01))
    elif (j == 1): 
        Rcon.append(0x02)
        print(hex(0x02))
    else:
        Rcon.append(xtimes(Rcon[j-1]))
        print(hex(xtimes(Rcon[j-1])))
print(len(Rcon))

0x1
0x2
0x4
0x8
0x10
0x20
0x40
0x80
0x1b
0x36
0x6c
11


In [44]:
r_con = (
    0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40,
    0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A,
    0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A,
    0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39
) ### padded with 0x00 to make sure the r_con works as below.
    
def create_keys(key_matrix): ####Look at The Design of Rijndael section 3.6.2
    W = [] #### W = [k_0,k_1,..., k_43] where each k_i is a column for 4. 
    ########### [k_i,k_i+1,k_i+2,k_i+3] is a round key for round i%4+1, where i multiple of 4, i=1,4,..,40
    for i in range(len(key_matrix)):
        W.append(key_matrix[i])
    for j in range(4,4*11):
        W.append([])
        if (j%4 == 0): 
            first_ele = W[j-4][0] ^ s_box[W[j-1][1]] ^ r_con[int(j/4)]
            W[j].append(first_ele)
            for k in range(1,4):
                W[j].append(W[j-4][k]^s_box[W[j-1][(k+1)%4]])
        else:
            for k in range(4):
                W[j].append(W[j-4][k]^W[j-1][k])
    return W

In [45]:
#### Testing for key schedule
key_matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
W = create_keys(key_matrix1)
print(W)

[[82, 9, 106, 213], [58, 145, 17, 65], [167, 141, 157, 132], [7, 18, 128, 226], [154, 196, 242, 16], [160, 85, 227, 81], [7, 216, 126, 213], [0, 202, 254, 55], [236, 127, 104, 115], [76, 42, 139, 34], [75, 242, 245, 247], [75, 56, 11, 192], [239, 84, 210, 192], [163, 126, 89, 226], [232, 140, 172, 21], [163, 180, 167, 213], [106, 8, 209, 202], [201, 118, 136, 40], [33, 250, 36, 61], [130, 78, 131, 232], [85, 228, 74, 217], [156, 146, 194, 241], [189, 104, 230, 204], [63, 38, 101, 36], [130, 169, 124, 172], [30, 59, 190, 93], [163, 83, 88, 145], [156, 117, 61, 181], [95, 142, 169, 114], [65, 181, 23, 47], [226, 230, 79, 190], [126, 147, 114, 11], [3, 206, 130, 129], [66, 123, 149, 174], [160, 157, 218, 16], [222, 14, 168, 27], [179, 12, 45, 156], [241, 119, 184, 50], [81, 234, 98, 34], [143, 228, 202, 57], [236, 120, 63, 239], [29, 15, 135, 221], [76, 229, 229, 255], [195, 1, 47, 198]]


## AES encryption and decryption

In [46]:
def encrypt(plaintext_matrix, key_matrix):
    W = create_keys(key_matrix)
    ##Round 0 
    k_1 = [W[0],W[1],W[2],W[3]]
    add_key(plaintext_matrix, k_1)
    ##Round 1-9 
    for i in range(1,10):
        subbytes(plaintext_matrix)
        shiftrow(plaintext_matrix)
        mix_column(plaintext_matrix)
        k = [W[4*i],W[4*i+1],W[4*i+2],W[4*i+3]]
        add_key(plaintext_matrix, k)
    ## Round 10
    subbytes(plaintext_matrix)
    shiftrow(plaintext_matrix)
    k_10 = [W[40],W[41],W[42],W[43]]
    add_key(plaintext_matrix, k_10)
    return plaintext_matrix

In [47]:
def decrypt(ciphertext_matrix, key_matrix):
    W = create_keys(key_matrix)
    #Round 0
    k_10 = [W[40],W[41],W[42],W[43]]
    add_key(ciphertext_matrix, k_10) #### Note the first and the last round do not need to have inv_mix_column on it.
    ##Round 1-9
    for i in range(9,0,-1):
        inv_subbytes(ciphertext_matrix)
        inv_shiftrow(ciphertext_matrix)
        inv_mix_column(ciphertext_matrix)
        eq_k =[W[4*i],W[4*i+1],W[4*i+2],W[4*i+3]]
        inv_mix_column(eq_k)
        add_key(ciphertext_matrix, eq_k)
    ## Round 10
    inv_subbytes(ciphertext_matrix)
    inv_shiftrow(ciphertext_matrix)
    k_1 = [W[0],W[1],W[2],W[3]]
    add_key(ciphertext_matrix, k_1)
    return ciphertext_matrix

In [48]:
### Testing encryption and decrypt function
key_matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
Matrix2 = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
print('plaintext:')
print(Matrix2)
ciphertext = encrypt(Matrix2, key_matrix1)
print('ciphertext:')
print(ciphertext)
decrypted_text = decrypt(ciphertext, key_matrix1)
print('decrypted_text:')
print(decrypted_text)

plaintext:
[[74, 13, 45, 229], [201, 125, 250, 89], [163, 158, 129, 243], [2, 4, 8, 16]]
ciphertext:
[[115, 32, 212, 197], [176, 194, 168, 6], [160, 177, 239, 29], [209, 104, 29, 95]]
decrypted_text:
[[74, 13, 45, 229], [201, 125, 250, 89], [163, 158, 129, 243], [2, 4, 8, 16]]


## Differential Fault in Encryption

### Creating a fault on 9th round

Look at https://www.youtube.com/watch?v=izF1CbDHCPc&list=PLbRMhDVUMngfulSvKL0cT-tn8ULtERsWk&index=47 for references on the 9th round.

In [49]:
def faulty_encrypt_on_round9(plaintext_matrix, key_matrix): ###firstly test on the first byte on the first round
    W = create_keys(key_matrix)
    ##Round 0 
    k_1 = [W[0],W[1],W[2],W[3]]
    add_key(plaintext_matrix, k_1)
    ##Round 1-7
    for i in range(1,8):
        subbytes(plaintext_matrix)
        shiftrow(plaintext_matrix)
        mix_column(plaintext_matrix)
        k = [W[4*i],W[4*i+1],W[4*i+2],W[4*i+3]]
        add_key(plaintext_matrix, k)
    ##Round 8
    subbytes(plaintext_matrix)
    shiftrow(plaintext_matrix)
    mix_column(plaintext_matrix)
    k_8 = [W[4*8],W[4*8+1],W[4*8+2],W[4*8+3]]
    add_key(plaintext_matrix, k_8)
    ## Round 9
    error = 0x01
    plaintext_matrix[0][0] ^= error
    subbytes(plaintext_matrix)
    shiftrow(plaintext_matrix)
    mix_column(plaintext_matrix)
    k_9 = [W[4*9],W[4*9+1],W[4*9+2],W[4*9+3]]
    add_key(plaintext_matrix, k_9)
    ## Round 10
    subbytes(plaintext_matrix)
    shiftrow(plaintext_matrix)
    k_10 = [W[40],W[41],W[42],W[43]]
    add_key(plaintext_matrix, k_10)
    return plaintext_matrix

In [51]:
### Testing faulty function
key_matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
print('plaintext:')
Matrix2 = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
Matrix2_duplicate = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
print(Matrix2)

print('ciphertext:')
ciphertext = encrypt(Matrix2, key_matrix1)
print(ciphertext)

print('decrypted_text:')
decrypted_text = decrypt(ciphertext, key_matrix1)
print(decrypted_text)

print('faulty_ciphertext:')
faulty_ciphertext = faulty_encrypt_on_round9(Matrix2_duplicate, key_matrix1)
print(faulty_ciphertext)

plaintext:
[[74, 13, 45, 229], [201, 125, 250, 89], [163, 158, 129, 243], [2, 4, 8, 16]]
ciphertext:
[[115, 32, 212, 197], [176, 194, 168, 6], [160, 177, 239, 29], [209, 104, 29, 95]]
decrypted_text:
[[74, 13, 45, 229], [201, 125, 250, 89], [163, 158, 129, 243], [2, 4, 8, 16]]
faulty_ciphertext:
[[38, 32, 212, 197], [176, 194, 168, 192], [160, 177, 59, 29], [209, 184, 29, 95]]


#### Differential Fault based on analysis

In [52]:
def DFA_round9(y,y_fault):
    key_hypothesis = []
    for key in range(4294967296):
        ##print("new_key:")
        ##print("{0:x}".format(key))
        key_1 = (key& 0XFF)
        ##print("{0:04x}".format(key_1))
        key_2 = ((key& 0XFF00)>>8)
        ##print("{0:04x}".format(key_2))
        key_3 = ((key& 0XFF0000)>>16)
        ##print("{0:04x}".format(key_3))
        key_4 = ((key& 0XFF000000)>>24)
        ##print("{0:04x}".format(key_4))
        f = inv_s_box[y[3][1]^key_2]^inv_s_box[y_fault[3][1]^key_2]
        if ((f == inv_s_box[y[2][2]^key_3]^inv_s_box[y_fault[2][2]^key_3]) and
        (xtimes(f) == inv_s_box[y[0][0]^key_1]^inv_s_box[y_fault[0][0]^key_1]) and 
        (xtimes(f)^f == inv_s_box[y[1][3]^key_4]^inv_s_box[y_fault[1][3]^key_4])) : 
            print("test")
            possible_subkey =[key_1,key_2,key_3,key_4]
            key_hypothesis.append(possible_subkey)
    print(key_hypothesis)
    print("--- %s seconds ---" % (time.time() - start_time))
    return key_hypothesis

In [203]:
##Exhaust all keys
start_time = time.time()
key_matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
Matrix2 = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
Matrix2_duplicate = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
y = encrypt(Matrix2, key_matrix1)
print(y)
y_fault = faulty_encrypt_on_round9(Matrix2_duplicate, key_matrix1)
print(y_fault)
key_hypothesis_round9= DFA_round9(y,y_fault)

[[115, 32, 212, 197], [176, 194, 168, 6], [160, 177, 239, 29], [209, 104, 29, 95]]
[[38, 32, 212, 197], [176, 194, 168, 192], [160, 177, 59, 29], [209, 184, 29, 95]]
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test
test

In [53]:
print(len(key_hypothesis))
W = create_keys(key_matrix1)
k_10 = [W[40],W[41],W[42],W[43]]
print(k_10)
subkey_answer = [k_10[0][0],k_10[3][1],k_10[2][2],k_10[1][3]]
print(subkey_answer)
D = []
for key_hypo in key_hypothesis:
    if(key_hypo == subkey_answer):
        print(True)

NameError: name 'key_hypothesis' is not defined

### Creating a fault on 8th round

References: <br>
1. https://www.youtube.com/watch?v=y67QwAuKV6E&list=PLbRMhDVUMngfulSvKL0cT-tn8ULtERsWk&index=49&t=1945s&pbjreload=101 <br>
2. Differential Fault Analysis of the Advanced Encryption Standard using a Single Fault 

In [54]:
def faulty_encrypt_on_round8(plaintext_matrix, key_matrix): ###firstly test on the first byte on the first round
    W = create_keys(key_matrix)
    ##Round 0 
    k_1 = [W[0],W[1],W[2],W[3]]
    add_key(plaintext_matrix, k_1)
    ##Round 1-7
    for i in range(1,8):
        subbytes(plaintext_matrix)
        shiftrow(plaintext_matrix)
        mix_column(plaintext_matrix)
        k = [W[4*i],W[4*i+1],W[4*i+2],W[4*i+3]]
        add_key(plaintext_matrix, k)
    ##Round 8
    error = random.randint(0, 0xFF) ### TODO check if is alright
    plaintext_matrix[0][0] ^= error ### Just attack the first byte
    subbytes(plaintext_matrix)
    shiftrow(plaintext_matrix)
    mix_column(plaintext_matrix)
    k_8 = [W[4*8],W[4*8+1],W[4*8+2],W[4*8+3]]
    add_key(plaintext_matrix, k_8)
    ## Round 9
    subbytes(plaintext_matrix)
    shiftrow(plaintext_matrix)
    mix_column(plaintext_matrix)
    k_9 = [W[4*9],W[4*9+1],W[4*9+2],W[4*9+3]]
    add_key(plaintext_matrix, k_9)
    ## Round 10
    subbytes(plaintext_matrix)
    shiftrow(plaintext_matrix)
    k_10 = [W[40],W[41],W[42],W[43]]
    add_key(plaintext_matrix, k_10)
    return plaintext_matrix

In [55]:
### Testing faulty function
key_matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
print('plaintext:')
Matrix2 = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
Matrix2_duplicate = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
print(Matrix2)

print('ciphertext:')
ciphertext = encrypt(Matrix2, key_matrix1)
print(ciphertext)

print('decrypted_text:')
decrypted_text = decrypt(ciphertext, key_matrix1)
print(decrypted_text)

print('faulty_ciphertext:')
faulty_ciphertext = faulty_encrypt_on_round8(Matrix2_duplicate, key_matrix1)
print(faulty_ciphertext)

plaintext:
[[74, 13, 45, 229], [201, 125, 250, 89], [163, 158, 129, 243], [2, 4, 8, 16]]
ciphertext:
[[115, 32, 212, 197], [176, 194, 168, 6], [160, 177, 239, 29], [209, 104, 29, 95]]
decrypted_text:
[[74, 13, 45, 229], [201, 125, 250, 89], [163, 158, 129, 243], [2, 4, 8, 16]]
faulty_ciphertext:
[[189, 34, 83, 238], [126, 85, 50, 196], [52, 248, 116, 212], [135, 232, 136, 44]]


#### Phase 1 Differential Fault

In [56]:
def DFA8_phase1(y,y_fault):
    start_time = time.time()
    key_hypothesis_1 = [] ## For key k_1,k_8,k_11,k_14
    key_hypothesis_2 = [] ## For key k_2,k_5,k_12,k_15
    key_hypothesis_3 = [] ## For key k_3,k_6,k_9,k_16
    key_hypothesis_4 = [] ## For key k_4,k_7,k_10,k_13
    for key in range(4294967296): ##Exhaust all keys: time complexity 2^32
        ##print("new_key:")
        ##print("{0:x}".format(key))
        key_1 = (key& 0XFF)
        ##print("{0:04x}".format(key_1))
        key_2 = ((key& 0XFF00)>>8)
        ##print("{0:04x}".format(key_2))
        key_3 = ((key& 0XFF0000)>>16)
        ##print("{0:04x}".format(key_3))
        key_4 = ((key& 0XFF000000)>>24)
        ##print("{0:04x}".format(key_4))
        f = inv_s_box[y[3][1]^key_4]^inv_s_box[y_fault[3][1]^key_4] ### 14
        if ((f == inv_s_box[y[2][2]^key_3]^inv_s_box[y_fault[2][2]^key_3]) and ##11
        (xtimes(f) == inv_s_box[y[0][0]^key_1]^inv_s_box[y_fault[0][0]^key_1]) and ##1
        (xtimes(f)^f == inv_s_box[y[1][3]^key_2]^inv_s_box[y_fault[1][3]^key_2])): ###8 For column 1 
            print("test_1")
            possible_subkey_1 =[key_1,key_2,key_3,key_4]
            key_hypothesis_1.append(possible_subkey_1)


        f = inv_s_box[y[1][0]^key_2]^inv_s_box[y_fault[1][0]^key_2] #5
        if ((f == inv_s_box[y[0][1]^key_1]^inv_s_box[y_fault[0][1]^key_1])  and ##2
        (xtimes(f)^f == inv_s_box[y[3][2]^key_4]^inv_s_box[y_fault[3][2]^key_4]) and  ##15
        (xtimes(f) == inv_s_box[y[2][3]^key_3]^inv_s_box[y_fault[2][3]^key_3])): ###12 For column 2 
            print("test_2")
            possible_subkey_2 =[key_1,key_2,key_3,key_4] ## need to deepcopy this
            key_hypothesis_2.append(possible_subkey_2)
            
            
        f = inv_s_box[y[2][0]^key_3]^inv_s_box[y_fault[2][0]^key_3] ##9
        if ((xtimes(f)^ f  == inv_s_box[y[1][1]^key_2]^inv_s_box[y_fault[1][1]^key_2]) and ##6
        (xtimes(f)== inv_s_box[y[0][2]^key_1]^inv_s_box[y_fault[0][2]^key_1]) and ##3 
        (f == inv_s_box[y[3][3]^key_4]^inv_s_box[y_fault[3][3]^key_4])): ###16 For column 3
            print("test_3")
            possible_subkey_3 =[key_1,key_2,key_3,key_4]
            key_hypothesis_3.append(possible_subkey_3)
        
        f == inv_s_box[y[0][3]^key_1]^inv_s_box[y_fault[0][3]^key_1] ##4
        if ((xtimes(f)  == inv_s_box[y[2][1]^key_3]^inv_s_box[y_fault[2][1]^key_3]) and ##10
        (f == inv_s_box[y[1][2]^key_2]^inv_s_box[y_fault[1][2]^key_2]) and ##7
        (xtimes(f)^ f == inv_s_box[y[3][0]^key_4]^inv_s_box[y_fault[3][0]^key_4])): ##13 For column 4
            print("test_4")
            possible_subkey_4 =[key_1,key_2,key_3,key_4]
            key_hypothesis_4.append(possible_subkey_4)
   
    key_hypothesis = []
    key_hypothesis.append(key_hypothesis_1)
    key_hypothesis.append(key_hypothesis_2)
    key_hypothesis.append(key_hypothesis_3)
    key_hypothesis.append(key_hypothesis_4)
    print("--- %s seconds ---" % (time.time() - start_time))
    return key_hypothesis

In [None]:
key_matrix1= [[0x52, 0x09, 0x6A, 0xD5],[0x3A, 0x91, 0x11, 0x41], [0xA7, 0x8D, 0x9D, 0x84],[0x07, 0x12, 0x80, 0xE2]]
Matrix2 = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
Matrix2_duplicate = [[0x4A, 0x0D, 0x2D, 0xE5],[0xC9, 0x7D, 0xFA, 0x59], [0xA3, 0x9E, 0x81, 0xF3], [0x02, 0x04, 0x08, 0x10]]
y = encrypt(Matrix2, key_matrix1)
print(y)
y_fault = faulty_encrypt_on_round8(Matrix2_duplicate, key_matrix1)
print(y_fault)
key_hypothesis_phase1 = DFA8_phase1(y,y_fault)
print(key_hypothesis_phase1)

[[115, 32, 212, 197], [176, 194, 168, 6], [160, 177, 239, 29], [209, 104, 29, 95]]
[[204, 179, 75, 145], [221, 216, 86, 177], [13, 203, 75, 105], [110, 152, 237, 213]]
test_1
test_1
test_1
test_1
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_1
test_1
test_1
test_1
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_1
test_1
test_1
test_1
test_1
test_1
test_1
test_1
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_1
test_1
test_1
test_1
test_1
test_1
test_1
test_1
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_3
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_2
test_1
test_1
test_1
test_1
test_3
test_3
test_3

Took 3.5 hours.

#### Phase 2 Differential Fault on round 8th Encryption