# MSP430 SCA Version 3

This code was developed to aid in performing a side channel attack (SCA) on AES128 encryption on an MSP430. Code from the following sources was either directly copied or used as an example for this code:
- https://wiki.newae.com/V4:Tutorial_B6_Breaking_AES_(Manual_CPA_Attack)
- Chatgpt

Version 3 includes the following updates from version 2:
- The plaintext is gathered from the data instead of the filename, as described in the "Plaintext" section below

## Library Import

In [1]:
# for numpy array and plotting
import numpy as np
import matplotlib.pyplot as plt

# for list of plaintext
import os

# for converting from .mat to .py
from scipy.io import loadmat
from datetime import datetime

# for program runtime
import time

## Data Preprocessing

### List plaintext

The following code stores the filenames to be used to extract data from the files.

In [2]:
file_names = []

# Specify the directory path
directory_path = 'C:/Users/ktrippe/OneDrive - University of Arkansas/Trulogic/SCA_AES/srand2/'

# Get a list of filenames in the directory
filenames = os.listdir(directory_path)

# Create a list of tuples containing filename and modification time
file_info_list = [(filename, os.path.getmtime(os.path.join(directory_path, filename))) for filename in filenames]

# Sort the list based on modification time
sorted_file_info = sorted(file_info_list, key=lambda x: x[1])

# Print the sorted filenames and modification times
for filename, modification_time in sorted_file_info:
    formatted_time = datetime.fromtimestamp(modification_time).strftime('%Y-%m-%d %H:%M:%S')
    file_names.append(filename)
    #print(filename)
    
print(f"{len(file_names)} files found")

1278 files found


### Plaintext

Previously, plaintext was stored as the filename for the data being saved on the oscilloscope. This method works for manual data collection, but not for automatic data collection. When the data is automatically collected as MATLAB files with the Textronix 5 Series MSO oscilloscope, the filename is _"const string" + channel + date + time + ".mat"_ (eg. power_trace_ch4_20240130110537000.mat). The plaintext data could not be embedded in the filename.

We chose to instead use the GPIO pins on the MSP430 to output the plaintext and probe those pins with digital oscilloscope probes. This method appears to provide the most assurance of plaintext correspondance with power traces but is unfortunately a little more complex. Firstly, the plaintext is sent to the output pins one byte at a time and is ordered from MSB (bit 7) to LSB (bit 0). The following list orients the oscilloscope channel and probe number with the plaintext bit number:
- Ch4 D7 = bit 7
- Ch4 D6 = bit 6
- Ch4 D5 = bit 5
- Ch4 D4 = bit 4
- Ch5 D3 = bit 3
- Ch5 D2 = bit 2
- Ch5 D1 = bit 1
- Ch5 D0 = bit 0

Secondly, we included a flag that served two purposes. One was to trigger everytime a new byte of data was recieved and to trigger the beginning and end of the sbox lookup, which is where we are interested in the power trace. In total, this flag was raised 18 times - 16 total bytes of data (128 bits), and then the start and end of sbox lookup. The following orients the sbox flag to the oscilloscope channel and probe number:
- Ch5 D4 = Sbox flag

In [3]:
bit_pos = ["D0", "D1", "D2", "D3", "D4", "D5", "D6", "D7"] # bit position
decimal_plaintext = []
traces_3d = []

def hex_to_decimal(hex_str):
    # Parse the hex number into pairs
    hex_pairs = [hex_str[i:i+2] for i in range(0, len(hex_str), 2)]

    # Convert the pairs to decimal
    decimal_numbers = [int(pair, 16) for pair in hex_pairs]

    return decimal_numbers

for j in range(24): #len(file_names)
    start_time = time.time()
    # Load a MATLAB file
    mat_data = loadmat('C:/Users/ktrippe/OneDrive - University of Arkansas/Trulogic/SCA_AES/srand2/' + file_names[j])
    # print(mat)
    # print("Loading file " + file_names[j])
    
    if "ch" in file_names[j]:
        if "ch4" in file_names[j]:
            plaintext_col = mat_data[bit_pos[7]]
            init = 6  
            tot = 3
        elif "ch5" in file_names[j]:
            init = 4 
            tot = 5
        else:
            print(file_names[j] + " not loaded")
            break

        for k in range(tot):
            # print(f"-------- Iteration {j}.{k} Complete --------")
            plaintext_col = np.column_stack((plaintext_col, mat_data[bit_pos[init - k]]))
            
    else:
        # save power trace data
        single_trace = mat_data['data']
        
        # transpose plaintext matrix            
        plaintext = plaintext_col.T
        
        my_array = plaintext[4]

        positions = []

        for i, value in enumerate(my_array[:-1]):
            if value == 0 and my_array[i + 1] == 1:
                positions.append(i + 1)
        
        # Assuming you have an array 'data' and a list of indices 'index_list'
        current_plaintext = plaintext
        data = current_plaintext[np.arange(current_plaintext.shape[0]) != 4]  # Example array, replace with your actual data
        index_list = positions  # Example list of indices, replace with your actual indices

        # Extract values using array indexing
        result_values = data[np.arange(data.shape[0])[:, None], index_list]

        plaintext_bits = []

        for x in range(len(result_values[0])):
            for y in range(len(result_values)):
                plaintext_bits.append(result_values[y][x])
                # print(result_values[y][x])

        plaintext_hex = np.array(plaintext_bits)

        # Assuming you have a NumPy array of bits
        bits_array = plaintext_hex

        # Reshape the array into chunks of 4 bits
        bits_matrix = bits_array.reshape(-1, 4)

        # Convert each chunk to its hexadecimal representation
        hex_list = [''.join(map(str, chunk)) for chunk in bits_matrix]
        hex_numbers = [hex(int(chunk, 2))[2:] for chunk in hex_list]
        hex_string = ''.join(hex_numbers[0:32])
        decimal_numbers = hex_to_decimal(hex_string)
        
        try:
            traces_3d.append(single_trace[positions[16]:positions[17]])
            decimal_plaintext.append(decimal_numbers)
            print(f"Updating traces and plaintext {hex_string}")
        except IndexError as e:
            # Handle the case where the array doesn't have 18 elements
            print(f"Error: Invalid plaintext {hex_string}")

        print(f"Traces size: {len(traces_3d)} | Plaintext size: {len(decimal_plaintext)}\n")
        
        # end_time = time.time()
        # elapsed_time = end_time - start_time

        # print(f"Elapsed Time: {elapsed_time} seconds")        

