In [None]:
#Resources and Dependencies, if you are interested, you can use the links to see the code
!wget https://raw.githubusercontent.com/orel509/AttacksonImplementationsCourseBook/master/Labs/WS2.mat
!wget https://raw.githubusercontent.com/orel509/AttacksonImplementationsCourseBook/master/Labs/hamming_weight.py
!wget https://raw.githubusercontent.com/orel509/AttacksonImplementationsCourseBook/master/Labs/aes_scripts/aes_lib.py

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sp
from hamming_weight import hamming_weight
from bokeh.models import Range1d
#from aes_scripts.aes_crypt_8bit_and_leak import aes_crypt_8bit_and_leak, aes_sbox
from aes_lib import aes_crypt_8bit_and_leak, aes_sbox

!pip install -q bokeh
from bokeh.plotting import figure, show
from bokeh.models import Range1d, ColorBar, LinearColorMapper, BasicTicker
from bokeh.io import output_notebook
import bokeh.colors.named as bokeh_colors_names

bokeh_colors_names_arr = dir(bokeh_colors_names)[10:]
bokeh_colors_names_arr = bokeh_colors_names_arr*3

# Call once to configure Bokeh to display plots inline in the notebook.
output_notebook()

In [None]:
def getHeatMap(title, xaxis, yaxis, data, dh, dw, low=1e-2):
    """Build (but do not show) a Bokeh heat map of a 2-D array.

    Also prints the array's shape and maximum value.

    Parameters
    ----------
    title, xaxis, yaxis : str
        Figure title and x/y axis labels.
    data : numpy.ndarray
        2-D array of values, drawn with ``fig.image`` (rows along y,
        columns along x).
    dh, dw : number
        Height and width of the image in data units; the caller in this
        notebook passes the array's row and column counts.
    low : float, optional
        Value mapped to the lowest colour of the linear palette. Defaults to
        1e-2, the value this function previously hard-coded.

    Returns
    -------
    bokeh.plotting.figure
        The heat-map figure; pass it to ``show`` to display it.
    """
    max_value = np.max(data)
    print(np.shape(data), max_value)
    # Linear (not logarithmic) colour scale from `low` up to the data maximum.
    color_mapper = LinearColorMapper(palette="Turbo256", low=low, high=max_value)

    # `width`/`height` work in Bokeh 2.x and 3.x; `plot_width`/`plot_height`
    # were removed in Bokeh 3.0, and this notebook installs an unpinned Bokeh.
    fig = figure(title=title, x_range=(0, dw), y_range=(0, dh),
                 x_axis_label=xaxis, y_axis_label=yaxis,
                 tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
                 width=700, height=700)
    # Palette names: https://docs.bokeh.org/en/latest/docs/reference/palettes.html
    fig.image(image=[data], x=0, y=0, dw=dw, dh=dh, color_mapper=color_mapper, level="image")

    # Colour bar showing the value-to-colour mapping, to the right of the plot.
    color_bar = ColorBar(color_mapper=color_mapper, ticker=BasicTicker(),
                         border_line_color=None, location=(0, 0))
    fig.add_layout(color_bar, 'right')
    return fig

In [None]:
# Which attack to run:
#   DPA - difference of means, classifying traces by the least significant
#         bit of the S-box output;
#   CPA - correlation with the Hamming weight of the S-box output.
DPA, CPA = 0, 1
dpa_or_cpa = CPA  # switch to DPA to run the difference-of-means attack

In [None]:
# Load the recorded power traces and their plaintext inputs from WS2.mat
# (downloaded in the first cell) and report the size of the trace matrix.
ws2 = sp.loadmat('WS2.mat')
print(ws2['traces'].shape)  # (traces, samples per trace): D = 200, T = 100000

In [None]:
# Keep only the first 30000 samples of every trace so the attack runs faster.
traces = ws2['traces'][:, :30000]
# Rows are individual traces (one per plaintext input), columns are samples.
input_count, trace_length = traces.shape

