In [1]:
import random
import string

def rand_str(chars):
    """Return a random string of ``chars`` uppercase ASCII letters.

    Letters are drawn with replacement from ``string.ascii_uppercase`` using
    the ``random`` module's generator.
    """
    alphabet = string.ascii_uppercase
    picked = random.choices(alphabet, k=chars)
    return "".join(picked)

def rand_bits(l):
    """Return ``l`` random bits as a string of '0'/'1' characters.

    The value comes from ``random.getrandbits(l)`` and is left-padded with
    zeros so the result is at least ``l`` characters long.
    """
    value = random.getrandbits(l)
    binary = format(value, "b")
    return binary.zfill(l)

def xor_str(key, s):
    """XOR the characters of ``s`` with ``key``, character by character.

    If ``key`` is shorter than ``s`` it is repeated until it covers ``s``.
    The result has ``min(len(padded key), len(s))`` characters, i.e. the
    length of ``s`` whenever ``key`` is non-empty. Applying the function
    twice with the same key restores the original string.
    """
    needed = len(s)
    if needed > len(key):
        # Repeat the key often enough to span the whole input.
        repeats = needed // len(key) + 1
        stream = key * repeats
    else:
        stream = key

    return ''.join(chr(ord(k) ^ ord(c)) for k, c in zip(stream, s))

In [2]:
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import random
import hashlib

# Number of bits in the XOR key that masks each puzzle. Alice draws a random
# key of this many bits per puzzle; Bob brute-forces keys of this length.
KEY_LEN = 20

class Alice:
    """Puzzle-generating party of the key exchange.

    Alice creates ``N`` puzzles, each hiding a puzzle id and an AES key behind
    a short XOR key. Once the other party reports the id of the puzzle it
    solved, Alice looks up the matching AES key and uses it for AES-GCM.

    Attributes:
        N: number of puzzles to generate.
        puzzles: dict mapping puzzle id -> {"aes_key": hex string,
            "puzzle": XOR-masked message}.
        aes_key: key selected by receive_puzzle_id(); None until then.
    """

    def __init__(self, N):
        self.N = N
        self.puzzles = {}
        self.aes_key = None

    def gen_puzzles(self):
        """Fill ``self.puzzles`` with ``self.N`` puzzles.

        Returns nothing; use get_puzzles() to obtain the masked puzzles.

        Raises:
            Exception: if puzzles have already been generated.
        """
        if len(self.puzzles) > 0:
            # Raising a bare string is a TypeError in Python 3, so raise a
            # real exception, consistent with receive_puzzle_id().
            raise Exception("Puzzles already generated")

        # Looping on the dict size (not a counter) means a colliding random
        # id simply overwrites its entry and another puzzle is generated.
        while len(self.puzzles) < self.N:
            puzzle_id = rand_str(20)
            aes_key = get_random_bytes(16).hex()

            message = f" id: {puzzle_id} aes_key: {aes_key}"
            # The hash covers the id/key part only; the solver uses it to
            # recognise the correct XOR key.
            message += f" hash: {hashlib.sha256(message.encode('utf-8')).hexdigest()} "
            key = rand_bits(KEY_LEN)

            self.puzzles[puzzle_id] = {
                "aes_key": aes_key,
                "puzzle": xor_str(key, message)
            }

    def get_puzzles(self):
        """Return the list of masked puzzle strings (without ids or keys)."""
        return [p["puzzle"] for p in self.puzzles.values()]

    def receive_puzzle_id(self, id):
        """Select the AES key belonging to the puzzle with the given ``id``.

        Raises:
            Exception: if a puzzle id has already been received.
            KeyError: if ``id`` does not match any generated puzzle.
        """
        if self.aes_key is not None:
            raise Exception("Puzzle ID already received")

        self.aes_key = bytearray.fromhex(self.puzzles[id]["aes_key"])

    def encrypt(self, m):
        """Encrypt ``m`` with AES-GCM under the selected key.

        ``m`` is passed straight to the cipher (the notebook passes UTF-8
        encoded bytes).

        Returns:
            A ``(ciphertext, tag, nonce)`` tuple; all three are needed to
            decrypt and authenticate the message.
        """
        cipher = AES.new(self.aes_key, AES.MODE_GCM)
        ciphertext, tag = cipher.encrypt_and_digest(m)
        return ciphertext, tag, cipher.nonce
       


