In [None]:
# S-AES Implementation

# --- S-Box and Inverse S-Box ---
# S-AES substitution table. The tuple lists S(0x0), S(0x1), ..., S(0xF); it is
# expanded into a map from 4-character bit strings to 4-character bit strings,
# with keys inserted in ascending order '0000' .. '1111'.
SBOX = {
    format(nibble_in, '04b'): format(nibble_out, '04b')
    for nibble_in, nibble_out in enumerate((
        0x9, 0x4, 0xA, 0xB, 0xD, 0x1, 0x8, 0x5,
        0x6, 0x2, 0x0, 0x3, 0xC, 0xE, 0xF, 0x7,
    ))
}

# Reverse lookup: the 16 outputs above are all distinct, so swapping keys and
# values yields the inverse substitution.
INVERSE_SBOX = {out_bits: in_bits for in_bits, out_bits in SBOX.items()}

# --- Helper Functions ---

def rotate_nibbles(word):
    """RotNib: move the first four characters of ``word`` to the end.

    For an 8-character bit string this exchanges its high and low nibble.
    """
    high, low = word[:4], word[4:]
    return low + high

def substitute_nibbles(word):
    """SubNib: replace every 4-character chunk of ``word`` via SBOX.

    Raises KeyError if a chunk is not a key of SBOX, e.g. when the length of
    ``word`` is not a multiple of 4.
    """
    chunks = (word[pos:pos + 4] for pos in range(0, len(word), 4))
    return ''.join(map(SBOX.__getitem__, chunks))

def inverse_substitute_nibbles(word):
    """Inverse SubNib: replace every 4-character chunk of ``word`` via INVERSE_SBOX.

    Raises KeyError if a chunk is not a key of INVERSE_SBOX, e.g. when the
    length of ``word`` is not a multiple of 4.
    """
    chunks = (word[pos:pos + 4] for pos in range(0, len(word), 4))
    return ''.join(map(INVERSE_SBOX.__getitem__, chunks))


def g(word, round_constant):
    """Key-schedule function g: RotNib, then SubNib, then XOR a round constant.

    Args:
        word: bit string (an 8-character key word when called from
            key_expansion).
        round_constant: bit string XORed into the substituted word.

    Returns:
        The XOR result as a bit string zero-padded to at least 8 characters.
    """
    mixed = substitute_nibbles(rotate_nibbles(word))
    return '{:08b}'.format(int(mixed, 2) ^ int(round_constant, 2))

def key_expansion(key):
    """Expand a 16-bit key into the three S-AES round keys.

    The key is split into words w0 = key[:8] and w1 = key[8:]. Each of the two
    expansion rounds derives two more words from the previous pair:
        w_even = prev_even XOR g(prev_odd, rcon)
        w_odd  = w_even    XOR prev_odd
    using round constants '10000000' and then '00110000'.

    Returns:
        [w0 + w1, w2 + w3, w4 + w5] as bit strings.
    """
    def xor8(left, right):
        # XOR two bit strings and zero-pad the result to 8 characters.
        return format(int(left, 2) ^ int(right, 2), '08b')

    words = [key[:8], key[8:]]
    for rcon in ('10000000', '00110000'):
        prev_even, prev_odd = words[-2], words[-1]
        new_even = xor8(prev_even, g(prev_odd, rcon))
        new_odd = xor8(new_even, prev_odd)
        words.extend((new_even, new_odd))

    return [words[i] + words[i + 1] for i in range(0, 6, 2)]

def add_round_key(state, round_key):
    """AddRoundKey: XOR ``round_key`` into ``state``.

    Both arguments are bit strings; the result is zero-padded to at least
    16 characters.
    """
    combined = int(state, 2) ^ int(round_key, 2)
    return f"{combined:016b}"

def shift_rows(state):
    """ShiftRows: exchange the second and fourth nibbles of the state.

    Only the first 16 characters of ``state`` are used.
    """
    nibbles = [state[pos:pos + 4] for pos in range(0, 16, 4)]
    nibbles[1], nibbles[3] = nibbles[3], nibbles[1]
    return ''.join(nibbles)

def inverse_shift_rows(state):
    """Inverse ShiftRows: exchange the second and fourth nibbles again.

    Swapping two nibbles is its own inverse, so this is the same permutation
    as ShiftRows. Only the first 16 characters of ``state`` are used.
    """
    first, second, third, fourth = (
        state[0:4], state[4:8], state[8:12], state[12:16]
    )
    return first + fourth + third + second

