<a href="https://colab.research.google.com/github/sshrutii/LP3/blob/main/2ii_saes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Description: Simplified AES implementation in Python 3
import sys
 
# S-AES substitution box: sBox[n] is the substituted value of nibble n.
sBox = [
    0x9, 0x4, 0xA, 0xB,
    0xD, 0x1, 0x8, 0x5,
    0x6, 0x2, 0x0, 0x3,
    0xC, 0xE, 0xF, 0x7,
]

# Inverse S-box: sBoxI[sBox[n]] == n for every nibble n.
sBoxI = [
    0xA, 0x5, 0x9, 0xB,
    0x1, 0x7, 0x8, 0xF,
    0x6, 0x0, 0x2, 0x3,
    0xC, 0x4, 0xD, 0xE,
]

# Round-key bytes filled in by keyExp():
# K0 = w0 || w1, K1 = w2 || w3, K2 = w4 || w5.
w = [None for _ in range(6)]
 
def mult(p1, p2):
    """Multiply two polynomials in GF(2^4) modulo x^4 + x + 1.

    Shift-and-add (carry-less) multiplication: for each set bit of ``p2``
    the current shifted copy of ``p1`` is XOR-ed into the product, and the
    shifted copy is reduced whenever an x^4 term appears.
    Only the low four bits of the product are returned.
    """
    product = 0
    multiplicand, multiplier = p1, p2
    while multiplier:
        if multiplier & 1:
            product ^= multiplicand
        multiplicand <<= 1
        if multiplicand & 0x10:
            # x^4 == x + 1 in this field: drop the x^4 term and add x + 1.
            multiplicand ^= 0b10011
        multiplier >>= 1
    return product & 0xF
 
def intToVec(n):
    """Split a 16-bit integer into four nibbles in state-vector order.

    The order is [bits 12-15, bits 4-7, bits 8-11, bits 0-3]; vecToInt()
    reverses it. ``n`` is assumed to fit in 16 bits (the first element is
    not masked).
    """
    top = n >> 12
    rest = [(n >> shift) & 0xF for shift in (4, 8, 0)]
    return [top] + rest
 
def vecToInt(m):
    """Pack a 4-element nibble vector back into a 16-bit integer.

    Inverse of intToVec(): m[0] supplies bits 12-15, m[2] bits 8-11,
    m[1] bits 4-7 and m[3] bits 0-3.
    """
    high_byte = (m[0] << 4) + m[2]
    low_byte = (m[1] << 4) + m[3]
    return (high_byte << 8) + low_byte
 
def addKey(s1, s2):
    """XOR two nibble vectors element-wise (addition in GF(2^4)).

    Used to add a round key to the state; stops at the shorter input.
    """
    result = []
    for left, right in zip(s1, s2):
        result.append(left ^ right)
    return result
     
def sub4NibList(sbox, s):
    """Return a new list with every nibble of *s* replaced by ``sbox[nibble]``.

    Called with sBox for encryption and sBoxI for decryption.
    """
    return list(map(sbox.__getitem__, s))
     
def shiftRow(s):
    """Swap the third and fourth entries of the state vector.

    Applying it twice restores the input, so decrypt() also uses it as the
    inverse ShiftRow.
    """
    a, b, c, d = s[0], s[1], s[2], s[3]
    return [a, b, d, c]
 
def keyExp(key):
    """Expand a 16-bit key into the six round-key bytes w[0]..w[5].

    The results are written in place into the module-level list ``w``
    (K0 = w0 w1, K1 = w2 w3, K2 = w4 w5); nothing is returned.
    """
    def sub2Nib(byte):
        """Swap the two nibbles of *byte* and pass each through sBox."""
        high, low = byte >> 4, byte & 0x0F
        return (sBox[low] << 4) + sBox[high]

    rcon1 = 0b10000000
    rcon2 = 0b00110000

    w0 = (key & 0xFF00) >> 8
    w1 = key & 0x00FF
    w2 = w0 ^ rcon1 ^ sub2Nib(w1)
    w3 = w2 ^ w1
    w4 = w2 ^ rcon2 ^ sub2Nib(w3)
    w5 = w4 ^ w3
    w[0], w[1], w[2], w[3], w[4], w[5] = w0, w1, w2, w3, w4, w5
 
