# Part 4, Topic 2: CPA on Firmware Implementation of AES

**SUMMARY**: *Now that you've seen a CPA attack work, let's explore it in more detail. The goal of this lab is to perform a CPA attack without using Analyzer.*


**LEARNING OUTCOMES:**
* Develop an algorithm based on a mathematical description
* Verify that correlation can be used to break a single byte of AES
* Extend the single byte attack to the rest of the key

**Requirements:** 
We'll be using a location value from the previous lab in this one, so make sure you've got it written down/recorded somewhere!

## Capturing Traces

Last time, we didn't go into much detail on capturing traces. This time, we'll take a closer look at the firmware. It communicates using simpleserial, a simple protocol that we use for most of our firmware, which lets us send commands to the target and receive data back. `simpleserial-aes` has two commands we care about:

* `'k'` - Set the key used for the AES implementation
* `'p'` - Send a plaintext to the target for it to encrypt. When the encryption is finished, it will respond with an `'r'` command.

We can use `target.simpleserial_write(<cmd>, <data>)` and `target.simpleserial_read(<cmd>, <dlen>)` to send and receive data. The overall procedure is:

1. Set the AES key used on the target using the `'k'` command
1. Arm the scope (`scope.arm()`)
1. Send the plaintext 
1. Capture the trace (`scope.capture()`). Note that this doesn't return the trace that gets captured, it returns whether (1) or not (0) the capture timed out
1. Read the ciphertext back

You can get the trace you just captured with `scope.get_last_trace()`. You'll want to do steps 2 through 5 multiple times.

## \#HARDWARE

In [None]:
# Set hardware settings
SCOPETYPE = 'OPENADC'
PLATFORM = 'CW308_SAM4S'
CRYPTO_TARGET='TINYAES128C' 
SS_VER='SS_VER_2_1'

In [None]:
# Connect to ChipWhisperer
%run "../Setup_Scripts/Setup_Generic.ipynb"

In [None]:
%%bash -s "$PLATFORM" "$CRYPTO_TARGET" "$SS_VER"
# compile firmware
cd ../../hardware/victims/firmware/simpleserial-aes
make PLATFORM=$1 CRYPTO_TARGET=$2 SS_VER=$3 -j

In [None]:
# program firmware onto target
cw.program_target(scope, prog, "../../hardware/victims/firmware/simpleserial-aes/simpleserial-aes-{}.hex".format(PLATFORM))

Unlike last time, where we used ChipWhisperer Projects, this time we'll be working with our traces as simple lists.

In [None]:
trace_array = []
textin_array = []

**IMPORTANT:** The code for this attack is a lot simpler if we only worry about a single point in time. You should fill in the value that you got from the last lab below:

In [None]:
from tqdm.notebook import trange
import numpy as np
import time

AES_VUL_LOC = ???

ktp = cw.ktp.Basic()
secret_aes_key, text = ktp.next()


# Code Block 1
target.simpleserial_write('k', secret_aes_key)

for i in trange(50, desc='Capturing traces'):
    target.flush()
    key, text = ktp.next()
    # Code Block 2
    scope.arm()
    
    target.simpleserial_write('p', text)
    
    ret = scope.capture()
    if ret:
        print("Target timed out!")
        continue
    
    response = target.simpleserial_read('r', 16)
    
    trace_array.append(scope.get_last_trace()[AES_VUL_LOC])
    textin_array.append(text[0])

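Since `trace_array` only stores the single sample at `AES_VUL_LOC` from each capture, plot the full last capture to make sure everything looks as expected:

In [None]:
# trace_array holds one sample per capture, so plot the complete last trace instead
cw.plot(scope.get_last_trace())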

## \#SIMULATED

If you don't have hardware, you can load previously captured traces instead:

In [None]:
"""
import numpy as np
from tqdm.notebook import trange
import chipwhisperer as cw

aes_traces_50_tracedata = np.load(r"traces/lab4_2_traces.npy")
aes_traces_50_textindata = np.load(r"traces/lab4_2_textin.npy")
secret_aes_key = np.load(r"traces/lab4_2_key.npy")  # name matches the check at the end of the attack

trace_array = aes_traces_50_tracedata
textin_array = aes_traces_50_textindata
"""

## Attack Theory

We've seen in the slides that:

* The power consumed by an electronic device is related to the data being manipulated. 
* Storing a `1` consumes power, while storing a `0` doesn't. 
* This power consumption is linear. Storing 4 `1`s takes roughly twice as much power as storing 2 `1`s.
* In microcontrollers, this power consumption is generally unrelated to what was previously stored.

In a perfect scenario (no noise), measuring the power consumption while the target manipulates data tells us how many `1`s are being stored, a quantity known as the Hamming weight. Simply watching the device load the key doesn't do us much good, though: it's effectively impossible to know when the key is being loaded, and even if we did, its Hamming weight alone wouldn't pin down its value.
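The Hamming weight model is easy to make concrete. This small sketch defines its own `hamming_weight` helper (the same one used later in this notebook), checks the "twice as many `1`s, twice the power" idea, and counts how many of the 256 byte values share each Hamming weight. The counts show why one measurement can't identify a byte: a weight of 4, for example, is consistent with 70 different values.

```python
def hamming_weight(val):
    """Number of 1 bits in val: the quantity our power model says leaks."""
    return bin(val).count("1")

# Under the linear model, 0x0F (four 1s) draws about twice the power of 0x03 (two 1s)
print(hamming_weight(0x0F), hamming_weight(0x03))  # 4 2

# How many of the 256 possible byte values share each Hamming weight (0..8)?
values_per_weight = [0] * 9
for value in range(256):
    values_per_weight[hamming_weight(value)] += 1
print(values_per_weight)  # [1, 8, 28, 56, 70, 56, 28, 8, 1]
```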

Instead, if the key is combined with some changing value that we know, we can use multiple power traces to fully recover the key! We do that by replicating the function that generates our power trace, then trying key guesses until we find one whose predicted traces match the measured ones.

### Perfect Traces

Instead of AES, let's start off with an 8-bit addition. Let's also assume our power trace is only the Hamming weight portion. This means our power trace is:

**`power_trace = hamming_weight((secret_key + plaintext) & 0xFF)`**

or:

$T_{pwr}(X)={HW}(K+X)$

If we know T (our power trace) and X, we should be able to figure out K. You could try to invert the function, but the Hamming weight throws away information: many different values of $K+X$ produce the same $T$. Instead, we need multiple T's and X's; the correct K is the one that fits this equation for every pair. Let's generate a K:

In [None]:
import random
secret_key = random.randint(0, 255)

And let's make a function to generate T:

In [None]:
def hamming_weight(val):
    return bin(val).count("1")

def power_trace(text):
    key = secret_key # don't tell
    return hamming_weight((text + key) & 0xFF)

Then we can generate some traces using random X values:

In [None]:
num_traces = 10

texts = [random.randint(0, 255) for _ in range(num_traces)]  # randint's upper bound is inclusive
traces = [power_trace(text) for text in texts]
print(traces)

Next, let's see if we can brute force the key:

In [None]:
import numpy as np

# recreate the function, but supply key guesses instead of the actual secret
def trace_guesser(guess, text):
    return hamming_weight((text + guess) & 0xFF)

# try all possible K
for guess in range(256):
    guess_traces = [trace_guesser(guess, text) for text in texts]
    if guess_traces == traces:
        print("Possible guess = {}".format(guess))

print(secret_key)
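It can be instructive to watch the set of surviving key guesses shrink as traces are added. This sketch reuses the noise-free model above with a fixed random seed (an arbitrary choice, so runs are repeatable). After one trace, the survivors are exactly the guesses that give the right Hamming weight, at most 70 of them; each additional trace can only remove candidates, and with enough traces the list typically narrows to the secret alone.

```python
import random

def hamming_weight(val):
    return bin(val).count("1")

def leak(key, text):
    # Perfect (noise-free) model from above: Hamming weight of the 8-bit sum
    return hamming_weight((key + text) & 0xFF)

rng = random.Random(0)  # fixed seed so the run is repeatable
secret_key = rng.randint(0, 255)
texts = [rng.randint(0, 255) for _ in range(20)]
traces = [leak(secret_key, t) for t in texts]

# For each prefix of the traces, keep the guesses that match every trace seen so far
candidate_counts = []
for n in range(1, len(texts) + 1):
    candidates = [g for g in range(256)
                  if all(leak(g, t) == traces[i] for i, t in enumerate(texts[:n]))]
    candidate_counts.append(len(candidates))

print(candidate_counts)
print("secret = {:02X}, remaining candidates = {}".format(
    secret_key, ["{:02X}".format(g) for g in candidates]))
```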

### Adding Noise

Hopefully you were able to recover the key from that. Obviously, this isn't very realistic: there's noise and other activity on the chip, so our power trace won't be nearly as clean. We also won't know how large the effect of the data is on the power trace. Instead, let's change our model a bit by adding a scale factor and random noise:

**`power_trace = data_pwr_size*hamming_weight((secret_key + X) & 0xFF) + noise`**

or

$T_{pwr}(X) = C*HW(K+X) + N$

Let's update our `power_trace` function so that it matches this model:

In [None]:
import chipwhisperer as cw

scale_factor = 0.1 # how much the operation contributes to the power trace
noise_range = [-0.1, 0.1] # how much random noise to add

def power_trace(text):
    key = secret_key # don't tell
    data_component = hamming_weight((text + key) & 0xFF) * scale_factor
    noise_component = random.uniform(noise_range[0], noise_range[1])
    return data_component + noise_component

And we'll generate some traces:

In [None]:
num_traces = 1000 # larger number required here
texts = [random.randint(0, 255) for _ in range(num_traces)]  # randint's upper bound is inclusive
traces = [power_trace(text) for text in texts]
cw.plot(traces)

Something you might notice about this equation is that it looks pretty linear. The noise is random with zero mean, so if we average many power traces for the same X, it cancels out and we're left with:

$T_{avg}(X) = C*HW(K+X)$

or, taking the Hamming weight, which we'll call D, as our variable,

$T_{avg}(D) = C*{D}$

So if we group by Hamming weight, we should have a "perfect" linear relationship. Let's put that to the test:


In [None]:
hw_groups = [0]*9

for hw in range (0,9): #HW min is 0, max is 8
    total = 0 # total number of traces in each group, so we can average
    for trace_number in range(len(traces)): # cycle through all traces
        if hamming_weight((secret_key + texts[trace_number])&0xFF) == hw: # if our hamming weight should be this value
            total += 1
            hw_groups[hw] += traces[trace_number]
    if total != 0:
        hw_groups[hw] /= total

cw.plot(hw_groups)
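The averaging argument can also be checked numerically. If $T = C \cdot D + N$ with zero-mean noise, a least-squares line fitted to the (Hamming weight, trace value) pairs should have a slope close to `scale_factor` and an intercept close to zero. The key value, seed, and trace count below are arbitrary choices for this sketch, which uses only the standard library.

```python
import random

def hamming_weight(val):
    return bin(val).count("1")

rng = random.Random(1)      # fixed seed for a repeatable run
secret_key = 0x2B           # hypothetical key standing in for the hidden secret
scale_factor = 0.1          # C in the model
noise_range = (-0.1, 0.1)   # N is drawn uniformly from this range (mean zero)

texts = [rng.randint(0, 255) for _ in range(20000)]
hws = [hamming_weight((secret_key + t) & 0xFF) for t in texts]         # D for each trace
traces = [scale_factor * d + rng.uniform(*noise_range) for d in hws]   # T = C*D + N

# Ordinary least-squares fit of T = slope*D + intercept over all traces
n = len(traces)
mean_d = sum(hws) / n
mean_t = sum(traces) / n
cov_dt = sum((d - mean_d) * (t - mean_t) for d, t in zip(hws, traces))
var_d = sum((d - mean_d) ** 2 for d in hws)
slope = cov_dt / var_d
intercept = mean_t - slope * mean_d
print("slope = {:.4f}, intercept = {:.4f}".format(slope, intercept))
```

With 20000 traces the fitted slope should land very close to 0.1; with fewer traces, or more noise, it wanders further.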

But if the key is something else, we'll no longer have this linear relationship:

In [None]:
wrong_secret = (secret_key  + 1) & 0xFF # make sure the secret is still 8 bits

hw_groups = [0]*9

for hw in range (0,9): #HW min is 0, max is 8
    total = 0 # total number of traces in each group, so we can average
    for trace_number in range(len(traces)): # cycle through all traces
        if hamming_weight((wrong_secret + texts[trace_number])&0xFF) == hw: # if our hamming weight should be this value
            total += 1
            hw_groups[hw] += traces[trace_number]
    if total != 0:
        hw_groups[hw] /= total

cw.plot(hw_groups)

So, at least in theory, if we didn't know the key, we could make this plot for every guess, and the correct key should be the one that produces the most linear plot. That works, but it's manual; it'd be nice to have an operation we could run on the data to tell us which plot is the most linear.

As it turns out, there actually is an operation that'll do that for us: the Pearson correlation coefficient! Let's use that operation on all the possible key guesses and see if we get the right key out:

In [None]:
def get_correlation(traces, texts, key_guess):
    HW_array = []
    for i in range(len(texts)):
        HW_array.append(hamming_weight((key_guess + texts[i])&0xFF))

    hw_trace_array = np.array([traces, HW_array])
    return abs(np.corrcoef(hw_trace_array)[1][0])

best_guess = 0
best_corr = 0
for key_guess in range(0, 256):
    cur_corr = get_correlation(traces, texts, key_guess)
    if cur_corr > best_corr:
        best_guess = key_guess
        best_corr = cur_corr
        print("New best guess found: {:02X} (corr={})".format(best_guess, best_corr))
        
print("You guessed {:02X}, correct is {:02X}".format(best_guess, secret_key))

You should've gotten the right key out of that!
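`np.corrcoef` computes the Pearson correlation coefficient, $r = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sqrt{\sum_i (x_i - \bar{x})^2}\sqrt{\sum_i (y_i - \bar{y})^2}}$, which is $+1$ for a perfect increasing line, $-1$ for a perfect decreasing line, and near 0 when there's no linear trend. A plain-Python version helps show there's nothing magic going on; the function name `pearson` is our own, not part of any library.

```python
import math

def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

hw = [0, 1, 2, 3, 4, 5, 6, 7, 8]
print(pearson(hw, [0.1 * d for d in hw]))        # perfectly linear, positive slope
print(pearson(hw, [0.8 - 0.1 * d for d in hw]))  # perfectly linear, negative slope
print(pearson(hw, [0, 1, 0, 1, 0, 1, 0, 1, 0]))  # no linear trend
```

The attack code above takes `abs()` of the coefficient, so a strong negative correlation counts as much as a strong positive one.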

Alright, that's all well and good, but we're talking about AES here; surely it's a lot more complicated than just adding two 8-bit numbers together. Well, let's look at the first operation of AES:

```C
for (uint8_t i = 0; i < 16; i++) {
    state[i] = plaintext[i] ^ key[i];
}
```

Uh-oh...
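To see why the XOR on its own is a problem, this sketch correlates the Hamming weight model for the true key against the model for other guesses, using all 256 plaintext values. With XOR alone, a guess that differs from the key in $d$ bits correlates at exactly $(8-2d)/8$: a guess that is one bit off still scores 0.75, and the bitwise complement of the key scores $-1$, which `abs()` makes indistinguishable from the real key. The key value used here is arbitrary.

```python
import math

def hamming_weight(val):
    return bin(val).count("1")

def pearson(xs, ys):
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    return cov / math.sqrt(sum((x - mx) ** 2 for x in xs) * sum((y - my) ** 2 for y in ys))

key = 0x3C            # arbitrary key byte for this sketch
texts = range(256)    # every possible plaintext byte
real = [hamming_weight(key ^ x) for x in texts]  # XOR-only leakage for the real key

def xor_corr(guess):
    """Correlation between the XOR-only model for `guess` and for the real key."""
    return pearson(real, [hamming_weight(guess ^ x) for x in texts])

for guess in (key, key ^ 0x01, key ^ 0x0F, key ^ 0xFF):
    d = hamming_weight(key ^ guess)  # number of bits the guess gets wrong
    print("guess {:02X}: d = {}, corr = {:+.2f}".format(guess, d, xor_corr(guess)))
```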

We could take our real power traces and real plaintexts and do the same thing with XOR, but this correlation works a lot better if we replace that XOR operation with something less linear. Luckily for us, the next thing AES does is try to make the state as nonlinear as possible by running it through a lookup table:

```C
extern uint8_t sbox[256];
for (uint8_t i = 0; i < 16; i++) {
    state[i] = sbox[state[i]];
}
```

To attack the actual AES encryption, all we need to do is replace that 8-bit add with that XOR then lookup, and the key should pop out! Let's give it a try:

In [None]:
sbox = [
    # 0    1    2    3    4    5    6    7    8    9    a    b    c    d    e    f 
    0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76, # 0
    0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0, # 1
    0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15, # 2
    0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75, # 3
    0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84, # 4
    0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf, # 5
    0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8, # 6
    0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2, # 7
    0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73, # 8
    0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb, # 9
    0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79, # a
    0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08, # b
    0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a, # c
    0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e, # d
    0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf, # e
    0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16  # f
]

def add_round_key(key, text):
    return key^text

def sub_bytes(state):
    return sbox[state]

def mini_aes(guess, text):
    st = add_round_key(guess, text)
    st = sub_bytes(st)
    return hamming_weight(st)

def get_correlation(traces, texts, key_guess):
    HW_array = []
    for i in range(len(texts)):
        HW_array.append(mini_aes(key_guess, texts[i]))

    hw_trace_array = np.array([traces, HW_array])
    return abs(np.corrcoef(hw_trace_array)[1][0])

best_guess = 0
best_corr = 0
for key_guess in range(0, 256):
    cur_corr = get_correlation(trace_array, textin_array, key_guess)
    if cur_corr > best_corr:
        best_guess = key_guess
        best_corr = cur_corr
        print("New best guess found: {:02X} (corr={})".format(best_guess, best_corr))
        
print("You guessed {:02X}, correct is {:02X}".format(best_guess, secret_aes_key[0]))

Hopefully the "you guessed" and "correct" key bytes are the same. If so, congrats, you've got the first byte of the key! We made a couple of simplifications along the way, but it's not much work to get past them:

1. If you don't know where in the power trace the SBox operation happens, repeat this correlation calculation at each point in time: only the point where the SBox is happening will show a large correlation.
2. You can recover the rest of the bytes of the key by changing out the first byte of the plaintext with subsequent bytes (so plaintext[i] will give key[i])
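Both extensions can be tried without hardware on synthetic data. The sketch below is a hypothetical stand-in for real captures: each "trace" has 16 samples, and plaintext byte i leaks $HW(\text{sbox}[p_i \oplus k_i])$ plus Gaussian noise at a shuffled sample position the attacker doesn't know. For every key byte and every guess, the attack correlates against all samples and keeps the best (guess, sample) pair, then repeats for the next byte. The noise level, trace count, and seed are arbitrary; real traces will typically need more traces and contain many more samples.

```python
import math
import random

# AES S-box (same table as above), packed as hex one row at a time
SBOX = bytes.fromhex(
    "637c777bf26b6fc53001672bfed7ab76"
    "ca82c97dfa5947f0add4a2af9ca472c0"
    "b7fd9326363ff7cc34a5e5f171d83115"
    "04c723c31896059a071280e2eb27b275"
    "09832c1a1b6e5aa0523bd6b329e32f84"
    "53d100ed20fcb15b6acbbe394a4c58cf"
    "d0efaafb434d338545f9027f503c9fa8"
    "51a3408f929d38f5bcb6da2110fff3d2"
    "cd0c13ec5f974417c4a77e3d645d1973"
    "60814fdc222a908846eeb814de5e0bdb"
    "e0323a0a4906245cc2d3ac629195e479"
    "e7c8376d8dd54ea96c56f4ea657aae08"
    "ba78252e1ca6b4c6e8dd741f4bbd8b8a"
    "703eb5664803f60e613557b986c11d9e"
    "e1f8981169d98e949b1e87e9ce5528df"
    "8ca1890dbfe6426841992d0fb054bb16"
)

def hamming_weight(val):
    return bin(val).count("1")

def centered(xs):
    m = sum(xs) / len(xs)
    return [x - m for x in xs]

# --- Synthetic "captures" (hypothetical leakage layout) ---
rng = random.Random(2024)
NUM_TRACES, NUM_SAMPLES, NOISE_SD = 100, 16, 0.2
secret_key = [rng.randint(0, 255) for _ in range(16)]
plaintexts = [[rng.randint(0, 255) for _ in range(16)] for _ in range(NUM_TRACES)]
leak_at = list(range(NUM_SAMPLES))
rng.shuffle(leak_at)  # byte i leaks at sample leak_at[i]; the attack never reads this
traces = []
for pt in plaintexts:
    trace = [rng.gauss(0, NOISE_SD) for _ in range(NUM_SAMPLES)]
    for i in range(16):
        trace[leak_at[i]] += hamming_weight(SBOX[pt[i] ^ secret_key[i]])
    traces.append(trace)

# --- CPA: every key byte, every guess, every sample point ---
# Center each sample column once, since it doesn't depend on the guess
columns = []
for s in range(NUM_SAMPLES):
    col_c = centered([tr[s] for tr in traces])
    columns.append((col_c, math.sqrt(sum(v * v for v in col_c))))

recovered_key, found_at = [], []
for byte in range(16):
    best = (0.0, 0, 0)  # (|r|, guess, sample)
    for guess in range(256):
        hc = centered([hamming_weight(SBOX[pt[byte] ^ guess]) for pt in plaintexts])
        h_norm = math.sqrt(sum(v * v for v in hc))
        for s, (col_c, col_norm) in enumerate(columns):
            r = abs(sum(a * b for a, b in zip(hc, col_c))) / (h_norm * col_norm)
            if r > best[0]:
                best = (r, guess, s)
    recovered_key.append(best[1])
    found_at.append(best[2])

print("recovered:", bytes(recovered_key).hex())
print("actual:   ", bytes(secret_key).hex())
```

Scanning every sample is what lets the attack find the leakage point on its own; on real hardware the same loop runs over the full trace returned by `scope.get_last_trace()` instead of a single `AES_VUL_LOC`.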

In [None]:
scope.dis()
target.dis()