In [194]:
from IPython.display import display, Math, Latex
import functools
import itertools

In [195]:
def AES_mult(a,b):
    res = 0
    p = b
    while a != 0:
        if a%2 == 1:
            res ^= p
        a >>= 1 # divide by x (discard constant coefficient)
        p <<= 1
        if p&0x100:
            p ^= 0x11b

    return res

In [196]:
inv_table = {}
for i in range(1,256):
    for j in range(1,i+1):
        if AES_mult(i,j) == 0x01:
            inv_table[i]=j
            inv_table[j] = i

In [197]:
ok = True
for i in range(1,256):
    if AES_mult(i,inv_table[i]) != 1:
        ok = False
        print("Hey. This is wrong %x * %x != 1"%(i,inv_table[i]))
if ok:
    print("Life is good")

Life is good


In [198]:
def AES_poly_mult_alt(p,q):
    res = [0,0,0,0]
    for k in range(4):
        for i in range(4):
            res[k] ^= AES_mult(p[i],q[(-i)%4])
    return res

In [199]:
def T(p):
    return ((p&0xf)<<4 & (p>>4)&0x0f)^ p ^ 0x63
def J(p):
    if p == 0:
        return p
    else:
        return inv_table[p]

def S(p):
    return T(J(p))

In [200]:
S(Sinv(240))

240

In [201]:
def Sinv(q):
    return J(T(q))

In [202]:
mixing_p = [[2,3,1,1],
            [1,2,3,1],
            [1,1,2,3],
            [3,1,1,2]]

In [203]:
def AES_poly_mult(p,q):
    res = [0,0,0,0]
    for i in range(4):
        for j in range(4):
            res[(i+j)%4] ^= AES_mult(p[i],q[j])
        
    return res

In [204]:
def MixColumns(s):
    for c in range(4):
        s[c] = AES_poly_mult(mixing_p[c], s[c])
        
def MixColumnsInv(s):
    for c in range(4):
        s[c] = AES_poly_mult(unmixing_p, s[c])

In [205]:
s = [[1,2,3,4],[4,5,6,7],[7,8,9,0],[0,1,2,3]]
MixColumns(s)
s

[[15, 0, 5, 14], [5, 2, 7, 0], [13, 20, 15, 16], [3, 4, 1, 6]]

In [206]:
def ShiftRows(s):   
    for r in range(1,4): # We do nothing to the 0th row, so skip that
        s[0][r],s[1][r],s[2][r],s[3][r] = \
        s[r][r], s[(r+1)&0x3][r], s[(r+2)&0x3][r], s[(r+3)&0x3][r]  

In [207]:
SBox = [S(p) for p in range(0x100)]
SBoxInv = [Sinv(q) for q in range(0x100)]

In [208]:
#SBox = [[0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76],
#        [0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0],
#        [0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15],
#        [0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75],
#        [0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84],
#        [0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf],
#        [0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8],
#        [0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2],
#        [0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73],
#        [0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb],
#        [0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79],
#        [0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08],
#        [0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a],
#        [0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e],
#        [0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf],
#        [0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16],]

In [209]:
#SBoxInv = [[0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb],
#           [0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb],
#           [0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e],
#           [0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25],
#           [0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92], 
#           [0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84], 
#           [0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06], 
#           [0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b], 
#           [0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73], 
#           [0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e], 
#           [0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b], 
#           [0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4], 
#           [0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f], 
#           [0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef], 
#           [0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61], 
#           [0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d]]

In [210]:
def SubstBytes(s):
    for c in range(4):
        for c in range(4):
            s[c][r] = SBox[s[c][r]]

def SustBytesInv(s):
    for c in range(4):
        for r in range(4):
            s[c][r] = SBoxInv[s[c][r]]

In [211]:
def AddRoundKey(s,w):
    for c in range(4):
        for r in range(4):
            w = HexToBin(w) #want to index bits from the hex value
            s[c][r] ^= int(w[r])

In [212]:
def Cipher(inputBlock, outputBlock, word):
    Nb = 4
    Nr = 10
    state = inputBlock
    AddRoundKey(state, word[0][Nb-1])
    for i in range (1,Nr):
        SubBytes(state) 
        ShiftRows(state)
        MixColumns(state)
        AddRoundKey(state, word[round*Nb][(round+1)*Nb-1])
    SubBytes(state)
    ShiftRows(state)
    AddRoundKey(state, word[Nr*Nb][(Nr+1)*Nb-1])
    outputBlock = state

In [None]:
def HexToBin(hex):
    hex = int(hex,0) #string to hex
    binary = bin(hex).lstrip("0b").rstrip("L") #return a string of binary value
    return binary

In [None]:
#testing the cipher
inputBlock = [[0xa5, 0x81, 0x60, 0xb0],
              [0x33, 0x88, 0x07, 0xc7],
              [0xa1, 0x66, 0x28, 0xd9],
              [0x47, 0xf1, 0x1a, 0x71]]

word = [[0xbb, 0x16, 0x8a, 0x91], 
        [0xe7, 0xad, 0x35, 0x85],
        [0xf8, 0xf6, 0x64, 0x86],
        [0x5b, 0xa2, 0x49, 0x6d]]
out = ""
Cipher(inputBlock, out, word)
out

In [None]:
Rcon = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]

In [142]:
def ByteToCol(byte1, byte2, byte3, byte4):
    return (((byte1 << 8 | byte2) << 8 | byte3) << 8 | byte4)
def SubWord(word):
    t1 = (word & 0xff000000) >> 24
    t2 = (word & 0x00ff0000) >> 16
    t3 = (word & 0x0000ff00) >> 8
    t4 = (word & 0x000000ff)
    return ByteToCol(SBox[t1], SBox[t2], SBox[t3], SBox[t4])
def RotWord(word):
    word = word << 8 | (word >> (32 - 8))

In [143]:
CipherKey = [0x8e, 0x73, 0xb0, 0xf7, 0xda, 0x0e, 0x64, 0x52, 0xc8, 0x10, 0xf3, 0x2b, 0x80, 
             0x90, 0x79, 0xe5, 0x62, 0xf8, 0xea, 0xd2, 0x52, 0x2c, 0x6b, 0x7b]

In [144]:
def KeyExpansion(key, word, Nk=4):
    i = 0
    while (i < Nk):
        w[i] = word(key[4*i], key[4*i+1], key[4*i+2], key[4*i+3])
        i = i+1
    i = Nk
    while (i < Nb * (Nr+1)):
        temp = w[i-1]
        if (i % Nk == 0):
            temp = SubWord(RotWord(temp)) ^ Rcon[i/Nk]
        elif (Nk > 6 and i % Nk == 4):
            temp = SubWord(temp)
        w[i] = w[i-Nk] ^ temp
        i = i + 1

In [145]:
def InvCipher(inputBlock, outputBlock, word):
    Nb = 4
    Nr = 10
    state = inputBlock
    AddRoundKey(state, w[Nr*Nb, (Nr+1)*Nb-1])
    for round in range(Nr-1, 1, -1): #round = Nr-1 step -1 downto 1
        InvShiftRows(state)
        InvSubBytes(state)
        AddRoundKey(state, w[round*Nb, (round+1)*Nb-1])
        InvMixColumns(state)
    InvShiftRows(state)
    InvSubBytes(state)
    AddRoundKey(state, w[0, Nb-1])
    outputBlock = state