In [32]:
# AES-128 decryption of a single 16-byte block.
# Inputs: ciphertext and key, each 128 bits / 16 bytes / 32 hex digits.

import nbimporter,os
from AES_key_schedule import key_schedule

class aes_decrypt():
    """AES-128 decryption of a single 16-byte block.

    Both inputs are 32-digit hex strings. Each byte is carried internally as
    a two-character hex string (e.g. ``'0E'``) and the 4x4 state as a list
    of lists.

    Example::

        aes_decrypt(cipher_32hex, key_32hex).generate()  # plaintext as hex
    """

    def __init__(self, cipher_32hex, key_32hex):
        # Column-major grid: after init_statearray(), statearray[c][r] holds
        # ciphertext byte 4*c + r (rows and columns are transposed relative
        # to the usual AES state layout).
        self.statearray = [[None for _ in range(4)] for _ in range(4)]
        self.cipher_32hex = cipher_32hex
        self.key_32hex = key_32hex

        # InvMixColumns coefficient matrix (entries are GF(2^8) bytes).
        self.inv_mix_matrix = [
            ['0e','0b','0d','09'],
            ['09','0e','0b','0d'],
            ['0d','09','0e','0b'],
            ['0b','0d','09','0e']
        ]
        # Inverse S-box, indexed as inv_sbox[high nibble][low nibble].
        self.inv_sbox = [
            ['52','09','6A','D5','30','36','A5','38','BF','40','A3','9E','81','F3','D7','FB'],
            ['7C','E3','39','82','9B','2F','FF','87','34','8E','43','44','C4','DE','E9','CB'],
            ['54','7B','94','32','A6','C2','23','3D','EE','4C','95','0B','42','FA','C3','4E'],
            ['08','2E','A1','66','28','D9','24','B2','76','5B','A2','49','6D','8B','D1','25'],
            ['72','F8','F6','64','86','68','98','16','D4','A4','5C','CC','5D','65','B6','92'],
            ['6C','70','48','50','FD','ED','B9','DA','5E','15','46','57','A7','8D','9D','84'],
            ['90','D8','AB','00','8C','BC','D3','0A','F7','E4','58','05','B8','B3','45','06'],
            ['D0','2C','1E','8F','CA','3F','0F','02','C1','AF','BD','03','01','13','8A','6B'],
            ['3A','91','11','41','4F','67','DC','EA','97','F2','CF','CE','F0','B4','E6','73'],
            ['96','AC','74','22','E7','AD','35','85','E2','F9','37','E8','1C','75','DF','6E'],
            ['47','F1','1A','71','1D','29','C5','89','6F','B7','62','0E','AA','18','BE','1B'],
            ['FC','56','3E','4B','C6','D2','79','20','9A','DB','C0','FE','78','CD','5A','F4'],
            ['1F','DD','A8','33','88','07','C7','31','B1','12','10','59','27','80','EC','5F'],
            ['60','51','7F','A9','19','B5','4A','0D','2D','E5','7A','9F','93','C9','9C','EF'],
            ['A0','E0','3B','4D','AE','2A','F5','B0','C8','EB','BB','3C','83','53','99','61'],
            ['17','2B','04','7E','BA','77','D6','26','E1','69','14','63','55','21','0C','7D'],
        ]

    def init_statearray(self):
        """Load ``cipher_32hex`` into ``self.statearray`` (column-major).

        ``statearray[c][r]`` receives ciphertext byte ``4*c + r``. The grid is
        rebuilt from scratch on every call, so ``generate`` can be called
        more than once on the same instance.

        Raises:
            ValueError: if ``cipher_32hex`` is not exactly 32 hex digits.
        """
        text = self.cipher_32hex
        if len(text) != 32 or not set(text) <= set('0123456789abcdefABCDEF'):
            raise ValueError(
                'cipher_32hex must be exactly 32 hex digits, got %r' % (text,))
        self.statearray = [
            [text[8 * col + 2 * row: 8 * col + 2 * row + 2] for row in range(4)]
            for col in range(4)
        ]

    def key_expansion(self):
        """Return the expanded key schedule for ``key_32hex``.

        Delegates to ``key_schedule`` from the ``AES_key_schedule`` notebook.
        The decryption below reads words 0-43 of the result and bytes 0-3 of
        each word, so presumably it is a list of 44 words, each a sequence
        of four 2-digit hex strings -- TODO confirm against AES_key_schedule.
        """
        return key_schedule(self.key_32hex).key_expansion()

    def transpose(self, matrix):
        """Return the transpose of a rectangular list of lists (new lists)."""
        return [list(column) for column in zip(*matrix)]

    def bin_to_hex(self, bit):
        """Convert a binary digit string to upper-case hex, padded to 2 digits."""
        return format(int(bit, 2), '02X')

    def hex_to_bin(self, _hex):
        """Convert a hex string to a binary digit string, padded to 8 bits."""
        return format(int(_hex, 16), '08b')

    def XOR(self, hex1, hex2):
        """Return the bitwise XOR of two hex bytes as 2-digit upper-case hex."""
        return format(int(hex1, 16) ^ int(hex2, 16), '02X')

    def GF_multiply(self, hex1, hex2):
        """Multiply two bytes in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1.

        Returns the product as a two-digit upper-case hex string.
        """
        a, b = int(hex1, 16), int(hex2, 16)
        # Carry-less polynomial multiplication: addition of coefficients is XOR.
        product = 0
        while b:
            if b & 1:
                product ^= a
            a <<= 1
            b >>= 1
        # Reduce modulo the AES polynomial (0x11B): clear every set bit of
        # degree >= 8, highest first.
        for shift in range(product.bit_length() - 9, -1, -1):
            if (product >> (shift + 8)) & 1:
                product ^= 0x11B << shift
        return format(product, '02X')

    # ----- inverse-cipher steps (all operate on a row-major 4x4 state) -----

    def _round_key(self, key_set, round_no):
        """Return round key ``round_no`` (words 4r..4r+3) as a row-major matrix."""
        return self.transpose([key_set[4 * round_no + k] for k in range(4)])

    def _add_round_key(self, state, round_key):
        """XOR the state element-wise with a row-major round key."""
        return [[self.XOR(s, k) for s, k in zip(state_row, key_row)]
                for state_row, key_row in zip(state, round_key)]

    def _inv_shift_rows(self, state):
        """Rotate row r of the state right by r positions."""
        shifted = []
        for r, row in enumerate(state):
            split = len(row) - r
            shifted.append(row[split:] + row[:split])
        return shifted

    def _inv_sub_bytes(self, state):
        """Replace every state byte with its inverse S-box entry."""
        result = []
        for row in state:
            new_row = []
            for byte in row:
                high, low = divmod(int(byte, 16), 16)
                new_row.append(self.inv_sbox[high][low])
            result.append(new_row)
        return result

    def _inv_mix_columns(self, state):
        """Multiply ``inv_mix_matrix`` by the state over GF(2^8)."""
        result = []
        for coeffs in self.inv_mix_matrix:
            out_row = []
            for col in range(len(state[0])):
                acc = '00'
                for coeff, state_row in zip(coeffs, state):
                    acc = self.XOR(acc, self.GF_multiply(coeff, state_row[col]))
                out_row.append(acc)
            result.append(out_row)
        return result

    def _decrypt_block(self, columns, key_set):
        """Run the 10-round inverse cipher and return the plaintext as hex.

        Args:
            columns: column-major 4x4 ciphertext grid (the shape produced by
                ``init_statearray``).
            key_set: expanded key, indexable by word 0-43 and, within a word,
                by byte 0-3.
        """
        # Initial AddRoundKey with the last round key (words 40-43).
        state = self._add_round_key(self.transpose(columns),
                                    self._round_key(key_set, 10))
        # Full rounds use round keys 9 down to 1.
        for round_no in range(9, 0, -1):
            state = self._inv_shift_rows(state)
            state = self._inv_sub_bytes(state)
            state = self._add_round_key(state, self._round_key(key_set, round_no))
            state = self._inv_mix_columns(state)
        # Final round has no InvMixColumns and uses the cipher key (words 0-3).
        state = self._inv_shift_rows(state)
        state = self._inv_sub_bytes(state)
        state = self._add_round_key(state, self._round_key(key_set, 0))
        # Serialize column by column, matching the input byte order.
        return ''.join(''.join(column) for column in self.transpose(state))

    def generate(self):
        """Decrypt ``cipher_32hex`` with ``key_32hex``.

        Returns:
            The plaintext as a 32-digit upper-case hex string.

        Raises:
            ValueError: if ``cipher_32hex`` is not exactly 32 hex digits.
        """
        self.init_statearray()
        return self._decrypt_block(self.statearray, self.key_expansion())
    
# Demo: the classic example with key "Thats my Kung Fu". Decrypting should
# yield 54776F204F6E65204E696E652054776F, i.e. "Two One Nine Two" in ASCII.
ciphertext = '29C3505F571420F6402299B31A02D73A'
key = '5468617473206D79204B756E67204675'
b = aes_decrypt(cipher_32hex=ciphertext, key_32hex=key)
b.generate()

'54776F204F6E65204E696E652054776F'