## Regular SHA-256 implementation using only NumPy

In [None]:
import numpy as np
import hashlib

# This is a simple implementation of SHA256, based on the specification:
# https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf

# Every word is held as an unsigned 64-bit integer, so sums of several
# 32-bit words do not overflow before they are reduced modulo 2^32.
ITEM_TYPE = np.uint64

# 32-bit all-ones mask: `value & MASK` keeps the low 32 bits (i.e. value mod 2^32).
MASK = ITEM_TYPE((1 << 32) - 1)


def uint64_to_bin(uint64):
    """
                    uint64_to_bin: 64-bit Number -> binary representation of the number as a string

                    Returns exactly 64 '0'/'1' characters, most significant bit first.
                    Only the low 64 bits of the input are used, so a negative Python int
                    yields its 64-bit two's-complement pattern. Accepts Python ints and
                    NumPy integer scalars (converted with int() first, which avoids
                    NumPy's uint64/int64 -> float64 promotion in `uint64 >> i`).
    """
    value = int(uint64) & 0xFFFF_FFFF_FFFF_FFFF
    return format(value, "064b")


def to_words(data):
    """
                    to_words: Array of bytes(uint8) -> Array of 32-bit words

                    Packs each consecutive group of 4 byte values (0..255) into one
                    big-endian word (the first byte is the most significant).

                    data:    1-D array or sequence of byte values; its length must be a
                             multiple of 4.
                    returns: 1-D int64 array with len(data) // 4 word values.
                    raises:  ValueError if len(data) is not a multiple of 4.
    """
    # Widen to int64 *before* shifting: with NumPy 2 (NEP 50), shifting a uint8
    # scalar by 24 stays uint8 and silently loses the high bytes.
    octets = np.asarray(data, dtype=np.int64).reshape(-1)
    if octets.shape[0] % 4 != 0:
        raise ValueError(
            f"to_words expects a multiple of 4 bytes, got {octets.shape[0]}")
    octets = octets.reshape(-1, 4)
    return ((octets[:, 0] << 24) | (octets[:, 1] << 16)
            | (octets[:, 2] << 8) | octets[:, 3])


def to_hex(data, chunks=8, delim=""):
    """
                    to_hex: Array of int -> String

                    Formats every element of `data` with `hex_with_chunks(x, chunks)`,
                    i.e. the lowest `chunks` hex digits of x (zero-padded, lowercase),
                    and joins the pieces with `delim`.

                    ex:
                            to_hex(np.array([0x12345678, 0x9ABCDEF0]), 8, " ")
                            -> "12345678 9abcdef0"
                            to_hex(np.array([0x12345678, 0x9ABCDEF0]), 4, " ")
                            -> "5678 def0"   (only the low 4 digits of each element)
    """
    return delim.join(hex_with_chunks(x, chunks) for x in data)


def hex_with_chunks(x, chunks):
    """
                    hex_with_chunks: int -> String

                    Returns the lowest `chunks` hexadecimal digits of x (after converting
                    it to np.uint64), zero-padded, lowercase, most significant digit
                    first. `chunks <= 0` yields "".

                    ex: hex_with_chunks(0xab, 4) -> "00ab"
    """
    # Work on a Python int: on NumPy 1.x, `np.uint64 % 16` promotes to float64
    # and loses precision for values above 2**53.
    value = int(np.uint64(x))
    digits = []
    for _ in range(chunks):
        digits.append("0123456789abcdef"[value & 0xF])
        value >>= 4
    return "".join(reversed(digits))


##############################################################
# SHA-256 constants
# K: the 64 round constants K_0..K_63 from FIPS 180-4 (section 4.2.2), one per
# compression round (indexed as k[t] in sha256_hash). Stored as ITEM_TYPE
# (uint64) like every other word in this implementation.
K = np.array([
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,

    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,

    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,

    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,

    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,

    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,

    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,

    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
], dtype=ITEM_TYPE)

print(f"K is: {K.shape} : {K.dtype}")

# H: the SHA-256 initial hash value H^(0) from FIPS 180-4 (section 5.3.3);
# passed as h0 to sha256_hash, where it seeds both the hash state and the
# eight working variables of the first block.
H = np.array([
    0x6a09e667,  # h0
    0xbb67ae85,  # h1
    0x3c6ef372,  # h2
    0xa54ff53a,  # h3
    0x510e527f,  # h4
    0x9b05688c,  # h5
    0x1f83d9ab,  # h6
    0x5be0cd19,  # h7
], dtype=ITEM_TYPE)

print(f"H is: {H.shape} : {H.dtype}")


##############################################################
# SHA256 Functions

def rotr_32bit(x, n):
    """Rotate the 32-bit word `x` right by `n` bits (n is taken modulo 32)."""
    amount = n % 32
    word = np.uint64(x)
    # Bits pushed out on the right re-enter at the top of the 32-bit word.
    wrapped = (word << ITEM_TYPE(32 - amount)) & MASK
    return ((word >> ITEM_TYPE(amount)) | wrapped) & MASK


def shr_32bit(x, n):
    """Logical right shift of the word `x` by `n` bits (n is capped at 32), masked to 32 bits."""
    shift = ITEM_TYPE(n if n < 32 else 32)
    shifted = x >> shift
    return shifted & MASK


def big_sigma0(x):
    """SHA-256 Sigma0: ROTR^2(x) XOR ROTR^13(x) XOR ROTR^22(x), as a 32-bit value."""
    mixed = rotr_32bit(x, 2)
    mixed = mixed ^ rotr_32bit(x, 13)
    mixed = mixed ^ rotr_32bit(x, 22)
    return mixed & MASK


def big_sigma1(x):
    """SHA-256 Sigma1: ROTR^6(x) XOR ROTR^11(x) XOR ROTR^25(x), as a 32-bit value."""
    mixed = rotr_32bit(x, 6)
    mixed = mixed ^ rotr_32bit(x, 11)
    mixed = mixed ^ rotr_32bit(x, 25)
    return mixed & MASK


def sigma0(x):
    """SHA-256 sigma0: ROTR^7(x) XOR ROTR^18(x) XOR SHR^3(x), as a 32-bit value."""
    mixed = rotr_32bit(x, 7)
    mixed = mixed ^ rotr_32bit(x, 18)
    mixed = mixed ^ shr_32bit(x, 3)
    return mixed & MASK


