<h1>CBC Padding Oracle Attack & SHA Length Extension Attack Demo Lab</h1>

v1.0: Ryan Lehmkuhl, Ben Hoberman

v2.0: Ben Hoberman, Peyrin Kao, EvanBot

Your grade for this lab will be the maximum of your score on this lab, your overall homework score, and your overall project score. This means this lab is effectively optional, so don't stress too much and have some fun with it.

# Storyline

`[Continued from Project 1]`

`Thanks to your heroic efforts, EvanBot is now working alongside you at the Caltopian Space Agency to learn more about the Jupiter mission. One day, EvanBot approaches you with a rather unusual request.`

`The CSA cafeteria is famous for serving chef Brown's "fluffy yet crunchy (?) pancakes", created with a secret ingredient. EvanBot can't get enough of these pancakes and must know what the secret ingredient is.`

`Luckily for you, Bot has intercepted a message from the highest levels of CSA intelligence, and Bot believes a part of this message may contain chef Brown's signature recipe.`

`There's just one tiny problem: this message is encrypted with more entropy than you could possibly dream of brute-forcing.`

# Introduction

Trying to implement crypto schemes yourself can be very dangerous. In the words of Runa Sandvik, senior director of information security at The New York Times, "Asking why you should not roll your own crypto is a bit like asking why you should not design your own aircraft engine."

While many of the concepts we've reviewed in class may seem intuitive and simple, even the subtlest leakage of information can completely compromise any hope for confidentiality.

We will demonstrate this by completely decrypting a message encrypted with AES-CBC using a **padding oracle attack**.

# Setup

Before you get started, run the following block to install packages needed for this notebook:

In [1]:
%pip install cryptography flask requests

Note: you may need to restart the kernel to use updated packages.


In [2]:
import base64
import os
from IPython.display import clear_output
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import ECB
from requests.adapters import HTTPAdapter
from requests import Session
from tests import test1, test2, test3, test4, test5, test6, test7
from helpers import PKCS7_pad, PKCS7_unpad, valid_pad, permute
%load_ext autoreload
%autoreload 2

# Question 1: Padding

Recall from Homework 2, Q4 that the AES-128 block cipher can only encrypt 16-byte blocks. This means that block chaining modes such as CBC mode require the message length to be a multiple of 16 bytes. If the length of the message we want to encrypt is not a multiple of 16, we need to add padding.

In Q4.1, Q4.2, and Q4.3, we saw that some padding and de-padding algorithms lead to ambiguity over what your original message was. We also saw that the PKCS#7 padding algorithm (PKCS stands for Public Key Cryptography Standard) always returns your original message when de-padding.

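The PKCS#7 algorithm appends $N$ bytes, each with the value $N$, where $N$ (between 1 and 16) is the number of bytes needed to make the length a multiple of the block size. For example, the 8-byte message `PANCAKES` gets eight `0x08` bytes appended: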

```
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
P A N C A K E S 8 8 8  8  8  8  8  8
```

Note that if the message length is already a multiple of the block size, a full extra block of sixteen 16s (`0x10`) is added.
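A minimal sketch of PKCS#7 padding and unpadding, assuming a 16-byte block size. The names `pkcs7_pad` and `pkcs7_unpad` are hypothetical; the lab's `helpers.PKCS7_pad` and `PKCS7_unpad` may be written differently.

```python
BLOCK_SIZE = 16  # AES block size in bytes


def pkcs7_pad(msg: bytes, block_size: int = BLOCK_SIZE) -> bytes:
    # Append n copies of the byte n, where n is in 1..block_size.
    # An already-aligned message gets a full extra block of padding.
    n = block_size - (len(msg) % block_size)
    return msg + bytes([n]) * n


def pkcs7_unpad(padded: bytes, block_size: int = BLOCK_SIZE) -> bytes:
    # Reject anything that is not a whole, non-empty number of blocks
    if not padded or len(padded) % block_size:
        raise ValueError("length is not a positive multiple of the block size")
    # The last byte says how many padding bytes there are; all of them must match it
    n = padded[-1]
    if not 1 <= n <= block_size or padded[-n:] != bytes([n]) * n:
        raise ValueError("invalid PKCS#7 padding")
    return padded[:-n]


print(pkcs7_pad(b"PANCAKES").hex(" "))  # 50 41 4e 43 41 4b 45 53 08 08 08 08 08 08 08 08
print(len(pkcs7_pad(b"A" * 16)))        # 32
```

Because the pad length is always at least 1, unpadding can read the last byte unambiguously, which is why PKCS#7 always recovers the original message.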

**Try replacing `msg` with some different messages to get a feel for how the padding algorithm works.**

In [5]:
### TODO: Replace the message to see how it's padded
msg = b"FROM: CSA HQ fjdsalfksd"

# Pads the message and splits it into blocks
padded_msg = PKCS7_pad(msg)
plaintext = [msg[i:i + 16].decode() for i in range(0, len(msg), 16)]
padded_plaintext = [padded_msg[i:i + 16] for i in range(0, len(padded_msg), 16)]

# Prints everything - decodes the blocks of byte objects to strings
print(f"Plaintext without padding: {plaintext}")
print(f"Length: {len(msg)}")
print(f"Plaintext with padding: {[x.decode() for x in padded_plaintext]}")
print(f"Length: {len(padded_msg)}\n")
print(f"Hex of padded message:")
for i in range(len(padded_plaintext)):
    print(f"Block {i}: ", " ".join('0x{:02x}'.format(c) for c in padded_plaintext[i]))

Plaintext without padding: ['FROM: CSA HQ fjd', 'salfksd']
Length: 23
Plaintext with padding: ['FROM: CSA HQ fjd', 'salfksd\t\t\t\t\t\t\t\t\t']
Length: 32

Hex of padded message:
Block 0:  0x46 0x52 0x4f 0x4d 0x3a 0x20 0x43 0x53 0x41 0x20 0x48 0x51 0x20 0x66 0x6a 0x64
Block 1:  0x73 0x61 0x6c 0x66 0x6b 0x73 0x64 0x09 0x09 0x09 0x09 0x09 0x09 0x09 0x09 0x09


## 1.1

In Q4.5, we thought about what happens to the padding if an attacker modifies the last byte of a padded plaintext message.

**Below, choose a byte for the last byte of the padded message that will cause `valid_pad` to return `False`.**

In [25]:
# Length of this message: 60
msg = b"TITLE: Company-wide implementation of zero food waste policy" # do not change this

### TODO: Choose a last byte (written as a number) that makes the padding invalid
invalid_last_byte = 60 - len(msg)

# Autograder
test1(invalid_last_byte)

All tests passed!


## 1.2

Also in Q4.5, we saw one way an attacker could change the last byte to a constant value that guarantees that the message has valid padding.

**Below, give *two* different last bytes for the padded message which will cause `valid_pad` to return `True`.**
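One way to check answers like these is to try all 256 possible final bytes. This sketch uses a hypothetical `valid_pkcs7` check written directly from the PKCS#7 rule; it is an assumption that the lab's `valid_pad` helper behaves the same way.

```python
def valid_pkcs7(padded: bytes, block_size: int = 16) -> bool:
    # Valid iff the last byte n is in 1..block_size and the last n bytes all equal n
    if not padded or len(padded) % block_size:
        return False
    n = padded[-1]
    return 1 <= n <= block_size and padded[-n:] == bytes([n]) * n


msg = b"As part of our ongoing environmental initiative, we are switching our canteens to zero food waste operation."
n = 16 - len(msg) % 16
padded = msg + bytes([n]) * n  # 108 message bytes + four 0x04 bytes

# Replace only the final byte and keep every value that still passes the check
valid = [b for b in range(256) if valid_pkcs7(padded[:-1] + bytes([b]))]
print(valid)  # [1, 4]
```

The value 1 works for any message; the value 4 works only because the three bytes before it are already `0x04`.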

In [42]:
# Length of this message: 108
msg = b"As part of our ongoing environmental initiative, we are switching our canteens to zero food waste operation."
# do not change this

### TODO: Choose two different last bytes (written as numbers) that make the padding valid
valid_last_byte_1 = 16 - len(msg) % 16 ### YOUR CODE HERE ###
valid_last_byte_2 = 1 ### YOUR CODE HERE ###

# Autograder
test2(valid_last_byte_1, valid_last_byte_2)

All tests passed!


# Question 2: Padding Oracles

We now know why padding is important and how to implement it, but how can a faulty implementation break our crypto system? For this, recall the **padding oracle** from Homework 2.

In cryptography, an oracle is a queryable 'black box' (a function with unknown inner-workings) which provides some piece of information which otherwise would not be available. For example, when we studied the IND-CPA game, the challenger acted as an **encryption oracle** since the adversary could query it on a given message and receive back a ciphertext without knowing how the encryption was done (i.e. which key was used). 

A padding oracle takes some ciphertext `c` as input, decrypts it using the secret shared key, and returns `True` if the resulting plaintext is properly padded and `False` otherwise. We've defined a function, `valid_pad`, which performs the PKCS#7 padding check at the core of such an oracle.
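To make the definition concrete, here is a small self-contained sketch of a padding oracle. It is an illustration under stated assumptions: a toy 4-round Feistel network built from SHA-256 stands in for AES (it is not AES and should not be used for real encryption), and the names `toy_E`, `toy_D`, `cbc_encrypt`, and `make_padding_oracle` are hypothetical. The oracle decrypts with the secret key but returns only a boolean.

```python
import hashlib
import os

BLOCK = 16


def _xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def _round(key: bytes, r: int, half: bytes) -> bytes:
    # Feistel round function: first 8 bytes of SHA-256(key || round || half)
    return hashlib.sha256(key + bytes([r]) + bytes(half)).digest()[:8]


def toy_E(key: bytes, block: bytes) -> bytes:
    # 4-round Feistel network on 16-byte blocks: a toy stand-in for AES, NOT AES
    L, R = block[:8], block[8:]
    for r in range(4):
        L, R = R, _xor(L, _round(key, r, R))
    return bytes(L + R)


def toy_D(key: bytes, block: bytes) -> bytes:
    # Inverse of toy_E: undo the rounds in reverse order
    L, R = block[:8], block[8:]
    for r in reversed(range(4)):
        L, R = _xor(R, _round(key, r, L)), L
    return bytes(L + R)


def cbc_encrypt(key: bytes, iv: bytes, msg: bytes) -> bytes:
    n = BLOCK - len(msg) % BLOCK
    padded = msg + bytes([n]) * n  # PKCS#7
    out, prev = b"", iv
    for i in range(0, len(padded), BLOCK):
        prev = toy_E(key, _xor(padded[i:i + BLOCK], prev))  # C_i = E(P_i XOR C_{i-1})
        out += prev
    return out


def make_padding_oracle(key: bytes):
    def oracle(iv: bytes, ct: bytes) -> bool:
        # Decrypt with the secret key, then reveal ONLY whether the padding is valid
        prev, pt = iv, b""
        for i in range(0, len(ct), BLOCK):
            block = ct[i:i + BLOCK]
            pt += _xor(toy_D(key, block), prev)  # P_i = D(C_i) XOR C_{i-1}
            prev = block
        n = pt[-1] if pt else 0
        return 1 <= n <= BLOCK and pt[-n:] == bytes([n]) * n
    return oracle


key, iv = os.urandom(16), os.urandom(16)
ct = cbc_encrypt(key, iv, b"secret ingredient")
oracle = make_padding_oracle(key)
print(oracle(iv, ct))  # True: an honestly produced ciphertext is always well padded
```

The oracle never returns plaintext; the rest of this lab shows that the single bit it does return is enough.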

Many real systems naturally act as padding oracles. Consider a web server which uses AES-CBC to encrypt communications with its clients (early versions of TLS did this). If a client sends a message with invalid padding, the exception might cause the web server to respond with something like:
<center>
<img src="https://i.imgur.com/IFhVUbJ.png" align="center" style="height:400px" />
</center>

Why is this bad? At a fundamental level, the resulting error leaks information about the plaintext which should never be allowed in any cryptographic system. But come on... detecting incorrect padding can't be that bad? Right? **WRONG.** This simple leakage ends up completely destroying any hope for confidentiality with the encryption scheme. To see why, let's review how the CBC block chaining mode works.

Recall that decryption in CBC mode is as follows:
<center>
<img src="https://i.imgur.com/CRUh4nu.png" align="center" style="height:300px" />
</center>
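Written as equations (using $D_K$ for block-cipher decryption under the shared key and $T_i$ for the temporary state before the XOR), the diagram says:

$$
\begin{aligned}
C_0 &= IV \\
T_i &= D_K(C_i) \\
P_i &= T_i \oplus C_{i-1}, \qquad i = 1, \ldots, n
\end{aligned}
$$

Each plaintext block depends on only two ciphertext blocks, which is why the attack can work on one pair $(C_{i-1}, C_i)$ at a time.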

In particular, we are interested in the decryption of a single block - especially the temporary block state that occurs before the XOR:
<center>
<img src="https://raw.githubusercontent.com/cs161-staff/labs/master/padding_oracle/cbc_decrypt_block.png" align="center" style="height:400px" />
</center>

## 2.1

In Homework 2, Q5, you found an expression for $P_n$ in terms of $C_n$ and $C_{n-1}$.

**Implement the expression you found for $P_n$ in the code block below.**

In [52]:
from helpers import xor_block
# xor_block(block1, block2) automatically xors two blocks of bytes together

def P_from_DCC(D, C, C_prev):
    """
    Compute P_n from D(.), C_n, and C_{n-1}.

    D:         a block cipher decryption function.
    C:         C_n, a 16-byte block of text
    C_prev:    C_{n-1}, the 16-byte ciphertext preceding C
    """
    T = D(C) ### YOUR CODE HERE ###
    P = xor_block(T, C_prev) ### YOUR CODE HERE ###

    return P

# Autograder
test3(P_from_DCC)

All tests passed!
[storyline message] Given the overwhelming success of this policy at HQ, we are excited to expand our ongoing commitment.


## 2.2

Now, we're going to start building towards a full-fledged attack based on this decryption process.

Assume that you have intercepted some ciphertext $(IV, C_1, C_2, \ldots, C_n)$ and have access to a padding oracle. You have complete freedom over what you send the padding oracle (e.g., a subset of the ciphertext blocks or something completely different). Whatever you send, the padding oracle will decrypt it **using the original symmetric key** and truthfully report whether it is padded correctly.

For now, ignore all of the blocks except for the last block $C_n$.
<center>
<img src="https://raw.githubusercontent.com/cs161-staff/labs/master/padding_oracle/cbc_decrypt_byte.png" align="center" style="height:400px" />
</center>

## 2.3

In HW2, Q6.1, you found a way to change $C_{n-1}[15]$ such that the decrypted message has valid padding no matter what.

**Implement your answer from HW2, Q6.1 in the following function to test it out.**
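The reasoning behind this can be written as a short derivation. Since $P'_n[15] = T_n[15] \oplus C'_{n-1}[15]$, forcing the last plaintext byte to `0x01` (which is always valid PKCS#7 padding) requires:

$$
C'_{n-1}[15] = T_n[15] \oplus \mathtt{0x01}
$$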

In [56]:
def pad_correctly(T_byte):
    """
    Computes C'_{n-1}[15] given T_n[15] which results
    in a correctly padded P_n[15].
    
    T_byte: T_n[15]
    """
    # Hint: ^ is XOR in Python.
    return T_byte ^ 1 ### YOUR CODE HERE ###

# Autograder
test4(pad_correctly)

All tests passed!
[storyline message] As the policy name suggests, no food waste is allowed in both cooking and dining facilities.


## 2.4

Your answer from the previous part requires knowing $T_n[15]$, but normally you wouldn't know this value.

**How might you leverage the padding oracle and the `pad_correctly` function to learn the value of $C_{n-1}'[15]$?**

*Hint: a brute force attack that only needs to try a small number of options is very fast. 256 is very small.*

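**<span style = "color: red">We can loop over all 256 possible values of $C'_{n-1}[15]$, send each modified ciphertext to the padding oracle, and keep the value for which the oracle reports valid padding.</span>**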
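A minimal sketch of that brute force, assuming a simulated oracle: for a fixed $C_n$, the oracle's answer depends only on the secret $T_n = D(C_n)$ and on the $C'_{n-1}$ we send, so the sketch draws a random `T_n` and checks the padding of $T_n \oplus C'_{n-1}$. The names `T_n` and `oracle` here are hypothetical stand-ins, not the lab's API.

```python
import os

# Simulated setup (assumption): the server's secret temporary state T_n = D(C_n)
T_n = os.urandom(16)
C_prev = os.urandom(16)


def oracle(C_prev_prime: bytes) -> bool:
    # Hypothetical stand-in oracle: is P'_n = T_n XOR C'_{n-1} validly PKCS#7-padded?
    p = bytes(t ^ c for t, c in zip(T_n, C_prev_prime))
    n = p[-1]
    return 1 <= n <= 16 and p[-n:] == bytes([n]) * n


# Brute force: try all 256 values for the last byte of C'_{n-1}
hits = []
for guess in range(256):
    C_prev_prime = bytearray(C_prev)
    C_prev_prime[15] = guess
    if oracle(bytes(C_prev_prime)):
        hits.append(guess)

print(hits)  # always contains T_n[15] ^ 0x01; occasionally one more value
```

At most 256 oracle queries are needed, and at most two guesses can pass, which is the subject of 2.6 and 2.7.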

## 2.5

In HW2, Q6.2, you found a way to learn the corresponding byte of plaintext $P_n[15]$ given $C'_{n-1}[15]$ and $C_{n-1}[15]$.

**Implement your answer from HW2, Q6.2 to calculate the corresponding byte of plaintext.**
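Once the oracle accepts some $C'_{n-1}[15]$ that forces $P'_n[15] = \mathtt{0x01}$, recovering the plaintext byte takes two XORs:

$$
\begin{aligned}
T_n[15] &= C'_{n-1}[15] \oplus \mathtt{0x01} \\
P_n[15] &= T_n[15] \oplus C_{n-1}[15] = C'_{n-1}[15] \oplus \mathtt{0x01} \oplus C_{n-1}[15]
\end{aligned}
$$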

In [63]:
def recover_byte(C_byte, C_prime_byte):
    """
    Given the modified ciphertext C'_{n-1}[15] which resulted in
    the padding always being correct, and the original ciphertext
    C_{n-1}[15], return the last byte of the original plaintext P_n[15].
    
    C_byte: C_{n-1}[15] - the original last byte of the ciphertext
    C_prime_byte: C'_{n-1}[15] - the modified last byte which resulted in correct padding all of the time
    """
    return C_byte ^ pad_correctly(C_prime_byte) ### YOUR CODE HERE ###

# Autograder --- make sure you've completed 2.3 first! This test requires your implementation of pad_correctly.
test5(recover_byte, pad_correctly)

All tests passed!
[storyline message] To enforce this policy, we are removing all trash cans from the kitchen and the dining hall.


## 2.6

In 2.4, you found a value $C'_{n-1}[15]$ that results in valid padding regardless of the plaintext message. However, it is sometimes possible for a *different* value of $C'_{n-1}[15]$ to result in valid padding, depending on the plaintext message $P_n$.

**When does this different value result in valid padding?**

*Hint: how did you come up with the two answers for question 1.2 in this lab?*

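**<span style = "color: red">When the plaintext bytes before the last one happen to form a longer valid padding. For example, if $P_n[14] = \mathtt{0x02}$, then the value of $C'_{n-1}[15]$ that makes $P'_n[15] = \mathtt{0x02}$ also gives valid padding (ending in 0x02 0x02), just as in 1.2 both 1 and the existing padding value 4 were valid.</span>**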

## 2.7

Suppose you get two possible values for $C'_{n-1}[15]$ that result in valid padding. One of them results in valid padding regardless of the message (2.3), and the other results in valid padding only for this specific message (2.6). **How can we modify $C'_{n-1}[14]$** to check which value is which?

*Hint: There are over 200 values that work here*

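**<span style = "color: red">Change $C'_{n-1}[14]$ to any value other than $C_{n-1}[14]$ (for example, flip its lowest bit) and query the oracle again. This changes $P'_n[14]$, so the message-specific value from 2.6 no longer gives valid padding, while the value that forces $P'_n[15] = \mathtt{0x01}$ still does.</span>**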
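The sketch below illustrates both 2.6 and 2.7 with a simulated oracle. It assumes a made-up plaintext block whose byte 14 happens to be `0x02`, derives the secret $T_n$ from it, and shows that two guesses for $C'_{n-1}[15]$ pass the oracle until byte 14 of $C'_{n-1}$ is changed. `valid_pkcs7`, `with_byte`, and the plaintext are hypothetical.

```python
import os


def valid_pkcs7(p: bytes) -> bool:
    # Hypothetical PKCS#7 check on a single 16-byte block
    n = p[-1]
    return 1 <= n <= 16 and p[-n:] == bytes([n]) * n


def with_byte(block: bytes, i: int, value: int) -> bytes:
    # Copy of block with position i replaced by value
    b = bytearray(block)
    b[i] = value
    return bytes(b)


# Made-up plaintext block whose byte 14 happens to be 0x02
P_n = b"made-up block " + b"\x02" + b"!"
C_prev = os.urandom(16)
T_n = bytes(p ^ c for p, c in zip(P_n, C_prev))  # simulated D(C_n), since P_n = T_n XOR C_{n-1}


def oracle(C_prev_prime: bytes) -> bool:
    return valid_pkcs7(bytes(t ^ c for t, c in zip(T_n, C_prev_prime)))


# 2.6: two guesses pass (P'_n ends in 0x01, or in 0x02 0x02)
hits = [g for g in range(256) if oracle(with_byte(C_prev, 15, g))]
print(len(hits))  # 2

# 2.7: also change C'_{n-1}[14]; only the guess giving P'_n[15] == 0x01 survives
confirmed = [g for g in hits
             if oracle(with_byte(with_byte(C_prev, 15, g), 14, C_prev[14] ^ 0x01))]
print(confirmed == [T_n[15] ^ 0x01])  # True
```

Flipping one bit is just one convenient choice; any change to $C'_{n-1}[14]$ moves $P'_n[14]$ away from `0x02`.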

## 2.8

**Fill in the function so that it always correctly finds the last byte of plaintext based on the last two blocks of ciphertext and a padding oracle.**

In [117]:
def solve_last_byte(C_prev, C, oracle):
    """
    Returns the correct last byte of the original plaintext
    P_n[15] given the ciphertext blocks C_n and C_{n-1}
    and a padding oracle.
    
    C: current block (the one whose plaintext we're solving for) of the ciphertext (C_n)
    C_prev: previous block of the ciphertext (C_{n-1})
    oracle: a function which returns whether (C_prev, C) are padded correctly
    """
    # Loop through the possible values for byte
    for byte in range(256): ### YOUR CODE HERE ###
        C_prev_prime = bytearray(C_prev)
        # Change the last byte of C_{n-1}
        C_prev_prime[15] = byte ### YOUR CODE HERE ###
        # If the oracle reports the modified ciphertext has valid padding...
        if oracle(C_prev_prime, C):
            # Flip a bit of C'_{n-1}[14] so that P'_n[14] changes
            C_prev_prime[14] ^= 1 ### YOUR CODE HERE ###
            # ...check that we get valid padding regardless of the message
            if oracle(C_prev_prime, C):
                return recover_byte(C_prev[15], C_prev_prime[15])
            
# Autograder
test6(solve_last_byte)

## 2.9

In HW2, Q6.3 and Q6.4, you saw how to modify the previous steps to learn the second-to-last byte of the plaintext message.

**How could you extend the previous steps to decode an entire block? An entire message?**

**<span style = "color: red">(YOUR ANSWER HERE)</span>**

## 2.10

Now, we can finally implement the full decryption.

**Fill out the function below based on prior parts in order to successfully decrypt any CBC ciphertext block based on a padding oracle.**

<center>
<img src="https://raw.githubusercontent.com/cs161-staff/labs/master/padding_oracle/cbc_decrypt_byte.png" align="center" style="height:400px" />
</center>

Remember that as you discover more of the message, the "correct" padding you need to enforce will change.
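For a general position $i$ (counting from 0), the target padding value is $k = 16 - i$. Writing the steps from 2.3 and 2.5 for this position (they reduce to the earlier equations when $i = 15$ and $k = \mathtt{0x01}$):

$$
\begin{aligned}
C'_{n-1}[j] &= T_n[j] \oplus k \qquad \text{for } i < j \le 15 \\
T_n[i] &= C'_{n-1}[i] \oplus k \qquad \text{for the } C'_{n-1}[i] \text{ the oracle accepts} \\
P_n[i] &= T_n[i] \oplus C_{n-1}[i]
\end{aligned}
$$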

In [105]:
def decrypt_block(C_prev, C, oracle, display = False, tail = None):
    """
    Recover plaintext P_n given ciphertext C_n, C_{n-1}, and a padding oracle.
    Don't worry about the display or tail arguments -- they're for cool visualization later.

    C: current block (the one we're solving for) of the ciphertext (C_i)
    C_prev: previous block of the ciphertext (C_{i-1})
    oracle: a function which returns whether (C_prev, C) are padded correctly
    """
    correct_block = bytearray(16) # Reconstructed plaintext (P from previous parts)
    temp_block = bytearray(16) # Reconstructed temporary state (T from previous parts)
    
    # Iterate over the block from end to beginning
    for i in reversed(range(16)):
        # Set the padding byte that the known bytes should be to guarantee correct padding
        # What is the correct plaintext padding for a block which has n < 16 bytes of text?
        padding_byte = 16 - i ### YOUR CODE HERE ###
        
        # Loop through the possible values for byte
        for byte in range(256): ### YOUR CODE HERE ###
            C_prev_prime = bytearray(C_prev)
            # How do we set the already-known plaintext to a value of our choice? ###
            # Hint: you'll need the temporary state
            # Force bytes i+1..15 to decrypt to padding_byte: C'[j] = T[j] XOR padding_byte.
            # Building the bytes directly keeps the slice length at 15 - i (16 bytes total).
            C_prev_prime[i+1:] = bytes(t ^ padding_byte for t in temp_block[i+1:]) ### YOUR CODE HERE ###
            
            C_prev_prime[i] = byte
            if oracle(C_prev_prime, C):
                if i == 15: # Recall from 2.6 that the last byte can have two possible values.
                    C_prev_prime[i-1] ^= 1 ### YOUR CODE HERE ###
                    if not oracle(C_prev_prime, C):
                        continue
                        
                # We can now deduce the ith byte.
                # As you pad farther and farther back, you'll need to slightly tweak how you recover
                # the correct byte. Think back to how you derived it for the last byte; the process
                # will be very similar
                correct_block[i] = byte ^ padding_byte ^ C_prev[i] ### YOUR CODE HERE ###
                temp_block[i] = byte ^ padding_byte ### YOUR CODE HERE ###
                
                # Visualization code -- don't worry about it
                if display:
                    progress = correct_block.replace(b'\x00', b' ')
                    clear_output(wait=True)
                    if tail is not None:
                        progress = progress + tail
                    print(f'Recovered so far: {progress}')
                    
                break
    return correct_block

# Autograder
test7(decrypt_block)

# Question 3: Implementing the Attack

`Chef Brown's IoT pancake maker happens to host a central web application. This allows Brown to start the machine anywhere and arrive at the canteen to a fresh stack of pancakes. Brown read somewhere that her encryption scheme should not be deterministic, so she opted to use AES-CBC with a random IV and PKCS#7 padding.`

`You discover the web application will return a 500 error if it receives an invalid command (such as a message with invalid padding). Using the web application's API, can you decrypt the ciphertext and discover Brown's recipe?`

## Setup

We have provided a few things for you:
* The pancake maker's web application
* A `Client` class to interface with the web application

**If you're running the notebook locally, open a new terminal window and run:**
```
python iot.py
```
**If you're running the notebook on DataHub/Google Colab, execute the block below:**

In [106]:
from iot import app
import threading
threading.Thread(target=app.run, kwargs={'host':'0.0.0.0','port':12000}).start() 

 * Serving Flask app "iot" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off

**To define the `Client` class, run the block below:**

In [107]:
# No code you have to write in here :)
LOCAL_URL = 'http://127.0.0.1:12000/api/'

class Client:
    def __init__(self):
        # Start a new HTTP Session
        self.session = Session()
        self.session.mount('http://', HTTPAdapter())

        # Get the cached command
        url = LOCAL_URL + 'cache'
        response = self.session.get(url)
        if response.status_code != 200:
            # Report any non-200 status (e.g. 400, 401, 403)
            print('/cache API Error: ' + str(response.status_code))
        content = response.json()

        # Decode the ciphertext
        self.iv = base64.b64decode(bytes(content['iv'], 'utf8'))
        self.ciphertext = base64.b64decode(bytes(content['ciphertext'], 'utf8'))

    def execute(self, iv, ciphertext):
        '''Sends ciphertext to web application. Return True if command is
        executed, False if the application returns an error'''
        if isinstance(iv, bytearray):
            iv = bytes(iv)
        if isinstance(ciphertext, bytearray):
            ciphertext = bytes(ciphertext)
        data = {
            'iv': base64.b64encode(iv), 
            'ciphertext': base64.b64encode(ciphertext)
            }
        response = self.session.post(LOCAL_URL + 'execute', data=data)
        content = response.json()
        return content['success']

Now, we can apply our attack! **First, we need to formulate our oracle in terms of the `client`. To do so, run the code block below.**

In [108]:
# No code you have to write in here :)
def make_network_oracle(client):
    def oracle(C_last, C):
        return client.execute(C_last, C)
    return oracle

## 3.1

Now we can finally put it all together and execute our attack against a real server. We can directly use your work from previous parts to fill out the `decrypt` function below!

In [109]:
# No code you have to write in here :)
def decrypt(client):
    # Make the ciphertext mutable by casting to bytearray instead of bytes
    iv = bytearray(client.iv)
    ciphertext = bytearray(client.ciphertext)
    
    # Split the ciphertext into blocks
    blocks = [iv] + [ciphertext[i:i + 16] for i in range(0, len(ciphertext), 16)]
    # This stores our recovered plaintext
    plaintext = [bytearray(16) for _ in range(len(blocks))]
    oracle = make_network_oracle(client)

    # Recover each block and byte in reverse order
    for i in reversed(range(1, len(blocks))):
        C_last = blocks[i - 1]
        C = blocks[i]
        plaintext[i] = decrypt_block(C_last, C, oracle, True, b''.join(plaintext[i+1:]))
    
    return b''.join(plaintext[1:]).decode()
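
If you want to try the whole pipeline without the Flask server, here is a self-contained offline sketch of the same idea as `decrypt_block` and `decrypt`. It rests on stated assumptions: a toy SHA-256-based Feistel network stands in for AES, the oracle is a local function rather than the `/api/execute` endpoint, and all names (`toy_E`, `toy_D`, `cbc_encrypt`, `make_oracle`, `attack_block`, `attack`) are hypothetical.

```python
import hashlib
import os

BLOCK = 16


def _xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


def _round(key, r, half):
    return hashlib.sha256(key + bytes([r]) + bytes(half)).digest()[:8]


def toy_E(key, block):
    # Toy 4-round Feistel block cipher standing in for AES (NOT AES, NOT for real use)
    L, R = block[:8], block[8:]
    for r in range(4):
        L, R = R, _xor(L, _round(key, r, R))
    return bytes(L + R)


def toy_D(key, block):
    L, R = block[:8], block[8:]
    for r in reversed(range(4)):
        L, R = _xor(R, _round(key, r, L)), L
    return bytes(L + R)


def cbc_encrypt(key, iv, msg):
    n = BLOCK - len(msg) % BLOCK
    padded = msg + bytes([n]) * n  # PKCS#7
    out, prev = b"", iv
    for i in range(0, len(padded), BLOCK):
        prev = toy_E(key, _xor(padded[i:i + BLOCK], prev))
        out += prev
    return out


def make_oracle(key):
    # Same shape as make_network_oracle: oracle(C_prev, C) -> bool
    def oracle(C_prev, C):
        p = _xor(toy_D(key, bytes(C)), C_prev)
        n = p[-1]
        return 1 <= n <= BLOCK and p[-n:] == bytes([n]) * n
    return oracle


def attack_block(C_prev, C, oracle):
    T = bytearray(BLOCK)  # recovered temporary state T = D(C)
    for i in reversed(range(BLOCK)):
        k = BLOCK - i  # padding value to force at positions i..15
        for g in range(256):
            Cp = bytearray(C_prev)
            Cp[i + 1:] = bytes(t ^ k for t in T[i + 1:])  # known bytes decrypt to k
            Cp[i] = g
            if not oracle(Cp, C):
                continue
            if i == BLOCK - 1:
                Cp[i - 1] ^= 1  # rule out a message-specific hit such as 0x02 0x02
                if not oracle(Cp, C):
                    continue
            T[i] = g ^ k
            break
    return _xor(T, C_prev)  # P = T XOR C_prev


def attack(iv, ct, oracle):
    blocks = [iv] + [ct[i:i + BLOCK] for i in range(0, len(ct), BLOCK)]
    pt = b"".join(attack_block(blocks[j - 1], blocks[j], oracle)
                  for j in range(1, len(blocks)))
    return pt[:-pt[-1]]  # strip PKCS#7 padding


key, iv = os.urandom(16), os.urandom(16)
secret = b"a hypothetical test message for the toy oracle"
recovered = attack(iv, cbc_encrypt(key, iv, secret), make_oracle(key))
print(recovered)
```

The attacker code never touches `key`; it only calls the oracle, at most 256 times per byte plus one extra check per block.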

Run your attack:

In [110]:
print(f"\n\nOrigin of chef Brown's recipe:\n\n {decrypt(Client())}")


What is the secret ingredient in Chef Brown's recipe?

**<span style = "color: red">(YOUR ANSWER HERE)</span>**

# Question 4: History and Defenses

The padding oracle attack that you've developed in this lab has been used in plenty of real-world attacks. The attack was first discovered (publicly) in 2002. Since 2002 is well within the era of the internet, you can find the paper published by the exploit's original authors [here](https://www.iacr.org/cryptodb/archive/2002/EUROCRYPT/2850/2850.pdf). Exploits were quickly engineered against all sorts of servers to great success.

Fixes for this original attack "fixed" the vulnerability by simply stopping the server from directly telling users whether a message was padded correctly. But later attacks, in particular [Lucky Thirteen](https://arstechnica.com/information-technology/2013/02/lucky-thirteen-attack-snarfs-cookies-protected-by-ssl-encryption/), used the fact that a *timing* side channel could let an attacker statistically infer whether the padding was correct even without the server saying so explicitly. This illustrates how hard it is to avoid side-channel attacks and why we always prefer schemes that fail as safely as possible.

For the most part, CBC padding oracle attacks have been patched, but there are notable circumstances in which they can still be used to great success. [Some attacks](https://en.wikipedia.org/wiki/POODLE) use the fact that clients can request older protocol versions for backwards compatibility, which can bring old attacks directly back to life. And since implementing cryptography is *hard*, patches and fixes can resurrect previously patched vulnerabilities. Check out [this](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-2107) CVE, an ironic example of how a patch for Lucky Thirteen actually *enabled* a padding oracle attack.

Now, let's explore a possible defense. Encrypting a message with CBC mode (and a random IV) gives us confidentiality against a passive eavesdropper if we use a secure block cipher, which we generally assume AES is. However, the lack of authentication and integrity checking means that we, as active attackers, can freely tamper with messages in order to decrypt them. It stands to reason that adding authentication and integrity checks would prevent us from successfully submitting modified messages for the server's examination. For simplicity, let's consider just two possible methods of achieving this: MAC-then-encrypt and encrypt-then-MAC, both with a secure MAC algorithm.
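As a concrete picture of encrypt-then-MAC on the receiving side, here is a minimal sketch using Python's standard `hmac` module with HMAC-SHA256. It assumes a separate MAC key and a hypothetical `decrypt_and_unpad` callback (`make_tag` and `receive` are also hypothetical names). The point it illustrates is ordering: the tag over the IV and ciphertext is checked first, and decryption (and therefore any padding check) only happens for authentic ciphertexts.

```python
import hashlib
import hmac


def make_tag(mac_key: bytes, iv: bytes, ciphertext: bytes) -> bytes:
    # Encrypt-then-MAC: the tag is computed over the IV and the ciphertext
    return hmac.new(mac_key, iv + ciphertext, hashlib.sha256).digest()


def receive(mac_key, iv, ciphertext, tag, decrypt_and_unpad):
    # 1. Verify the tag first, using a constant-time comparison
    if not hmac.compare_digest(make_tag(mac_key, iv, ciphertext), tag):
        return None  # one uniform rejection for every tampered message
    # 2. Only authentic ciphertexts are ever decrypted and unpadded
    return decrypt_and_unpad(iv, ciphertext)


mac_key, iv, ct = b"k" * 32, bytes(16), bytes(range(32))
tag = make_tag(mac_key, iv, ct)
print(receive(mac_key, iv, ct, tag, lambda i, c: b"decrypted"))        # b'decrypted'
tampered = ct[:15] + bytes([ct[15] ^ 1]) + ct[16:]
print(receive(mac_key, iv, tampered, tag, lambda i, c: b"decrypted"))  # None
```

Because a modified ciphertext is rejected before `decrypt_and_unpad` runs, the padding check never sees attacker-chosen input.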

## 4.1

Does MAC-then-encrypt prevent users from exploiting our padding oracle attack? If yes, explain why. If not, explain how to modify the attack so that it still works.

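**<span style = "color: red">No. In MAC-then-encrypt, the server must decrypt and remove the padding before it can verify the MAC, so it still checks the padding of attacker-modified ciphertexts. If a padding error can be distinguished from a MAC failure (for example, by a different error message or by timing, as in Lucky Thirteen), the attack works unchanged: treat "MAC failure" as "padding was valid" and "padding error" as "padding was invalid".</span>**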

## 4.2

Does encrypt-then-MAC prevent users from exploiting our padding oracle attack? If yes, explain why. If not, explain how to modify the attack so that it still works.

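**<span style = "color: red">Yes. In encrypt-then-MAC, the server verifies the MAC over the ciphertext (and IV) before decrypting. Any ciphertext the attacker has modified fails the MAC check and is rejected without ever being decrypted or unpadded, so the server leaks nothing about the padding.</span>**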

# Question 5: SHA Length-Extension

In lecture, Nick let slip that SHA-2, great as (we currently think) it is, is vulnerable to an attack called a *length-extension attack*. In this question, we'll walk you through how the SHA-2 algorithm actually works and show you how an attacker could use that knowledge to forge SHA-2 hashes based on (message, hash) pairs that they observe. As with the CBC padding oracle part of this lab, we highly recommend that you complete HW2 before attempting this question in order to gain some intuition.

Note that you don't actually have to write any code for this question, so it's double optional. But please do read, run, and play with the examples :).

Recall the general outline of the SHA-2 algorithm from HW2:

<img src="https://raw.githubusercontent.com/cs161-staff/labs/master/padding_oracle/sha-outline.png" style="height:300px" />
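The piece of the outline that a length-extension attacker must reproduce is the padding SHA-256 appends (steps 1.a to 1.c in `update` below): a `0x80` byte, zero bytes until the length is 56 mod 64, then the original length in bits as a 64-bit big-endian integer. A minimal sketch, with the hypothetical helper name `sha256_glue_padding` and a made-up scenario:

```python
def sha256_glue_padding(msg_len: int) -> bytes:
    # 0x80 byte (a single '1' bit followed by seven '0' bits),
    # then zeros so that the length is 56 mod 64,
    # then the original length in bits as a 64-bit big-endian integer
    zeros = (55 - msg_len) % 64
    return b"\x80" + b"\x00" * zeros + (8 * msg_len).to_bytes(8, "big")


# Hypothetical scenario: a server publishes H(secret || msg), and the attacker
# knows (or guesses) that the secret is 16 bytes long.
msg = b"amount=10"
glue = sha256_glue_padding(16 + len(msg))
forged_msg = msg + glue + b"&amount=1000000"
print(len(glue), (16 + len(msg) + len(glue)) % 64)  # 39 0
```

Because `secret || msg || glue` ends exactly on a 64-byte boundary, SHA-256's internal state at that point equals the published hash. An attacker who loads that state and keeps hashing `&amount=1000000` (which the `SHA2` class below appears designed to allow via `hash_pieces` and `fakelength`) obtains the hash of the forged message without knowing the secret.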

## The Nitty-Gritty

Since we're going to actually write a length extension attack, we have to interact with the real algorithm specified in actual code. Below, you can (and should) read the SHA-256 code generously provided under the MIT License [here](https://perso.crans.org/besson/publis/notebooks/Benchmark_of_the_SHA256_hash_function__Python_Cython_Numba.html). The code is pretty well commented, but it's okay if you don't understand every small detail.

In [111]:
def leftrotate(x, c):
    """ Left rotate the 32-bit number x by c bits."""
    x &= 0xFFFFFFFF
    return ((x << c) | (x >> (32 - c))) & 0xFFFFFFFF

def rightrotate(x, c):
    """ Right rotate the 32-bit number x by c bits."""
    x &= 0xFFFFFFFF
    return ((x >> c) | (x << (32 - c))) & 0xFFFFFFFF

def leftshift(x, c):
    """ Left shift the number x by c bits."""
    return x << c

def rightshift(x, c):
    """ Right shift the number x by c bits."""
    return x >> c

class SHA2():
    """SHA256 hashing, see https://en.wikipedia.org/wiki/SHA-2#Pseudocode."""
    
    def __init__(self):
        self.name        = "SHA256"
        self.byteorder   = 'big'
        self.block_size  = 64
        self.digest_size = 32
        # Note 2: For each round, there is one round constant k[i] and one entry in the message schedule array w[i], 0 ≤ i ≤ 63
        # Note 3: The compression function uses 8 working variables, a through h
        # Note 4: Big-endian convention is used when expressing the constants in this pseudocode,
        #         and when parsing message block data from bytes to words, for example,
        #         the first word of the input message "abc" after padding is 0x61626380

        # Initialize hash values:
        # (first 32 bits of the fractional parts of the square roots of the first 8 primes 2..19):
        # you don't need to worry about these
        h0 = 0x6a09e667
        h1 = 0xbb67ae85
        h2 = 0x3c6ef372
        h3 = 0xa54ff53a
        h4 = 0x510e527f
        h5 = 0x9b05688c
        h6 = 0x1f83d9ab
        h7 = 0x5be0cd19
        
        # Initialize array of round constants:
        # (first 32 bits of the fractional parts of the cube roots of the first 64 primes 2..311):
        # You don't need to worry about these
        self.k = [
            0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
            0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
            0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
            0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
            0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
            0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
            0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
            0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
        ]

        # Store them
        # this is the internal state that we update at the end of every 512-bit chunk of input
        self.hash_pieces = [h0, h1, h2, h3, h4, h5, h6, h7]
    
    #this function is what actually does a "hash"
    #if we use it multiple times, EACH input will be padded and treated as
    #the message to hash with the internal state hash_pieces
    def update(self, arg, fakelength=None, visualize=False):
        h0, h1, h2, h3, h4, h5, h6, h7 = self.hash_pieces
        # 1. Pre-processing, exactly like MD5
        data = bytearray(arg)
        orig_len_in_bits = (8 * len(data)) & 0xFFFFFFFFFFFFFFFF
        # 1.a. Add a single '1' bit at the end of the input bits
        data.append(0x80) #0x80 = 0b10000000, and we assume inputs are always given in bytes
        # 1.b. Padding with zeros as long as the input bits length ≡ 448 (mod 512)
        # why 448? because we append length as a 64 bit integer next
        while len(data) % 64 != 56:
            data.append(0)
        # 1.c. append original length in bits mod (2 pow 64) to message
        # modification to allow us to override the message length for when we spoof hashes
        l = None
        if fakelength:
            l = (fakelength * 8).to_bytes(8, byteorder='big')
        else:
            l = orig_len_in_bits.to_bytes(8, byteorder='big') #unsigned by default, which we want
        data += l 
        assert len(data) % 64 == 0, "Error in padding"
        
        # 2. Computations
        # Process the message in successive 512-bit = 64-bytes chunks:
        for r, offset in enumerate(range(0, len(data), 64)):
            # 2.a. 512-bits = 64-bytes chunks
            chunks = data[offset : offset + 64]
            w = [0 for i in range(64)]
            # 2.b. Break chunk into sixteen 32-bit = 4-bytes words w[i], 0 ≤ i ≤ 15
            for i in range(16):
                w[i] = int.from_bytes(chunks[4*i : 4*i + 4], byteorder='big')
            # 2.c.  Extend the first 16 words into the remaining 48
            #       words w[16..63] of the message schedule array:
            for i in range(16, 64):
                s0 = (rightrotate(w[i-15], 7) ^ rightrotate(w[i-15], 18) ^ rightshift(w[i-15], 3)) & 0xFFFFFFFF
                s1 = (rightrotate(w[i-2], 17) ^ rightrotate(w[i-2], 19) ^ rightshift(w[i-2], 10)) & 0xFFFFFFFF
                w[i] = (w[i-16] + s0 + w[i-7] + s1) & 0xFFFFFFFF
            # 2.d. Initialize hash value for this chunk
            a, b, c, d, e, f, g, h = h0, h1, h2, h3, h4, h5, h6, h7
            # 2.e. Main loop, cf. https://tools.ietf.org/html/rfc6234
            # this is the "compression function" you saw in HW2
            for i in range(64):
                S1 = (rightrotate(e, 6) ^ rightrotate(e, 11) ^ rightrotate(e, 25)) & 0xFFFFFFFF
                ch = ((e & f) ^ ((~e) & g)) & 0xFFFFFFFF
                temp1 = (h + S1 + ch + self.k[i] + w[i]) & 0xFFFFFFFF
                S0 = (rightrotate(a, 2) ^ rightrotate(a, 13) ^ rightrotate(a, 22)) & 0xFFFFFFFF
                maj = ((a & b) ^ (a & c) ^ (b & c)) & 0xFFFFFFFF
                temp2 = (S0 + maj) & 0xFFFFFFFF

                new_a = (temp1 + temp2) & 0xFFFFFFFF
                new_e = (d + temp1) & 0xFFFFFFFF
                # Rotate the 8 variables
                a, b, c, d, e, f, g, h = new_a, a, b, c, new_e, e, f, g

            # Add this chunk's hash to result so far:
            # the & 0xFFFFFFFF is to ensure we're keeping each part of the hash
            # within 32 bits (enforces addition mod 2^32)
            h0 = (h0 + a) & 0xFFFFFFFF
            h1 = (h1 + b) & 0xFFFFFFFF
            h2 = (h2 + c) & 0xFFFFFFFF
            h3 = (h3 + d) & 0xFFFFFFFF
            h4 = (h4 + e) & 0xFFFFFFFF
            h5 = (h5 + f) & 0xFFFFFFFF
            h6 = (h6 + g) & 0xFFFFFFFF
            h7 = (h7 + h) & 0xFFFFFFFF
            
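            if visualize:
                # Caution: unlike digest(), this does not reverse the word list, so h0 lands
                # in the lowest 32 bits; with a big-endian self.byteorder the printed bytes
                # read h7, h6, ..., h0 (the reverse of hexdigest()'s word order).
                # r is the 0-based index of the chunk just processed, so "0 blocks"
                # actually means one block has been compressed.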
                rep = sum(leftshift(x, 32 * i) for i, x in enumerate([h0, h1, h2, h3, h4, h5, h6, h7]))
                hexrep = rep.to_bytes(self.digest_size, byteorder=self.byteorder)
                print(f"SHA's internal state after processing {r} blocks is {hexrep}")
        # 3. Conclusion
        self.hash_pieces = [h0, h1, h2, h3, h4, h5, h6, h7]

    def digest(self):
        # h0 append h1 append h2 append h3 append h4 append h5 append h6 append h7
        # since Python 3 allows for arbitrarily large integers, this is actually just a way of
        # doing h0 || h1 || h2 || h3 || h4 || h5 || h6 || h7
        return sum(leftshift(x, 32 * i) for i, x in enumerate(self.hash_pieces[::-1]))

    def hexdigest(self):
        """ Like digest() except the digest is returned as a string object of double length, containing only hexadecimal digits. This may be used to exchange the value safely in email or other non-binary environments."""
        digest = self.digest()
        raw = digest.to_bytes(self.digest_size, byteorder=self.byteorder)
        format_str = '{:0' + str(2 * self.digest_size) + 'x}'
        return format_str.format(int.from_bytes(raw, byteorder='big'))

def hash(data):
    """ Shortcut function to directly receive the hex digest from SHA2(data)."""
    h = SHA2()
    if isinstance(data, str):
        data = bytes(data, encoding='utf8')
    h.update(data)
    return h.hexdigest()
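
Before trusting this from-scratch implementation for the attack, it helps to check it against known answers. Below is a minimal standalone harness. It assumes only that a candidate takes `bytes` and returns a lowercase hex string, as `hash` above does. The name `check_sha256` is hypothetical, and Python's built-in `hashlib` serves as the reference.

```python
import hashlib
import os

# Well-known SHA-256 digests: the empty message and the FIPS 180 "abc" example.
KNOWN_ANSWERS = {
    b"": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
    b"abc": "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad",
}


def check_sha256(impl) -> bool:
    """Return True if impl(msg: bytes) -> hex str agrees with SHA-256."""
    for msg, expected in KNOWN_ANSWERS.items():
        if impl(msg) != expected:
            return False
    # Random messages of every length 0..199 cross the 55/56-byte padding
    # boundary and several 64-byte block boundaries.
    for n in range(200):
        msg = os.urandom(n)
        if impl(msg) != hashlib.sha256(msg).hexdigest():
            return False
    return True


print(check_sha256(lambda m: hashlib.sha256(m).hexdigest()))  # True
```

If the notebook's implementation is correct, `check_sha256(hash)` would be expected to return `True`. A `False` would point to a bug in the padding or the compression loop.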

## Implementing the Attack

The notorious hacker Vladimir Computin has hacked into Commerica's *Department of Voting Machine Software* in order to corrupt Commerica's upcoming election by modifying the *Department*'s star product, VoteOS. In particular, Computin has discovered a buffer overflow vulnerability in VoteOS that circumvents stack canaries and ASLR, but he also wants to defeat NX bit protection. He intends to do this by distributing a version of VoteOS that contains his shellcode in the binary itself. The only problem is that VoteOS is distributed along with the signature $\mathrm{SHA}(k||B)$, where $k$ is a secret key that all government officials memorize and $B$ is the VoteOS binary. Computin controls an empire of USB flash drive distributors, but he is somehow unable to discover $k$. Stymied, Computin calls you, his trusty CS 161 friend and former KGB colleague, to help circumvent this problem.
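
The "signature" here is really a keyed hash. Whoever knows $k$ publishes $\mathrm{SHA}(k||B)$, and a verifier who also knows $k$ recomputes it and compares. Below is a minimal sketch of that naive scheme, assuming SHA-256 via Python's `hashlib`. The names `naive_tag` and `naive_verify` and the sample bytes are hypothetical stand-ins.

```python
import hashlib
import hmac


def naive_tag(key: bytes, binary: bytes) -> str:
    """Naive keyed hash SHA-256(key || binary); vulnerable to length extension."""
    return hashlib.sha256(key + binary).hexdigest()


def naive_verify(key: bytes, binary: bytes, tag: str) -> bool:
    """Recompute the tag and compare it in constant time."""
    return hmac.compare_digest(naive_tag(key, binary), tag)


key = b"stacks of pancakes"  # example secret key k
binary = b"\x90\x90\xc3"     # hypothetical stand-in for the VoteOS binary B
tag = naive_tag(key, binary)
print(naive_verify(key, binary, tag))            # True
print(naive_verify(key, binary + b"\xcc", tag))  # False: appended bytes change the hash
```

Blindly appending bytes breaks the tag, as the last line shows. The length extension attack succeeds because it appends the padding $p$ before $M$ and derives the new tag from the old one, without ever touching $k$.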

You know the gist of SHA-2 length extension attacks, so you tell Computin that it'll be no problem. Since you know the hash of the original binary and are only adding code to the end of the file, you code up a way to spoof the new binary's SHA-256 hash.

First, you realize that since SHA-256 pads its inputs before performing any computations, you can't directly extend $\mathrm{SHA}(k||B)$ to $\mathrm{SHA}(k||B||M)$ (where $M$ is Computin's malcode). Consulting your reference implementation of SHA-256, you realize that you'll only be able to construct the hash $\mathrm{SHA}(k||B||p||M)$, where $p$ is the padding that SHA-256 appends to $k||B$. This, however, is no problem: Computin's attacks will work just as well on the binary $B||p||M$ as they would on $B||M$. (Why? Remember that the message itself is a binary, and think about the structure of assembly programs.)
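
Here is a standalone sketch of the trick end to end, without the notebook's `SHA2` class, checked against Python's `hashlib`. It assumes the attacker knows only the published tag $\mathrm{SHA}(k||B)$ and the length of $k||B$. The helper names (`compress`, `padding`, `extend`) are hypothetical. The round constants are derived from the first 64 primes with exact integer roots rather than typed in by hand.

```python
import hashlib
import math
import struct

MASK = 0xFFFFFFFF
PRIMES = [n for n in range(2, 312) if all(n % d for d in range(2, math.isqrt(n) + 1))]


def _icbrt(n: int) -> int:
    """Largest x with x**3 <= n (integer binary search, so no float error)."""
    lo, hi = 0, 1 << (n.bit_length() // 3 + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** 3 <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo


# First 32 fractional bits of cube roots (K) and square roots (H0) of primes.
K = [_icbrt(p << 96) & MASK for p in PRIMES]           # 64 round constants
H0 = [math.isqrt(p << 64) & MASK for p in PRIMES[:8]]  # initial state h0..h7


def _rotr(x: int, n: int) -> int:
    return ((x >> n) | (x << (32 - n))) & MASK


def compress(state, block: bytes):
    """Fold one 64-byte block into the 8-word state (SHA-256 compression)."""
    w = list(struct.unpack(">16I", block))
    for i in range(16, 64):
        s0 = _rotr(w[i - 15], 7) ^ _rotr(w[i - 15], 18) ^ (w[i - 15] >> 3)
        s1 = _rotr(w[i - 2], 17) ^ _rotr(w[i - 2], 19) ^ (w[i - 2] >> 10)
        w.append((w[i - 16] + s0 + w[i - 7] + s1) & MASK)
    a, b, c, d, e, f, g, h = state
    for i in range(64):
        t1 = (h + (_rotr(e, 6) ^ _rotr(e, 11) ^ _rotr(e, 25))
              + ((e & f) ^ (~e & g)) + K[i] + w[i]) & MASK
        t2 = ((_rotr(a, 2) ^ _rotr(a, 13) ^ _rotr(a, 22))
              + ((a & b) ^ (a & c) ^ (b & c))) & MASK
        a, b, c, d, e, f, g, h = (t1 + t2) & MASK, a, b, c, (d + t1) & MASK, e, f, g
    return [(x + y) & MASK for x, y in zip(state, (a, b, c, d, e, f, g, h))]


def padding(msg_len: int) -> bytes:
    """SHA-256 padding for a msg_len-byte message; depends only on the length."""
    return b"\x80" + b"\x00" * ((55 - msg_len) % 64) + struct.pack(">Q", 8 * msg_len)


def extend(tag_hex: str, prefix_len: int, suffix: bytes):
    """Given SHA-256(prefix) and len(prefix), forge SHA-256(prefix || p || suffix).

    Returns (p, forged_hex). The prefix's bytes (including the key) are never used.
    """
    p = padding(prefix_len)
    state = list(struct.unpack(">8I", bytes.fromhex(tag_hex)))  # digest == final state
    total = prefix_len + len(p) + len(suffix)  # length an honest hasher will see
    data = suffix + padding(total)             # pad using the *total* length
    for off in range(0, len(data), 64):
        state = compress(state, data[off:off + 64])
    return p, struct.pack(">8I", *state).hex()


prefix = b"stacks of pancakes" + b"original binary"  # stands in for k||B
suffix = b"malcode"                                 # stands in for M
tag = hashlib.sha256(prefix).hexdigest()
p, forged = extend(tag, len(prefix), suffix)       # only len(prefix) is passed
print(forged == hashlib.sha256(prefix + p + suffix).hexdigest())  # True
```

The key line is `padding(total)`. The forged final block must carry the length of all of $k||B||p||M$, which is what the notebook's `fakelength` argument provides.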

You write the following code to compute $p$; only the length $\lvert k||B \rvert$ of the input actually matters:

In [114]:
def gen_pad(arg):
    """Return the padding p that SHA-256 appends to arg. Only len(arg) affects the result."""
    data = bytearray(arg)
    orig_len_in_bits = (8 * len(data)) & 0xFFFFFFFFFFFFFFFF
    # 1.a. Add a single '1' bit at the end of the input bits
    data.append(0x80) #0x80 = 0b10000000, and we assume inputs are always given in bytes
    # 1.b. Padding with zeros as long as the input bits length ≡ 448 (mod 512)
    # why 448? because we append length as a 64 bit integer next
    while len(data) % 64 != 56:
        data.append(0)
    # 1.c. append original length in bits mod (2 pow 64) to message
    data += orig_len_in_bits.to_bytes(8, byteorder='big') #unsigned by default, which we want
    assert len(data) % 64 == 0
    assert orig_len_in_bits % 8 == 0
    return data[orig_len_in_bits // 8:]
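
Notice that `gen_pad` reads only `len(arg)`; the bytes of $k$ and $B$ never influence $p$. Below is a minimal length-only variant under the hypothetical name `pad_for_length`. It follows the same steps: a `0x80` byte, zeros up to 56 mod 64, then the 64-bit big-endian bit length.

```python
import struct


def pad_for_length(n_bytes: int) -> bytes:
    """SHA-256 padding p for any n_bytes-long message: 0x80, zeros, 64-bit bit length."""
    zeros = (55 - n_bytes) % 64  # so that n_bytes + 1 + zeros ≡ 56 (mod 64)
    return b"\x80" + b"\x00" * zeros + struct.pack(">Q", (8 * n_bytes) % 2**64)


msg_len = len(b"stacks of pancakes") + 102  # |k| + |B| with this lab's demo values
p = pad_for_length(msg_len)
print(len(p), (msg_len + len(p)) % 64)  # 72 0
```

In a real attack the key length is usually unknown. An attacker would then typically try each plausible $|k|$ in turn and see which forged tag the verifier accepts.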

With your dirty work done, you implement your attack. Feel free to play with the key, message, and modification to the message, and try setting `visualize=True` to see the SHA algorithm updating as it processes its input.

In [115]:
import random
random.seed(161)
msg_size = 102
mod_size = 16
k = b'stacks of pancakes'
B = random.getrandbits(msg_size * 8).to_bytes(msg_size, byteorder='big')
p = gen_pad(k + B)
M = random.getrandbits(mod_size * 8).to_bytes(mod_size, byteorder='big')
hasher = SHA2()
hasher.update(k + B, visualize=True)
print(f"The original binary B has signature SHA(k||B) = {hasher.hexdigest()}")
hasher.update(M, fakelength=len(k + B + p + M))  # resume from SHA(k||B), but pad as if the input were k||B||p||M
print(f"Our modified hash for k||B||p||M is {hasher.hexdigest()}")
print(f"If we had explicitly found the hash of k||B||p||M, we'd have gotten {hash(k + B + p + M)}")

SHA's internal state after processing 0 blocks is b'\xcd\xcd\x86\xbf\xd5\xf2\xff\xec7e\xf8\x88n\xdd\xd8\xe9=\x12V\x93\xfd\x0cOS\x14\x0b\xac ;\xc8\xe7\x0e'
SHA's internal state after processing 1 blocks is b'\xabY\x86\xeb\xdaj#\x12\xa2\xf0.w\xc0\xce\xb8(\x0f\xddR\x94s|\x9f\x932{$vY\xdf\x87\xed'
SHA's internal state after processing 2 blocks is b'\xca;:\x8e\x04\x00\xaa\xf7\x0b\x92\x82\xe3P\xfc`\x07\xf5qz\xbf\xe2\x18|\xa6\x98\xef\xca\x0e<\xb8J_'
The original binary B has signature SHA(k||B) = 3cb84a5f98efca0ee2187ca6f5717abf50fc60070b9282e30400aaf7ca3b3a8e
Our modified hash for k||B||p||M is f8e67240f834cd6fb98254ed795549e7da54418df680aa062e4eb6b6bd80fad6
If we had explicitly found the hash of k||B||p||M, we'd have gotten f8e67240f834cd6fb98254ed795549e7da54418df680aa062e4eb6b6bd80fad6
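
The visualization prints three states (indexed from 0) because of how long $k||B$ is once padded. Recall that $p$ is one `0x80` byte, enough zero bytes to reach 56 mod 64, and an 8-byte length. With this cell's values, `k = b'stacks of pancakes'` (18 bytes) and `msg_size = 102`:

$$
\begin{aligned}
\lvert k||B \rvert &= 18 + 102 = 120 \text{ bytes} \\
\lvert p \rvert &= 1 + \big((55 - 120) \bmod 64\big) + 8 = 1 + 63 + 8 = 72 \text{ bytes} \\
\lvert k||B||p \rvert &= 120 + 72 = 192 = 3 \cdot 64 \text{ bytes} \quad (3 \text{ blocks})
\end{aligned}
$$

The second `update` call compresses exactly one more block. Since $\lvert M \rvert = 16$, $M$ plus its padding is $16 + 1 + 39 + 8 = 64$ bytes. Its length field encodes $8 \cdot (192 + 16) = 1664$ bits, the same value a direct hash of $k||B||p||M$ would write.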


# Question 6: Lab Feedback

This is the third time 161 has offered this lab, and we would greatly appreciate your feedback!

<a href="https://forms.gle/BaY7tSbCF5TdGE9cA">Feedback form</a>