In [3]:
import re

class Bob:
    """Puzzle-solving party: brute-forces one puzzle to obtain an AES key.

    Attributes:
        aes_key: AES key recovered by solve_rand_puzzle(); None until then.
    """

    # A correctly unmasked puzzle reads " id: <ID> aes_key: <hex> hash: <hex> ".
    # Group 1 is the part the hash was computed over, groups 2-4 are the id,
    # the hex AES key and the hex SHA-256 digest.
    _PUZZLE_PATTERN = re.compile("( id: ([A-Z]*) aes_key: ([0123456789abcdef]*)) hash: ([0123456789abcdef]*) ")

    def __init__(self):
        self.aes_key = None

    def solve_rand_puzzle(self, puzzles):
        """Pick a random puzzle, brute-force its XOR key and keep its AES key.

        Every KEY_LEN-bit key is tried, from ``2 ** KEY_LEN - 1`` down to and
        including 0. A candidate is accepted when the unmasked text matches
        the puzzle layout and its embedded SHA-256 digest verifies.

        Args:
            puzzles: non-empty sequence of masked puzzle strings.

        Returns:
            The id of the solved puzzle.

        Raises:
            ValueError: if ``puzzles`` is empty.
            Exception: if no key unmasks the chosen puzzle.
        """
        idx = random.randrange(len(puzzles))
        print("Puzzle index ", idx)
        puzzle = puzzles[idx]

        # The key space is 0 .. 2**KEY_LEN - 1 inclusive. The all-zero key
        # must be tried too, and 2**KEY_LEN itself is not a KEY_LEN-bit key.
        for i in range(2 ** KEY_LEN - 1, -1, -1):
            key = format(i, "b").zfill(KEY_LEN)
            match = self._PUZZLE_PATTERN.search(xor_str(key, puzzle))
            if match is None:
                continue

            body, puzzle_id, aes_key_hex, digest = match.groups()
            # The layout alone could match by chance; the hash confirms the key.
            if hashlib.sha256(body.encode('utf-8')).hexdigest() == digest:
                self.aes_key = bytearray.fromhex(aes_key_hex)
                return puzzle_id

        raise Exception("Unable to solve a puzzle")

    def decrypt(self, ciphertext, tag, nonce):
        """Decrypt an AES-GCM message and check its authentication tag.

        Returns:
            The plaintext, only if the tag verifies.

        Raises:
            Exception: if verification fails (wrong key or altered message).
        """
        cipher = AES.new(self.aes_key, AES.MODE_GCM, nonce=nonce)
        plaintext = cipher.decrypt(ciphertext)

        try:
            cipher.verify(tag)
        except ValueError as err:
            raise Exception("Key incorrect or message corrupted") from err
        return plaintext


In [4]:
%%time
alice = Alice(2 ** 24)
# gen_puzzles() fills alice.puzzles in place and returns None, so generate
# first and then fetch the masked puzzles explicitly instead of binding
# `puzzles` to None.
alice.gen_puzzles()
puzzles = alice.get_puzzles()

CPU times: user 4min 56s, sys: 24.5 s, total: 5min 20s
Wall time: 5min 21s


In [5]:
%%time
# Bob receives all masked puzzles, picks one at random and brute-forces its
# XOR key. He keeps the recovered AES key and returns only the puzzle id.
bob = Bob()
puzzle_id = bob.solve_rand_puzzle(alice.get_puzzles())

Puzzle index  15165906
CPU times: user 6.86 s, sys: 53 ms, total: 6.92 s
Wall time: 6.92 s


In [6]:
# Alice looks up the AES key stored under the id Bob solved; this may only be
# done once per Alice instance (a second call raises).
alice.receive_puzzle_id(puzzle_id)

In [13]:
# Round trip: Alice encrypts with the key she looked up, and Bob decrypts
# with the key he recovered from the puzzle.
m = "Hi Bob"

# encrypt() takes bytes and returns the ciphertext plus the GCM tag and nonce,
# all of which Bob needs to decrypt and authenticate the message.
c, tag, nonce = alice.encrypt(m.encode('utf-8'))
print("Ciphertext ", c)
print("Decrypted ", bob.decrypt(c, tag, nonce))

Ciphertext  b'Q\xff\xa0\x0f\xe2\x9f'
Decrypted  b'Hi Bob'