def _gf16_mul(a, b):
    """Multiply two nibbles in GF(2^4) with modulus x^4 + x + 1.

    Args:
        a: Integer in range(16).
        b: Integer in range(16).

    Returns:
        The product, an integer in range(16).
    """
    product = 0
    for _ in range(4):
        if b & 1:
            product ^= a
        overflow = a & 0x8
        a = (a << 1) & 0xF
        if overflow:
            # x^4 reduces to x + 1 modulo x^4 + x + 1.
            a ^= 0x3
        b >>= 1
    return product

def _mul_columns(state, diag, off):
    """Multiply both state columns by the 2x2 matrix [[diag, off], [off, diag]].

    The 16-bit state holds nibbles s0 s1 s2 s3 in column-major order,

        | s0  s2 |
        | s1  s3 |

    so the two columns are (s0, s1) and (s2, s3).

    Args:
        state: 16-character bit string.
        diag: Diagonal matrix entry, integer in range(16).
        off: Off-diagonal matrix entry, integer in range(16).

    Returns:
        16-character bit string.
    """
    s0, s1, s2, s3 = (int(state[i:i + 4], 2) for i in range(0, 16, 4))
    out = []
    for top, bottom in ((s0, s1), (s2, s3)):
        out.append(_gf16_mul(diag, top) ^ _gf16_mul(off, bottom))
        out.append(_gf16_mul(off, top) ^ _gf16_mul(diag, bottom))
    return ''.join(format(n, '04b') for n in out)

def mix_columns(state):
    """S-AES MixColumns: multiply each column by [[1, 4], [4, 1]] over GF(2^4).

    For each column (top, bottom):
        top'    = top XOR 4*bottom
        bottom' = 4*top XOR bottom

    Args:
        state: 16-character bit string.

    Returns:
        16-character bit string.
    """
    return _mul_columns(state, 1, 4)

def inverse_mix_columns(state):
    """Inverse S-AES MixColumns: multiply each column by [[9, 2], [2, 9]].

    [[9, 2], [2, 9]] is the inverse of the MixColumns matrix [[1, 4], [4, 1]]
    over GF(2^4) (modulus x^4 + x + 1), so this undoes mix_columns.

    Args:
        state: 16-character bit string.

    Returns:
        16-character bit string.
    """
    return _mul_columns(state, 9, 2)

# --- Main Encryption and Decryption ---

def encrypt(plaintext, key):
    """Encrypt one block with S-AES.

    Steps: AddRoundKey(K0); round 1 = SubNib, ShiftRows, MixColumns,
    AddRoundKey(K1); round 2 = SubNib, ShiftRows, AddRoundKey(K2). The final
    round omits MixColumns.

    Args:
        plaintext: 16-character bit string.
        key: 16-character bit string.

    Returns:
        Ciphertext as a bit string of at least 16 characters.
    """
    k0, k1, k2 = key_expansion(key)

    state = add_round_key(plaintext, k0)

    # Round 1: full round.
    for step in (substitute_nibbles, shift_rows, mix_columns):
        state = step(state)
    state = add_round_key(state, k1)

    # Round 2: final round, no MixColumns.
    for step in (substitute_nibbles, shift_rows):
        state = step(state)
    return add_round_key(state, k2)

def decrypt(ciphertext, key):
    """Decrypt one block with S-AES by applying the inverse steps in reverse.

    Steps: AddRoundKey(K2), InvShiftRows, InvSubNib; AddRoundKey(K1),
    InvMixColumns, InvShiftRows, InvSubNib; AddRoundKey(K0).

    Args:
        ciphertext: 16-character bit string.
        key: 16-character bit string.

    Returns:
        Recovered block as a bit string of at least 16 characters.
    """
    k0, k1, k2 = key_expansion(key)

    # Undo round 2.
    state = add_round_key(ciphertext, k2)
    state = inverse_substitute_nibbles(inverse_shift_rows(state))

    # Undo round 1.
    state = inverse_mix_columns(add_round_key(state, k1))
    state = inverse_substitute_nibbles(inverse_shift_rows(state))

    # Undo the initial key addition.
    return add_round_key(state, k0)

# --- Example Usage ---

if __name__ == "__main__":
    # Demo: push one sample block through encrypt() and back through decrypt().
    plaintext = '0110111101101011'  # Example 16-bit input
    key = '1010001110100011'        # Example 16-bit key

    for banner in (f"Plaintext : {plaintext}", f"Key        : {key}"):
        print(banner)

    ciphertext = encrypt(plaintext, key)
    print(f"Ciphertext : {ciphertext}")

    decrypted_text = decrypt(ciphertext, key)
    print(f"Decrypted  : {decrypted_text}")


Plaintext : 0110111101101011
Key        : 1010001110100011
Ciphertext : 1111110010000100
Decrypted  : 0000101100001011