def sigma1(x):
    """SHA-256 sigma1: ROTR^17(x) XOR ROTR^19(x) XOR SHR^10(x), as a 32-bit value."""
    mixed = rotr_32bit(x, 17)
    mixed = mixed ^ rotr_32bit(x, 19)
    mixed = mixed ^ shr_32bit(x, 10)
    return mixed & MASK


def inv(x, width):
    """
            uint, uint -> uint

            Bitwise complement of x within `width` bits, computed as
            (2**width - 1) - x, passed through np.uint32 and masked to 32 bits.
    """
    all_ones = (1 << width) - 1
    flipped = np.uint32(all_ones - x)
    return flipped & MASK


def ch(x, y, z):
    """SHA-256 Ch: each result bit comes from y where x has a 1 and from z where x has a 0."""
    from_y = x & y
    from_z = inv(x, 32) & z
    return (from_y ^ from_z) & MASK


def maj(x, y, z):
    """SHA-256 Maj: each result bit is the majority vote of the matching bits of x, y and z."""
    xy = x & y
    xz = x & z
    yz = y & z
    return (xy ^ xz ^ yz) & MASK

###############################################################


def sha256_preprocess(data):
    """
            Pad a message as specified in FIPS 180-4 (section 5.1.1) and split
            the result into 32-bit words.

            The padded message is the original bytes, followed by a single '1'
            bit, then k '0' bits, then the original message length in bits as a
            64-bit big-endian integer. k is the smallest non-negative value that
            makes the total length a multiple of 512 bits.

            data:    the message, as a sequence of byte values (0..255) or a
                     bytes-like object (bytes / bytearray / memoryview).
            returns: 1-D array of 32-bit word values (16 per 512-bit block),
                     as produced by `to_words`.
    """
    if isinstance(data, (bytes, bytearray, memoryview)):
        # np.array(b"...", dtype=np.uint8) does not treat bytes as a sequence
        # of byte values, so view the raw buffer instead.
        data = np.frombuffer(data, dtype=np.uint8)
    data = np.array(data, dtype=np.uint8)
    message_len = data.shape[0] * 8  # denoted as 'l' in spec

    # Smallest k >= 0 with l + 1 + k = 448 (mod 512). Python's % already
    # returns a non-negative result, so no extra correction is needed.
    k = (448 - 1 - message_len) % 512
    padstring = "1" + "0" * k + uint64_to_bin(message_len)

    total_size = len(padstring) + message_len
    print("total size:", total_size)
    assert total_size % 512 == 0

    # Pack the padding bit string into bytes, 8 bits at a time (MSB first).
    pad = np.array([int(padstring[i:i + 8], 2)
                    for i in range(0, len(padstring), 8)], dtype=np.uint8)
    padded = np.concatenate((data, pad))

    # split into 32bit words
    return to_words(padded)


