# Coding SHA-256 with concrete numpy



## I. SHA-256 introduction
SHA is a family of hashing algorithms that take as input a sequence of bytes of any length, such as a text string, and output a fixed-length sequence called a hash. The hash changes completely if the input bytes change ever so slightly, the probability of two inputs sharing the same hash is negligible, and it is computationally infeasible to reverse a hash to recover one of its possible inputs. This makes SHA well suited to creating reliable fingerprints of files and guaranteeing that they have not been corrupted. SHA-256 is the version that outputs a hash of 256 bits (64 hexadecimal characters).  
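
As a quick illustration of these properties, the sketch below uses Python's standard `hashlib` module (not the implementation developed in this tutorial) to hash two inputs that differ by a single character. The inputs and variable names are illustrative only.

```python
import hashlib

# Two inputs that differ only in their last character
hash_a = hashlib.sha256(b"hello world").hexdigest()
hash_b = hashlib.sha256(b"hello worle").hexdigest()

# A SHA-256 hash is always 256 bits long, i.e. 64 hexadecimal characters
print(len(hash_a), len(hash_b))

# Count how many of the 256 output bits differ between the two hashes
differing_bits = bin(int(hash_a, 16) ^ int(hash_b, 16)).count("1")
print(hash_a != hash_b, differing_bits)
```

A one-character change typically flips a large share of the output bits, which is why the hash works as a fingerprint.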

The goal of this tutorial is to explain how SHA-256 works and how to make it run on encrypted data using concrete numpy.


## II. SHA-256 in python with numpy
First of all, let's take a look at the initial algorithm, working on clear, unencrypted inputs, in regular numpy. In the file `sha256_original.py`, it has been coded in python following the [official publication](http://csrc.nist.gov/publications/fips/fips180-4/fips-180-4.pdf) and keeping the exact same notations, for a fixed input length of 150 text characters.

### Constants
First, let's declare the constants used by the algorithm:
- **K256**: an array of 64 **32-bit** integers used in the algorithm (the first thirty-two bits of the fractional parts of the cube roots of the first sixty-four prime numbers).
- **H_INITIAL**: an array of 8 **32-bit** integers used to initialise H in the algorithm.
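
These constants do not have to be taken on faith. The sketch below recomputes them with exact integer arithmetic: `K256` from cube roots of the first 64 primes as stated above, and `H_INITIAL` from square roots of the first 8 primes (the latter definition comes from the FIPS 180-4 publication, not from this tutorial). The helper names `first_primes` and `icbrt` are hypothetical and exist only for this illustration.

```python
import math

def first_primes(count):
    """Return the first `count` prime numbers by trial division."""
    primes = []
    candidate = 2
    while len(primes) < count:
        if all(candidate % p for p in primes):
            primes.append(candidate)
        candidate += 1
    return primes

def icbrt(n):
    """Largest integer r such that r**3 <= n, found by binary search."""
    lo, hi = 0, 1 << (n.bit_length() // 3 + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** 3 <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo

# floor(cbrt(p) * 2**32) equals icbrt(p * 2**96); the low 32 bits are the
# first 32 bits of the fractional part
derived_K = [icbrt(p << 96) & 0xFFFFFFFF for p in first_primes(64)]
# Same idea with square roots: floor(sqrt(p) * 2**32) equals isqrt(p * 2**64)
derived_H = [math.isqrt(p << 64) & 0xFFFFFFFF for p in first_primes(8)]

print(hex(derived_K[0]), hex(derived_K[-1]))
print(hex(derived_H[0]), hex(derived_H[-1]))
```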


In [11]:
import numpy as np

K256 = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
]

H_INITIAL = [0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19]



### Preparing the input
Let's now prepare our input message from a text of 150 characters. This message could be any array of 150 bytes, not necessarily text.

In [12]:
text = (
    b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
    b"Curabitur bibendum, urna eu bibendum egestas, neque augue eleifend odio, et sagittis viverra."
)
assert len(text) == 150

# convert text to byte list and then to numpy array
message = list(text)
message = np.array(message).astype(np.int32)


### Padding the message
The first operation pads the message so that its bit length becomes a multiple of 512. In our scenario this function is just a concatenation of constant arrays, because the padding content depends only on the length of the message, which is constant for us. To determine the arrays to concatenate, we followed these steps:
- Compute **l**, the bit length of the message: `l = 150*8 = 1200`
- Find the smallest non-negative **k** such that `l+k+1 ≡ 448 (mod 512)`, which gives `k=271`
- Convert **l** into its 64-bit binary representation, following big-endian encoding (most significant bits first): `l64 = 00000000...0000010010110000`
- Split the bits into bytes (8 bits each): `l64_8 = [0,0,0,0,0,0,4,176]`
- Write 1 followed by **k** zeros as bytes in **_1k0** (272 bits, i.e. 34 bytes)
- Concatenate the message with **_1k0** and **l64_8**  

The padded message now has a bit length of **3** times **512 bits**. Chunks of **512 bits** are called **blocks**.  
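
The arithmetic in the steps above can be checked with a few lines of plain Python. This is only a sketch: the function name `padding_parameters` is hypothetical and is not part of `sha256_original.py`.

```python
def padding_parameters(n_bytes):
    """Return (k, l64_8, padded_length_in_bytes) for a message of n_bytes bytes."""
    l = 8 * n_bytes                      # bit length of the message
    k = (448 - (l + 1)) % 512            # smallest k >= 0 with l + 1 + k = 448 mod 512
    l64_8 = list(l.to_bytes(8, "big"))   # 64-bit big-endian length, as 8 bytes
    padded_bytes = (l + 1 + k + 64) // 8
    return k, l64_8, padded_bytes

k, l64_8, padded_bytes = padding_parameters(150)
print(k, l64_8, padded_bytes)
```

For a 150-byte message this reproduces the values used in the tutorial: `k = 271`, the length bytes `[0, 0, 0, 0, 0, 0, 4, 176]`, and a padded length of 192 bytes (3 blocks).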

In [13]:
def padding_150(M):
    """
    SHA-256 padding for a message with fixed length of 150 characters (see section 5.1.1 in the paper)
    """
    assert(len(M)==150)
   
    """
    Length in bits is : l = 150*8 = 1200
    Find the smallest k >= 0 such that (l+k+1) % 512 == 448 => this gives k = 271
    Convert l=1200 in 64 bits binary: l64 = 0000000000000000000000000000000000000000000000000000010010110000
    Convert l64 to an array of 8-bits integers, which is the format of the input message:
    l64_8 = [00000000,00000000,00000000,00000000,00000000,00000000,00000100,10110000] = [0,0,0,0,0,0,4,176]
    """
    l64_8 = np.array([0,0,0,0,0,0,4,176]).astype(np.int32)
    
    """
    Write 1 followed by k zeros in binary as an array of 8-bits integers:
    _1k0 = [ 10000000, 00000000, ..., 00000000]
    _1k0 = [128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
    """
    _1k0 = np.array([128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]).astype(np.int32)

    """
    The padded message is the concatenation of <M> <1 followed by k zeros> <l in 64 bits> 
    It has size 192 bytes = 192*8 bits = 3*512 bits = 3 blocks of 512 bits
    """
    return np.concatenate([M, _1k0, l64_8])


padded_message = padding_150(message)

### Parsing the padded message
Parsing consists of regrouping the bits of each block into **16** words of **32 bits** each (512=16\*32). The original algorithm operates directly on **32-bit** integers, so each word is a **32-bit** integer, but this will be one of the main changes in the next section with concrete numpy.
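
To make the regrouping concrete, here is a small sketch that packs the first four bytes of the message (`b"Lore"`) into one big-endian 32-bit word, using the same shift-and-combine pattern as the parsing function, and cross-checks it with Python's built-in `int.from_bytes`. The variable names are illustrative.

```python
# ASCII codes of "L", "o", "r", "e"
word_bytes = [0x4C, 0x6F, 0x72, 0x65]

# The first byte lands in the most significant position (big-endian)
word = (word_bytes[0] << 24) ^ (word_bytes[1] << 16) ^ (word_bytes[2] << 8) ^ word_bytes[3]

# Built-in equivalent, used here only as a cross-check
reference = int.from_bytes(bytes(word_bytes), "big")

print(hex(word), hex(reference))
```

Because the four shifted bytes occupy disjoint bit positions, combining them with XOR, OR or addition gives the same result.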

In [14]:
def parsing_150(M):
    """
    SHA-256 parsing for a padded message with initial fixed length (before padding) of 150 characters
    (see section 5.2.1 in the paper)
    """
    assert len(M) == 192  # after padding, length is now 192
    N=3
    """
    We split M into 3 x 64 bytes (3 x 512 bits) and split each 512 bits into 16 words of 32 bits
    """

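    # initialize array with an integer dtype wide enough to hold 32-bit words
    # (np.zeros defaults to float64 otherwise)
    parsed_M = np.zeros([N,16], dtype=np.int64)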

    # loop through blocks
    for i in range(0,N):
        # loop through words
        for j in range(0,16):
            # Convert groups of 4 bytes (8-bits integers) into 32-bits integers
            ind = 64*i + j*4
            parsed_M[i][j]= (M[ind] << 24) ^ (M[ind+1] << 16) ^ (M[ind+2] << 8) ^ M[ind+3]

    return parsed_M

parsed_message = parsing_150(padded_message)

### Hashing the parsed message

The algorithm uses functions operating on **32-bit** integers, built from the following operators:  

| Operation | Math symbol | Python symbol |
|-----------|--------|-------|
| AND | ∧ | `&` |  
| OR | ∨ | `\|` |  
| XOR | ⊕ | `^` | 
| Right-shift | >> | `>>` |  
| Left-shift | << | `<<` |  

- The **Right-shift** operation by n bits `x >> n` is obtained by discarding the rightmost n bits of x and then padding the result with n zeros on the left (see function `SHR`)  
- The **Left-shift** operation by n bits `x << n` is obtained by discarding the leftmost n bits of x and then padding the result with n zeros on the right (see function `SHL`)  
- The **AND** operator can be used to cast a number `x` to 32 bits: `x & (2**32 - 1)` is equal to `x` modulo `2**32`  
- The `mod2p32` function applies this cast; applied to a sum, it gives the addition modulo `2**32` required by the publication.  
- The right-rotate function `ROTR` rotates the bits to the right by combining, with a bitwise OR, a right shift of `n` bits and a left shift of `32-n` bits.  
- The `sigma0`, `sigma1`, `SIGMA0`, `SIGMA1`, choice function `Ch` and majority function `Maj` are also defined below.
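
The `Ch` and `Maj` functions in the cell below are written in a compact form that differs from the textbook definitions in the publication. The following sketch checks on random 32-bit values that both forms agree, and shows the rotation on a simple example. It is self-contained and uses its own illustrative names (`ch_spec`, `ch_short`, `maj_spec`, `maj_short`, `rotr`), which are not part of the tutorial code.

```python
import random

MASK32 = 2**32 - 1

def ch_spec(x, y, z):
    # Textbook form: (x AND y) XOR (NOT x AND z), with NOT limited to 32 bits
    return (x & y) ^ ((~x & MASK32) & z)

def ch_short(x, y, z):
    # Compact form used in this tutorial
    return z ^ (x & (y ^ z))

def maj_spec(x, y, z):
    # Textbook form: (x AND y) XOR (x AND z) XOR (y AND z)
    return (x & y) ^ (x & z) ^ (y & z)

def maj_short(x, y, z):
    # Compact form used in this tutorial
    return ((x | y) & z) | (x & y)

def rotr(x, n):
    # Right rotation of a 32-bit value by n bits (0 < n < 32)
    return ((x >> n) | (x << (32 - n))) & MASK32

rng = random.Random(0)
mismatches = 0
for _ in range(1000):
    x, y, z = (rng.getrandbits(32) for _ in range(3))
    if ch_spec(x, y, z) != ch_short(x, y, z) or maj_spec(x, y, z) != maj_short(x, y, z):
        mismatches += 1

print("mismatches:", mismatches)
print(hex(rotr(0x12345678, 8)))
```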

In [15]:
TP32=2**32
BITS_PER_WORD = 32

def SHR(x, n):
    """
    The right shift operation.
    Cast x to 32 bits and then shift it by n bits to the right
    """
    return (x & (TP32-1)) >> n

def SHL(x, n):
    """
    The left shift operation.
    Shift x by n bits to the left and then cast it to 32 bits
    """
    return (x << n) & (TP32-1)

def ROTR(x, n):
    """
    The rotate right (circular right shift) operation.
    It is the bitwise OR of a right shift of n bits and a left shift of (32 - n) bits
    """
    return SHR(x,n) | SHL(x,BITS_PER_WORD-n)


def Ch(x, y, z):
    """
    The choose function
    """
    return z ^ (x & (y ^ z))


def Maj(x, y, z):
    """
    The majority function
    """
    return ((x | y) & z) | (x & y)


def SIGMA0(x):
    """
    Upper case sigma 0 function
    """     
    return ROTR(x, 2) ^ ROTR(x, 13) ^ ROTR(x, 22)


def SIGMA1(x):
    """
    Upper case sigma 1 function
    """      
    return ROTR(x, 6) ^ ROTR(x, 11) ^ ROTR(x, 25)


def sigma0(x):
    """
    Lower case sigma 0 function
    """
    return ROTR(x, 7) ^ ROTR(x, 18) ^ SHR(x, 3)


def sigma1(x):
    """
    Lower case sigma 1 function
    """      
    return ROTR(x, 17) ^ ROTR(x, 19) ^ SHR(x, 10)


def mod2p32(x):
    """
    Computes x modulo 2**32 which is equivalent to cast x to 32 bits
    """ 
    return x & (TP32-1)


def hexdigest(digest):
    """
    Convert bytes (8 bits) to string of hex symbols (4 bits)
    """
    # one 4-bit value (nibble) per hexadecimal character
    nibbles = [0]*(2*len(digest))
    for i in range(0,len(digest)):
        d=digest[i]
        nibbles[i*2] = (d >> 4) & 15
        nibbles[i*2+1] = d & 15

    hex_chars = [hex(x)[-1] for x in nibbles]
    return ''.join(hex_chars)


Now we can write the core of the algorithm, following the same notations as in the original publication:

In [16]:
# initialize array H
# (int64 is used so that 32-bit values above 2**31 and their intermediate sums fit in the dtype)
H=np.zeros(8,dtype=np.int64)
for i in range(0,8):
    H[i] = H_INITIAL[i]

# main loop
for i in range(0,3):
    #1 prepare the message schedule W
    W=np.zeros(64,dtype=np.int64)
    W[0:16]=parsed_message[i,:]
    for t in range(16,64):
        W[t] = mod2p32(sigma1(W[t-2]) + W[t-7] + sigma0(W[t-15]) + W[t-16])

    #2 initialize values of a,b,c,d,e,f,g,h with previous values in H
    a=H[0]; b=H[1]; c=H[2]; d=H[3]; e=H[4]; f=H[5]; g=H[6]; h=H[7];

    #3
    for t in range(0,64):
        # use mod2p32 to compute the addition modulo 2**32
        T1= mod2p32(h + SIGMA1(e) + Ch(e,f,g) + K256[t] + W[t])
        T2= mod2p32(SIGMA0(a) + Maj(a,b,c))
        h=g; g=f; f=e
        e= mod2p32(d + T1)
        d=c; c=b; b=a
        a= mod2p32(T1+T2)

    #4 compute update of H
    H[0]=mod2p32(H[0]+a); H[1]=mod2p32(H[1]+b); H[2]=mod2p32(H[2]+c); H[3]=mod2p32(H[3]+d);
    H[4]=mod2p32(H[4]+e); H[5]=mod2p32(H[5]+f); H[6]=mod2p32(H[6]+g); H[7]=mod2p32(H[7]+h);

"""
Finally, the result is the concatenation of the bits of H
"""
# The result is the concatenation of bits of H, which are 8 x 32 bits
# 8 x 32 bits is also 32 x 8 bits which is 32 bytes
digest = np.zeros(32,dtype=np.short);
for i in range(0,8):
    h=H[i]
    # split 32-bits integers into 4 x 8-bits integers
    for j in range(0,4):
        digest[i*4+j] = (h >> (8*(3-j))) & 255
     
print('digest: ', digest)

digest:  [142  81  42  35 184 164 123 211 178 193  74 131  72 226 202  27 129   5
  61 244   8  90  21 189 116 175 166  63 115 114  10 214]


The digest is as expected a sequence of **32 bytes** (32\*8 bits), which we can also convert into a sequence of **64** hexadecimal characters, called the **hash**:

In [17]:
message_hash = hexdigest(digest)
print(message_hash)

8e512a23b8a47bd3b2c14a8348e2ca1b81053df4085a15bd74afa63f73720ad6
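
The link between the digest bytes and the hexadecimal string can be checked by hand on the first few bytes. The sketch below uses Python's built-in `bytes.hex()` rather than the `hexdigest` helper defined earlier, and takes the first four values of the digest printed above.

```python
# First four bytes of the digest printed above
first_bytes = [142, 81, 42, 35]

# Each byte becomes two hexadecimal characters
hex_prefix = bytes(first_bytes).hex()
print(hex_prefix)
```

The result matches the first eight characters of the hash.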


And we can test it against the hashlib library:

In [18]:
import hashlib

hasher = hashlib.sha256()
hasher.update(text)
assert(message_hash == hasher.hexdigest());
print('Hash is correct !')

Hash is correct !


You can import the `sha256_150` function from the `sha256_original.py` file, or run the file directly from the command line to run tests.

## III. SHA-256 with concrete numpy

Now that we have a straightforward working algorithm coded in Python with numpy, let's convert it to a concrete numpy circuit, in order to compute a hash homomorphically on encrypted inputs. This part of the tutorial follows the code in `sha256.py`, which relies on `utils.py`.

### Tackling the bit width limitation issue

The biggest issue in converting our algorithm into a concrete circuit is the limitation on the bit width of integers. In our algorithm, integers are **32-bit**, but concrete numpy currently supports encrypted integers of at most 16 bits. We thus need to represent a **32-bit** integer as an array of smaller **n-bit** integers of length `nchunks=32/nbits`. Hence we can represent a **32-bit** integer as follows:
- **32** integers of **1 bit**
- **16** integers of **2 bits**
- **8** integers of **4 bits**
- **4** integers of **8 bits** 

We cannot use **2** integers of **16 bits** in this case, as we will see in the summing algorithm.
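
The reason 16-bit chunks are excluded can be previewed with a little arithmetic. The sketch below assumes, as stated above, that encrypted integers are limited to 16 bits, and computes for each chunk size the number of chunks and the bit width needed to hold the sum of two chunks plus an incoming carry of 1. The names are illustrative.

```python
MAX_ENCRYPTED_BITS = 16  # limit stated in this tutorial

bits_needed = {}
for nbits in (1, 2, 4, 8, 16):
    nchunks = 32 // nbits
    max_chunk = 2**nbits - 1
    # worst case when adding two chunks and a carry coming from the previous chunk
    worst_case = 2 * max_chunk + 1
    bits_needed[nbits] = worst_case.bit_length()
    fits = bits_needed[nbits] <= MAX_ENCRYPTED_BITS
    print(f"nbits={nbits:2d} nchunks={nchunks:2d} bits needed={bits_needed[nbits]:2d} fits={fits}")
```

Every chunk size needs one extra bit for the intermediate sum, so only chunks of 8 bits or fewer stay within the limit.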

#### 1. Splitting and merging bits

We first need a function to split the bits of a **32-bit** integer into an array of **n-bit** integers of length `nchunks=32/nbits`:

In [20]:
import concrete.numpy as cnp

def split_bits(num, nbits, bit_width=32):
    """
    Splits an integer of 'bit_width' bits into smaller integers of 'nbits' bits each.
    """
    # Check if nbits is a valid value
    if not (nbits in [1, 2, 4, 8, 16]):
        raise ValueError("nbits must be in [1, 2, 4, 8, 16]")

    if bit_width<nbits:
        raise ValueError("nbits must not exceed bit_width")

    # Convert the input integer to binary string
    binary_str = format(num, '0' + str(bit_width) + 'b')

    # Split the binary string into smaller chunks of 'nbits' bits each
    chunks = [binary_str[i:i+nbits] for i in range(0, bit_width, nbits)]

    # Convert each chunk to integer
    integers = [int(chunk, 2) for chunk in chunks]

    # Convert the list of integers to a NumPy array
    return np.array(integers)

print(split_bits((2**32)-1, 1, bit_width=32))
print(split_bits((2**32)-1, 2, bit_width=32))

[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3]


The file `utils.py` also provides a function for merging such an array back into a 32-bit integer, as well as functions to convert an array of **32-bit** integers into an array of such arrays, and back.

In [23]:
import utils

assert utils.merge_bits( split_bits((2**32)-1, 1, bit_width=32), 1, 32) == (2**32-1)
assert utils.merge_bits( split_bits((2**32)-1, 2, bit_width=32), 2, 32) == (2**32-1)
print('split and merge correctly')

split and merge correctly
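
The content of `utils.py` is not reproduced in this tutorial. As a rough idea of what the merge step involves, here is a minimal sketch of a merging function; the name `merge_bits_sketch` is hypothetical and the real `utils.merge_bits` may be written differently.

```python
import numpy as np

def merge_bits_sketch(chunks, nbits, bit_width=32):
    """Merge big-endian chunks of nbits bits each into a single Python integer."""
    assert len(chunks) * nbits == bit_width
    value = 0
    for chunk in chunks:
        # make room for the next chunk, then append it in the low bits
        value = (value << nbits) | int(chunk)
    return value

print(merge_bits_sketch(np.array([3] * 16), 2))
print(merge_bits_sketch([0, 0, 4, 87], 8))
```

The second call merges four bytes: `4*256 + 87 = 1111`.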


In [24]:
binarray=utils.ints_to_nbits([1111,2222,3333],1,32)
print(binarray)

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 1 0 1 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 1 0 1 1 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0 1 0 1]]


In [25]:
print(utils.ints_from_nbits( binarray, 1, 32))

[1111 2222 3333]


#### 2. Implementing operations on the n-bits integers array representation
We cannot sum two of these arrays with regular element-wise array addition, because the resulting integers could reach **n+1 bits**, leading to an incorrect encoding of the **32-bit** integer they represent.  

Let's first create a summing algorithm (modulo `2**32`) for two binary arrays, each representing a **32-bit** integer, that can operate homomorphically.  

A typical algorithm of this kind could be: 
```python
def sumBin(x: np.ndarray, y: np.ndarray):
    result = cnp.zeros(32)
    carry = 0
    # iterate from the least significant bit (index 31) to the most significant (index 0)
    for i in range(31, -1, -1):
        # Calculate the sum of the current integer and the carry
        s = x[i] + y[i] + carry
        # If the sum is greater than 1, there is a carry for the next integer
        if s > 1:
            carry = 1
            s -= 2
        else:
            carry = 0
        # Set the current integer in the result, last carry is unused modulo 2**32
        result[i] = s
    return result

```

The main issue here is that the comparison `s > 1` cannot be evaluated on encrypted data. So we need to propagate the carry without branching on it. The following algorithm instead uses only homomorphic operations:

In [2]:
def sumBin(x: np.ndarray, y: np.ndarray):
    """
    Addition mod 2**32 of two 32-bits integers encoded as arrays of 32 1-bit integers
    """
    n=32-1
    z=np.zeros(32, dtype=np.uint32)
    # first sum the last integer (starting from the end to account for big-endian bit encoding)
    z[n] = x[n] + y[n]

    for i in range(1,32):
        # each integer of the result is then the sum of the integers of the inputs
        # and of the carry, i.e. the previously summed integer shifted right by 1 bit
        z[n-i] = x[n-i] + y[n-i] + (z[n-i+1] >> 1)
        # then we cast this previous integer to 1 bit
        z[n-i+1] = z[n-i+1] & (2**1-1)

    # finally cast the first integer (most significant bit) to 1 bit so the result is a sum modulo 2**32
    z[0] = z[0] & (2**1-1)

    return z

        
# test it on our binary array
result1 = sumBin(binarray[0,:], binarray[1,:])
assert( np.all(result1==binarray[2,:]))
print('sum: ',result1)
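
To see why the shift-and-mask form is a valid replacement for the comparison, the sketch below enumerates every combination of two input bits and an incoming carry, and checks that `s >> 1` and `s & 1` give the same carry and result bit as the branching version. It runs on plain Python integers and is only meant as an illustration.

```python
checked = 0
for x_bit in (0, 1):
    for y_bit in (0, 1):
        for carry_in in (0, 1):
            s = x_bit + y_bit + carry_in

            # branching version (not usable on encrypted values)
            if s > 1:
                carry_branch, bit_branch = 1, s - 2
            else:
                carry_branch, bit_branch = 0, s

            # branch-free version: shift for the carry, mask for the result bit
            carry_shift, bit_mask = s >> 1, s & 1

            assert (carry_branch, bit_branch) == (carry_shift, bit_mask)
            checked += 1

print("combinations checked:", checked)
```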


Now we can make this code **generic** by writing a **template** function that can deal with other bit widths. In Python this can be done with either a factory function (see later) or a **class** definition taking the `nbits` value at initialization.

We can also specify whether we want to use concrete numpy or plain numpy. This is not mandatory for creating a circuit in concrete numpy, but it is very useful for debugging, because we can compare the outputs produced by the clear, numpy-only version with those of the homomorphic, concrete numpy version. For now, let's test with numpy only.

Note that when we add up two integers, we need to store the result in an integer that has at least 1 more bit.  
This is why we cannot use **16-bit** integers here: they would require a **17-bit** integer, which is not available in concrete numpy.

In [27]:
class FnSplit:
    """
    Base class of factory classes for functions operating on 32-bits integers encoded into several smaller integers
    """
    def __init__(self, nbits, use_cnp=True):
        assert( nbits in [1,2,4,8,16] )
        self.nbits = nbits
        self.nchunks = int(32/nbits)        
        self._2p_nbits = 2**nbits
        if use_cnp:
            self.zeros = lambda : cnp.zeros(self.nchunks)
        else:
            self.zeros = lambda : np.zeros(self.nchunks, dtype=np.uint32)


class AddSplit(FnSplit):
    """
    Factory class for Addition functions operating on 32-bits integers encoded into several smaller integers
    """
    def __call__(self, x: np.ndarray, y: np.ndarray):
        """
        Addition mod 2**32 of two 32-bits integers encoded as arrays of nchunks n-bits integers (big-endian)
        """
        z = self.zeros()
        n=self.nchunks-1
        z[n] = x[n] + y[n] # z will need 1 more bit than x and y, so they cannot be 16 bits

        for i in range(1,self.nchunks):
            z[n-i] = x[n-i] + y[n-i] + (z[n-i+1] >> self.nbits)
            z[n-i+1] = z[n-i+1] & (self._2p_nbits-1) 

        # cast first integer to nbits to be modulo 2**32
        z[0] = z[0] & (self._2p_nbits-1)

        return z

        
# create a function to sum modulo 2**32 two arrays with 2-bits encoding
add2 = AddSplit(2, False)

twoBitsArray=utils.ints_to_nbits([1111,2222,3333],2,32)
print(twoBitsArray)

result2 = add2(twoBitsArray[0,:], twoBitsArray[1,:])
assert( np.all(result2==twoBitsArray[2,:]))
print("\nresult:",result2)

[[0 0 0 0 0 0 0 0 0 0 1 0 1 1 1 3]
 [0 0 0 0 0 0 0 0 0 0 2 0 2 2 3 2]
 [0 0 0 0 0 0 0 0 0 0 3 1 0 0 1 1]]

result: [0 0 0 0 0 0 0 0 0 0 3 1 0 0 1 1]
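
One way to gain more confidence in the chunked addition is to compare it with plain modular addition on random inputs, for every supported chunk size. The sketch below is self-contained: it re-implements the same carry propagation on Python lists, with its own hypothetical helpers `split`, `merge` and `add_chunks`, rather than calling the classes above.

```python
import random

def split(num, nbits):
    """Big-endian split of a 32-bit integer into chunks of nbits bits."""
    nchunks = 32 // nbits
    mask = (1 << nbits) - 1
    return [(num >> (nbits * (nchunks - 1 - i))) & mask for i in range(nchunks)]

def merge(chunks, nbits):
    """Inverse of split."""
    value = 0
    for chunk in chunks:
        value = (value << nbits) | chunk
    return value

def add_chunks(x, y, nbits):
    """Addition mod 2**32 on chunked values, same carry propagation as above."""
    nchunks = 32 // nbits
    mask = (1 << nbits) - 1
    z = [0] * nchunks
    n = nchunks - 1
    z[n] = x[n] + y[n]
    for i in range(1, nchunks):
        z[n - i] = x[n - i] + y[n - i] + (z[n - i + 1] >> nbits)
        z[n - i + 1] &= mask
    z[0] &= mask  # drop the final carry: result is modulo 2**32
    return z

rng = random.Random(1)
checked = 0
for nbits in (1, 2, 4, 8):
    for _ in range(200):
        a, b = rng.getrandbits(32), rng.getrandbits(32)
        got = merge(add_chunks(split(a, nbits), split(b, nbits), nbits), nbits)
        assert got == (a + b) % 2**32
        checked += 1

print("additions checked:", checked)
```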


In the same way, let's make a **template** function for the right rotation of the bits of a **32-bit** integer encoded as an array of **n-bit** integers. Making a **template** function also allows us to store in the class some variables that serve the computation, without having to handle them outside.  

To rotate the encoded 32-bit integer by `y` bits, the idea is to first decompose `y` into `yc` and `yr` such that `y = yc*nbits + yr`, then to rotate the whole array by `yc` positions, and finally to rotate the bits across the integers of the new array by the remaining `yr` bits, where `yr` is smaller than `nbits`.  

The parameter `y` will not be encrypted; it is a known constant. So we can write conditions such as `if(yc>0):`. In contrast, `x` will be encrypted, so we must work on it using only homomorphic-compatible operations such as bit shifting `>>`, bitwise AND `&`, addition `+`, etc.
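
The decomposition of `y` is a plain integer division with remainder. As an illustration, the sketch below computes `(yc, yr)` for the three rotation amounts used by `SIGMA0` (2, 13 and 22) and for each chunk size; the helper name `decompose_shift` is hypothetical.

```python
def decompose_shift(y, nbits):
    """Return (yc, yr) such that y = yc*nbits + yr with 0 <= yr < nbits."""
    yc, yr = divmod(y, nbits)
    return yc, yr

for nbits in (1, 2, 4, 8):
    print(nbits, [decompose_shift(y, nbits) for y in (2, 13, 22)])
```

With 1-bit chunks `yr` is always zero, so the rotation reduces to moving whole array entries.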

In [3]:
class ROTRSplit(FnSplit):
    """
    Factory class for ROTR functions operating on 32-bits integers encoded into several smaller integers
    """       

    def __call__(self, x: np.ndarray, y: np.uint32):
        """
        Right bits rotation of a 32-bits integers encoded as array of nchunks n-bits integers (big-endian)
        """  
        z = self.zeros()
        temp = self.zeros()

        # first right rotate array of amount int(y/nbits)
        yc = int(y/self.nbits)
        if(yc>0):
            # right shift of yc chunks
            temp[yc:] = x[:-yc]
            # left shift of nchunks-yc chunks
            temp[:yc] = x[-yc:]
        else:
            temp[:] = x[:]

        # now rotate everything with remaining shift
        yr = y%self.nbits
        if(self.nbits>1 and yr>0):
            for i in range(0,self.nchunks):
                z[i] = (temp[i] >> yr) + ( (temp[i-1] << (self.nbits-yr)) & (self._2p_nbits-1) )
        else:
            z[:] = temp[:]

        return z
            
            
# test it on the previous arrays, using only numpy for now:
rotr1 = ROTRSplit(1, use_cnp=False)
result1 = rotr1(binarray[2,:], 3)
print('binary array:')
print(binarray[2,:])
print('\nthe array is right rotated of 3 bits:')
print(result1)

rotr2 = ROTRSplit(2, use_cnp=False)
result2 = rotr2(twoBitsArray[2,:], 2)
print('\n2-bits array:')
print(twoBitsArray[2,:])
print('\nthe bits are right rotated of 2 bits, which causes the 2-bits integers to be rotated themselves in the array:')
print(result2)
result2 =rotr2(twoBitsArray[2,:], 3)
print('\nWith 3 bits rotation, the integers are modified, but the overall bit sequence is correctly rotated:')
print(result2)
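
The chunked rotation can also be validated against a rotation on plain 32-bit integers. The sketch below is a self-contained re-implementation on Python lists, following the same two steps (rotate whole chunks, then rotate the remaining bits across neighbouring chunks). The helper names `rotr32`, `split`, `merge` and `rotr_chunks` are hypothetical and independent of the classes above.

```python
import random

MASK32 = 2**32 - 1

def rotr32(x, n):
    """Reference right rotation on a plain 32-bit integer."""
    return ((x >> n) | (x << (32 - n))) & MASK32

def split(num, nbits):
    """Big-endian split of a 32-bit integer into chunks of nbits bits."""
    nchunks = 32 // nbits
    mask = (1 << nbits) - 1
    return [(num >> (nbits * (nchunks - 1 - i))) & mask for i in range(nchunks)]

def merge(chunks, nbits):
    value = 0
    for chunk in chunks:
        value = (value << nbits) | chunk
    return value

def rotr_chunks(x, y, nbits):
    """Right rotation by y bits of a chunked 32-bit value."""
    mask = (1 << nbits) - 1
    yc, yr = divmod(y, nbits)
    # step 1: rotate whole chunks by yc positions
    temp = x[-yc:] + x[:-yc] if yc > 0 else list(x)
    if yr == 0:
        return temp
    # step 2: each chunk keeps its high bits and receives the low bits of its
    # left neighbour (index -1 wraps around for the first chunk)
    return [(temp[i] >> yr) + ((temp[i - 1] << (nbits - yr)) & mask)
            for i in range(len(temp))]

rng = random.Random(2)
checked = 0
for nbits in (1, 2, 4, 8):
    for y in (2, 6, 7, 11, 13, 17, 18, 19, 22, 25):  # rotation amounts used by SHA-256
        for _ in range(20):
            value = rng.getrandbits(32)
            assert merge(rotr_chunks(split(value, nbits), y, nbits), nbits) == rotr32(value, y)
            checked += 1

print("rotations checked:", checked)
```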


We can write a similar function for the right shift operation, where we first shift the array of integers itself by `yc` positions, then shift the bits by the remaining `yr` bits.  

Note that during the shifting of the array, we need to set the leftmost values of the array to zero. To do this on encrypted data, we cannot use the plain value `0`, so we use the values in `cnp.zeros(.)`, which are encrypted zeros.

In [29]:
class SHRSplit(FnSplit):
    """
    Factory class for SHR functions operating on 32-bits integers encoded into several smaller integers 
    """

    def __call__(self, x: np.ndarray, y: np.uint32):
        """
        Right shift of a 32-bits integers encoded as array of nchunks n-bits integers (big-endian)
        """        
        # first right shift array of int(y/nbits)

        z = self.zeros()
        temp = self.zeros()

        yc = int(y/self.nbits)
        if(yc>0):
            # right shift of yc chunks
            temp[yc:] = x[:-yc]
            # the left values are null
        else:
            temp[:] = x[:]

        # now shift everything with remaining shift
        yr = y%self.nbits
        if(self.nbits>1 and yr>0):
            for i in range(1,self.nchunks):
                z[i] = (temp[i] >> yr) + ( (temp[i-1] << (self.nbits-yr)) & (self._2p_nbits-1) )
            z[0] = (temp[0] >> yr) # first chunk is a simple right shift
        else:
            z[:] = temp[:]

        return z

# test it on the previous arrays, using only numpy for now:
shr1 = SHRSplit(1, use_cnp=False)
result1 = shr1(binarray[2,:], 3)
print('binary array:')
print(binarray[2,:])
print('\nthe array is right shifted of 3 bits:')
print(result1)

shr2 = SHRSplit(2, use_cnp=False)
result2 = shr2(twoBitsArray[2,:], 2)
print('\n2-bits array:')
print(twoBitsArray[2,:])
print('\nthe bits are right shifted of 2 bits, which causes the 2-bits integers to be shifted themselves in the array:')
print(result2)
result2 = shr2(twoBitsArray[2,:], 3)
print('\nWith 3 bits shifting, the integers are modified, lefts ones stay 0')
print(result2)

binary array:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0 1 0 1]

the array is right shifted of 3 bits:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0]

2-bits array:
[0 0 0 0 0 0 0 0 0 0 3 1 0 0 1 1]

the bits are right shifted of 2 bits, which causes the 2-bits integers to be shifted themselves in the array:
[0 0 0 0 0 0 0 0 0 0 0 3 1 0 0 1]

With 3 bits shifting, the integers are modified, lefts ones stay 0
[0 0 0 0 0 0 0 0 0 0 0 1 2 2 0 0]


In the same fashion, we can design a template class `Sigmas` to hold the four functions `sigma0`, `sigma1`, `SIGMA0`, `SIGMA1` using internally a `ROTR` and a `SHR` as well as the bitwise **XOR** operator `^` which is included in concrete numpy:

In [30]:
class Sigmas(FnSplit):
    """
    Class holding variables to run sigma functions
    """
    def __init__(self, nbits, use_cnp=True):
        FnSplit.__init__(self, nbits, use_cnp)  # forward use_cnp so the base class matches the chosen mode

        self.ROTR = ROTRSplit(nbits,use_cnp)
        self.SHR = SHRSplit(nbits,use_cnp)

    def SIGMA0(self, x: np.ndarray):
        """
        Upper case sigma 0 function
        """     
        return self.ROTR(x, 2) ^ self.ROTR(x, 13) ^ self.ROTR(x, 22);

    def SIGMA1(self, x: np.ndarray):
        """
        Upper case sigma 1 function
        """      
        return self.ROTR(x, 6) ^ self.ROTR(x, 11) ^ self.ROTR(x, 25);


    def sigma0(self, x: np.ndarray):
        """
        Lower case sigma 0 function
        """
        return self.ROTR(x, 7) ^ self.ROTR(x, 18) ^ self.SHR(x, 3);


    def sigma1(self, x: np.ndarray):
        """
        Lower case sigma 1 function
        """   
        return self.ROTR(x, 17) ^ self.ROTR(x, 19) ^ self.SHR(x, 10);


#### 3. Processing the inputs to fit our new integers representation
Before jumping to the core algorithm, we need to process the input text and the constants to convert them into the **n-bits** representation.  

We also perform the padding operation here, before encryption: the padding would otherwise have to be encrypted during the computation, so we may as well do it beforehand. Doing so also hides the length of the message, which makes it an even better secret.

In [31]:
def processInput(text, H, K256_, nbits):
    """
    Processes SHA-256 inputs for concrete-numpy circuit
    """    
    # convert text to uint8
    textAsInts = list(text)

    # apply padding 
    textAsInts = padding_150(textAsInts)

    # convert text from 8 bits to nbits
    textAsInts = utils.ints_to_nbits(textAsInts, nbits, 8).flatten() # flatten this one
    # convert constants from 32 (default) bits to nbits
    H_nbits = utils.ints_to_nbits(H, nbits)
    K256_nbits = utils.ints_to_nbits(K256_, nbits)

    return (textAsInts, H_nbits, K256_nbits)

We will also need a function to convert the output back to a readable format by converting it to a hexadecimal hash:

In [32]:
def outputToHash(output, nbits):
    """
    Processes the circuit output into a hexadecimal hash
    """      
    nchunks = int(32/nbits)

    # first convert output back to uint8 values
    output32 = utils.ints_from_nbits(output.reshape((8,nchunks)),nbits,32)
    output8 = utils.ints_to_nbits( output32, 8, 32)
    
    # also convert to hex
    outputHash = hexdigest(output8.flatten())

    return outputHash


### Implementing the core algorithm

Now we can finally proceed to implementing the core algorithm. The main changes compared with the simple Python version are the use of arrays of **n-bit** integers instead of simple **32-bit** integers, and the use of the corresponding operators to sum them, rotate them, etc. Thus, instead of directly summing many variables, we must split the sum into several summing operations.  

As earlier, let's make a factory for the **SHA-256** circuit with:
- A template value of **n** bits.
- The possibility to run with **numpy** only.
- A **quick test** mode that will not loop, useful for testing quickly with concrete numpy.  

We provide as encrypted inputs to the circuit the preprocessed values of our text message, but also of the constants `H_INITIAL` and `K256` which are also inputs of the algorithm.  

When dealing with array objects, make sure to copy the array values and not the array objects themselves, by using for instance the notation `h[:]=g[:]`.
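
The difference between copying values and rebinding names is easy to miss, so here is a minimal sketch with plain numpy arrays (the variable names are illustrative). Rebinding makes two names point to the same array, so a later in-place update of one silently changes the other.

```python
import numpy as np

g = np.array([1, 2, 3])
f = np.array([4, 5, 6])

# Rebinding: h_alias and g now refer to the same array object
h_alias = g

# Copying values: h_copy has its own storage
h_copy = np.zeros(3, dtype=g.dtype)
h_copy[:] = g[:]

# In-place update of g, as done with g[:]=f[:] in the round loop
g[:] = f[:]

print(h_alias)  # follows g, the old value of g is lost
print(h_copy)   # still holds the old value of g
```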

In [33]:
def sha256CircuitFactory(nbits, use_cnp=True, quick_test=False):

    assert(nbits in [1,2,4,8])
    nchunks = int(32/nbits)

    if use_cnp:
        zeros = lambda shape: cnp.zeros(shape)
        types = {1:cnp.uint1, 2:cnp.uint2, 4:cnp.uint4, 8:cnp.uint8}
        tensorTypes = (cnp.tensor[types[nbits], int(nchunks*192*8/32) ],
                       cnp.tensor[types[nbits], 8, nchunks],
                       cnp.tensor[types[nbits], 64, nchunks])
    else:
        # Warning, with numpy the inputs array H is modified by the circuit
        zeros = lambda shape: np.zeros(shape, dtype=np.uint32)
        tensorTypes = (np.ndarray, np.ndarray, np.ndarray)        

    #create functions for processing n-bits
    sigmas = Sigmas(nbits, use_cnp)
    add = AddSplit(nbits, use_cnp)

    # the computation being slow, a quick test can be made by setting these variables to 1
    N = 3 if not quick_test else 1
    Nt = 64 if not quick_test else 1  

    def sha256_150(M: tensorTypes[0], H: tensorTypes[1], K256: tensorTypes[2]):
        """
        SHA-256 implementation for a message with fixed length of 150 characters encoded as 150x8 binary values
        Returns a digest of 32 bytes
        """

        print('Compiling sha256_150...')

        """
        padding is already done, so apply parsing: making 3x16 32-bits words encoded as nchunks nbits integers
        """
        parsed_M = M.reshape((3,16,nchunks))

        """
        Then, proceed to the main computation (see section 6.2 in the paper)
        """

        #create W and a,b,c,d,e,f,g arrays
        W=zeros((64,nchunks))        
        a=zeros(nchunks); b=zeros(nchunks); c=zeros(nchunks); d=zeros(nchunks);
        e=zeros(nchunks); f=zeros(nchunks); g=zeros(nchunks); h=zeros(nchunks);

        #main loop
        for i in range(0,N):
            #1 prepare the message schedule W
            W[0:16,:]=parsed_M[i,:,:]
            for t in range(16,Nt):
                W[t] = add(add(sigmas.sigma1(W[t-2]), W[t-7]),
                           add(sigmas.sigma0(W[t-15]), W[t-16]))

            #2 initialize values of a,b,c,d,e,f,g,h with previous values in H
            a[:]=H[0]; b[:]=H[1]; c[:]=H[2]; d[:]=H[3]; e[:]=H[4]; f[:]=H[5]; g[:]=H[6]; h[:]=H[7];

            #3
            for t in range(0,Nt):
                T1 = add( add(h, sigmas.SIGMA1(e)), add( add(Ch(e,f,g),K256[t]),  W[t]))
                T2 = add(sigmas.SIGMA0(a), Maj(a,b,c))
                h[:]=g[:]; g[:]=f[:]; f[:]=e[:] # ! be sure to copy values with [:] and not the array objects
                e = add(d,T1)
                d[:]=c[:]; c[:]=b[:]; b[:]=a[:]
                a = add(T1,T2)

            if quick_test:
                # return (a,b,c,d,e,f,g,h) in H. Warning: the hash will be incorrect
                H[0,:]=a[:]; H[1,:]=b[:]; H[2,:]=c[:]; H[3,:]=d[:];
                H[4,:]=e[:]; H[5,:]=f[:]; H[6,:]=g[:]; H[7,:]=h[:];
                break

            #4 compute update of H
            H[0]=add(H[0],a); H[1]=add(H[1],b); H[2]=add(H[2],c); H[3]=add(H[3],d);
            H[4]=add(H[4],e); H[5]=add(H[5],f); H[6]=add(H[6],g); H[7]=add(H[7],h);

        """
        Finally, the result is the concatenation of the values of H
        """
        print('Done')    
        return H.reshape((8*nchunks,))

    # create this function to copy the value of H in python mode, otherwise it is modified by the circuit
    def circuit(M: tensorTypes[0], H: tensorTypes[1], K256: tensorTypes[2]):
        return sha256_150(M, H if use_cnp else H.copy(), K256)

    return circuit

Note that we did the parsing in one step this time with a simple reshaping, because we had already preprocessed the input bit sequence.

Now let's test this circuit with **numpy** only, and compare it to the hashing library **hashlib**:

In [48]:
# process input for n-bit encoding
nbits=2
(textAsInts, H_nbits, K256_nbits) = processInput(text, H_INITIAL, K256, nbits)

# create the full circuit for numpy only to test it
sha256circuit = sha256CircuitFactory(nbits, use_cnp=False, quick_test=False)

# no encryption, trivial computation with numpy
output = sha256circuit(textAsInts, H_nbits, K256_nbits)

# compute expected output with hashlib
hasher = hashlib.sha256()
hasher.update(text)

# Finally, convert output into a hexadecimal hash and compare it to hashlib
outputHash = outputToHash(output, nbits)
expectedHash = hasher.hexdigest()

print(outputHash)
assert(outputHash==expectedHash)
print("\nhashes are equal !")

Compiling sha256_150...
Done
8e512a23b8a47bd3b2c14a8348e2ca1b81053df4085a15bd74afa63f73720ad6

hashes are equal !


The circuit runs correctly with numpy. Let's now test it in **quick test** mode for both numpy and concrete, and compare the results.  
Note that the values of the input set must not exceed `2**nbits-1`; an incorrectly built input set will lead to errors.
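
As a sanity check on the input set constraint, the sketch below builds random samples with the three expected shapes and verifies that no value exceeds `2**nbits-1`. It uses only numpy, and the helper name `make_inputset_sketch` is hypothetical; the compilation cell builds its input set inline instead.

```python
import numpy as np

def make_inputset_sketch(nbits, n_samples=100, seed=0):
    """Random (M, H, K256) samples whose values all fit in nbits bits."""
    rng = np.random.default_rng(seed)
    nchunks = 32 // nbits
    high = 2**nbits  # exclusive upper bound, so the maximum value is 2**nbits - 1
    return [
        (
            rng.integers(0, high, size=(192 * 8 // nbits,)),  # padded message
            rng.integers(0, high, size=(8, nchunks)),         # H
            rng.integers(0, high, size=(64, nchunks)),        # K256
        )
        for _ in range(n_samples)
    ]

samples = make_inputset_sketch(nbits=2, n_samples=5)
largest = max(int(arr.max()) for sample in samples for arr in sample)
print([arr.shape for arr in samples[0]], largest)
```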

In [46]:
# create quick mode circuits for numpy and concrete
sha256circuit_quick_np = sha256CircuitFactory(nbits, use_cnp=False, quick_test=True)
sha256circuit_quick_cnp = sha256CircuitFactory(nbits, use_cnp=True, quick_test=True)

# compile the cnp circuit
configuration = cnp.Configuration(
    enable_unsafe_features=True,
    use_insecure_key_cache=True,
    insecure_key_cache_location=".keys",
)

nchunks=int(32/nbits)

compiler = cnp.Compiler(sha256circuit_quick_cnp, {"M": "encrypted", "H": "encrypted", "K256": "encrypted"})
circuit = compiler.compile(
    inputset=[
        ( np.random.randint(0, 2**nbits, size=(int(192*8*nchunks/32),) ),
          np.random.randint(0, 2**nbits, size=(8,nchunks)), 
          np.random.randint(0, 2**nbits, size=(64,nchunks)) )
        for _ in range(100)
    ],
    configuration=configuration,
    verbose=True,
)

# run circuits
output_np = sha256circuit_quick_np( textAsInts, H_nbits, K256_nbits )
output_cnp = circuit.encrypt_run_decrypt( textAsInts, H_nbits, K256_nbits )

# these hashes are not the true result of SHA-256 but they should be equal if everything is working well
hash_np = outputToHash(output_np, nbits)
hash_cnp = outputToHash(output_cnp, nbits)

print(hash_np)
assert(hash_np==hash_cnp)
print('hashes are equal !')

Compiling sha256_150...
Done
4877fab26a09e667bb67ae853c6ef372e5375507510e527f9b05688c1f83d9ab
hashes are equal !


You can measure the time it takes for different values of **n-bits**. The **1-bit** and **2-bit** versions should be by far the fastest.  

Finally, if you have a fast computer at your disposal, you can create the full concrete version of the circuit and run it:

```python

# create full circuit for concrete
sha256circuit_full_cnp = sha256CircuitFactory(nbits, use_cnp=True, quick_test=False)

# compile the cnp circuit
configuration = cnp.Configuration(
    enable_unsafe_features=True,
    use_insecure_key_cache=True,
    insecure_key_cache_location=".keys",
)

compiler = cnp.Compiler(sha256circuit_full_cnp, {"M": "encrypted", "H": "encrypted", "K256": "encrypted"})
circuit = compiler.compile(
    inputset=[
        ( np.random.randint(0, 2**nbits, size=(int(192*8*nchunks/32),) ),
          np.random.randint(0, 2**nbits, size=(8,nchunks)), 
          np.random.randint(0, 2**nbits, size=(64,nchunks)) )
        for _ in range(100)
    ],
    configuration=configuration,
    verbose=True,
)

# run circuits
output_full_cnp = circuit.encrypt_run_decrypt( textAsInts, H_nbits, K256_nbits )

# compute hash and compare it to the previous one
hash_full_cnp = outputToHash(output_full_cnp, nbits)

print(hash_full_cnp)
assert(expectedHash==hash_full_cnp)
print('hashes are equal !')
```