In [72]:
from bitarray import bitarray
from bitarray.util import hex2ba, parity, ba2int, int2ba, zeros, ba2hex

from pprint import pprint

In [2]:
DES_IP_TABLE = (58, 50, 42, 34, 26, 18, 10, 2, 60, 52, 44, 36, 28, 20, 12, 4,
                62, 54, 46, 38, 30, 22, 14, 6, 64, 56, 48, 40, 32, 24, 16, 8,
                57, 49, 41, 33, 25, 17, 9, 1, 59, 51, 43, 35, 27, 19, 11, 3,
                61, 53, 45, 37, 29, 21, 13, 5, 63, 55, 47, 39, 31, 23, 15, 7)

DES_SHIFT_TABLE = (1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1)

DES_PC_1_TABLE = (57, 49, 41, 33, 25, 17, 9, 1, 58, 50, 42, 34, 26, 18,
                  10, 2, 59, 51, 43, 35, 27, 19, 11, 3, 60, 52, 44, 36,
                  63, 55, 47, 39, 31, 23, 15, 7, 62, 54, 46, 38, 30, 22,
                  14, 6, 61, 53, 45, 37, 29, 21, 13, 5, 28, 20, 12, 4)

DES_PC_2_TABLE = (14, 17, 11, 24, 1, 5, 3, 28, 15, 6, 21, 10, 23, 19, 12, 4,
                  26, 8, 16, 7, 27, 20, 13, 2, 41, 52, 31, 37, 47, 55, 30, 40,
                  51, 45, 33, 48, 44, 49, 39, 56, 34, 53, 46, 42, 50, 36, 29, 32)

DES_E_TABLE = (32, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 8, 9,
               10, 11, 12, 13, 12, 13, 14, 15, 16, 17, 16,
               17, 18, 19, 20, 21, 20, 21, 22, 23, 24, 25,
               24, 25, 26, 27, 28, 29, 28, 29, 30, 31, 32, 1)

DES_S_TABLE = (
    # s1
    ((14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7),
     (0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8),
     (4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0),
     (15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13)),
    
    # s2
    ((15, 1, 8, 14, 6, 11, 3, 4, 9, 7, 2, 13, 12, 0, 5, 10),
     (3, 13, 4, 7, 15, 2, 8, 14, 12, 0, 1, 10, 6, 9, 11, 5),
     (0, 14, 7, 11, 10, 4, 13, 1, 5, 8, 12, 6, 9, 3, 2, 15),
     (13, 8, 10, 1, 3, 15, 4, 2, 11, 6, 7, 12, 0, 5, 14, 9)),
    
    # s3
    ((10, 0, 9, 14, 6, 3, 15, 5, 1, 13, 12, 7, 11, 4, 2, 8),
     (13, 7, 0, 9, 3, 4, 6, 10, 2, 8, 5, 14, 12, 11, 15, 1),
     (13, 6, 4, 9, 8, 15, 3, 0, 11, 1, 2, 12, 5, 10, 14, 7),
     (1, 10, 13, 0, 6, 9, 8, 7, 4, 15, 14, 3, 11, 5, 2, 12)),
    
    # s4
    ((7, 13, 14, 3, 0, 6, 9, 10, 1, 2, 8, 5, 11, 12, 4, 15), 
     (13, 8, 11, 5, 6, 15, 0, 3, 4, 7, 2, 12, 1, 10, 14, 9), 
     (10, 6, 9, 0, 12, 11, 7, 13, 15, 1, 3, 14, 5, 2, 8, 4), 
     (3, 15, 0, 6, 10, 1, 13, 8, 9, 4, 5, 11, 12, 7, 2, 14)),
    
    # s5
    ((2, 12, 4, 1, 7, 10, 11, 6, 8, 5, 3, 15, 13, 0, 14, 9), 
     (14, 11, 2, 12, 4, 7, 13, 1, 5, 0, 15, 10, 3, 9, 8, 6),
     (4, 2, 1, 11, 10, 13, 7, 8, 15, 9, 12, 5, 6, 3, 0, 14),
     (11, 8, 12, 7, 1, 14, 2, 13, 6, 15, 0, 9, 10, 4, 5, 3)),
    
    # s6
    ((12, 1, 10, 15, 9, 2, 6, 8, 0, 13, 3, 4, 14, 7, 5, 11),
     (10, 15, 4, 2, 7, 12, 9, 5, 6, 1, 13, 14, 0, 11, 3, 8),
     (9, 14, 15, 5, 2, 8, 12, 3, 7, 0, 4, 10, 1, 13, 11, 6),
     (4, 3, 2, 12, 9, 5, 15, 10, 11, 14, 1, 7, 6, 0, 8, 13)),
    
    # s7
    ((4, 11, 2, 14, 15, 0, 8, 13, 3, 12, 9, 7, 5, 10, 6, 1),
     (13, 0, 11, 7, 4, 9, 1, 10, 14, 3, 5, 12, 2, 15, 8, 6),
     (1, 4, 11, 13, 12, 3, 7, 14, 10, 15, 6, 8, 0, 5, 9, 2),
     (6, 11, 13, 8, 1, 4, 10, 7, 9, 5, 0, 15, 14, 2, 3, 12)),
    
    # s8
    ((13, 2, 8, 4, 6, 15, 11, 1, 10, 9, 3, 14, 5, 0, 12, 7),
     (1, 15, 13, 8, 10, 3, 7, 4, 12, 5, 6, 11, 0, 14, 9, 2),
     (7, 11, 4, 1, 9, 12, 14, 2, 0, 6, 10, 13, 15, 3, 5, 8),
     (2, 1, 14, 7, 4, 10, 8, 13, 15, 12, 9, 0, 3, 5, 6, 11))
)

