In [1]:
# Required variables for Simplified AES (S-AES)

# 16-bit demo key, as a bit string (generate_keys asserts the 16-bit length).
key = "0100101010010101"

# Forward S-box. A 4-bit nibble b0 b1 b2 b3 is looked up as
# table[int(b0 b1)][int(b2 b3)]. Entries are written in hex and expanded to
# 4-character bit strings.
sub_table = [
    [format(nibble, "04b") for nibble in row]
    for row in (
        (0x9, 0x4, 0xA, 0xB),
        (0xD, 0x1, 0x8, 0x5),
        (0x6, 0x2, 0x0, 0x3),
        (0xC, 0xE, 0xF, 0x7),
    )
]

# Inverse S-box, indexed the same way as sub_table.
inv_sub_table = [
    [format(nibble, "04b") for nibble in row]
    for row in (
        (0xA, 0x5, 0x9, 0xB),
        (0x1, 0x7, 0x8, 0xF),
        (0x6, 0x0, 0x2, 0x3),
        (0xC, 0x4, 0xD, 0xE),
    )
]

# MixColumns matrix and its inverse; entries are multiplied in GF(2^4) by get_mult.
mix_col_mat = [[1, 4], [4, 1]]

inv_mix_col_mat = [[9, 2], [2, 9]]

In [2]:
# Function to rotate the nibbles
def rotate_nibble(wi):
    """Swap the high and low 4-bit halves of an 8-character bit string.

    Raises AssertionError("Invalid word length") if ``wi`` is not 8 characters.
    """
    assert len(wi) == 8, "Invalid word length"
    high_nibble, low_nibble = wi[:4], wi[4:]
    return low_nibble + high_nibble

In [3]:
def get_xor(word, key, output_length):
    """XOR two bit strings of the same length.

    Parameters:
        word, key: bit strings, each exactly ``output_length`` characters long.
        output_length: expected length of both inputs and of the result.

    Returns:
        The XOR as a bit string left-padded with zeros to ``output_length``.

    Raises:
        AssertionError if either input has the wrong length.
    """
    assert (len(word) == output_length), "Invalid Word length"
    assert (len(key) == output_length), "Invalid Key length"

    combined = int(word, 2) ^ int(key, 2)
    return format(combined, f"0{output_length}b")

In [4]:
# S-box substitution
def sbox_substitution(word, process):
    """Replace each nibble of an 8-character bit string through the S-box.

    Uses ``inv_sub_table`` when ``process == "de"`` and ``sub_table`` for any
    other value. For each nibble, its first two bits select the row and its
    last two bits select the column.

    Raises AssertionError("Invalid Word length") if ``word`` is not 8 characters.
    """
    assert (len(word)==8), "Invalid Word length"

    table = sub_table if process != "de" else inv_sub_table

    pieces = []
    for start in (0, 4):
        row = int(word[start:start + 2], 2)
        col = int(word[start + 2:start + 4], 2)
        pieces.append(table[row][col])
    return "".join(pieces)

In [5]:
def generate_keys(key):
    """Expand a 16-bit S-AES key into its three round keys.

    The key is split into 8-bit words w0, w1. Each round constant produces two
    new words:
        w_even = w_prev_even XOR (rcon XOR SubNib(RotNib(w_prev_odd)))
        w_odd  = w_even XOR w_prev_odd

    Returns:
        (key, w2 + w3, w4 + w5) -- three 16-character bit strings.

    Raises:
        AssertionError if ``key`` is not 16 characters long.
    """
    assert (len(key) == 16), "Key is not 16 bits long"

    words = [key[:8], key[8:]]
    for rcon in ("10000000", "00110000"):
        prev_even, prev_odd = words[-2], words[-1]
        g_value = get_xor(rcon, sbox_substitution(rotate_nibble(prev_odd), "en"), 8)
        new_even = get_xor(prev_even, g_value, 8)
        new_odd = get_xor(new_even, prev_odd, 8)
        words.extend((new_even, new_odd))

    return key, words[2] + words[3], words[4] + words[5]

In [6]:
def get_mult(num1, num2):
    """Multiply two GF(2^4) elements, reducing modulo x^4 + x + 1 (0x13).

    Parameters:
        num1: bit string for the first factor (normally 4 characters).
        num2: int for the second factor; only its low 4 bits are used.

    Returns:
        The product as a zero-padded bit string of at least 4 characters.
    """
    multiplicand = int(num1, 2)

    # Carry-less (XOR) shift-and-add over the four low bits of num2.
    product = 0
    for bit in range(4):
        if (num2 >> bit) & 1:
            product ^= multiplicand << bit

    # Clear bits 7..4 from the top down by XOR-ing shifted copies of 0b10011.
    for shift in range(3, -1, -1):
        if (product >> (shift + 4)) & 1:
            product ^= 0b10011 << shift

    return format(product, "04b")

In [7]:
def mix_columns(data, process):
    """Apply MixColumns (``process == "de"`` selects the inverse) to a 16-bit state.

    The state is read as four nibbles s00 s10 s01 s11, i.e. two 2-nibble
    columns. Each column is multiplied by the 2x2 matrix, using get_mult for
    products and get_xor for sums. The result nibbles are emitted in the same
    column order.

    Raises AssertionError("Invalid data length") if ``data`` is not 16 characters.
    """
    assert (len(data) == 16), "Invalid data length"

    mat = mix_col_mat if process != "de" else inv_mix_col_mat

    result = ""
    for col_start in (0, 8):
        top = data[col_start:col_start + 4]
        bottom = data[col_start + 4:col_start + 8]
        for coeff_top, coeff_bottom in mat:
            result += get_xor(get_mult(top, coeff_top), get_mult(bottom, coeff_bottom), 4)
    return result

