# AES (Advanced Encryption Standard) description

This file is designed as a high-level, conceptual introduction to the AES encryption/decryption algorithm. I emphasize the role that finite fields play in its design. The motivation behind the choice of the individual round functions (SubBytes, ShiftRows, MixColumns, AddRoundKey) is not discussed in this version of the file.

## Finite field $\mathbb{F}_{2^8}$ setup

In [266]:
#polynomial ring GF(2)[x] and the AES byte field GF(2^8) = GF(2)[x]/(x^8+x^4+x^3+x+1) with generator a
R = PolynomialRing(GF(2), 'x')
x = R.gen()
F256 = GF(2**8, name='a', modulus=x**8 + x**4 + x**3 + x + 1)
a = F256.gen()

### Hexadecimal representation of elements in $\mathbb{F}_{2^8}$

In [439]:
def HexToF256(n):
    """Convert a hexadecimal string such as '63' into an element of F256.

    Bit i of the integer value becomes the coefficient of x^i of the
    polynomial representative.
    """
    bit_string = bin(int(n, 16))[2:]
    coefficients = [int(bit) for bit in reversed(bit_string)]  # lowest degree first
    return F256(R(coefficients))
def F256ToHex(el):
    """Return el (an element of F256) as a two-digit upper-case hex string.

    The coefficient of x^i in the polynomial representative of el is bit i of
    the byte. Every result has exactly two digits ('00', '07', 'E0'), so bytes
    below 0x10 keep their leading zero just like the zero element does.
    """
    coefficients = R(el).list()  # lowest degree first; [] for the zero element
    value = sum(int(c) << i for i, c in enumerate(coefficients))
    return format(value, '02X')

def h(x):
    """Short alias: hex string -> F256 element (see HexToF256)."""
    return HexToF256(x)
def g(x):
    """Short alias: F256 element -> hex string (see F256ToHex)."""
    return F256ToHex(x)


def HexToBits(h):
    """Return the bits of the hex string h, most significant first, left-padded with zeros to 8."""
    bits = [int(ch) for ch in bin(int(h, 16))[2:]]
    return leftpadding(bits, 8)
def b(h):
    """Short alias for HexToBits."""
    return HexToBits(h)

### AES specification

In [401]:
#AES specification
Nb=4 #number of columns of the state
AEStype='128' #choose '128','192' or '256'

#key length Nk (in 4-byte words) and number of rounds Nr for each AES variant
_AES_PARAMETERS={'128':(4,10),'192':(6,12),'256':(8,14)}
if AEStype not in _AES_PARAMETERS:
    #fail loudly instead of leaving Nk and Nr undefined
    raise ValueError("AEStype must be '128', '192' or '256', got {!r}".format(AEStype))
Nk,Nr=_AES_PARAMETERS[AEStype]
    
#example state: a 4 by 4 array (list of rows) of F256 elements, all zero
state=[[F256(0)]*4 for _ in range(4)]
state

### S-Box setup and SubBytes routine

In [237]:
#cyclic shifts of lists
def CyclicRot(li,n):
    """Rotate the sequence li left by n positions (a negative n rotates right).

    n may be any integer; it is reduced modulo len(li). An empty sequence is
    returned as a copy instead of raising ZeroDivisionError.
    """
    size=len(li)
    if size==0:
        return li[:]
    shift=n%size
    return li[shift:]+li[:shift]

In [294]:
#(right)padding of the list
def padding(li,n):
    """Return a new list: li followed by zeros up to total length n.

    Nothing is truncated: if li already has n or more entries, the result
    equals li.
    """
    shortfall=max(0,n-len(li))
    return li+[0]*shortfall

#left padding of the list
def leftpadding(li,n):
    """Return a new list: zeros followed by li, so that the length is at least n."""
    zeros=[0]*max(0,n-len(li))
    return zeros+li

#coefficient list of an F256 element
def l(x,p):
    """Return the coefficients of x (an element of F256), lowest degree first,
    zero-padded on the right to length p.

    x is lifted to the polynomial ring R so that its coefficients can be read off.
    """
    coefficients=R(x).list()
    return padding(coefficients,p)

In [400]:
#important matrix used in the SBox design: the linear part of the S-box affine map.
#Row k is [1,0,0,0,1,1,1,1] rotated right by k. The base ring is given explicitly as
#GF(2), so m*v is always computed mod 2 whatever parent the input vector has.
m=matrix(GF(2),[CyclicRot([1,0,0,0,1,1,1,1],-k) for k in range(0,8)])
def SBox(el):
    """Return the AES S-box value of el (an element of F256).

    el is replaced by its multiplicative inverse (0 is mapped to 0). The
    coefficient vector of that value is multiplied by m over GF(2), and
    h('63') is added to the resulting F256 element.
    """
    inv=el**(-1) if el!=0 else F256(0)
    bits=vector(GF(2),l(inv,8))  #coefficients of inv, lowest degree first
    return F256(list(m*bits))+h('63')

def SubBytes(state):
    """Apply SBox to every byte of state and return a new list of row lists.

    The rows are built eagerly as lists. Under Python 3 a lazy map object has
    no len() (which CyclicRot in ShiftRows needs), and it is exhausted after
    one pass (e.g. by PrintState).
    """
    return [[SBox(el) for el in row] for row in state]

### MixColumns

In [551]:
#MixColumns setup: polynomials over F256 in T, later reduced modulo T^4+1
R2=PolynomialRing(F256,'T')
T=R2.gen()
#coefficient lists start with the constant term
cpol=R2([h('02'),h('01'),h('01'),h('03')])     #{03}T^3+{01}T^2+{01}T+{02}
cpolinv=R2([h('0E'),h('09'),h('0D'),h('0B')])  #{0B}T^3+{0D}T^2+{09}T+{0E}

In [553]:
#inverse check: cpol*cpolinv reduced modulo T^4+1 should be 1
(cpol*cpolinv).quo_rem(T**4+1)[1]

1

In [564]:
def MixColumns(state):
    """Mix each column of state by polynomial multiplication.

    A column (s0, s1, s2, s3) is read as s0 + s1*T + s2*T^2 + s3*T^3 in R2,
    multiplied by cpol and reduced modulo T^4+1. The coefficients of the
    product, lowest degree first, form the new column.
    Returns the new state as a list of rows (Sage vectors over F256).
    """
    modulus=T**4+1
    newcolumns=[]
    for column in transpose(matrix(state)):
        product=(R2(list(column))*cpol)%modulus
        #product.list() drops trailing zero coefficients, so pad back to 4 entries
        newcolumns.append(padding(product.list(),4))
    #explicit base ring: an all-zero result would otherwise become an integer matrix
    return list(transpose(matrix(F256,newcolumns)))

### ShiftRows routine

In [114]:
#ShiftRows offsets: row r of the state is rotated left by Cr positions
#(defined for Nb==4, the block size AES uses)
if Nb==4:
    C0,C1,C2,C3=0,1,2,3

In [395]:
def ShiftRows(state):
    """Rotate row r of state left by Cr positions and return a list of row lists.

    As in InvShiftRows, the rows are first converted to plain lists, so that
    CyclicRot's slice-and-concatenate works whatever row type the previous
    step produced (lists, Sage vectors, iterators).
    """
    row0,row1,row2,row3=[list(row) for row in state]
    offsets=(C0,C1,C2,C3)
    return [CyclicRot(row,c) for row,c in zip((row0,row1,row2,row3),offsets)]

### KeyExpansion routine

In [383]:
def KeyExpansion(Key):
    """Expand the cipher key into the Nb*(Nr+1) words of the AES key schedule.

    Key: sequence of 4*Nk bytes (F256 elements or anything F256() accepts).
    Returns a list of Nb*(Nr+1) words; each word is a list of 4 F256 elements.
    Raises ValueError if len(Key) != 4*Nk.

    Covers all three key sizes, including Nk>6 (AES-256), which applies an
    extra SubWord to every word with i % Nk == 4.
    """
    if len(Key)!=4*Nk:
        raise ValueError("expected a key of {} bytes (4*Nk), got {}".format(4*Nk,len(Key)))
    total=Nb*(Nr+1)
    W=[None]*total
    for i in range(0,Nk):
        W[i]=[F256(Key[4*i+j]) for j in range(0,4)]
    for i in range(Nk,total):
        temp=W[i-1]
        if i%Nk==0:
            #RotWord then SubWord, then add Rcon = a^(i/Nk - 1) to the first byte
            temp=[SBox(el) for el in CyclicRot(temp,1)]
            temp[0]=temp[0]+a**(i//Nk-1)
        elif Nk>6 and i%Nk==4:
            #AES-256 only: SubWord in the middle of each block of Nk words
            temp=[SBox(el) for el in temp]
        #addition in F256 is bytewise XOR
        W[i]=[u+v for u,v in zip(W[i-Nk],temp)]
    return W

### RoundKey generation

In [384]:
def RoundKey(expkey,i):
    """Return the Nb consecutive words of the expanded key used in round i."""
    first=Nb*i
    return expkey[first:first+Nb]

### AddRoundKey routine

In [350]:
def AddRoundKey(state,roundkey):
    """Add (XOR in F256) the round key into the state.

    Each word of roundkey is one column of the state, hence the transpose.
    Returns a list of rows (Sage vectors).
    """
    key_matrix=matrix(roundkey).transpose()
    return list(matrix(state)+key_matrix)

### PrintState (used to display intermediate states when checking the implementation against the NIST example)

In [568]:
def PrintState(state):
    """Return state as nested lists of hex strings (the caller does the printing)."""
    return [list(map(g,row)) for row in state]

### AES encryption (Rijndael; the NIST example below uses a 128-bit key)

In [654]:
def AESEncryptionVerbose(inp,key):
    """Encrypt inp with key exactly like AESEncryption, printing every step.

    The state is printed after each transformation, and the round key is
    printed after each AddRoundKey, so the trace can be compared with the
    NIST example. The final round is labelled with Nr (not a hard-coded 10),
    so the trace is also correct for 192- and 256-bit keys.
    Returns the ciphertext state as a list of row lists.
    """
    def show(title,st):
        #print a heading followed by the state in hex
        print(title)
        print(PrintState(st))

    expkey=KeyExpansion(key)
    state=inp
    show("Input",state)

    state=AddRoundKey(state,RoundKey(expkey,0))
    show("Round Key Value",RoundKey(expkey,0))

    for rnd in range(1,Nr):
        show("\n Start Of Round {}".format(rnd),state)

        state=SubBytes(state)
        show("After SubBytes",state)

        state=ShiftRows(state)
        show("After ShiftRows",state)

        state=MixColumns(state)
        show("After MixColumns",state)

        state=AddRoundKey(state,RoundKey(expkey,rnd))
        show("Round Key Value",RoundKey(expkey,rnd))

    #the final round omits MixColumns
    show("\n Start Of the Final Round {}".format(Nr),state)

    state=SubBytes(state)
    show("After SubBytes",state)

    state=ShiftRows(state)
    show("After ShiftRows",state)

    state=AddRoundKey(state,RoundKey(expkey,Nr))
    show("\n Output",state)

    return [list(x) for x in state]

In [655]:
def AESEncryption(inp,key):
    """Encrypt the 4x4 state inp with key and return the result as a list of row lists.

    After the initial AddRoundKey, rounds 1..Nr apply SubBytes, ShiftRows,
    MixColumns (skipped in the last round) and AddRoundKey.
    """
    schedule=KeyExpansion(key)
    state=AddRoundKey(inp,RoundKey(schedule,0))
    for rnd in range(1,Nr+1):
        state=SubBytes(state)
        state=ShiftRows(state)
        if rnd<Nr:
            state=MixColumns(state)
        state=AddRoundKey(state,RoundKey(schedule,rnd))
    return [list(row) for row in state]


## Remarks

#key is a list of 4*Nk elements of F256 (16 for AES-128), e.g.
#key=[F256(0)]*16

#state is a 4 by 4 array of F256 elements, e.g.
#state=[[F256(0) for i in range(0,4)] for j in range(0,4)]
#AESEncryption(state,key)

### NIST example

In [656]:
#NIST test vector: plaintext and key bytes in hex; the plaintext bytes fill the state column by column
inp=list(matrix(4,4,[h(byte) for byte in ['32','43','f6','a8','88','5a','30','8d','31','31','98','a2','e0','37','07','34']]).transpose())
key=[h(byte) for byte in ['2b','7e','15','16','28','ae','d2','a6','ab','f7','15','88','09','cf','4f','3c']]

cipher=AESEncryptionVerbose(inp,key)

Input
[['32', '88', '31', 'E0'], ['43', '5A', '31', '37'], ['F6', '30', '98', '7'], ['A8', '8D', 'A2', '34']]
Round Key Value
[['2B', '7E', '15', '16'], ['28', 'AE', 'D2', 'A6'], ['AB', 'F7', '15', '88'], ['9', 'CF', '4F', '3C']]

 Start Of Round 1
[['19', 'A0', '9A', 'E9'], ['3D', 'F4', 'C6', 'F8'], ['E3', 'E2', '8D', '48'], ['BE', '2B', '2A', '8']]
After SubBytes
[['D4', 'E0', 'B8', '1E'], ['27', 'BF', 'B4', '41'], ['11', '98', '5D', '52'], ['AE', 'F1', 'E5', '30']]
After ShiftRows
[['D4', 'E0', 'B8', '1E'], ['BF', 'B4', '41', '27'], ['5D', '52', '11', '98'], ['30', 'AE', 'F1', 'E5']]
After MixColumns
[['4', 'E0', '48', '28'], ['66', 'CB', 'F8', '6'], ['81', '19', 'D3', '26'], ['E5', '9A', '7A', '4C']]
Round Key Value
[['A0', 'FA', 'FE', '17'], ['88', '54', '2C', 'B1'], ['23', 'A3', '39', '39'], ['2A', '6C', '76', '5']]

 Start Of Round 2
[['A4', '68', '6B', '2'], ['9C', '9F', '5B', '6A'], ['7F', '35', 'EA', '50'], ['F2', '2B', '43', '49']]
After SubBytes
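[['49', '45', '7F', '77'], ['DE', 'DB', '39', '2'], ['D2', '96', '87', '53'], ['89', 'F1', '1A', '3B']]
... (remaining output truncated in this export)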

### AES decryption (Rijndael; the NIST example below uses a 128-bit key)

#### InvSubBytes

In [644]:
#matrix of the S-box affine map, taken over GF(2) so that it can be inverted;
#row k is [1,0,0,0,1,1,1,1] rotated right by k
m=matrix(GF(2),[CyclicRot([1,0,0,0,1,1,1,1],-k) for k in range(0,8)])
minv=m.inverse()
def SBoxinv(el):
    """Inverse of SBox for an element el of F256.

    Subtract h('63'), apply minv to the coefficient vector, and take the
    multiplicative inverse of the result (0 stays 0).
    """
    unmixed=minv*vector(l(el-h('63'),8))
    val=F256(list(unmixed))
    if val==F256(0):
        return val
    return val**(-1)

def InvSubBytes(state):
    """Apply SBoxinv to every byte of state and return a new list of row lists.

    The rows are built eagerly as lists. Under Python 3 a lazy map object has
    no len() and is exhausted after one pass, which breaks the following
    matrix-based AddRoundKey and any printing of the state.
    """
    return [[SBoxinv(el) for el in row] for row in state]

#### InvMixColumns

In [684]:
def InvMixColumns(state):
    """Undo MixColumns by multiplying each column polynomial by cpolinv.

    A column (s0, s1, s2, s3) is read as s0 + s1*T + s2*T^2 + s3*T^3 in R2,
    multiplied by cpolinv and reduced modulo T^4+1. The coefficients of the
    product, lowest degree first, form the new column.
    Returns the new state as a list of rows (Sage vectors over F256).
    """
    modulus=T**4+1
    newcolumns=[]
    for column in transpose(matrix(state)):
        product=(R2(list(column))*cpolinv)%modulus
        #product.list() drops trailing zero coefficients, so pad back to 4 entries
        newcolumns.append(padding(product.list(),4))
    #explicit base ring: an all-zero result would otherwise become an integer matrix
    return list(transpose(matrix(F256,newcolumns)))

#### InvShiftRows

In [685]:
def InvShiftRows(state):
    """Undo ShiftRows: rotate row r of state right by Cr positions.

    The rotation amount is -Cr (rather than 4-Cr), which makes this the exact
    inverse of ShiftRows whatever the row length. Rows are converted to plain
    lists first; returns a list of row lists.
    """
    row0,row1,row2,row3=[list(x) for x in state]
    return [CyclicRot(row0,-C0),CyclicRot(row1,-C1),CyclicRot(row2,-C2),CyclicRot(row3,-C3)]

In [688]:
def AESDecryption(inp,key):
    """Decrypt the 4x4 state inp with key and return the result as a list of row lists.

    After adding round key Nr, rounds Nr-1 down to 0 apply InvShiftRows,
    InvSubBytes, AddRoundKey, and InvMixColumns (skipped for round 0).
    """
    schedule=KeyExpansion(key)
    state=AddRoundKey(inp,RoundKey(schedule,Nr))
    for rnd in reversed(range(Nr)):
        state=InvShiftRows(state)
        state=InvSubBytes(state)
        state=AddRoundKey(state,RoundKey(schedule,rnd))
        if rnd>0:
            state=InvMixColumns(state)
    return [list(row) for row in state]

## NIST example continued

In [690]:
#round trip on the same NIST test vector: decrypting the ciphertext should give back the plaintext
inp=matrix(4,4,[h(byte) for byte in ['32','43','f6','a8','88','5a','30','8d','31','31','98','a2','e0','37','07','34']]).transpose()
key=[h(byte) for byte in ['2b','7e','15','16','28','ae','d2','a6','ab','f7','15','88','09','cf','4f','3c']]

cipher=AESEncryption(inp,key)
plain=AESDecryption(matrix(cipher),key)
[[g(byte) for byte in row] for row in plain]

[['32', '88', '31', 'E0'],
 ['43', '5A', '31', '37'],
 ['F6', '30', '98', '7'],
 ['A8', '8D', 'A2', '34']]