DES_P_TABLE = (16, 7, 20, 21, 29, 12, 28, 17,
               1, 15, 23, 26, 5, 18, 31, 10,
               2, 8, 24, 14, 32, 27, 3, 9,
               19, 13, 30, 6, 22, 11, 4, 25)

DES_IP_INV_TABLE = (40, 8, 48, 16, 56, 24, 64, 32, 39, 7, 47, 15, 55, 23, 63, 31,
                    38, 6, 46, 14, 54, 22, 62, 30, 37, 5, 45, 13, 53, 21, 61, 29,
                    36, 4, 44, 12, 52, 20, 60, 28, 35, 3, 43, 11, 51, 19, 59, 27,
                    34, 2, 42, 10, 50, 18, 58, 26, 33, 1, 41, 9, 49, 17, 57, 25)

In [4]:
class DESError(Exception):
    pass

In [150]:
class DES:
    def __init__(self, key: str, iv: str = None, mode: str = "ECB"):
        self._modes = {"ECB": self._ECB,
                       "CBC": self._CBC,
                       "CFB": self._CFB,
                       "OFB": self._OFB}
        
        try:
            self.key = hex2ba(key)
        except ValueError:
            raise DESError("The entered key is not a hexadecimal one!")
        
        if len(self.key) != 56:
            raise DESError(f"Key length must be 56 bits (7 bytes)! ({len(self.key)} bits entered)")

        if iv:
            try:
                self.iv = hex2ba(iv)
            except ValueError:
                raise DESError("The entered IV is not a hexadecimal one!")

            if len(self.iv) != 64:
                raise DESError(f"IV length must be 64 bits (8 bytes)! ({len(self.iv)} bits entered)")
        
        if iv is None and mode != "ECB":
            raise DESError(f"Encryption in '{mode}' mode requires an initialization vector!")
        
        if mode not in self._modes.keys():
            raise DESError(f"Invalid encryption mode entered ({mode})! "
                           f"Possible modes: {tuple(self._modes.keys())}")

        self._mode = self._modes.get(mode)
        self.keys = self.generate_keys(self.key)
        
    def generate_keys(self, key: bitarray):
        ext_key = key.copy()
        
        # Getting a 64-bit key
        for pos in range(0, 64, 8):
            match parity(ext_key[pos:pos+7]):
                case 0:
                    ext_key.insert(pos+8, 1)

                case 1:
                    ext_key.insert(pos+8, 0)

        # Permutation
        new_key = self._permutation(ext_key, DES_PC_1_TABLE)

        l, d = new_key[:28], new_key[28:]
        keys = []

        for i in range(16):
            shift = DES_SHIFT_TABLE[i]

            addition_l, addition_d = l[:shift], d[:shift]        
            l <<= shift
            d <<= shift
            l[-shift:], d[-shift:] = addition_l, addition_d
            
            # Final permutation
            new_key = self._permutation(l + d, DES_PC_2_TABLE)
            keys.append(new_key)

        return keys        
        
    def _permutation(self, key: bitarray, table: tuple):
        return bitarray(key[i - 1] for i in table)
        
    def _transform(self, block: bitarray, mode: str = "encrypt"):
        block = self._permutation(block, DES_IP_TABLE)
        l, r = block[:32], block[32:]
        
        match mode:
            case "encrypt":
                for i in range(16):
                    l, r = r, l ^ self._feistel(r, self.keys[i])
            
            case "decrypt":
                for i in range(15, -1, -1):
                    r, l = l, r ^ self._feistel(l, self.keys[i])
            
            case _:
                raise DESError(f"Invalid processing mode! -> {mode}")
        
        return self._permutation(l + r, DES_IP_INV_TABLE)
    
    def _feistel(self, chunk: bitarray, key: bitarray):
        chunk = self._permutation(chunk, DES_E_TABLE)
        chunk ^= key

        new_chunk = bitarray()
        for k, pos in enumerate(range(0, 48, 6)):
            b = chunk[pos:pos+6]
            i = ba2int(b[::5])
            j = ba2int(b[1:5])
            b = int2ba(DES_S_TABLE[k][i][j], length=4)
            new_chunk += b
        
        return self._permutation(new_chunk, DES_P_TABLE)
    
    def _ECB(self, data: bitarray, mode: str = "encrypt"):
        processed_data = bitarray()
        
        for pos in range(0, len(data), 64):
            block = data[pos:pos+64]
            processed_block = self._transform(block, mode)
            processed_data += processed_block

        return processed_data
    
    def _CBC(self, data: bitarray, mode: str = "encrypt"):
        processed_data = bitarray()
        vector = self.iv
        
        for pos in range(0, len(data), 64):
            block = data[pos:pos+64]
            
            match mode:
                case "encrypt":
                    processed_block = self._transform(block ^ vector, "encrypt")
                    vector = processed_block
                
                case "decrypt":
                    processed_block = self._transform(block, "decrypt") ^ vector
                    vector = block
                    
                case _:
                    raise DESError(f"Invalid processing mode! -> {mode}")
                    
            processed_data += processed_block
            
        return processed_data
    
    def _CFB(self, data: bitarray, mode: str = "encrypt"):
        processed_data = bitarray()
        vector = self.iv
        
        for pos in range(0, len(data), 64):
            block = data[pos:pos+64]
            
            match mode:
                case "encrypt":
                    processed_block = self._transform(vector, "encrypt") ^ block
                    vector = processed_block
                
                case "decrypt":
                    processed_block = self._transform(vector, "encrypt") ^ block
                    vector = block
                
                case _:
                    raise DESError(f"Invalid processing mode! -> {mode}")
                    
            processed_data += processed_block
            
        return processed_data 
    
    def _OFB(self, data: bitarray, mode: str = "encrypt"):
        processed_data = bitarray()
        vector = self.iv

        for pos in range(0, len(data), 64):
            block = data[pos:pos+64]
            
            match mode:
                case "encrypt":
                    vector = self._transform(vector, "encrypt")
                    processed_block = vector ^ block
                
                case "decrypt":
                    vector = self._transform(vector, "encrypt")
                    processed_block = vector ^ block
                
                case _:
                    raise DESError(f"Invalid processing mode! -> {mode}")
                    
            processed_data += processed_block
            
        return processed_data 
    
    def _data_processing(self, data: bytes or str, mode: str = "encrypt"):
        data_bits = bitarray()
        
        match mode, data:
            case "encrypt", str():
                data_bits.frombytes(data.encode("utf-8"))
                
            case "decrypt", str():
                data_bits = hex2ba(data)
            
            case _, bytes():
                data_bits.frombytes(data)
                
            case _:
                raise DESError(f"Invalid processing mode! -> {mode}")
                
        if mode == "encrypt" and (k := len(data_bits) % 64) != 0:
            data_bits += bitarray("0" * (64 - k))     
                
        processed_data = self._mode(data_bits, mode)
        
        match mode, data:
            case "encrypt", str():
                return ba2hex(processed_data)
                
            case "decrypt", str():
                return processed_data.tobytes().decode("utf-8")
            
            case _, bytes():
                return processed_data.tobytes()
                
            case _:
                raise DESError(f"Invalid processing mode! -> {mode}")
    
    def encrypt(self, data: bytes or str):
        return self._data_processing(data, "encrypt")
    
    def decrypt(self, data: bytes or str):
        return self._data_processing(data, "decrypt")
    
    def make(self, data: bytes or str, mode: str = "encrypt"):
        match mode:
            case "encrypt":
                return self.encrypt(data)
            
            case "decrypt":
                return self.decrypt(data)
            
            case _:
                raise DESError(f"Invalid processing mode! -> {mode}")

In [134]:
string = "Hello, World!!! Привет, Мир!😎"
key = "2F7D8AAADA6C9D"
iv = "9A8D5AAAB32E21CC"

In [152]:
des = DES(key=key, iv=iv, mode="OFB")
encrypted_str = des.make(string, "encrypt")
print("encrypted str:", encrypted_str)

decrypted_str = des.make(encrypted_str, "decrypt")
print("decrypted str:", decrypted_str)

encrypted str: 6b95cbf1c15f148c79203f710ece99dd359e84a866c4479c016f0b3baf9ea2011bb26029a75e95da7ccc6dc83e4540a7
decrypted str: Hello, World!!! Привет, Мир!😎       


In [None]:
k = (64 - len(data_bits) % 64) // 8
data_bits += int2ba(k, length=8) * k