def sha256_hash(data, h0, k):
    """
                    Run the SHA-256 compression function over a preprocessed message.

                    data:    1-D array of 32-bit word values whose length is a multiple
                             of 16 (one 512-bit block = 16 words), e.g. the output of
                             `sha256_preprocess`.
                    h0:      the 8 initial hash words (e.g. `H`).
                    k:       the 64 round constants (e.g. `K`).
                    returns: ITEM_TYPE array with the 8 words of the final hash value.
    """
    schedule_len = 64
    rounds = 64
    digest = h0

    for block in range(data.shape[0] // BLOCK_SIZE_IN_WORDS):
        # Message schedule: the first 16 words are the block itself ...
        w = np.zeros(schedule_len, dtype=ITEM_TYPE)
        first = block * 16
        w[:16] = data[first:first + 16]
        # ... and the remaining 48 are derived from earlier schedule words.
        for t in range(16, schedule_len):
            w[t] = (sigma0(w[t - 15]) + sigma1(w[t - 2]) + w[t - 7] + w[t - 16]) & MASK

        # Working variables start from the current hash value.
        a, b, c, d, e, f, g, h = digest
        for t in range(rounds):
            temp1 = (h + big_sigma1(e) + ch(e, f, g) + k[t] + w[t]) & MASK
            temp2 = (big_sigma0(a) + maj(a, b, c)) & MASK
            h, g, f = g, f, e
            e = (d + temp1) & MASK
            d, c, b = c, b, a
            a = (temp1 + temp2) & MASK

        # Fold the compressed block back into the running hash value.
        digest = (digest + np.array([a, b, c, d, e, f, g, h], dtype=ITEM_TYPE)) & MASK
    return digest


def sha256(text):
    """
                    Hash a message (sequence of byte values) with the NumPy reference
                    implementation, using the module-level constants H and K.

                    returns: array of 8 words holding the hash value.
    """
    padded_words = sha256_preprocess(text)
    return sha256_hash(padded_words, H, K)


# Block geometry; sha256_hash reads BLOCK_SIZE_IN_WORDS at call time, so these
# only need to exist before sha256() is called below.
BLOCK_SIZE_IN_BYTES = 64  # in bytes
WORD_SIZE_IN_BYTES = 4  # in bytes
BLOCK_SIZE_IN_WORDS = BLOCK_SIZE_IN_BYTES // WORD_SIZE_IN_BYTES


text = (
    b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
    b"Curabitur bibendum, urna eu bibendum egestas, neque augue eleifend odio, et sagittis viverra."
)
assert len(text) == 150

hasher = hashlib.sha256(text)

sample_input = list(text)

expected_output = np.array(list(hasher.digest()))  # 32 byte values (0..255)
impl_output = sha256(sample_input)  # 8 x 32-bit words

expected_hash = to_hex(expected_output, 2)  # 2 hex digits per byte
impl_hash = to_hex(impl_output)  # 8 hex digits per word
print(f"expected_hash: {expected_hash}")
print(f"impl_hash    : {impl_hash}")
assert expected_hash == impl_hash


## Converting to FHE computation using ConcreteNumpy

In [None]:
# Import required libraries
import concrete.numpy as cnp
import numpy as np


First, note that we need an encoding and a decoding step. Because we cannot work with 32-bit integers directly, each 32-bit word is encoded as an array of 4-bit integers (8 chunks, most significant first). This is done by the functions `encode` and `decode`.

In [None]:
def encode(number: np.uint32, width: np.uint32, chunk_size: np.uint32):
    """
        Split `number` into big-endian chunks of `chunk_size` bits.

        number: the value to encode (rendered with `width` binary digits)
        width: total number of bits of the representation
        chunk_size: bits per chunk
        returns: 1-D array of chunk values, most significant chunk first
    """
    step = int(chunk_size)
    bits = np.binary_repr(int(number), width=int(width))
    return np.array([int(bits[start:start + step], 2) for start in range(0, len(bits), step)])


def decode(encoded_number, chunk_size: np.uint32) -> np.uint32:
    """
        Inverse of `encode`: rebuild an integer from big-endian chunks.

        encoded_number: sequence of chunk values, most significant first
        chunk_size: number of bits per chunk
        returns: the reconstructed value as np.uint32
    """
    base = 2 ** chunk_size
    value = 0
    # Horner's scheme: shift the accumulated value one chunk left, add the next chunk.
    for chunk in encoded_number:
        value = value * base + chunk
    return np.uint32(value)


Now we can start to implement the SHA-256 algorithm. We use the same function names as in the reference implementation above. The difference is that we use `ConcreteNumpy` functions where needed instead of plain `Numpy`, and the functions `ch`, `maj`, `sigma0`, `sigma1`, `big_sigma0` and `big_sigma1` are implemented so that they work on words encoded as 4-bit unsigned integers. The functions `rotr_32bit` and `shr_32bit` are not reimplemented. Instead, `sigma0`, `sigma1`, `big_sigma0` and `big_sigma1` are written out by hand, computing which input bits end up in each output bit.

Now let's implement 4-bit versions of the functions `ch`, `maj`, `sigma0`, `sigma1`, `big_sigma0` and `big_sigma1`:

* Input to below functions are (8 x uint4) arrays
* Output is (8 x uint4) array

In [None]:
def ch(x, y, z):
    """
        Choice function: every result bit is taken from y where the matching
        bit of x is 1, and from z where it is 0.

        x: 8 x (uint4)
        y: 8 x (uint4)
        z: 8 x (uint4)

        returns: 8 x (uint4)
    """
    # 4-bit one's complement of x: for 4-bit values, 0xf - x equals ~x & 0xf.
    nibble_mask = 0xf
    not_x = nibble_mask - (x & nibble_mask)
    return (x & y) ^ (not_x & z)


# Majority function: a result bit is set when at least two of the three input
# bits are set. It only uses & and ^, so it works element-wise on arrays too.
def maj(x, y, z):
    """Bitwise majority of x, y and z (scalars or arrays of 4-bit chunks)."""
    both_yz = y & z
    both_xz = x & z
    both_xy = x & y
    return both_xy ^ both_xz ^ both_yz


* To calculate sigma0, we need rotations and shifts on (8 x uint4) arrays
* We do them manually by:
  * Finding which input bits end up in each output bit for the given rotations/shift
  * XORing those bits together to get the output bit
  * Merging every 4 output bits into a single uint4 by multiplying them by 2^3, 2^2, 2^1, 2^0 respectively (most significant bit first)
  * Returning the resulting array

Note: 
* In our experiments, *8, *4, *2, *1 was faster than <<3, <<2, <<1, <<0 (at least for 4-bit numbers)
* Likewise, //8, //4, //2, //1 was faster than >>3, >>2, >>1, >>0


In [None]:
def sigma0(x):
    """
        SHA-256 small sigma0 on one word encoded as 8 nibbles:
            ROTR^7(x) XOR ROTR^18(x) XOR SHR^3(x)

        x: 8 x (uint4), x[0] is the most significant nibble
        returns: 8 x (uint4)

        The rotations and the shift are resolved by hand at the bit level:
        `(x[j]//8) & 1`, `(x[j]//4) & 1`, `(x[j]//2) & 1` and `(x[j]) & 1` are
        bits 3, 2, 1 and 0 of nibble j. For each output bit, the contributing
        input bits are XOR-ed, and the four bits of every output nibble are
        recombined with weights 8, 4, 2, 1. The top three bits of output
        nibble 0 have only two terms because SHR^3 shifts zeros into them.
    """
    # NOTE(review): this function wraps its result in cnp.array(...), whereas
    # sigma1 / big_sigma0 / big_sigma1 return the plain np.array -- confirm the
    # difference is intentional.

    result = np.array([
        (((((x[6]//4) & 1) ^ ((x[3]//2) & 1))) * 8 +
         ((((x[6]//2) & 1) ^ ((x[3]) & 1))) * 4 +
         ((((x[6]) & 1) ^ ((x[4]//8) & 1))) * 2 +
         ((((x[7]//8) & 1) ^ ((x[4]//4) & 1) ^ ((x[0]//8) & 1)))),

        (((((x[7]//4) & 1) ^ ((x[4]//2) & 1) ^ ((x[0]//4) & 1))) * 8 +
         ((((x[7]//2) & 1) ^ ((x[4]) & 1) ^ ((x[0]//2) & 1))) * 4 +
         ((((x[7]) & 1) ^ ((x[5]//8) & 1) ^ ((x[0]) & 1))) * 2 +
         ((((x[0]//8) & 1) ^ ((x[5]//4) & 1) ^ ((x[1]//8) & 1)))),

        (((((x[0]//4) & 1) ^ ((x[5]//2) & 1) ^ ((x[1]//4) & 1))) * 8 +
         ((((x[0]//2) & 1) ^ ((x[5]) & 1) ^ ((x[1]//2) & 1))) * 4 +
         ((((x[0]) & 1) ^ ((x[6]//8) & 1) ^ ((x[1]) & 1))) * 2 +
         ((((x[1]//8) & 1) ^ ((x[6]//4) & 1) ^ ((x[2]//8) & 1)))),

        (((((x[1]//4) & 1) ^ ((x[6]//2) & 1) ^ ((x[2]//4) & 1))) * 8 +
         ((((x[1]//2) & 1) ^ ((x[6]) & 1) ^ ((x[2]//2) & 1))) * 4 +
         ((((x[1]) & 1) ^ ((x[7]//8) & 1) ^ ((x[2]) & 1))) * 2 +
         ((((x[2]//8) & 1) ^ ((x[7]//4) & 1) ^ ((x[3]//8) & 1)))),

        (((((x[2]//4) & 1) ^ ((x[7]//2) & 1) ^ ((x[3]//4) & 1))) * 8 +
         ((((x[2]//2) & 1) ^ ((x[7]) & 1) ^ ((x[3]//2) & 1))) * 4 +
         ((((x[2]) & 1) ^ ((x[0]//8) & 1) ^ ((x[3]) & 1))) * 2 +
         ((((x[3]//8) & 1) ^ ((x[0]//4) & 1) ^ ((x[4]//8) & 1)))),

        (((((x[3]//4) & 1) ^ ((x[0]//2) & 1) ^ ((x[4]//4) & 1))) * 8 +
         ((((x[3]//2) & 1) ^ ((x[0]) & 1) ^ ((x[4]//2) & 1))) * 4 +
         ((((x[3]) & 1) ^ ((x[1]//8) & 1) ^ ((x[4]) & 1))) * 2 +
         ((((x[4]//8) & 1) ^ ((x[1]//4) & 1) ^ ((x[5]//8) & 1)))),

        (((((x[4]//4) & 1) ^ ((x[1]//2) & 1) ^ ((x[5]//4) & 1))) * 8 +
         ((((x[4]//2) & 1) ^ ((x[1]) & 1) ^ ((x[5]//2) & 1))) * 4 +
         ((((x[4]) & 1) ^ ((x[2]//8) & 1) ^ ((x[5]) & 1))) * 2 +
         ((((x[5]//8) & 1) ^ ((x[2]//4) & 1) ^ ((x[6]//8) & 1)))),

        (((((x[5]//4) & 1) ^ ((x[2]//2) & 1) ^ ((x[6]//4) & 1))) * 8 +
         ((((x[5]//2) & 1) ^ ((x[2]) & 1) ^ ((x[6]//2) & 1))) * 4 +
         ((((x[5]) & 1) ^ ((x[3]//8) & 1) ^ ((x[6]) & 1))) * 2 +
         ((((x[6]//8) & 1) ^ ((x[3]//4) & 1) ^ ((x[7]//8) & 1))))
    ])
    return cnp.array(result)


* Similarly we can implement sigma1, big_sigma0, big_sigma1

In [None]:
def sigma1(x):
    """
        SHA-256 small sigma1 on one word encoded as 8 nibbles:
            ROTR^17(x) XOR ROTR^19(x) XOR SHR^10(x)

        x: 8 x (uint4), x[0] is the most significant nibble
        returns: 8 x (uint4)

        Same bit-level construction as sigma0: `(x[j]//8) & 1` ... `(x[j]) & 1`
        are bits 3..0 of nibble j; contributing bits are XOR-ed and recombined
        with weights 8, 4, 2, 1. The first 10 output bits (nibbles 0-1 and the
        top two bits of nibble 2) have only two terms because SHR^10 shifts
        zeros into them.
    """
    result = np.array([
        ((((x[3]) & 1) ^ ((x[3]//4) & 1))) * 8 +
        ((((x[4]//8) & 1) ^ ((x[3]//2) & 1))) * 4 +
        ((((x[4]//4) & 1) ^ ((x[3]) & 1))) * 2 +
        ((((x[4]//2) & 1) ^ ((x[4]//8) & 1))),

        ((((x[4]) & 1) ^ ((x[4]//4) & 1))) * 8 +
        ((((x[5]//8) & 1) ^ ((x[4]//2) & 1))) * 4 +
        ((((x[5]//4) & 1) ^ ((x[4]) & 1))) * 2 +
        ((((x[5]//2) & 1) ^ ((x[5]//8) & 1))),

        ((((x[5]) & 1) ^ ((x[5]//4) & 1))) * 8 +
        ((((x[6]//8) & 1) ^ ((x[5]//2) & 1))) * 4 +
        ((((x[6]//4) & 1) ^ ((x[5]) & 1) ^ ((x[0] // 8) & 1))) * 2 +
        ((((x[6]//2) & 1) ^ ((x[6]//8) & 1) ^ ((x[0] // 4) & 1))),

        ((((x[6]) & 1) ^ ((x[6]//4) & 1) ^ ((x[0] // 2) & 1))) * 8 +
        ((((x[7]//8) & 1) ^ ((x[6]//2) & 1) ^ ((x[0]) & 1))) * 4 +
        ((((x[7]//4) & 1) ^ ((x[6]) & 1) ^ ((x[1] // 8) & 1))) * 2 +
        ((((x[7]//2) & 1) ^ ((x[7]//8) & 1) ^ ((x[1] // 4) & 1))),

        ((((x[7]) & 1) ^ ((x[7]//4) & 1) ^ ((x[1] // 2) & 1))) * 8 +
        ((((x[0]//8) & 1) ^ ((x[7]//2) & 1) ^ ((x[1]) & 1))) * 4 +
        ((((x[0]//4) & 1) ^ ((x[7]) & 1) ^ ((x[2] // 8) & 1))) * 2 +
        ((((x[0]//2) & 1) ^ ((x[0]//8) & 1) ^ ((x[2] // 4) & 1))),

        ((((x[0]) & 1) ^ ((x[0]//4) & 1) ^ ((x[2] // 2) & 1))) * 8 +
        ((((x[1]//8) & 1) ^ ((x[0]//2) & 1) ^ ((x[2]) & 1))) * 4 +
        ((((x[1]//4) & 1) ^ ((x[0]) & 1) ^ ((x[3] // 8) & 1))) * 2 +
        ((((x[1]//2) & 1) ^ ((x[1]//8) & 1) ^ ((x[3] // 4) & 1))),

        ((((x[1]) & 1) ^ ((x[1]//4) & 1) ^ ((x[3] // 2) & 1))) * 8 +
        ((((x[2]//8) & 1) ^ ((x[1]//2) & 1) ^ ((x[3]) & 1))) * 4 +
        ((((x[2]//4) & 1) ^ ((x[1]) & 1) ^ ((x[4] // 8) & 1))) * 2 +
        ((((x[2]//2) & 1) ^ ((x[2]//8) & 1) ^ ((x[4] // 4) & 1))),

        ((((x[2]) & 1) ^ ((x[2]//4) & 1) ^ ((x[4] // 2) & 1))) * 8 +
        ((((x[3]//8) & 1) ^ ((x[2]//2) & 1) ^ ((x[4]) & 1))) * 4 +
        ((((x[3]//4) & 1) ^ ((x[2]) & 1) ^ ((x[5] // 8) & 1))) * 2 +
        ((((x[3]//2) & 1) ^ ((x[3]//8) & 1) ^ ((x[5] // 4) & 1))),
    ]
    )
    return (result)


In [None]:
def big_sigma0(x):
    """
        SHA-256 big Sigma0 on one word encoded as 8 nibbles:
            ROTR^2(x) XOR ROTR^13(x) XOR ROTR^22(x)

        x: 8 x (uint4), x[0] is the most significant nibble
        returns: 8 x (uint4)

        Every output bit is the XOR of three input bits (one per rotation),
        extracted with `(x[j] // 2**m) & 1` and recombined per nibble with
        weights 8, 4, 2, 1.
    """
    result = np.array([
        (((x[7] // 2) & 1) ^ ((x[4]) & 1) ^ ((x[2] // 2) & 1)) * 8 +
        (((x[7]) & 1) ^ ((x[5] // 8) & 1) ^ ((x[2]) & 1)) * 4 +
        (((x[0] // 8) & 1) ^ ((x[5] // 4) & 1) ^ ((x[3] // 8) & 1)) * 2 +
        (((x[0] // 4) & 1) ^ ((x[5] // 2) & 1)
         ^ ((x[3] // 4) & 1)),

        (((x[0] // 2) & 1) ^ ((x[5]) & 1) ^ ((x[3] // 2) & 1)) * 8 +
        (((x[0]) & 1) ^ ((x[6] // 8) & 1) ^ ((x[3]) & 1)) * 4 +
        (((x[1] // 8) & 1) ^ ((x[6] // 4) & 1) ^ ((x[4] // 8) & 1)) * 2 +
        (((x[1] // 4) & 1) ^ ((x[6] // 2) & 1)
         ^ ((x[4] // 4) & 1)),

        (((x[1] // 2) & 1) ^ ((x[6]) & 1) ^ ((x[4] // 2) & 1)) * 8 +
        (((x[1]) & 1) ^ ((x[7] // 8) & 1) ^ ((x[4]) & 1)) * 4 +
        (((x[2] // 8) & 1) ^ ((x[7] // 4) & 1) ^ ((x[5] // 8) & 1)) * 2 +
        (((x[2] // 4) & 1) ^ ((x[7] // 2) & 1)
         ^ ((x[5] // 4) & 1)),

        (((x[2] // 2) & 1) ^ ((x[7]) & 1) ^ ((x[5] // 2) & 1)) * 8 +
        (((x[2]) & 1) ^ ((x[0] // 8) & 1) ^ ((x[5]) & 1)) * 4 +
        (((x[3] // 8) & 1) ^ ((x[0] // 4) & 1) ^ ((x[6] // 8) & 1)) * 2 +
        (((x[3] // 4) & 1) ^ ((x[0] // 2) & 1)
         ^ ((x[6] // 4) & 1)),

        (((x[3] // 2) & 1) ^ ((x[0]) & 1) ^ ((x[6] // 2) & 1)) * 8 +
        (((x[3]) & 1) ^ ((x[1] // 8) & 1) ^ ((x[6]) & 1)) * 4 +
        (((x[4] // 8) & 1) ^ ((x[1] // 4) & 1) ^ ((x[7] // 8) & 1)) * 2 +
        (((x[4] // 4) & 1) ^ ((x[1] // 2) & 1)
         ^ ((x[7] // 4) & 1)),

        (((x[4] // 2) & 1) ^ ((x[1]) & 1) ^ ((x[7] // 2) & 1)) * 8 +
        (((x[4]) & 1) ^ ((x[2] // 8) & 1) ^ ((x[7]) & 1)) * 4 +
        (((x[5] // 8) & 1) ^ ((x[2] // 4) & 1) ^ ((x[0] // 8) & 1)) * 2 +
        (((x[5] // 4) & 1) ^ ((x[2] // 2) & 1)
         ^ ((x[0] // 4) & 1)),

        (((x[5] // 2) & 1) ^ ((x[2]) & 1) ^ ((x[0] // 2) & 1)) * 8 +
        (((x[5]) & 1) ^ ((x[3] // 8) & 1) ^ ((x[0]) & 1)) * 4 +
        (((x[6] // 8) & 1) ^ ((x[3] // 4) & 1) ^ ((x[1] // 8) & 1)) * 2 +
        (((x[6] // 4) & 1) ^ ((x[3] // 2) & 1)
         ^ ((x[1] // 4) & 1)),

        (((x[6] // 2) & 1) ^ ((x[3]) & 1) ^ ((x[1] // 2) & 1)) * 8 +
        (((x[6]) & 1) ^ ((x[4] // 8) & 1) ^ ((x[1]) & 1)) * 4 +
        (((x[7] // 8) & 1) ^ ((x[4] // 4) & 1) ^ ((x[2] // 8) & 1)) * 2 +
        (((x[7] // 4) & 1) ^ ((x[4] // 2) & 1)
         ^ ((x[2] // 4) & 1)),
    ])

    return result


In [None]:
def big_sigma1(x):
    """
        SHA-256 big Sigma1 on one word encoded as 8 nibbles:
            ROTR^6(x) XOR ROTR^11(x) XOR ROTR^25(x)

        x: 8 x (uint4), x[0] is the most significant nibble
        returns: 8 x (uint4)

        Every output bit is the XOR of three input bits (one per rotation),
        extracted with `(x[j] // 2**m) & 1` and recombined per nibble with
        weights 8, 4, 2, 1.
    """
    result = np.array([
        (((x[6] // 2) & 1) ^ ((x[5] // 4) & 1) ^ ((x[1]) & 1)) * 8 +
        (((x[6]) & 1) ^ ((x[5] // 2) & 1) ^ ((x[2] // 8) & 1)) * 4 +
        (((x[7] // 8) & 1) ^ ((x[5]) & 1) ^ ((x[2] // 4) & 1)) * 2 +
        (((x[7] // 4) & 1) ^ ((x[6] // 8) & 1) ^ ((x[2] // 2) & 1)),

        (((x[7] // 2) & 1) ^ ((x[6] // 4) & 1) ^ ((x[2]) & 1)) * 8 +
        (((x[7]) & 1) ^ ((x[6] // 2) & 1) ^ ((x[3] // 8) & 1)) * 4 +
        (((x[0] // 8) & 1) ^ ((x[6]) & 1) ^ ((x[3] // 4) & 1)) * 2 +
        (((x[0] // 4) & 1) ^ ((x[7] // 8) & 1) ^ ((x[3] // 2) & 1)),

        (((x[0] // 2) & 1) ^ ((x[7] // 4) & 1) ^ ((x[3]) & 1)) * 8 +
        (((x[0]) & 1) ^ ((x[7] // 2) & 1) ^ ((x[4] // 8) & 1)) * 4 +
        (((x[1] // 8) & 1) ^ ((x[7]) & 1) ^ ((x[4] // 4) & 1)) * 2 +
        (((x[1] // 4) & 1) ^ ((x[0] // 8) & 1) ^ ((x[4] // 2) & 1)),

        (((x[1] // 2) & 1) ^ ((x[0] // 4) & 1) ^ ((x[4]) & 1)) * 8 +
        (((x[1]) & 1) ^ ((x[0] // 2) & 1) ^ ((x[5] // 8) & 1)) * 4 +
        (((x[2] // 8) & 1) ^ ((x[0]) & 1) ^ ((x[5] // 4) & 1)) * 2 +
        (((x[2] // 4) & 1) ^ ((x[1] // 8) & 1) ^ ((x[5] // 2) & 1)),

        (((x[2] // 2) & 1) ^ ((x[1] // 4) & 1) ^ ((x[5]) & 1)) * 8 +
        (((x[2]) & 1) ^ ((x[1] // 2) & 1) ^ ((x[6] // 8) & 1)) * 4 +
        (((x[3] // 8) & 1) ^ ((x[1]) & 1) ^ ((x[6] // 4) & 1)) * 2 +
        (((x[3] // 4) & 1) ^ ((x[2] // 8) & 1) ^ ((x[6] // 2) & 1)),

        (((x[3] // 2) & 1) ^ ((x[2] // 4) & 1) ^ ((x[6]) & 1)) * 8 +
        (((x[3]) & 1) ^ ((x[2] // 2) & 1) ^ ((x[7] // 8) & 1)) * 4 +
        (((x[4] // 8) & 1) ^ ((x[2]) & 1) ^ ((x[7] // 4) & 1)) * 2 +
        (((x[4] // 4) & 1) ^ ((x[3] // 8) & 1) ^ ((x[7] // 2) & 1)),

        (((x[4] // 2) & 1) ^ ((x[3] // 4) & 1) ^ ((x[7]) & 1)) * 8 +
        (((x[4]) & 1) ^ ((x[3] // 2) & 1) ^ ((x[0] // 8) & 1)) * 4 +
        (((x[5] // 8) & 1) ^ ((x[3]) & 1) ^ ((x[0] // 4) & 1)) * 2 +
        (((x[5] // 4) & 1) ^ ((x[4] // 8) & 1) ^ ((x[0] // 2) & 1)),

        (((x[5] // 2) & 1) ^ ((x[4] // 4) & 1) ^ ((x[0]) & 1)) * 8 +
        (((x[5]) & 1) ^ ((x[4] // 2) & 1) ^ ((x[1] // 8) & 1)) * 4 +
        (((x[6] // 8) & 1) ^ ((x[4]) & 1) ^ ((x[1] // 4) & 1)) * 2 +
        (((x[6] // 4) & 1) ^ ((x[5] // 8) & 1) ^ ((x[1] // 2) & 1)),
    ])

    return result


## Implementing the Sha256 algorithm

* Now that we have the required functions, we can implement the SHA-256 algorithm
* Let's start with padding the message

In [None]:
# SHA-256 word arithmetic is done modulo 2^32.
SHA256_MODULUS = 1 << 32
# One 512-bit message block = 64 bytes = 16 words of 4 bytes.
WORD_SIZE_IN_BYTES = 4  # in bytes
BLOCK_SIZE_IN_BYTES = 64  # in bytes
BLOCK_SIZE_IN_WORDS = BLOCK_SIZE_IN_BYTES // WORD_SIZE_IN_BYTES
# Each 32-bit word is encoded as chunks of this many bits for the FHE circuit.
CHUNK_SIZE_IN_BITS = 4


In [None]:
def to_words(data):
    """
        to_words: Array of bytes -> Array of 32-bit words

        Packs each consecutive group of 4 byte values (0..255) into one
        big-endian word (the first byte is the most significant).

        data:    1-D array or sequence of byte values; its length must be a
                 multiple of 4.
        returns: 1-D int64 array with len(data) // 4 word values.
        raises:  ValueError if len(data) is not a multiple of 4.
    """
    # Widen to int64 *before* shifting: with NumPy 2 (NEP 50), shifting a uint8
    # scalar by 24 stays uint8 and silently loses the high bytes.
    octets = np.asarray(data, dtype=np.int64).reshape(-1)
    if octets.shape[0] % 4 != 0:
        raise ValueError(
            f"to_words expects a multiple of 4 bytes, got {octets.shape[0]}")
    octets = octets.reshape(-1, 4)
    return ((octets[:, 0] << 24) | (octets[:, 1] << 16)
            | (octets[:, 2] << 8) | octets[:, 3])


def sha256_preprocess(text):
    """
        Pad a message (FIPS 180-4, section 5.1.1) and encode it for the FHE circuit.

        The padded message is the original bytes, followed by a single '1' bit,
        then k '0' bits, then the original message length in bits as a 64-bit
        big-endian integer; k is the smallest non-negative value that makes the
        total length a multiple of 512 bits.

        text:    the message as a 1-D array / sequence of byte values (0..255),
                 or a bytes-like object (bytes / bytearray / memoryview).
        returns: 2-D array; row j is word j of the padded message split into
                 chunks of CHUNK_SIZE_IN_BITS bits, most significant first
                 (8 nibbles per word with the default 4-bit chunks).
    """
    if isinstance(text, (bytes, bytearray, memoryview)):
        # np.asarray(b"...", dtype=np.uint8) does not treat bytes as a sequence
        # of byte values, so view the raw buffer instead.
        text = np.frombuffer(text, dtype=np.uint8)
    data = np.asarray(text, dtype=np.uint8)

    message_len = data.shape[0] * 8  # denoted as 'l' in spec

    # Smallest k >= 0 with l + 1 + k = 448 (mod 512); Python's % is already
    # non-negative, so no extra correction is needed.
    k = (448 - 1 - message_len) % 512
    padstring = "1" + "0" * k + uint64_to_bin(message_len)
    total_size = len(padstring) + message_len

    assert total_size % 512 == 0

    # Pack the padding bit string into bytes, 8 bits at a time (MSB first).
    pad = np.array([int(padstring[i:i + 8], 2)
                    for i in range(0, len(padstring), 8)], dtype=np.uint8)
    padded = np.concatenate((data, pad))

    # convert bytes to words, then each word to chunks of CHUNK_SIZE_IN_BITS bits
    words = to_words(padded)
    return np.array([encode(word, 32, CHUNK_SIZE_IN_BITS) for word in words])


* Now we can write the main function
* We use ConcreteNumpy's `cnp.tag` to label the functions in the computation graph

In [None]:
def add_modulo_32(x, y):
    """
        Array of uint4 -> Array of uint4

        Computes (x + y) mod 2^32 for two words encoded as 8 x uint4
        (index 0 is the most significant nibble).

        Works like pencil-and-paper addition: nibbles are added from the least
        significant one (index 7) upward, keeping the low 4 bits of each sum
        and carrying the rest into the next nibble to the left. The carry out
        of index 0 is dropped, which is exactly the reduction modulo 2^32.
    """
    out = cnp.zeros(x.shape)
    carry = cnp.zero()

    for idx in reversed(range(8)):
        nibble_total = x[idx] + y[idx] + carry
        out[idx] = nibble_total & 0xf
        # at most 15 + 15 + 1 = 31, so the carry is 0 or 1
        carry = nibble_total >> 4
    return out


def add_modulo_32_multiple(elements):
    """
        Calculate (x0 + x1 + ... xn) mod 2^32 for arrays of uint4
        where x0, x1, ..., xn are (8 x uint4) arrays (index 0 most significant).

        Adding all operands nibble by nibble in a single pass needs fewer
        operations than chaining add_modulo_32 pairwise. The per-nibble sum can
        exceed 31 here, so the carry (sum >> 4) can be larger than 1.

        elements: sequence / object array of (8 x uint4) operands
        returns: (8 x uint4) encoded sum modulo 2^32
    """
    z = cnp.zeros(elements[0].shape)
    carry = cnp.zero()

    # process from 7 to 0 [x[0] is most significant]
    for i in range(7, -1, -1):
        # Builtin sum over the i-th nibble of every operand. The previous
        # np.sum(generator) call is deprecated and just falls back to this
        # builtin; the local is named `nibble_sum` so that `sum` stays the builtin.
        nibble_sum = sum(element[i] for element in elements) + carry
        z[i] = nibble_sum & 0xf
        carry = nibble_sum >> 4

    return z


def add_modulo_32_encoded(x, y):
    """
        Word-wise (x[i] + y[i]) mod 2^32 for stacks of encoded words.

        x, y: arrays of shape (n_words x 8 x uint4); each row is one 32-bit word
              with index 0 as the most significant nibble (used to add the
              working variables back into the hash state).
        returns: array with the same shape as x.

        Each word is reduced modulo 2^32 independently; no carry crosses from
        one word into the next. Every nibble produced by add_modulo_32 is
        already < 16, so a word-to-word carry would always be zero. Leaving it
        out avoids useless `& 0xf` / `>> 4` operations in the FHE circuit.
    """
    z = cnp.zeros(x.shape)
    for i in range(x.shape[0]):
        z[i] = add_modulo_32(x[i], y[i])
    return z


def create_working_next(working, t1, t2):
    """
        One SHA-256 round update of the eight working variables a..h.

        working: the current (8 x 8 uint4) working variables, a = working[0]
        t1, t2:  the round temporaries T1 and T2, each (8 x uint4)
        returns: new working variables
                 (T1 + T2, a, b, c, d + T1, e, f, g), all sums mod 2^32
    """
    shifted = np.zeros_like(working)
    shifted[0] = add_modulo_32(t1, t2)
    # b, c, d <- a, b, c
    for dst in (1, 2, 3):
        shifted[dst] = working[dst - 1]
    shifted[4] = add_modulo_32(working[3], t1)
    # f, g, h <- e, f, g
    for dst in (5, 6, 7):
        shifted[dst] = working[dst - 1]
    return shifted


def sha256_hash(data, h0, k):
    """
        SHA-256 compression over a preprocessed, nibble-encoded message.

        data: (n_words x 8) array from `sha256_preprocess`; every
              BLOCK_SIZE_IN_WORDS (16) rows form one 512-bit block.
        h0:   (8 x 8) encoded initial hash value (e.g. H_INIT).
        k:    (64 x 8) encoded round constants (e.g. K_ENCODED).
        returns: (8 x 8) encoded hash words; decode each row to get a 32-bit word.

        The `with cnp.tag(...)` blocks only label the operations in the traced
        computation graph.
    """
    # find number of blocks
    block_count = data.shape[0] // (BLOCK_SIZE_IN_WORDS)  # each block: 16 words x 8 nibbles

    # initialize hash values
    h = h0

    # loop bounds: number of blocks, message-schedule length, compression rounds
    N = block_count
    W_SIZE = 64
    INNER_ROUNDS = 64

    # for each block in the message (each block is 512 bits)
    for i in range(N):
        # initialize message schedule (64 encoded words of 8 nibbles)
        w = cnp.zeros((W_SIZE, 8))

        # w_t = M_t for 0 <= t <= 15
        with cnp.tag("first-16-message-schedule-words"):
            for t in range(16):
                w[t] = data[i*BLOCK_SIZE_IN_WORDS + t]

        # w_t = sigma1(w_t-2) + w_t-7 + sigma0(w_t-15) + w_t-16 for 16 <= t <= 63
        with cnp.tag("remaining-48-message-schedule-words"):
            for t in range(16, W_SIZE):
                with cnp.tag("sigma1"):
                    sigma1_result = sigma1(w[t-2])
                with cnp.tag("sigma0"):
                    sigma0_result = sigma0(w[t-15])

                # the four-term sum is done as three pairwise mod-2^32 additions
                sigma_sum = add_modulo_32(sigma0_result, sigma1_result)
                head_sum = add_modulo_32(w[t-7], w[t-16])
                w[t] = add_modulo_32(sigma_sum, head_sum)

        # initialize working variables a..h with the current hash value
        working = h
        # inner loop
        inner_rounds = INNER_ROUNDS
        with cnp.tag("inner-rounds"):
            for t in range(0, inner_rounds):
                with cnp.tag("big_sigma0"):
                    big_sigma0_result = big_sigma0(working[0])

                with cnp.tag("big_sigma1"):
                    big_sigma1_result = big_sigma1(working[4])

                with cnp.tag("ch"):
                    ch_result = ch(working[4], working[5], working[6])

                with cnp.tag("maj"):
                    maj_result = maj(working[0], working[1], working[2])

                # T1 = h + Sigma1(e) + Ch(e, f, g) + K_t + W_t  (mod 2^32)
                with cnp.tag("t1"):
                    t1 = add_modulo_32_multiple(
                        np.array([working[7], big_sigma1_result, ch_result, k[t], w[t]], dtype=object))

                # T2 = Sigma0(a) + Maj(a, b, c)  (mod 2^32)
                with cnp.tag("t2"):
                    t2 = add_modulo_32_multiple(
                        np.array([big_sigma0_result, maj_result], dtype=object))

                with cnp.tag("working_next"):
                    working_next = create_working_next(working, t1, t2)

                working = working_next

        # H^(i) = H^(i-1) + (a, ..., h), word by word, mod 2^32
        with cnp.tag("update-hash"):
            h = add_modulo_32_encoded(h, working)

    return h


## Testing the Sha256 algorithm

* Let's test it out

In [None]:
# SHA-256 constants
# K: the 64 round constants K_0..K_63 from FIPS 180-4 (section 4.2.2).
# NOTE(review): K uses NumPy's default integer dtype while H below is uint32.
# In this notebook both are only fed through `encode`, which calls int() on
# each value, so the difference does not affect results.
K = np.array([
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,

    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,

    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,

    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,

    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,

    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,

    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,

    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
])

# H: the SHA-256 initial hash value H^(0) from FIPS 180-4 (section 5.3.3).
H = np.array([
    0x6a09e667,  # h0
    0xbb67ae85,  # h1
    0x3c6ef372,  # h2
    0xa54ff53a,  # h3
    0x510e527f,  # h4
    0x9b05688c,  # h5
    0x1f83d9ab,  # h6
    0x5be0cd19,  # h7
], dtype=np.uint32)


In [None]:
# Initialize the inputs to the SHA-256 algorithm using constants:
# every 32-bit constant becomes a row of 32 / CHUNK_SIZE_IN_BITS chunks.
word_width = np.uint32(32)
chunk_bits = np.uint32(CHUNK_SIZE_IN_BITS)

H_INIT = np.array([encode(word, word_width, chunk_size=chunk_bits) for word in H])

K_ENCODED = np.array([encode(word, word_width, chunk_size=chunk_bits) for word in K])


In [None]:
# Init Message Input (a bare text line here was a SyntaxError in this code cell)
text = (
    b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
    b"Curabitur bibendum, urna eu bibendum egestas, neque augue eleifend odio, et sagittis viverra."
)
assert len(text) == 150


### First, test the function without compiling

In [None]:
import hashlib


def to_hex(data, chunks=8, delim=""):
    """Render each element of `data` with `hex_with_chunks(value, chunks)` and join the pieces with `delim`."""
    pieces = [hex_with_chunks(value, chunks) for value in data]
    return delim.join(pieces)


hasher = hashlib.sha256()
hasher.update(text)

# Reference digest from hashlib as lowercase hex (2 digits per byte), which is
# the same string hex_with_chunks(byte, 2) produces for every digest byte.
expected_output = hasher.hexdigest()

# Run the nibble-level implementation as plain Python (no compilation yet).
preprocessed = sha256_preprocess(np.array(list(text), dtype=np.uint8))
normal_run_result = to_hex([
    decode(encoded_word, chunk_size=np.uint32(CHUNK_SIZE_IN_BITS))
    for encoded_word in sha256_hash(preprocessed, H_INIT, K_ENCODED)
])

print(f"Expected output  : {expected_output}")
print(f"Normal run result: {normal_run_result}")
assert expected_output == normal_run_result


## Now let's compile the function and test it

In [None]:
import time

# Now, compile and run fhe version
# With VIRTUAL = True only compilation is exercised; the encrypt/run/decrypt
# path below is skipped.
VIRTUAL = True

configuration = cnp.Configuration(
    p_error=0.01,
    loop_parallelize=True,
    enable_unsafe_features=True,
    # NOTE(review): insecure key cache -- intended for experiments only.
    use_insecure_key_cache=True,
    insecure_key_cache_location=".keys",

    verbose=True,
    dataflow_parallelize=True,
    virtual=VIRTUAL,
)

start_time = time.time()
# The preprocessed message (n_words x 8 nibbles) is the encrypted input.
circuit_input = preprocessed
print(f"Input shape {circuit_input.shape}")

compiler = cnp.Compiler(
    sha256_hash, {"data": "encrypted", "h0": "clear", "k": "clear"})
circuit = compiler.compile(
    inputset=[
        # Every entry is a 4-bit nibble, so sample the full range 0..15
        # (np.random.randint's upper bound is exclusive).
        (np.random.randint(0, 16, size=circuit_input.shape, dtype=np.uint8), H_INIT, K_ENCODED)
    ],
    configuration=configuration,
)

took = time.time() - start_time
print(f"Compilation took {took:.2f} seconds")

if not VIRTUAL:
    inter_time = time.time()
    print("Encrypting...")
    encrypted = circuit.encrypt(circuit_input, H_INIT, K_ENCODED)
    print("encrypted:", encrypted)
    print("Encrypting took", time.time() - inter_time, "seconds")

    inter_time = time.time()
    print("Running...")
    fhe_output_encrypted = circuit.run(encrypted)
    print("fhe_output_enc:", fhe_output_encrypted)
    print("Running took", time.time() - inter_time, "seconds")

    inter_time = time.time()
    print("Decrypting...")
    fhe_output = circuit.decrypt(fhe_output_encrypted)
    fhe_output = [decode(item, chunk_size=np.uint32(
        CHUNK_SIZE_IN_BITS)) for item in fhe_output]
    fhe_output = to_hex(fhe_output)
    assert fhe_output == expected_output
