In [39]:
import random
import string

def rand_hex(chars):
    return ''.join(random.choices('0123456789ABCDEF', k = chars))

def rand_bits(l):
    return "{0:b}".format(random.getrandbits(l)).rjust(l, '0')

def xor_str(key, s):
    if len(key) < len(s):
        key = key * (len(s) // len(key) + 1)

    return ''.join([chr(ord(a) ^ ord(b)) for a, b in zip(key, s)])

In [68]:
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import random
import hashlib

KEY_PREFIX = '0123456789ABCDEF0123456789ABCDEF'

class Alice:
    def __init__(self):
        self.puzzles = {}
        self.aes_key = None

    def gen_puzzles(self, N):
        if len(self.puzzles) > 0:
            raise Exception("Puzzles already generated")

        for i in range(0, 2 ** N):
            id = rand_hex(20)
            aes_key = get_random_bytes(16).hex()

            guess_key_len = N // 4
            key = KEY_PREFIX[guess_key_len:] + rand_hex(guess_key_len)

            message = f"id: {id} aes_key: {aes_key}"

            cipher = AES.new(bytearray.fromhex(key), AES.MODE_GCM)
            puzzle, tag = cipher.encrypt_and_digest(message.encode('utf-8'))

            self.puzzles[id] = {
                "aes_key": aes_key,
                "puzzle": puzzle,
                "nonce": cipher.nonce
            }

    def get_puzzles(self):
        return [(p["puzzle"], p["nonce"]) for p in self.puzzles.values()]
    
    def receive_puzzle_id(self, id):
        if self.aes_key is not None:
            raise Exception("Puzzle ID already received")
        
        self.aes_key = bytearray.fromhex(self.puzzles[id]["aes_key"])
    
    def encrypt(self, m):
        cipher = AES.new(self.aes_key, AES.MODE_GCM)
        ciphertext, tag = cipher.encrypt_and_digest(m)
        return ciphertext, tag, cipher.nonce
       


In [74]:
import re

class Bob:
    def __init__(self):
        self.aes_key = None
    
    def solve_rand_puzzle(self, N, puzzles):
        idx = random.randint(1, len(puzzles)) - 1
        print("Puzzle index ", idx)
        puzzle, nonce = puzzles[idx]
        pattern = re.compile("id: (.*) aes_key: ([0123456789abcdef]*)")
        key_len = N // 4

        for i in range(int('F' * key_len, 16), 0, -1):
            key = KEY_PREFIX[key_len:] + hex(i)[2:].rjust(key_len, '0')
            cipher = AES.new(bytearray.fromhex(key), AES.MODE_GCM, nonce=nonce)

            try:
                plaintext = cipher.decrypt(puzzle).decode('utf-8')
                print(plaintext)
                matches = [(g.group(1), g.group(2)) for g in pattern.finditer(plaintext)]

                if len(matches) > 0:
                    self.aes_key = bytearray.fromhex(matches[0][1])
                    return matches[0][0]
            except:
                pass

        raise Exception("Unable to solve a puzzle")
    
    def decrypt(self, ciphertext, tag, nonce):
        cipher = AES.new(self.aes_key, AES.MODE_GCM, nonce=nonce)
        plaintext = cipher.decrypt(ciphertext)

        try:
            cipher.verify(tag)
            return plaintext
        except ValueError:
            raise Exception("Key incorrect or message corrupted")


In [80]:
N = 16

In [81]:
%%time
alice = Alice()
puzzles = alice.gen_puzzles(N)

CPU times: user 3.08 s, sys: 45.7 ms, total: 3.12 s
Wall time: 3.13 s


In [82]:
%%time
bob = Bob()
puzzle_id = bob.solve_rand_puzzle(N, alice.get_puzzles())

Puzzle index  31783
id: 78A2F1F7197894603EAC aes_key: 302cc03ef76a54b28bc18d7d6ca8948a
CPU times: user 2.37 s, sys: 15.4 ms, total: 2.39 s
Wall time: 2.39 s


In [83]:
alice.receive_puzzle_id(puzzle_id)

In [84]:
m = "Hi Bob"

c, tag, nonce = alice.encrypt(m.encode('utf-8'))
print("Ciphertext ", c)
print("Decrypted ", bob.decrypt(c, tag, nonce))

Ciphertext  b'Q<\xe3\x99\xc7o'
Decrypted  b'Hi Bob'