In [None]:
# Overlay the first two traces over a short window of samples.
# np.arange gives the exact sample index of every point; the previous
# np.linspace(0, N, N) used a step of N/(N-1) and drifted away from it.
times = np.arange(traces.shape[1])
p = figure(title='Traces of AES computation', x_axis_label='Time (ms)', y_axis_label='value',
           tooltips=[("x", "$x"), ("y", "$y")])
p.x_range = Range1d(300, 500)  # zoom in on samples 300..500
p.line(times, traces[0, :], legend_label='Trace 1', line_color='blue')
p.line(times, traces[1, :], legend_label='Trace 2', line_color='orange')
show(p)

In [None]:
# Key byte to attack. The index is 1-based: the attack loop reads plaintext
# column `key_byte_to_guess - 1`, so 5 selects the fifth key byte.
key_byte_to_guess = 5

# Attack score for every (key-byte guess, sample) pair: one row per guess.
classification_output = np.zeros((2**8, trace_length))
print(classification_output.shape)

# Predicted leakage for every (key-byte guess, input) pair, filled in below.
trace_classification = np.zeros((2**8, input_count))
inputs = ws2['inputs']

In [None]:
print(traces.shape, classification_output.shape, trace_classification.shape)  # sanity-check the array sizes

In [None]:
# For every guess of the key byte, predict the leakage of S[P ^ K] for each
# plaintext and score that prediction against the measured traces.
if dpa_or_cpa == CPA:
  # Standardize every sample column once; it does not depend on the key guess,
  # so there is no need to redo it 256 times inside the loop.
  # Zero-variance columns carry no information: dividing them by 1 instead of
  # 0 makes their correlation 0 instead of NaN (np.argmax would otherwise
  # report a NaN position as the best (key, time) pair).
  trace_std = traces.std(axis=0)
  trace_std[trace_std == 0] = 1
  traces_standardized = (traces - traces.mean(axis=0)) / trace_std

for key_guess in range(2**8):
  # For each plaintext input (named trace_idx to avoid shadowing builtin `input`)
  for trace_idx in range(input_count):
    # Calculate what the value of S[P ^ K] is (key_byte_to_guess is 1-based)
    p_xor_k = np.bitwise_xor(inputs[trace_idx, key_byte_to_guess - 1], key_guess)
    s_p_xor_k = aes_sbox(p_xor_k, 1)

    if dpa_or_cpa == DPA:
      # DPA: classify by the least significant bit of the S-box output
      trace_classification[key_guess, trace_idx] = (np.bitwise_and(s_p_xor_k, 1) != 0)
    else:
      # CPA: predicted leakage is the Hamming weight of the S-box output
      trace_classification[key_guess, trace_idx] = hamming_weight(s_p_xor_k)

  hypothesis = trace_classification[key_guess, :]
  if dpa_or_cpa == DPA:
    # Difference between the mean traces of the bit==1 and bit==0 classes
    mean_for_1 = np.mean(traces[hypothesis == 1, :], axis=0)
    mean_for_0 = np.mean(traces[hypothesis == 0, :], axis=0)
    classification_output[key_guess, :] = mean_for_1 - mean_for_0
  else:
    # Pearson correlation between the hypothesis and every sample column
    hypothesis_std = hypothesis.std()
    if hypothesis_std == 0:
      # A constant prediction cannot correlate with anything (avoids 0/0 -> NaN)
      classification_output[key_guess, :] = 0
    else:
      hypothesis_standardized = (hypothesis - hypothesis.mean()) / hypothesis_std
      classification_output[key_guess, :] = hypothesis_standardized @ traces_standardized / input_count

  # Progress: 16 key guesses per printed row
  print('[{:02x}]'.format(key_guess), end=" ")
  if (key_guess % 16) == 15:
    print('\n')

