# Rijndael

## Key expansion procedure

In [26]:
# The AES S-box used in the SubWord function.
S_BOX = [ 
  # 0    1    2    3    4    5    6    7    8    9    a    b    c    d    e    f 
  0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76, # 0
  0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0, # 1
  0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15, # 2
  0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75, # 3
  0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84, # 4
  0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf, # 5
  0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8, # 6
  0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2, # 7
  0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73, # 8
  0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb, # 9
  0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79, # a
  0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08, # b
  0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a, # c
  0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e, # d
  0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf, # e
  0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16  # f
];

def RotWord(word):
    return word[1:] + word[:1]

def SubWord(word):
    return [S_BOX[byte] for byte in word]

In [27]:
# AES supports key sizes of 16 (AES-128), 24 (AES-192), or 32 (AES-256)
KEY_SIZE = 16

# This is the round constant (Rcon) array used in key expansion.
RCON = [
    0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 
    0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, 
    0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, 
    0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39, 
]

def expand_key(key):
    # The number of 32 bit words in the key.
    Nk = KEY_SIZE // 4

    # The number of 32 bit words in the expanded key.
    Nr = Nk + 6  # This is the number of rounds. It's always Nk + 6.
    Nb = 4  # The number of columns comprising a state in AES. This is fixed to 4 for AES.
    
    # Initialize the expanded key with zeros.
    w = [0] * Nb * (Nr + 1)

    # The first Nk words are the original key.
    for i in range(Nk):
        w[i] = key[i*4 : i*4 + 4]

    # Derive the remaining words in the expanded key.
    for i in range(Nk, Nb * (Nr + 1)):
        temp = w[i-1]

    # Apply the special function to every Nk'th word.
        if i % Nk == 0:
            temp = SubWord(RotWord(temp))
            temp[0] ^= RCON[i // Nk - 1]  # Subtract 1 here

        # For 256-bit keys, apply another SubWord
        elif Nk > 6 and i % Nk == 4:
            temp = SubWord(temp)

        w[i] = [a ^ b for a, b in zip(w[i - Nk], temp)]

    return w


In [28]:
# Test key (16 bytes for AES-128)
key = [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F]

expanded_key = expand_key(key)
print(expanded_key)


[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [215, 170, 116, 253], [211, 175, 114, 250], [219, 166, 120, 241], [215, 171, 118, 254], [180, 146, 207, 243], [103, 61, 189, 9], [188, 155, 197, 248], [107, 48, 179, 6], [178, 255, 160, 140], [213, 194, 29, 133], [105, 89, 216, 125], [2, 105, 107, 123], [79, 128, 129, 251], [154, 66, 156, 126], [243, 27, 68, 3], [241, 114, 47, 120], [7, 149, 61, 90], [157, 215, 161, 36], [110, 204, 229, 39], [159, 190, 202, 95], [185, 225, 242, 129], [36, 54, 83, 165], [74, 250, 182, 130], [213, 68, 124, 221], [130, 241, 51, 130], [166, 199, 96, 39], [236, 61, 214, 165], [57, 121, 170, 120], [116, 93, 143, 144], [210, 154, 239, 183], [62, 167, 57, 18], [7, 222, 147, 106], [233, 129, 141, 85], [59, 27, 98, 226], [5, 188, 91, 240], [2, 98, 200, 154], [88, 105, 53, 34], [99, 114, 87, 192], [102, 206, 12, 48], [100, 172, 196, 170]]


In [29]:
# Only print the first round key
for i in range(8):
    print([hex(byte) for byte in expanded_key[i]])


['0x0', '0x1', '0x2', '0x3']
['0x4', '0x5', '0x6', '0x7']
['0x8', '0x9', '0xa', '0xb']
['0xc', '0xd', '0xe', '0xf']
['0xd7', '0xaa', '0x74', '0xfd']
['0xd3', '0xaf', '0x72', '0xfa']
['0xdb', '0xa6', '0x78', '0xf1']
['0xd7', '0xab', '0x76', '0xfe']