In [8]:
def _shift_rows(data):
    """Swap the 2nd and 4th nibbles of a 16-bit state (S-AES ShiftRows).

    The swap is its own inverse, so the same helper serves encryption and decryption.
    """
    return data[:4] + data[12:] + data[8:12] + data[4:8]


def _sub_nibbles(data, process):
    """Substitute all four nibbles of a 16-bit state via sbox_substitution."""
    return sbox_substitution(data[:8], process) + sbox_substitution(data[8:], process)


def perform_saes(key, data, process, verbose=True):
    """Encrypt or decrypt a single 16-bit block with Simplified AES.

    Parameters:
        key: 16-character bit string, expanded by generate_keys into (k1, k2, k3).
        data: 16-character bit string -- plaintext for "en", ciphertext for "de".
        process: "en" to encrypt, "de" to decrypt.
        verbose: when True (default), print the state after every step.

    Returns:
        16-character bit string: the ciphertext when encrypting, the recovered
        plaintext when decrypting.

    Raises:
        AssertionError if ``process`` is not "en"/"de" or ``data`` is not 16 characters.
    """
    # Throw error if process is not en or de
    assert(process in ["en", "de"]), "Enter valid process"

    # Throw error if data is not 16 bits
    assert (len(data) == 16), "Invalid data length"

    k1, k2, k3 = generate_keys(key)

    def show(label, state):
        # Per-step trace; the format matches the notebook's recorded output.
        if verbose:
            print("\n" + label + " : " + state)

    if process == "de":
        # Undo the encryption steps in reverse order (ShiftRows is self-inverse).
        data = get_xor(data, k3, 16)
        show("Add round key 3", data)

        data = _shift_rows(data)
        show("Shift Rows", data)

        data = _sub_nibbles(data, process)
        show("Substitute nibbles", data)

        data = get_xor(data, k2, 16)
        show("Add round key 2", data)

        data = mix_columns(data, process)
        show("Mix columns", data)

        data = _shift_rows(data)
        show("Shift Rows", data)

        data = _sub_nibbles(data, process)
        show("Substitute nibbles", data)

        data = get_xor(data, k1, 16)
        show("Add round key 1", data)
    else:
        data = get_xor(data, k1, 16)
        show("Add round key 1", data)

        data = _sub_nibbles(data, process)
        show("Substitute nibbles", data)

        data = _shift_rows(data)
        show("Shift Rows", data)

        data = mix_columns(data, process)
        show("Mix columns", data)

        data = get_xor(data, k2, 16)
        show("Add round key 2", data)

        # The final round omits MixColumns.
        data = _sub_nibbles(data, process)
        show("Substitute nibbles", data)

        data = _shift_rows(data)
        show("Shift Rows", data)

        data = get_xor(data, k3, 16)
        show("Add round key 3", data)

    # Ciphertext when encrypting, recovered plaintext when decrypting.
    return data

In [9]:
# Accept data from user to encrypt
data = input("Enter data to encrypt : ")
print("data to encrypt : ", data)

# Each character is packed into exactly one byte, and the decryption cell
# rebuilds characters with chr() on each byte. So only code points 0-255 are
# supported. Latin-1 maps those to bytes equal to ord(c). Anything wider would
# otherwise produce an over-long block and fail later with "Invalid data length".
try:
    plain_bytes = data.encode("latin-1")
except UnicodeEncodeError as exc:
    raise ValueError(
        "Only characters with code points 0-255 can be encrypted (one byte per character)"
    ) from exc

# Add a new line at the end if the number of characters in data is odd,
# so the input splits evenly into 16-bit blocks.
if len(plain_bytes) % 2 != 0:
    plain_bytes += b"\n"

# Initialize variable
cipher_text = ""

# Encrypt data 2 characters (16 bits) at a time
for start in range(0, len(plain_bytes), 2):
    block = format(plain_bytes[start], "08b") + format(plain_bytes[start + 1], "08b")
    cipher_text += perform_saes(key, block, "en")


print("\nEncrypted Data : ", cipher_text)

Enter data to encrypt : AB
data to encrypt :  AB

Add round key 1 : 0000101111010111

Substitute nibbles : 1001001111100101

Shift Rows : 1001010111100011

Mix columns : 1110011100101110

Add round key 2 : 0011111101100011

Substitute nibbles : 1011011110001011

Shift Rows : 1011101110000111

Add round key 3 : 1011111011001111

Encrypted Data :  1011111011001111


In [10]:
original = ""

# Ciphertext is produced in whole 16-bit blocks; any remainder means it was
# truncated or corrupted (previously the extra bits were silently dropped).
assert len(cipher_text) % 16 == 0, "Cipher text length is not a multiple of 16 bits"

# Decrypt data 16 bits at a time
for start in range(0, len(cipher_text), 16):
    block = perform_saes(key, cipher_text[start:start + 16], "de")
    original += chr(int(block[:8], 2)) + chr(int(block[8:], 2))

# The encryption cell pads odd-length input with a single "\n". input() never
# returns a trailing newline, so a trailing "\n" here can only be that padding.
# This assumes cipher_text came from the encryption cell above.
if original.endswith("\n"):
    original = original[:-1]

# Print decrypted data
print("\nDecrypted Data : ", original)


Add round key 3 : 1011101110000111

Shift Rows : 1011011110001011

Substitute nibbles : 0011111101100011

Add round key 2 : 1110011100101110

Mix columns : 1001010111100011

Shift Rows : 1001001111100101

Substitute nibbles : 0000101111010111

Add round key 1 : 0100000101000010

Decrypted Data :  AB