In [None]:
# Heat map of the predicted leakage: one row per key guess, one column per trace.
dh, dw = trace_classification.shape
classification_heatmap = getHeatMap('Trace classification', 'Trace index',
                                    'Key guess for byte  ' + str(key_byte_to_guess),
                                    trace_classification, dh, dw)
show(classification_heatmap)

In [None]:
# Find the correct time and key: the largest absolute score over all
# (key guess, sample) pairs. Its row is the recovered key byte and its
# column is the leaking sample.
# nanargmax skips NaN scores (e.g. zero-variance samples or an empty DPA
# class), which np.argmax would otherwise return as the "maximum".
scores = np.abs(classification_output)
correct_key, correct_time = np.unravel_index(np.nanargmax(scores), scores.shape)
# Row index == key_guess value, so correct_key is the key byte itself (0..255),
# not an off-by-one index.
print(correct_key, correct_time)

In [None]:
# |score| at the leaking sample for every key guess. Bars sit at x == key
# guess (0..255), so the tallest bar is at x == correct_key; the previous
# range(1, n + 1) shifted every bar one position to the right.
heights = np.abs(classification_output[:, correct_time])
p = figure(title='Correlation for each value of the byte', x_axis_label='Key guess', y_axis_label='Correlation',
           tooltips=[("x", "$x"), ("y", "$y")])
p.vbar(x=np.arange(len(heights)), top=heights, width=0.8, fill_color='blue')
show(p)

In [None]:
# CPA only: trace by trace, compare the measured power at the leaking sample
# with the Hamming-weight model of the recovered key.
# NOTE(review): the measured samples are divided by 5, presumably only to put
# them on a scale comparable with the model (Hamming weight of a byte, 0..8).
measured_power = np.true_divide(traces[:, correct_time], 5)
model_power = trace_classification[correct_key, :]
trace_index = np.arange(len(measured_power))  # one x value per trace, not a hard-coded 200

p = figure(title='The actual power consumption at correct time compared to the power model',
           x_axis_label='Trace index', y_axis_label='power', tooltips=[("x", "$x"), ("y", "$y")])
# Legend labels now match the data; previously the two labels were swapped.
p.line(trace_index, measured_power, legend_label='Power consumption at correct time', line_color='blue')
p.line(trace_index, model_power, legend_label='Power model for correct key', line_color='orange')
show(p)

In [None]:
# Score curves of every key guess in a window around the leaking sample;
# the correct key should stand out.
window_frame = 100  # half-width of the initially visible x range, in samples
offframe = 2 * window_frame  # half-width of the data sent to the plot (allows panning)

# Clamp to the trace so a peak near either end neither wraps around (negative
# start index) nor produces x and y arrays of different lengths.
start = max(correct_time - offframe, 0)
stop = min(correct_time + offframe, classification_output.shape[1])
ys = [row[start:stop] for row in classification_output]
# Exact sample index of every plotted point; the previous linspace was off by one.
times = np.arange(start, stop)
xs = list(np.tile(times, (len(ys), 1)))  # multi_line needs one x array per line

p = figure(title='The correct key at the correct time', x_axis_label='Time (ms)', y_axis_label='Correlation',
           tooltips=[("x", "$x"), ("y", "$y"), ("Key guess", "$index")])
p.multi_line(xs=xs, ys=ys, color=bokeh_colors_names_arr[:len(ys)])
p.x_range = Range1d(correct_time - window_frame, correct_time + window_frame)
show(p)

In [None]:
# Score of the recovered key at every sample of the (shortened) trace; the
# peak should be at correct_time.
p = figure(title='Correlation of the correct key over time', x_axis_label='Time (ms)', y_axis_label='Correlation',
           tooltips=[("x", "$x"), ("y", "$y")])
# Exact sample indices (np.linspace(0, N, N) drifted by N/(N-1) per step).
times = np.arange(classification_output.shape[1])
p.line(times, classification_output[correct_key, :], line_color='blue')
show(p)