# Implementation of a ChaCha Cipher

A walkthrough of how the ChaCha cipher works, including its matrix and quarter-round structure. The implementation is based on the book "Implementing Cryptography Using Python". It is intended to show how the cipher works rather than to be good production code!

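Encryption operates on a 4x4 matrix in which each word is 32 bits. The core operations are Add-Rotate-XOR (ARX). A series of quarter rounds is applied to each of the four columns (A, B, C, D) and then to each of the four diagonals. The matrix is initialised as shown in the table below, where Input0 and Input1 are set from the stream position and Input2 and Input3 hold the 8-byte initialisation vector (IV):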


A    |B    |C    |D
-----|-----|-----|----
Const0|Const1|Const2|Const3
Key0|Key1|Key2|Key3
Key4|Key5|Key6|Key7
Input0|Input1|Input2|Input3

In [6]:
# Import required modules.
from base64 import b64encode
import  os
import struct

In [7]:
class ChaCha():
    '''ChaCha20 stream cipher (educational implementation).

    The cipher state is a 4x4 matrix of 32-bit words holding four constants,
    the 256-bit key, the block position and the 64-bit IV. Each 64-byte
    keystream block is produced by running 20 rounds (10 column/diagonal
    double rounds) of Add-Rotate-XOR quarter rounds over a copy of the state,
    adding the result back onto the state and serialising it little-endian.
    Data is XORed with this keystream, so encryption and decryption are the
    same operation.
    '''

    def __init__(self):
        ''' Initialisation of member variables. None implemented - no state is
            kept on the instance; the cipher matrix is built in chachaGen.'''
        pass


    def chachaGen(self, key, iv, position=0):
        ''' Generate the ChaCha20 keystream that input data will be xor'd against.

            INPUTS:
                key - The secret key as type bytes (exactly 32 bytes).
                iv -  The initialisation vector as type bytes (exactly 8 bytes).
                position - The initial block position, an unsigned 32-bit int.

            OUTPUTS:
                This is an infinite generator that yields the keystream one byte
                (an int in 0..255) at a time, 64 bytes per ChaCha block.

            RAISES:
                TypeError / ValueError - from checkInput, on the first next() call,
                if the inputs are malformed.
        '''

        # Validate input data.
        self.checkInput(position, key, iv)

        # Setup the ChaCha 4x4 table - see note above for structure.
        chaMtrx = [0] * 16   # Initialise all with zeros.
        # The constants are the ASCII string "expand 32-byte k" read as four
        # little-endian 32-bit words.
        chaMtrx[:4] = (1634760805, 857760878, 2036477234, 1797285236)
        chaMtrx[4 : 12] = struct.unpack('<8L', key)   # Put the key on the next two rows.
        # NOTE(review): both counter words are seeded with `position`. A
        # conventional 64-bit counter would put 0 in the high word; this is
        # kept so that output for a given position does not change.
        chaMtrx[12] = chaMtrx[13] = position
        chaMtrx[14 : 16] = struct.unpack('<LL', iv)  # Put the IV on to the remaining two inputs.

        # Each pass of this loop produces one 64-byte keystream block.
        while True:

            # Work on a copy (not a reference): the untouched state is added
            # back to the scrambled one below.
            x = chaMtrx.copy()

            # Undertake 10 column rounds and 10 diagonal rounds (so 20 in total).
            for _ in range(10):
                # Column round.
                self.quarterRound(x, 0, 4,  8, 12)
                self.quarterRound(x, 1, 5,  9, 13)
                self.quarterRound(x, 2, 6, 10, 14)
                self.quarterRound(x, 3, 7, 11, 15)
                # Diagonal round.
                self.quarterRound(x, 0, 5, 10, 15)
                self.quarterRound(x, 1, 6, 11, 12)
                self.quarterRound(x, 2, 7,  8, 13)
                self.quarterRound(x, 3, 4,  9, 14)

            # Add the scrambled state back onto the input state (mod 2**32) and
            # serialise it little-endian: 64 bytes of keystream.
            yield from struct.pack(
                '<16L',
                *((xw + cw) & 0xffffffff for xw, cw in zip(x, chaMtrx)))

            # Advance the block counter (low word first, carrying into the
            # high word) so every block gets a fresh keystream. Without this
            # the same 64 bytes would repeat for the whole message.
            chaMtrx[12] = (chaMtrx[12] + 1) & 0xffffffff
            if chaMtrx[12] == 0:
                chaMtrx[13] = (chaMtrx[13] + 1) & 0xffffffff


    def quarterRound(self, x, a, b, c, d):
        ''' Undertake a quarter round, defined mathematically as
            (+ is addition mod 2**32, ⊕ is xor, <<< is a 32-bit left rotation):

            a = a + b;  d = d ⊕ a;  d = d <<< 16
            c = c + d;  b = b ⊕ c;  b = b <<< 12
            a = a + b;  d = d ⊕ a;  d = d <<< 8
            c = c + d;  b = b ⊕ c;  b = b <<< 7

        INPUTS:
            x - The chacha table as a list of 16 32-bit words.
            a,b,c,d - integer position indexes into the list. NB for the first
                    column a=0 is Const0 (top left), b=4 is Key0, c=8 is Key4
                    and d=12 is Input0.
        OUTPUTS:
            None - x is modified in place.
        '''

        # Masking with 0xffffffff keeps every addition within 32 bits.
        x[a] = (x[a] + x[b]) & 0xffffffff
        x[d] = self.rotate(x[d] ^ x[a], 16)
        x[c] = (x[c] + x[d]) & 0xffffffff
        x[b] = self.rotate(x[b] ^ x[c], 12)
        x[a] = (x[a] + x[b]) & 0xffffffff
        x[d] = self.rotate(x[d] ^ x[a], 8)
        x[c] = (x[c] + x[d]) & 0xffffffff
        x[b] = self.rotate(x[b] ^ x[c], 7)


    def rotate(self, v, c):
        ''' Rotate the 32-bit word v left by c bits, as defined by the ChaCha spec.
            Assumes v already fits in 32 bits.'''

        # Bits shifted out of the top are masked off and re-enter at the bottom.
        return ((v << c) & 0xffffffff) | v >> (32 - c)


    def chachaEncryptDecrypt(self, data, key, iv=None, position=0):
        ''' Encrypt or decrypt data with ChaCha20 (the operation is its own inverse).

            INPUTS:
                data - The message as bytes.
                key - The key as bytes. Keys shorter than 32 bytes are stretched
                      to 32 bytes by repetition; longer keys are rejected.
                iv - The 8-byte initialisation vector as bytes; defaults to all zeros.
                position - The initial block position (uint32), default 0.
            OUTPUT: Either the cipher text (if encrypting) or plain text. Both as bytes.
            RAISES:
                TypeError - data, key or iv is not bytes, or position is not an int.
                ValueError - key is empty or longer than 32 bytes, iv is not
                             8 bytes, or position does not fit in 32 bits.
        '''

        if not isinstance(data, bytes):
            raise TypeError('Data must be bytes.')
        if iv is None:
            iv = b'\0' * 8
        if isinstance(key, bytes):
            if not key:
                raise ValueError('Key is empty.')
            # NOTE(review): repeating a short key adds no entropy; this is a
            # convenience for the demo, not a substitute for a real 32-byte key.
            if len(key) < 32:
                key = (key * (32 // len(key) + 1))[:32]
            if len(key) > 32:
                raise ValueError('Key too long.')

        # Validate eagerly: chachaGen only validates on its first next(), and
        # zip() never pulls from it when data is empty.
        self.checkInput(position, key, iv)

        # XOR the data with the keystream byte by byte, so this could easily
        # be streamed.
        keystream = self.chachaGen(key, iv, position)
        return bytes(a ^ b for a, b in zip(data, keystream))


    def checkInput(self, position, key, iv):
        ''' Validate the cipher inputs, raising if any condition is not met.

            RAISES:
                TypeError - position is not an int, or key/iv is not bytes.
                ValueError - position is not an unsigned 32-bit value, key is
                             not 32 bytes, or iv is not 8 bytes.
        '''

        if not isinstance(position, int):
            raise TypeError('Position must be an int.')

        # Any bit set outside the low 32 (including the sign bits of a negative
        # int) means the value does not fit in a uint32.
        if position & ~0xffffffff:
            raise ValueError('Position is not uint32.')

        # Type and length checks on key and IV.
        if not isinstance(key, bytes):
            raise TypeError('Key must be bytes.')
        if not isinstance(iv, bytes):
            raise TypeError('IV must be bytes.')
        if len(key) != 32:
            raise ValueError('Key must be 32 bytes.')
        if len(iv) != 8:
            raise ValueError('IV must be 8 bytes.')

### Encrypt a message and then decrypt it.

In [8]:
# Create the cipher instance that the following cells use to encrypt and decrypt.
chaCipher = ChaCha()

In [9]:
# Set up the secret key and the message to protect.
key = b'ThisKeyMustBeKeptSecret-abcd!!'
plaintext = b'The start of the journey was ....'
print(f'\nThe plaintext is {plaintext}')

# A fresh random 8-byte IV for this message.
iv = os.urandom(8)

# Encrypt, then Base64-encode the ciphertext so it prints as readable text.
enc = chaCipher.chachaEncryptDecrypt(plaintext, key, iv)
decode_enc = str(b64encode(enc), 'utf-8')
print(f'\nThe encrypted text is {decode_enc}. ')


The plaintext is b'The start of the journey was ....'

The encrypted text is rUr4ivRtaRlVZ7QVpiVPd6nSSz+Tn8lFioHPpmWXqZtj. 


In [10]:
# Decryption applies the same keystream XOR, so reuse the key and IV from encryption.
dec = chaCipher.chachaEncryptDecrypt(enc, key, iv)
print(f'\nThe decrypted text is {dec}. ')


The decrypted text is b'The start of the journey was ....'. 
