# In [None]:
import pennylane as qml
from fractions import Fraction
from pennylane import numpy as np


def factor_with_shor(N):
    """Implements Shor's algorithm to determine the prime factors of a provided
    value N.

    This function _must_ execute a QNode in order to factor the number.

    Random bases a coprime to N are drawn. For each one, quantum phase
    estimation (QPE) on the modular-multiplication unitary U_(N, a) is used to
    estimate the order r of a modulo N. If r is even and a^(r/2) is not
    congruent to +/-1 modulo N, then gcd(a^(r/2) - 1, N) is a non-trivial
    factor of N.

    Bases that share a factor with N are skipped rather than returned as a
    "lucky" factorization: U_(N, a) is only a permutation (hence unitary) when
    gcd(a, N) == 1, and skipping them guarantees that any returned factors
    come from a QNode execution.

    Args:
        N (int): The number to factor. Intended for odd composites that are
            not prime powers (e.g. an RSA modulus p * q); for other inputs the
            attempt budget may be exhausted.

    Returns:
        int, int: p, q such that N = p * q.

    Raises:
        ValueError: If N < 4, since there are then no candidate bases in
            [2, N - 2].
        RuntimeError: If no factorization is found within the attempt budget.
    """
    import math

    N = int(N)
    if N < 4:
        raise ValueError(f"N must be at least 4, got {N}")

    # Same total budget as the previous 100 retries x 10 bases each.
    max_attempts = 1000
    # QPE samples drawn (in a single QNode execution) per base a.
    num_samples = 10
    # Bits of phase precision. NOTE: Shor's analysis calls for roughly
    # 2 * ceil(log2(N)) + 1 bits; 10 is only adequate for small N.
    num_estimation_qubits = 10

    def phase_to_order(phase, max_denominator):
        """Estimate which integers divide to produce a float.

        Given some floating-point phase, estimate integers s, r such
        that s / r = phase, where r is no greater than some specified value.

        Args:
            phase (float): Some fractional value (here, the output of QPE).
            max_denominator (int): The largest r to be considered when looking
                for s, r such that s / r = phase.

        Returns:
            int: The estimated value of r.
        """
        return Fraction(float(phase)).limit_denominator(max_denominator).denominator

    def fractional_binary_to_float(sample):
        """Convert an n-bit sample [k1, k2, ..., kn] to a floating point
        value using fractional binary representation,

            k = (k1 / 2) + (k2 / 2 ** 2) + ... + (kn / 2 ** n)

        Args:
            sample (list[int] or array[int]): A list or array of bits, e.g.,
                one sample output of a quantum circuit.

        Returns:
            float: The value computed from the fractional binary representation.
        """
        return sum(
            int(bit) / 2 ** (position + 1) for position, bit in enumerate(sample)
        )

    def get_U_Na(N, a):
        """Computes the unitary matrix U_(N, a) used in the order-finding
        portion of Shor's algorithm.

        U_(N, a) multiplies a computational basis state by a modulo N, i.e.,
            U_(N, a) |k> = |ak mod N>

        In Shor's algorithm, we try to find its order, i.e., the smallest
        m such that
            U_(N, a)^m |k> = |k mod N> = |k>

        The matrix is a permutation (and therefore unitary) only when
        gcd(a, N) == 1; callers must ensure this.

        Args:
            N (int): The modulus. In Shor's algorithm, this is the number
                we are trying to find the prime factors of.
            a (int): The candidate base, coprime to N.

        Returns:
            array: The (2**n x 2**n) 0/1 matrix representation of U_(N, a),
            where n = ceil(log2(N)).
        """
        # We need at least log2(N) qubits to hold basis states modulo N.
        n_qubits = int(np.ceil(np.log2(N)))

        U_Na = np.zeros([2 ** n_qubits, 2 ** n_qubits])

        # For each k < N, map |k> to |a k mod N>, i.e. set U_Na[l, k] = 1.
        for k in range(N):
            U_Na[(k * a) % N, k] = 1

        # Basis states >= N are unused; act as the identity on them.
        for extra in range(N, 2 ** n_qubits):
            U_Na[extra, extra] = 1

        return U_Na

    def run_order_finding(a, N):
        """Estimate the multiplicative order of a modulo N using QPE.

        Each sampled phase s / r yields the denominator r / gcd(s, r), which
        divides the true order. The smallest candidate that satisfies
        a^r == 1 (mod N) is returned; failing that, the lcm of all candidates
        is tried.

        Args:
            a (int): Base, coprime to N.
            N (int): The modulus.

        Returns:
            int: A verified r with a^r == 1 (mod N), or 0 if none was found.
        """
        # get_U_Na takes (N, a) -- the argument order matters.
        U_Na = get_U_Na(N, a)

        num_target_qubits = int(np.log2(len(U_Na)))

        estimation_wires = range(num_estimation_qubits)
        target_wires = range(
            num_estimation_qubits, num_estimation_qubits + num_target_qubits
        )

        dev = qml.device(
            "default.qubit",
            wires=num_estimation_qubits + num_target_qubits,
            shots=num_samples,
        )

        @qml.qnode(dev)
        def find_order():
            # Prepare the target register in |1>, which is a uniform
            # superposition of U_(N, a) eigenstates.
            qml.PauliX(wires=target_wires[-1])

            qml.QuantumPhaseEstimation(
                U_Na,
                estimation_wires=estimation_wires,
                target_wires=target_wires,
            )

            return qml.sample(wires=estimation_wires)

        # One execution yields all samples: shape (num_samples, num_wires).
        samples = np.reshape(find_order(), (num_samples, num_estimation_qubits))

        candidates = sorted(
            {
                phase_to_order(fractional_binary_to_float(sample), N)
                for sample in samples
            }
        )

        for r in candidates:
            if pow(a, r, N) == 1:
                return r

        # Each candidate may be a proper divisor of the order; their lcm can
        # still recover it.
        r = 1
        for candidate in candidates:
            r = r * candidate // math.gcd(r, candidate)
        return r if pow(a, r, N) == 1 else 0

    def split_with_base(a):
        """Try to split N using base a; return (p, q) or None on failure."""
        r = run_order_finding(a, N)

        if r == 0 or r % 2 == 1:
            return None

        # Modular pow avoids the overflow of a ** (r // 2) on numpy ints.
        x = pow(a, r // 2, N)

        if x == 1 or x == N - 1:
            return None

        p = math.gcd(x - 1, N)
        if 1 < p < N:
            return p, N // p
        return None

    bases = list(range(2, N - 1))
    for _ in range(max_attempts):
        a = int(np.random.choice(bases))

        # Non-coprime bases would make U_(N, a) non-unitary; skip them.
        if math.gcd(a, N) != 1:
            continue

        factors = split_with_base(a)
        if factors is not None:
            return factors

    raise RuntimeError(
        f"Failed to factor N={N} after {max_attempts} attempts; Shor's "
        "algorithm requires an odd composite N that is not a prime power."
    )


def decode_message(message, key_pair):
    """Use Shor's algorithm to decrypt an arbitrary secret message encoded using
    the provided RSA public key pair.

    Messages are encoded as a list of integers; the mapping between characters in
    the message and `decoded` integers is:
      - 0-9: numbers 0-9
      - 10-35: letters a-z (only lowercase is used)
      - 36: space

    The modulus N is factored into p * q with ``factor_with_shor``. The private
    exponent d = e^-1 mod (p - 1)(q - 1) is then recovered, and each
    ciphertext integer c is decrypted as c^d mod N.

    Args:
        message (List[int]): A list of integers representing the secret message.
            Each integer in the list represents a different character in the message.
        key_pair ((int, int)): The public RSA key (e, N).

    Returns:
        str: The decoded message.

    Raises:
        ValueError: If e is not invertible modulo (p - 1)(q - 1), or a
            decrypted value falls outside the 0-36 character range.
    """
    # Index i holds the character encoded by the integer i (0-36).
    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz "

    # Nothing to decrypt: avoid running the (expensive) factorization.
    if len(message) == 0:
        return ""

    e, N = int(key_pair[0]), int(key_pair[1])

    # The RSA modulus (not the ciphertext) is what must be factored.
    p, q = factor_with_shor(N)
    phi = (int(p) - 1) * (int(q) - 1)

    # Modular inverse via three-argument pow (Python 3.8+); raises
    # ValueError when gcd(e, phi) != 1.
    d = pow(e, -1, phi)

    decoded_chars = []
    for ciphertext in message:
        value = pow(int(ciphertext), d, N)
        if value >= len(alphabet):
            raise ValueError(
                f"Decrypted value {value} is outside the 0-36 character range"
            )
        decoded_chars.append(alphabet[value])

    return "".join(decoded_chars)


def find_party_location(party_message, party_key_pair):
    """Recover the location of the surprise party.

    *Note that there is no explicit test function provided for this; a hidden
    test function will be used (with just this one test case), and is worth 1
    point. The values that will be passed to the test function are in
    problem_3_data.py.*

    The party message is an ordinary RSA-encrypted message, so this delegates
    to ``decode_message``.

    Args:
        party_message (List[int]): A list of integers representing the secret message.
            Each integer in the list represents a different character in the message.
        party_key_pair ((int, int)): The public RSA key (e, N).

    Returns:
        str: a 10-character string indicating the location of the surprise party.

    """
    return decode_message(party_message, party_key_pair)