def encrypt(ptext):
    """Encrypt one 16-bit plaintext block using the round keys in ``w``.

    keyExp() must have been called first to fill ``w``.
    Returns the 16-bit ciphertext as an int.
    """
    def mixCol(s):
        # Multiply each state column (s[i], s[i + 2]) by [[1, 4], [4, 1]] over GF(2^4).
        upper = [s[0] ^ mult(4, s[2]), s[1] ^ mult(4, s[3])]
        lower = [s[2] ^ mult(4, s[0]), s[3] ^ mult(4, s[1])]
        return upper + lower

    def roundKey(i):
        # Round key i is the 16-bit value w[2i] || w[2i + 1].
        return (w[2 * i] << 8) + w[2 * i + 1]

    # Round 0: add K0.
    state = intToVec(roundKey(0) ^ ptext)
    # Round 1: NibbleSub, ShiftRow, MixColumns, add K1.
    state = sub4NibList(sBox, state)
    state = shiftRow(state)
    state = mixCol(state)
    state = addKey(intToVec(roundKey(1)), state)
    # Round 2 (no MixColumns): NibbleSub, ShiftRow, add K2.
    state = shiftRow(sub4NibList(sBox, state))
    return vecToInt(addKey(intToVec(roundKey(2)), state))
     
def decrypt(ctext):
    """Decrypt one 16-bit ciphertext block using the round keys in ``w``.

    keyExp() must have been called first to fill ``w``.
    Returns the 16-bit plaintext as an int.
    """
    def iMixCol(s):
        # Multiply each state column (s[i], s[i + 2]) by [[9, 2], [2, 9]] over GF(2^4).
        upper = [mult(9, s[0]) ^ mult(2, s[2]), mult(9, s[1]) ^ mult(2, s[3])]
        lower = [mult(9, s[2]) ^ mult(2, s[0]), mult(9, s[3]) ^ mult(2, s[1])]
        return upper + lower

    def roundKey(i):
        # Round key i is the 16-bit value w[2i] || w[2i + 1].
        return (w[2 * i] << 8) + w[2 * i + 1]

    # Undo round 2: add K2, inverse ShiftRow, inverse NibbleSub.
    state = intToVec(roundKey(2) ^ ctext)
    state = shiftRow(state)
    state = sub4NibList(sBoxI, state)
    # Undo round 1: add K1, inverse MixColumns, inverse ShiftRow, inverse NibbleSub.
    state = addKey(intToVec(roundKey(1)), state)
    state = iMixCol(state)
    state = sub4NibList(sBoxI, shiftRow(state))
    # Undo round 0: add K0.
    return vecToInt(addKey(intToVec(roundKey(0)), state))
 

     
    

#functions of encryption and decryption

#add round key - XOR the state with the round subkey
#nibble substitution - replace each nibble using the S-box
#shift row - swap the 2nd and 4th nibbles
#mix columns - matrix multiplication over GF(2^4)
#Key Generation
#
#Input - 16-bit key
#Output - three 16-bit subkeys {key0, key1, key2}
#Mathematical Notation
#w0 = key[:8]
#w1 = key[8:]
#w2 = w0 ^ [1,0,0,0,0,0,0,0] ^ sub_nib(rot_nib(w1))
#w3 = w1 ^ w2
#w4 = w2 ^ [0,0,1,1,0,0,0,0] ^ sub_nib(rot_nib(w3))
#w5 = w3 ^ w4
#key0 = key
#key1 = w2w3
#key2 = w4w5
#Encryption

#Input - 16 bit plain text, 16 bit key
#Output - 16 bit cipher text
#cipher_text = ShRow(NibSub(MixCol(ShRow(NibSub(plain_text^key0)))^key1))^key2
#Decryption

#Input - 16 bit cipher text, 16 bit key
#Output - 16 bit plain text
#decrypted_text = InvNibSub(InvShRow(InvMixCol((InvNibSub(InvShRow(cipher_text^key2)))^key1)))^key0
#The 16-bit key produces three subkeys key0, key1 and key2; each is used once during encryption and once during decryption















    

In [3]:
# Known-answer test for the integer S-AES implementation.
plaintext = 0b1101011100101000
key = 0b0100101011110101
ciphertext = 0b0010010011101100
keyExp(key)
# Explicit comparisons instead of `assert`: asserts are stripped under
# `python -O`, which would make this check silently always pass.
encrypted = encrypt(plaintext)
if encrypted != ciphertext:
    print("Encryption error")
    print(encrypted, ciphertext)
    sys.exit(1)
decrypted = decrypt(ciphertext)
if decrypted != plaintext:
    print("Decryption error")
    print(decrypted, plaintext)
    sys.exit(1)
print("Test ok!")

Test ok!


In [8]:
# Known-answer inputs as 16-character binary strings.
plaintext = "1101011100101000"
key = "0100101011110101"
ciphertext = ""

