# Part 4, Topic 2: CPA on Firmware Implementation of AES

**SUMMARY**: *Now that you've seen a CPA attack work, let's explore it in more detail. The goal of this lab will be to do a CPA attack without using Analyzer.*


**LEARNING OUTCOMES:**
* Develop an algorithm based on a mathematical description
* Verify that correlation can be used to break a single byte of AES
* Extend the single byte attack to the rest of the key

**Requirements:** 
We'll be using a location value from the previous lab in this one, so make sure you've got it written down/recorded somewhere!

## Capturing Traces

Last time, we didn't go into much detail on capturing traces. This time, we'll take a closer look at the firmware. It uses something called simpleserial for communication, which is a simple communication protocol that we use for most of our firmware. We can use it to send different commands to the target and also to receive data. `simpleserial-aes` has two commands we care about:

* `'k'` - Set the key used for the AES implementation
* `'p'` - Send a plaintext to the target for it to encrypt. When the encryption is finished, it will respond with an `'r'` command.

We can use `target.simpleserial_write(<cmd>, <data>)` and `target.simpleserial_read(<cmd>, <dlen>)` to send and receive data. Basically, we want to:

1. Set the AES key used on the target using the `'k'` command
1. Arm the scope (`scope.arm()`)
1. Send the plaintext 
1. Capture the trace (`scope.capture()`). Note that this doesn't return the trace that gets captured, it returns whether (1) or not (0) the capture timed out
1. Read the ciphertext back

You can get the trace you just captured with `scope.get_last_trace()`. You'll want to do steps 2 through 5 multiple times.

## \#HARDWARE

In [None]:
# Set hardware settings
SCOPETYPE = 'OPENADC'
PLATFORM = 'CW308_SAM4S'
CRYPTO_TARGET='TINYAES128C' 
SS_VER='SS_VER_2_1'

In [None]:
# Connect to ChipWhisperer
%run "../Setup_Scripts/Setup_Generic.ipynb"

In [None]:
%%bash -s "$PLATFORM" "$CRYPTO_TARGET" "$SS_VER"
# compile firmware
cd ../../hardware/victims/firmware/simpleserial-aes
make PLATFORM=$1 CRYPTO_TARGET=$2 SS_VER=$3 -j

In [None]:
# program firmware onto target
cw.program_target(scope, prog, "../../hardware/victims/firmware/simpleserial-aes/simpleserial-aes-{}.hex".format(PLATFORM))

Unlike last time, where we used ChipWhisperer Projects, this time we'll be working with our traces as simple lists.

In [None]:
trace_array = []
textin_array = []

In [None]:
from tqdm.notebook import trange
import numpy as np
import time

ktp = cw.ktp.Basic()
real_aes_key, text = ktp.next()

# Code Block 1
target.simpleserial_write('k', real_aes_key)  # send the key generated by ktp.next() above

for i in trange(50, desc='Capturing traces'):
    target.flush()
    key, text = ktp.next()
    # Code Block 2
    scope.arm()
    
    target.simpleserial_write('p', text)
    
    ret = scope.capture()
    if ret:
        print("Target timed out!")
        continue
    
    response = target.simpleserial_read('r', 16)
    
    trace_array.append(scope.get_last_trace())
    textin_array.append(text[0])

You may want to plot some traces to make sure everything looks as expected:

In [None]:
cw.plot(trace_array[0]) * cw.plot(trace_array[1])

## \#SIMULATED

If you don't have hardware, you can load previously captured traces instead:

In [None]:
"""
import numpy as np
from tqdm.notebook import trange
import chipwhisperer as cw

aes_traces_50_tracedata = np.load(r"traces/lab4_2_traces.npy")
aes_traces_50_textindata = np.load(r"traces/lab4_2_textin.npy")
key = np.load(r"traces/lab4_2_key.npy")

trace_array = aes_traces_50_tracedata
textin_array = aes_traces_50_textindata
"""

## Attack Theory

### Why AES Is Secure

Before we can get into why our attack from the previous lab works, we need to talk a little bit about AES and why it's secure. Don't worry, we won't be getting into all the gory details
here, just the bare minimum to compare our attack to normal AES operation. AES is essentially just a repeated application (called a round) of four operations: AddRoundKey, SubBytes, ShiftRows, and
MixColumns acting on an internal state, which begins as the plaintext and ends as the ciphertext:

<img src="img/AES_Block_Diagram.png" alt="AES block diagram"/>

These operations do the following:

1. AddRoundKey XORs the AES state with the current round key. A new key is generated for each round using the previous round key.
1. SubBytes puts each byte of the state through a lookup table
1. ShiftRows shuffles the bytes around
1. MixColumns combines bytes of the state together using a matrix multiply. This is done in groups of 4 bytes

Each of these operations is important in the following way:

1. AddRoundKey combines the key with the state
2. SubBytes prevents a type of attack called linear cryptanalysis by breaking the linear relationship between its input and output
3. MixColumns prevents a single byte of the output from being solely associated with a single byte of the input/key. This is a bit like the difference between 16 single character passwords
   versus a single 16 character long password.
4. ShiftRows makes it so that MixColumns isn't combining the same sets of 4 bytes over and over again.

From an attacking point of view, we'd really like to peek into the state after the AddRoundKey happens, but before any MixColumns happens, as MixColumns instantly takes our attack from
`16*2^8` guesses (16 key bytes attacked one at a time) to `4*2^32` (4 groups of 4 bytes). If you think back to Lab 1, the 5 single character passwords were much easier to break than the 5 character long password. This puts our point of attack somewhere
during the first round: after AddRoundKey, after SubBytes, or after ShiftRows. We'll also assume that we know the plaintext.
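
To put numbers on those two search spaces, here is a minimal sketch that just evaluates both expressions. The variable names are illustrative and not part of the lab code.

```python
# Guesses needed when each of the 16 key bytes can be attacked on its own
per_byte_guesses = 16 * 2**8

# Guesses needed once MixColumns has tied the bytes together in groups of 4
per_column_guesses = 4 * 2**32

print(per_byte_guesses)                         # 4096
print(per_column_guesses)                       # 17179869184
print(per_column_guesses // per_byte_guesses)   # 4194304
```

Attacking before MixColumns is therefore roughly four million times less work.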

### Peeking Into the State

We've seen in the slides that:

* The power consumed by an electronic device is related to the data being manipulated. 
* Storing a `1` consumes power, while storing a `0` doesn't. 
* This power consumption is linear. Storing 4 `1`s takes roughly twice as much power as storing 2 `1`s.
* In microcontrollers, this power consumption is generally unrelated to what was previously stored.

In a perfect scenario (no noise), if we measure the power consumption of the target when it's manipulating data, we end up knowing how many `1`s are being stored, a quantity known as the Hamming weight. Watching the device load the key doesn't do us much good on its own, though, as it's effectively impossible to know when the key is being loaded.

### Perfect Traces

Let's take a look at a simplified situation: we'll take a look at just one byte after AddRoundKey and assume there's no noise:

**`power_trace = hamming_weight(secret_key ^ plaintext)`**

or:

$T_{pwr}(X)={HW}({XOR}(K,X))$

If we know T (our power trace) and X, we should be able to figure out K. You could try to invert the function, but we actually lose some information via the Hamming weight. Instead, we can take multiple sets of `T`s and `X`s to figure out `K`.
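
To see how much information the Hamming weight throws away, the sketch below counts the key bytes that are consistent with a single noise-free observation, then narrows them down with more observations. The key byte `demo_key` and the chosen `X` values are made up for this illustration.

```python
from math import comb

def hamming_weight(val):
    return bin(val).count("1")

def candidates(t, x):
    """All key bytes K for which HW(K ^ x) equals the observed value t."""
    return {k for k in range(256) if hamming_weight(k ^ x) == t}

demo_key = 0x3C  # made-up key byte, only for this illustration

# A single observation leaves comb(8, t) possible keys
x0 = 0x00
t0 = hamming_weight(demo_key ^ x0)
one_obs = candidates(t0, x0)
print(len(one_obs), comb(8, t0))   # 70 70

# Each extra (T, X) pair intersects the candidate sets and narrows things down
remaining = set(one_obs)
for x in [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80]:
    remaining &= candidates(hamming_weight(demo_key ^ x), x)
print(remaining)                   # {60}
```

One trace cannot be inverted to a unique key, but a handful of well-chosen ones can. With random `X` values, as used in the rest of this section, it may take a few more.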

In [None]:
import random
secret_key = random.randint(0, 255)

And let's make a function to generate T:

In [None]:
def hamming_weight(val):
    return bin(val).count("1")

def power_trace(text):
    key = secret_key # don't tell
    return hamming_weight(text ^ key)

Then we can generate some traces using random X values:

In [None]:
num_traces = 10

texts = [random.randint(0,255) for _ in range(num_traces)]
traces = [power_trace(text) for text in texts]
print(traces)
cw.plot(traces)

The simplest strategy here is to take a guess at the key and check it against the power traces that we have:

In [None]:
import numpy as np

# recreate the function, but supply key guesses instead of the actual secret
def trace_guesser(guess, text):
    return hamming_weight(text ^ guess)

# try all possible K
for guess in range(256):
    guess_traces = [trace_guesser(guess, text) for text in texts]
    if guess_traces == traces:
        print("Possible guess = {}".format(guess))

print(secret_key)

### Adding Noise

Hopefully you were able to recover the key from that. Obviously, this isn't super realistic - there's both noise and other things going on on the chip, so our power trace won't be nearly as nice. We also won't know how "large" the effect of the data is on the power trace. Instead, let's change our model a bit by adding random noise in:

**`power_trace = data_pwr_size*hamming_weight(secret_key ^ X) + noise`**

or

$T_{pwr}(X) = C*HW({XOR}(K,X)) + N$

Let's update our power_trace function so that it matches this model:

In [None]:
import chipwhisperer as cw

scale_factor = 0.1 # how much the operation contributes to the power trace
noise_range = [-0.1, 0.1] # how much random noise to add

def power_trace(text):
    key = secret_key # don't tell
    data_component = hamming_weight(text ^ key) * scale_factor
    noise_component = random.uniform(noise_range[0], noise_range[1])
    return data_component + noise_component

And we'll generate some traces:

In [None]:
num_traces = 1000 # larger number required here
texts = [random.randint(0,255) for _ in range(num_traces)]
traces = [power_trace(text) for text in texts]
cw.plot(traces)

Something you might notice about this equation is that it looks pretty linear. Noise should be effectively random with an average of zero, so if we average a bunch of power traces, we should end up with:

$T_{avg}(X) = C*HW(XOR(K,X))$

or, taking the Hamming weight, which we'll call D, as our variable,

$T_{avg}(D) = C*{D}$

So if we group by Hamming weight, we should have a "perfect" linear relationship. Let's put that to the test:


In [None]:
hw_groups = [0]*9

for hw in range (0,9): #HW min is 0, max is 8
    total = 0 # total number of traces in each group, so we can average
    for trace_number in range(len(traces)): # cycle through all traces
        if hamming_weight(secret_key ^ texts[trace_number]) == hw: # if our hamming weight should be this value
            total += 1
            hw_groups[hw] += traces[trace_number]
    if total != 0:
        hw_groups[hw] /= total

cw.plot(hw_groups)

But if the key is something else, we'll no longer have this linear relationship:

In [None]:
wrong_secret = (secret_key + 45) & 0xFF # make sure the secret is still 8 bits

hw_groups = [0]*9

for hw in range (0,9): #HW min is 0, max is 8
    total = 0 # total number of traces in each group, so we can average
    for trace_number in range(len(traces)): # cycle through all traces
        if hamming_weight(wrong_secret ^ texts[trace_number]) == hw: # if our hamming weight should be this value
            total += 1
            hw_groups[hw] += traces[trace_number]
    if total != 0:
        hw_groups[hw] /= total

cw.plot(hw_groups)

One issue that you're going to run into here is that the plot isn't this far from linear for every wrong guess of the key. In fact, a key guess and its binary inverse
will have exactly the same plot with an inverted slope: flipping every bit of the key flips every bit of the XOR result, so a Hamming weight of `D` becomes `8 - D`. This comes down to XOR being a linear binary operation.
Luckily for us, the very next operation in AES, SubBytes, is in there specifically to break any linear relationship between its input and output! Let's try repeating the above block, but adding the SubBytes operation:

In [None]:
import chipwhisperer as cw

scale_factor = 0.1 # how much the operation contributes to the power trace
noise_range = [-0.1, 0.1] # how much random noise to add

sbox = [
    # 0    1    2    3    4    5    6    7    8    9    a    b    c    d    e    f 
    0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76, # 0
    0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0, # 1
    0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15, # 2
    0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75, # 3
    0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84, # 4
    0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf, # 5
    0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8, # 6
    0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2, # 7
    0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73, # 8
    0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb, # 9
    0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79, # a
    0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08, # b
    0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a, # c
    0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e, # d
    0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf, # e
    0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16  # f
]

def power_trace(text):
    #print(text)
    key = secret_key # don't tell
    data_component = hamming_weight(sbox[text ^ key]) * scale_factor
    noise_component = random.uniform(noise_range[0], noise_range[1])
    return data_component + noise_component

num_traces = 1000 # larger number required here
texts = [random.randint(0,255) for _ in range(num_traces)]
traces = [power_trace(text) for text in texts]

hw_groups = [0]*9

for hw in range (0,9): #HW min is 0, max is 8
    total = 0 # total number of traces in each group, so we can average
    for trace_number in range(len(traces)): # cycle through all traces
        if hamming_weight(sbox[secret_key ^ texts[trace_number]]) == hw: # if our hamming weight should be this value
            total += 1
            hw_groups[hw] += traces[trace_number]
    if total != 0:
        hw_groups[hw] /= total

cw.plot(hw_groups)

In [None]:
wrong_secret = (secret_key + 2) & 0xFF # make sure the secret is still 8 bits

hw_groups = [0]*9

for hw in range (0,9): #HW min is 0, max is 8
    total = 0 # total number of traces in each group, so we can average
    for trace_number in range(len(traces)): # cycle through all traces
        if hamming_weight(sbox[wrong_secret ^ texts[trace_number]]) == hw: # if our hamming weight should be this value
            total += 1
            hw_groups[hw] += traces[trace_number]
    if total != 0:
        hw_groups[hw] /= total

cw.plot(hw_groups)

Now we're getting somewhere!
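
As a quick check of the earlier point about a key guess and its binary inverse, the sketch below uses the XOR-only model (no S-box). It shows that the two guesses give Hamming weights that mirror each other. The key byte `demo_key` is an arbitrary example value.

```python
import numpy as np

def hamming_weight(val):
    return bin(val).count("1")

demo_key = 0x2B                 # arbitrary example key byte
inverse_key = demo_key ^ 0xFF   # every bit flipped

texts = range(256)
hw_key = np.array([hamming_weight(t ^ demo_key) for t in texts])
hw_inv = np.array([hamming_weight(t ^ inverse_key) for t in texts])

# Flipping every key bit flips every bit of the XOR result, so HW -> 8 - HW
print(np.array_equal(hw_inv, 8 - hw_key))

# The two models are perfectly anti-correlated, so a test that only looks
# at how linear the relationship is cannot tell the two guesses apart
print(np.corrcoef(hw_key, hw_inv)[0, 1])
```

Once the S-box is in the model, this mirror symmetry no longer holds, which is why the wrong-guess plot above stops looking linear.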

So, theoretically at least, if we didn't know the key, we could do this plot with a bunch of guesses. The correct key should then be the one that's the most linear. This is okay, but is there any way we could automate this process? Indeed there is!

As it turns out, there actually is an operation that'll do that for us: the Pearson correlation coefficient! Let's use that operation on all the possible key guesses and see if we get the right key out:

In [None]:
def get_correlation(traces, texts, key_guess):
    HW_array = []
    for i in range(len(texts)):
        HW_array.append(hamming_weight(sbox[key_guess ^ texts[i]]))

    hw_trace_array = np.array([traces, HW_array])
    return abs(np.corrcoef(hw_trace_array)[1][0])

best_guess = 0
best_corr = 0
for key_guess in range(0, 256):
    cur_corr = get_correlation(traces, texts, key_guess)
    if cur_corr > best_corr:
        best_guess = key_guess
        best_corr = cur_corr
        print("New best guess found: {:02X} (corr={})".format(best_guess, best_corr))
        
print("You guessed {:02X}, correct is {:02X}".format(best_guess, secret_key))
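
For reference, the Pearson correlation coefficient that `np.corrcoef` computes can be written out by hand. The following is a minimal sketch with made-up numbers. The helper `pearson` is a hypothetical name for illustration, not part of ChipWhisperer or NumPy.

```python
import numpy as np

def pearson(x, y):
    """Pearson correlation coefficient of two equal-length 1-D sequences."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()   # deviation of each point from the mean
    dy = y - y.mean()
    return np.sum(dx * dy) / np.sqrt(np.sum(dx**2) * np.sum(dy**2))

# Made-up data: a noisy linear relationship and an unrelated sequence
hw = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])
noise = np.array([0.02, -0.01, 0.0, 0.01, -0.02, 0.01, 0.0, -0.01, 0.02])
linear = 0.1 * hw + noise
unrelated = np.array([0.4, 0.1, 0.7, 0.2, 0.5, 0.3, 0.6, 0.1, 0.4])

print(pearson(hw, linear))      # close to 1: strongly linear
print(pearson(hw, unrelated))   # close to 0: no linear relationship
print(np.isclose(pearson(hw, linear), np.corrcoef(hw, linear)[0, 1]))
```

The coefficient is unaffected by the scale and offset of either input, which is why we never need to know how "large" the data's effect on the power trace is.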

You should've gotten the right key out of that! At this point, we're 99% of the way to a full attack! Really, the only difference between this and a full attack is that we
A) don't know when the AES operation is in our power trace and B) we're only targeting a single byte. Both things turn out to be really easy to work around.

We can also use our guess-and-check approach to figure out when the AES operation we care about is happening. At every other point in time, the data being handled is effectively
unrelated to our model, so the correlation there stays low, just like when we feed in an incorrect key guess. This means that we just have to repeat the key guess for each point in time in our power
trace. Let's work with our real power trace data now:
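
One way this could look is sketched below. It is not the lab's reference solution: `cpa_single_byte`, `make_sbox` and the synthetic demo data are all assumptions made for illustration. The S-box is generated programmatically only to keep the block self-contained, and it produces the same values as the table used earlier.

```python
import numpy as np

def rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF

def make_sbox():
    """Build the AES S-box from GF(2^8) inverses plus the affine transform."""
    sbox = [0] * 256
    p = q = 1
    while True:
        p = (p ^ (p << 1) ^ (0x1B if p & 0x80 else 0)) & 0xFF  # p *= 3
        q = (q ^ (q << 1)) & 0xFF                              # q /= 3
        q = (q ^ (q << 2)) & 0xFF
        q = (q ^ (q << 4)) & 0xFF
        if q & 0x80:
            q ^= 0x09
        sbox[p] = q ^ rotl8(q, 1) ^ rotl8(q, 2) ^ rotl8(q, 3) ^ rotl8(q, 4) ^ 0x63
        if p == 1:
            break
    sbox[0] = 0x63
    return sbox

SBOX = np.array(make_sbox(), dtype=np.uint8)
HW = np.array([bin(v).count("1") for v in range(256)])

def cpa_single_byte(traces, text_bytes):
    """Correlate every key guess against every sample point.

    traces: (n_traces, n_samples); text_bytes: (n_traces,) plaintext bytes.
    Returns (best_guess, best_sample, corr), corr has shape (256, n_samples).
    Assumes no sample point is constant across all traces.
    """
    traces = np.asarray(traces, dtype=float)
    text_bytes = np.asarray(text_bytes, dtype=np.uint8)
    t_centered = traces - traces.mean(axis=0)
    t_norm = np.sqrt((t_centered ** 2).sum(axis=0))
    corr = np.zeros((256, traces.shape[1]))
    for guess in range(256):
        model = HW[SBOX[text_bytes ^ np.uint8(guess)]].astype(float)
        m_centered = model - model.mean()
        m_norm = np.sqrt((m_centered ** 2).sum())
        corr[guess] = np.abs(m_centered @ t_centered) / (m_norm * t_norm)
    best_guess, best_sample = np.unravel_index(np.argmax(corr), corr.shape)
    return int(best_guess), int(best_sample), corr

# Synthetic demo: noise everywhere, with the S-box output leaking at one sample
rng = np.random.default_rng(0)
demo_key, leak_at = 0x2B, 30   # made-up key byte and leak position
n_traces, n_samples = 200, 50
text_bytes = rng.integers(0, 256, n_traces, dtype=np.uint8)
traces = rng.normal(0.0, 0.1, (n_traces, n_samples))
traces[:, leak_at] += 0.1 * HW[SBOX[text_bytes ^ np.uint8(demo_key)]]

guess, sample, corr = cpa_single_byte(traces, text_bytes)
print(hex(guess), sample)
```

With the captured data, a call along the lines of `cpa_single_byte(np.array(trace_array), first_bytes)` should behave the same way, where `first_bytes` holds byte 0 of each plaintext. The returned sample index then tells you where in the trace the S-box output is handled.
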

To get the rest of the bytes of the key, you just need to repeat this same attack with the other bytes of the text. We won't do that in this lab, but if you're a bit ahead, feel
free to give it a try.

In [None]:
scope.dis()
target.dis()