In [None]:
# Importing necessary libraries and functions from spn file
import random
random.seed(40)
import numpy as np
from spn import *

In [None]:
# Global parameters
no_of_inputs = 2048  # number of random plaintext/ciphertext pairs generated for the attack

# Functions used to break the cipher

In [None]:
# Given a S-Box this function returns corresponding Linear approximation table
def linear_approx_table(sbox):
    l = []
    for i in range(16):
        l1 = []
        for j in range(16):
            count=-8
            for x in range(16):
                t = (x & i )^(j & sbox[x])
                count += (1 - no_of_ones(t))
            l1.append(count)
        l.append(l1)
    return l

# Given an input,linear_table this function returns the input to the next round
# that we will use to get the linear equation

def f(input,linear_table):
    y = demux(input)
    l = []
    p = 1
    for i in y:
        if i == 0:
            l.append(0)
        else:
            max_val=0
            max_ele=0
            for j in range(16):
                if abs(linear_table[i][j])> abs(max_val):
                    max_val=linear_table[i][j]
                    max_ele=j
            l.append(max_ele)
            p*=(max_val/8)
    output = pbox(mux(l))
    return output,p/2

# Given inputs,cipher texts this function returns it best guess for Key
def get_k(last_before_output,inputs,cipher_texts,initial_input,bias):
    y = demux(last_before_output)
    count_zeros = no_of_zeros(y)
    if(count_zeros<1):
        return -1,-1,-1,-1,-1
    all_k=[]
    # assume pos_of_k be pos1 and pos2 as in 0 1 2 3 order
    pos1 = -1
    pos2 = -1
    pos3 = -1

    for i_ in range(4):
        if y[i_] != 0 and pos1 == -1:
            pos1 = i_
        elif y[i_] != 0 and pos2 == -1:
            pos2 = i_
        elif y[i_] != 0 and pos3 == -1:
            pos3=i_

    if(pos2 == -1):
        for i_ in range(16):
            all_k.append((i_<<(4*(3-pos1))))
    elif (pos3 == -1):
        for i_ in range(16):
            for j_ in range(16):
                all_k.append((i_<<(4*(3-pos1))) + (j_<<(4*(3-pos2))))
    else:
        for i_ in range(16):
            for j_ in range(16):
                for k_ in range(16):
                    all_k.append((i_<<(4*(3-pos1))) + (j_<<(4*(3-pos2))) + (k_ << (4*(3-pos3))))

    # p contains biases of equations with all possible values of Keys
    p=[0]*len(all_k)
    for i in range(len(inputs)):
        probe=demux((initial_input & inputs[i]))
        a=0
        for _ in probe:
            a+=(no_of_ones(_))
        a=a%2

        output_ciphers=demux(cipher_texts[i])
        for k in range(len(all_k)):
            b=0
            split_k=demux(all_k[k])
            for e in range(4):
                # Computing inverse of Key XOR ciphertext (to be placed in the computed )
                inv=s_inv[no_of_rounds-1][split_k[e]^output_ciphers[e]] & y[e]
                b+=no_of_ones(inv)
            if (a+b)%2==0:
                p[k]+=1
            else:
                p[k]-=1
    # Taking value of Key which has bias with plaintext,ciphertext pairs almost same as computed bias from linear equations
    probs = [abs(abs(x / (2*len(inputs)))- bias) for x in p]
    min_val=min(probs)
    return all_k[probs.index(min_val)],pos1,pos2,pos3,min_val



In [None]:
# generating inputs and their cipher texts
inputs=[]
cipher_texts=[]
key=[]
for i in range(no_of_rounds):
    key.append(random.randint(0, 2**16 - 1))
for i in range(no_of_inputs):
    m=random.randint(0, 2**16 - 1)
    c=encrypt(key,m,no_of_rounds)
    inputs.append(m)
    cipher_texts.append(c)

In [None]:
linear_tables=[]
for i in range(no_of_rounds):
    linear_tables.append(linear_approx_table(s[i]))

In [None]:
all_x=[]
ans=[]
ans_probs=[]
ans_pos=[]
for i in range(1,16):
    all_x.extend([i,16*i,256*i,4096*i])
linear_probs=[]
linear_last_outputs=[]

for x in all_x:
    p=0.5
    output=x
    i=0
    while(i<no_of_rounds-1):
        output,probs=f(output,linear_tables[i])
        p*=(2*probs)
        i+=1
    last_before_output=output
    linear_probs.append(abs(p))
    linear_last_outputs.append(last_before_output)

linear_prob=np.array(linear_probs)
indices=np.argsort(-linear_prob)

new_x=np.array(all_x)
new_x=new_x[indices]

linear_last_output=np.array(linear_last_outputs)
linear_last_output=linear_last_output[indices]
linear_prob=linear_prob[indices]
key_val=[-1,-1,-1,-1]
for i in range(len(linear_prob)):
    count=0
    l=demux(linear_last_output[i])
    # print(l)
    for j in range(4):
        if key_val[j]==-1 and l[j]!=0:
            count+=1
    if(count==0):
        continue
    k,pos1,pos2,pos3,min_val=get_k(linear_last_output[i],inputs,cipher_texts,new_x[i],linear_prob[i])
    if(k!=-1):
        # print(pos1,pos2,pos3)
        ans_probs.append(min_val)
        ans.append(demux(k))
        ans_pos.append([pos1,pos2,pos3])
    k_=demux(k)
    for j in [pos1,pos2,pos3]:
        if j!=-1 and key_val[j]==-1:
            key_val[j]=k_[j]
            print("found a value")
    # print(i,":",key_val)
    if (min(key_val)>-1):
        break

In [None]:
print("Actual Key taken in last round ",demux(key[no_of_rounds-1]))
print("Obtained Key by Cryptanalysis ",key_val)

key_taken = key[no_of_rounds-1]
key_obtained = mux(key_val)
xor_of_both = key_taken ^ key_obtained
bits_matched = 0
for i_ in range(16):
    bits_matched += (1-(xor_of_both & 1))
    xor_of_both >>= 1
print(f"Number of bits matched : {bits_matched}/16",)