# S-box as nested lookup tables: SBox[high two bits][low two bits] -> output nibble.
# The table must be a bijection; nibble 0000 maps to 1001 (0x0 -> 0x9).
SBox = {}
SBox["00"] = {"00" : "1001", "01" : "0100", "10" : "1010", "11" : "1011"}
SBox["01"] = {"00" : "1101", "01" : "0001", "10" : "1000", "11" : "0101"}
SBox["10"] = {"00" : "0110", "01" : "0010", "10" : "0000", "11" : "0011"}
SBox["11"] = {"00" : "1100", "01" : "1110", "10" : "1111", "11" : "0111"}

# Inverse S-box, indexed the same way; SBox_inv undoes SBox.
SBox_inv = {}
SBox_inv["00"] = {"00" : "1010", "01" : "0101", "10" : "1001", "11" : "1011"}
SBox_inv["01"] = {"00" : "0001", "01" : "0111", "10" : "1000", "11" : "1111"}
SBox_inv["10"] = {"00" : "0110", "01" : "0000", "10" : "0010", "11" : "0011"}
SBox_inv["11"] = {"00" : "1100", "01" : "0100", "10" : "1101", "11" : "1110"}

# Multiplication tables over GF(2^4) mod x^4 + x + 1: mult[a][b] == a * b.
# Only the constants used by MixColumns (4) and inverse MixColumns (2, 9) are tabulated.
# NOTE: this dict rebinds the name `mult`, replacing the mult() function from the earlier cell.
mult = {}

mult["0100"] = {"0000" : "0000", "0001" : "0100", "0010" : "1000", "0011" : "1100",
                "0100" : "0011", "0101" : "0111", "0110" : "1011", "0111" : "1111",
                "1000" : "0110", "1001" : "0010", "1010" : "1110", "1011" : "1010",
                "1100" : "0101", "1101" : "0001", "1110" : "1101", "1111" : "1001"}

mult["0010"] = {"0000" : "0000", "0001" : "0010", "0010" : "0100", "0011" : "0110",
                "0100" : "1000", "0101" : "1010", "0110" : "1100", "0111" : "1110",
                "1000" : "0011", "1001" : "0001", "1010" : "0111", "1011" : "0101",
                "1100" : "1011", "1101" : "1001", "1110" : "1111", "1111" : "1101"}

mult["1001"] = {"0000" : "0000", "0001" : "1001", "0010" : "0001", "0011" : "1000",
                "0100" : "0010", "0101" : "1011", "0110" : "0011", "0111" : "1010",
                "1000" : "0100", "1001" : "1101", "1010" : "0101", "1011" : "1100",
                "1100" : "0110", "1101" : "1111", "1110" : "0111", "1111" : "1110"}




def XOR(a, b):
    """Bitwise XOR of two binary strings.

    The result is left-padded with zeros to at least ``len(a)`` characters.
    """
    value = int(a, 2) ^ int(b, 2)
    bits = bin(value)[2:]
    return bits.rjust(len(a), "0")

def rotateNibble(w):
    """Split an 8-bit string into nibbles and return them swapped as (low, high)."""
    high, low = w[:4], w[4:]
    return low, high

def subNibble(a, b):
    """Substitute two 4-bit strings through SBox and concatenate the results."""
    def lookup(nibble):
        # SBox is indexed by the nibble's two high bits, then its two low bits.
        return SBox[nibble[:2]][nibble[2:]]

    return lookup(a) + lookup(b)

def generateKey(w0, w1, w2, w3, w4, w5):
    """Concatenate the six key-schedule words into round keys and print them.

    Sets the module-level globals key1 = w0 + w1, key2 = w2 + w3 and
    key3 = w4 + w5, which encrypt() and decrypt() read.
    """
    global key1, key2, key3
    key1 = w0 + w1
    key2 = w2 + w3
    key3 = w4 + w5
    for label, value in (("key1 : ", key1), ("key2 : ", key2), ("key3 : ", key3)):
        print(label, value)