Error: Invalid plaintext efa5
Traces size: 0 | Plaintext size: 0

Updating traces and plaintext 00112233445566778899aabbccddeeff
Traces size: 1 | Plaintext size: 1

Updating traces and plaintext 69c4e0d86a7b0430d8cdb78070b4c55a
Traces size: 2 | Plaintext size: 2

Updating traces and plaintext 4f638c735f614301567824b1a21a4f6a
Traces size: 3 | Plaintext size: 3

Updating traces and plaintext 507840ad15b6581ea266f2c63fb28276
Traces size: 4 | Plaintext size: 4

Error: Invalid plaintext f0ff
Traces size: 4 | Plaintext size: 4

Updating traces and plaintext 00112233445566778899aabbccddeeff
Traces size: 5 | Plaintext size: 5

Updating traces and plaintext 69c4e0d86a7b0430d8cdb78070b4c55a
Traces size: 6 | Plaintext size: 6



In [4]:
print(f"Shape of plaintext: {plaintext_col.shape}")

Shape of plaintext: (4062500, 9)


### Additional Filtering and Conversion to Numpy Array

The SCA algorithm below is best suited for the traces array in a 

In [5]:
# Define the minimum length threshold
min_length_threshold = 1100000

# Step 1: Remove arrays that are too small
filtered_arrays = [arr for arr in traces_3d if len(arr) >= min_length_threshold]
filtered_plaintext = [plaintext for arr, plaintext in zip(traces_3d, decimal_plaintext) if len(arr) >= min_length_threshold]

# Step 2: Find the minimum length among the remaining arrays
min_length = min(len(arr) for arr in filtered_arrays)

# Step 3: Shorten all remaining arrays to the minimum length
shortened_arrays = [arr[:min_length] for arr in filtered_arrays]

# Step 4: Convert the list of shortened arrays to a NumPy array
traces = np.array(shortened_arrays)
plaintext = np.array(filtered_plaintext)

# Collapse 3D Traces Numpy Array into 2D Array
traces = np.squeeze(traces, axis=-1)

# Print the results
print("Shape of Traces Array:", traces.shape)
print("Shape of Plaintext Array:", plaintext.shape)

Shape of Traces Array: (6, 1127501)
Shape of Plaintext Array: (6, 16)


### Attack

This is where the attack is performed. All of the variables used here should be taken care of in data preprocessing.

In [6]:
#Lookup table for number of 1's in binary numbers 0-256
HW = [bin(n).count("1") for n in range(0,256)] 

sbox=(
0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16)

In [None]:
def intermediate(pt, keyguess):
    return sbox[pt ^ keyguess]

pt = decimal_plaintext

numtraces = np.shape(traces)[0]-1
numpoint = np.shape(traces)[1]

#Use less than the maximum traces by setting numtraces to something
#numtraces = 15

bestguess = [0]*16

start_time = time.time()
#Set 16 to something lower (like 1) to only go through a single subkey & save time!
for bnum in range(0, 4):
    cpaoutput = [0]*256
    maxcpa = [0]*256
    for kguess in range(0, 256):
        # print ("Subkey %2d, hyp = %02x: "%(bnum, kguess))


        #Initialize arrays & variables to zero
        sumnum = np.zeros(numpoint)
        sumden1 = np.zeros(numpoint)
        sumden2 = np.zeros(numpoint)

        hyp = np.zeros(numtraces)
        for tnum in range(0, numtraces):
            hyp[tnum] = HW[intermediate(pt[tnum][bnum], kguess)]


        #Mean of hypothesis
        meanh = np.mean(hyp, dtype=np.float64)

        #Mean of all points in trace
        meant = np.mean(traces, axis=0, dtype=np.float64)

        #For each trace, do the following
        for tnum in range(0, numtraces):
            hdiff = (hyp[tnum] - meanh)
            tdiff = traces[tnum,:] - meant

            sumnum = sumnum + (hdiff*tdiff)
            sumden1 = sumden1 + hdiff*hdiff 
            sumden2 = sumden2 + tdiff*tdiff

        cpaoutput[kguess] = sumnum / np.sqrt( sumden1 * sumden2 )
        maxcpa[kguess] = max(abs(cpaoutput[kguess]))

        # print (maxcpa[kguess])

    #Find maximum value of key
    bestguess[bnum] = np.argmax(maxcpa)

print ("Best Key Guess: ")
for b in bestguess: 
    print ("%02x "%b)

end_time = time.time()
elapsed_time = end_time - start_time

print(f"Elapsed Time: {elapsed_time} seconds")

  cpaoutput[kguess] = sumnum / np.sqrt( sumden1 * sumden2 )