def encrypt(ptext=None):
    """Encrypt a 16-bit binary-string block with S-AES.

    Args:
        ptext: 16-character string of '0'/'1'. Defaults to the module-level
            ``plaintext``, which preserves the original zero-argument call.

    Uses the round keys ``key1``, ``key2`` and ``key3`` set by generateKey().
    Returns the ciphertext as a binary string.
    """
    if ptext is None:
        ptext = plaintext

    def nibbleSub(bits):
        # Look each 4-bit nibble up in SBox (row = high 2 bits, column = low 2 bits).
        return "".join(SBox[bits[i : i + 2]][bits[i + 2 : i + 4]] for i in range(0, 16, 4))

    def rowShift(bits):
        # Swap the 2nd and 4th nibbles.
        return bits[:4] + bits[12:] + bits[8:12] + bits[4:8]

    def mixCols(bits):
        # Columns are (n0, n1) and (n2, n3); multiply each by [[1, 4], [4, 1]] over GF(2^4).
        n0, n1, n2, n3 = bits[:4], bits[4:8], bits[8:12], bits[12:]
        times4 = mult["0100"]
        return (XOR(n0, times4[n1]) + XOR(n1, times4[n0])
                + XOR(n2, times4[n3]) + XOR(n3, times4[n2]))

    # Round 0: add key1.
    state = XOR(ptext, key1)
    # Round 1: NibbleSub, ShiftRow, MixColumns, add key2.
    state = XOR(mixCols(rowShift(nibbleSub(state))), key2)
    # Round 2 (no MixColumns): NibbleSub, ShiftRow, add key3.
    state = rowShift(nibbleSub(state))
    return XOR(state, key3)

def decrypt(ctext=None):
    """Decrypt a 16-bit binary-string block with S-AES.

    Args:
        ctext: 16-character string of '0'/'1'. Defaults to the module-level
            ``ciphertext``, which preserves the original zero-argument call.

    Uses the round keys ``key1``, ``key2`` and ``key3`` set by generateKey().
    Returns the plaintext as a binary string.
    """
    if ctext is None:
        ctext = ciphertext

    def invNibbleSub(bits):
        # Look each 4-bit nibble up in SBox_inv (row = high 2 bits, column = low 2 bits).
        return "".join(SBox_inv[bits[i : i + 2]][bits[i + 2 : i + 4]] for i in range(0, 16, 4))

    def rowShift(bits):
        # Swap the 2nd and 4th nibbles; this swap is its own inverse.
        return bits[:4] + bits[12:] + bits[8:12] + bits[4:8]

    def invMixCols(bits):
        # Columns are (n0, n1) and (n2, n3); multiply each by [[9, 2], [2, 9]] over GF(2^4).
        n0, n1, n2, n3 = bits[:4], bits[4:8], bits[8:12], bits[12:]
        times9, times2 = mult["1001"], mult["0010"]
        return (XOR(times9[n0], times2[n1]) + XOR(times9[n1], times2[n0])
                + XOR(times9[n2], times2[n3]) + XOR(times9[n3], times2[n2]))

    # Undo round 2: add key3, inverse ShiftRow, inverse NibbleSub.
    state = invNibbleSub(rowShift(XOR(ctext, key3)))
    # Undo round 1: add key2, inverse MixColumns, inverse ShiftRow, inverse NibbleSub.
    state = XOR(state, key2)
    state = invNibbleSub(rowShift(invMixCols(state)))
    # Undo round 0: add key1.
    return XOR(state, key1)


def keyUtil():
    """Expand the module-level 16-bit ``key`` string into three round keys.

    Derives the key-schedule words w0..w5 and passes them to generateKey(),
    which stores key1 (= w0 w1), key2 (= w2 w3) and key3 (= w4 w5).
    """
    def nextWord(prev, rcon, word):
        # prev XOR rcon XOR SubNib(RotNib(word))
        return XOR(XOR(prev, rcon), subNibble(*rotateNibble(word)))

    w0, w1 = key[:8], key[8:]
    w2 = nextWord(w0, "10000000", w1)
    w3 = XOR(w2, w1)
    w4 = nextWord(w2, "00110000", w3)
    w5 = XOR(w4, w3)
    generateKey(w0, w1, w2, w3, w4, w5)

# Driver: show the inputs, expand the key, then round-trip the block.
print("Original Plaintext : ", plaintext)
print("Original Key : ", key)
print()
# keyUtil() sets the global round keys key1/key2/key3 (generateKey() prints them).
keyUtil()
print()
# encrypt() with no argument reads the global `plaintext`; its result is bound
# to the global `ciphertext`, which decrypt() reads on the next call.
ciphertext = encrypt()
print("Ciphertext after encryption : ", ciphertext)
# Rebinds the global `plaintext` to the decrypted text (expected to equal the original).
plaintext = decrypt()
print("Plaintext after decryption : ", plaintext) 

Original Plaintext :  1101011100101000
Original Key :  0100101011110101

key1 :  0100101011110101
key2 :  1101110100101000
key3 :  1000011110101111

Ciphertext after encryption :  0010010011101100
Plaintext after decryption :  1101011100101000
