In [None]:
from qiskit import QuantumCircuit
import numpy as np

def dj_function(num_qubits: int, rng=None):
    """
    Create a random Deutsch-Jozsa function.

    The circuit acts on ``num_qubits`` input qubits (indices
    ``0 .. num_qubits - 1``) and one output qubit (index ``num_qubits``).
    First the output qubit is flipped with probability 1/2. Then, with
    probability 1/2, the circuit is returned as-is (a constant function);
    otherwise the output is additionally flipped for exactly half of the
    ``2 ** num_qubits`` input states, chosen at random (a balanced function).

    Args:
        num_qubits: Number of input qubits. Must be at least 1, since no
            balanced function exists on zero inputs.
        rng: Optional source of randomness exposing ``randint`` and
            ``choice`` with the ``np.random`` signatures, e.g.
            ``np.random.RandomState(seed)``, so a notebook can reproduce the
            same oracle on every run. Defaults to NumPy's global
            ``np.random`` state.

    Returns:
        A ``QuantumCircuit`` with ``num_qubits + 1`` qubits.

    Raises:
        ValueError: If ``num_qubits`` is less than 1.
    """
    if num_qubits < 1:
        raise ValueError(f"num_qubits must be at least 1, got {num_qubits}")
    if rng is None:
        # Same calls as before on the global state, so existing
        # np.random.seed(...) usage keeps producing the same oracles.
        rng = np.random

    qc = QuantumCircuit(num_qubits + 1)
    if rng.randint(0, 2):
        # Flip output qubit with 50% chance
        qc.x(num_qubits)
    if rng.randint(0, 2):
        # return constant circuit with 50% chance
        return qc

    # Next, choose half the possible input states
    on_states = rng.choice(
        range(2 ** num_qubits),  # numbers to sample from
        2 ** num_qubits // 2,  # number of samples
        replace=False,  # makes sure states are only sampled once
    )

    def flip_set_bits(qc, bit_string):
        # Apply X to every qubit whose bit is "1". bit_string is MSB-first,
        # so reversing it lines character i up with qubit i; missing leading
        # zeros simply mean those qubits are left alone.
        for qubit, bit in enumerate(reversed(bit_string)):
            if bit == "1":
                qc.x(qubit)
        return qc

    controls = list(range(num_qubits))
    for state in on_states:
        # Barriers are added to help visualize how the functions are created. They can safely be removed.
        qc.barrier()
        # After flipping the set bits, the MCX fires on the bitwise complement
        # of `state`. Complementing is a bijection on the sampled half, so the
        # function still flips the output for exactly half of all inputs.
        qc = flip_set_bits(qc, f"{state:0b}")
        qc.mcx(controls, num_qubits)
        qc = flip_set_bits(qc, f"{state:0b}")

    qc.barrier()

    return qc

In [None]:
# Draw one randomly generated 3-input oracle to see how it is assembled.
example_drawing = dj_function(3).draw("mpl")
display(example_drawing)

In [None]:
def compile_circuit(function: QuantumCircuit):
    """
    Wrap an oracle circuit in the Deutsch-Jozsa measurement circuit.

    The oracle's last qubit is treated as the output qubit and all earlier
    qubits as inputs. The returned circuit applies X to the output qubit,
    H to every qubit, then the oracle, then H to the inputs, and finally
    measures each input qubit into the classical bit with the same index.

    Args:
        function: Oracle circuit whose last qubit is the output qubit.

    Returns:
        A new ``QuantumCircuit`` with one classical bit per input qubit.
    """
    num_inputs = function.num_qubits - 1
    output_qubit = num_inputs
    inputs = list(range(num_inputs))

    circuit = QuantumCircuit(num_inputs + 1, num_inputs)
    circuit.x(output_qubit)
    circuit.h(inputs + [output_qubit])
    circuit.compose(function, inplace=True)
    circuit.h(inputs)
    circuit.measure(inputs, inputs)
    return circuit

In [None]:
from qiskit_aer import AerSimulator

def dj_algorithm(function: QuantumCircuit, shots: int = 500):
    """
    Determine if a Deutsch-Jozsa function is constant or balanced.

    The oracle is wrapped with ``compile_circuit`` and sampled ``shots`` times
    on ``AerSimulator``. In the ideal Deutsch-Jozsa circuit, a constant oracle
    always measures the all-zeros bitstring and a balanced one never does.
    The verdict is therefore read from the most frequent measured bitstring.

    Args:
        function: Oracle circuit with the output qubit last, e.g. as
            produced by ``dj_function``.
        shots: Number of simulator shots (default 500). On the ideal simulator
            every shot agrees; more shots only help on noisy backends.

    Returns:
        ``"balanced"`` if the most frequent outcome contains a 1,
        otherwise ``"constant"``.
    """
    qc = compile_circuit(function)

    counts = AerSimulator().run(qc, shots=shots).result().get_counts()
    # Aggregate over every shot instead of inspecting a single one; a majority
    # vote also tolerates occasional erroneous outcomes on noisy backends.
    most_frequent = max(counts, key=counts.get)
    if "1" in most_frequent:
        return "balanced"
    return "constant"

In [None]:
# Generate a random 3-input oracle, show its circuit, then the algorithm's verdict.
f = dj_function(3)
oracle_figure = f.draw("mpl")
display(oracle_figure)
verdict = dj_algorithm(f)
display(verdict)

In [None]:
def subject_constant_oracle(num_qubits: int = 3):
    """
    Build a fixed constant Deutsch-Jozsa oracle.

    The circuit has ``num_qubits`` input qubits and an output qubit at index
    ``num_qubits``. It applies a single X to the output qubit and never
    touches the inputs, so it computes f(x) = 1 for every input.

    Args:
        num_qubits: Number of input qubits (default 3, the original size).

    Returns:
        A ``QuantumCircuit`` with ``num_qubits + 1`` qubits.
    """
    qc = QuantumCircuit(num_qubits + 1)
    qc.x(num_qubits)

    return qc

f = subject_constant_oracle()
display(f.draw("mpl"))
display(dj_algorithm(f))

In [None]:
def subject_balanced_oracle(num_qubits: int = 3):
    """
    Build a fixed balanced Deutsch-Jozsa oracle.

    Each input qubit controls a CX onto the output qubit (index
    ``num_qubits``). The CXs are sandwiched between X gates on the inputs, so
    each control effectively fires when its input bit is 0. The output is
    therefore flipped by the parity of the number of zero bits. For at least
    one input qubit, that parity is 1 for exactly half of all inputs, which
    makes the function balanced.

    Args:
        num_qubits: Number of input qubits (default 3, the original size).

    Returns:
        A ``QuantumCircuit`` with ``num_qubits + 1`` qubits.

    Raises:
        ValueError: If ``num_qubits`` is less than 1 (no balanced function
            exists without inputs).
    """
    if num_qubits < 1:
        raise ValueError(f"num_qubits must be at least 1, got {num_qubits}")

    qc = QuantumCircuit(num_qubits + 1)
    inputs = list(range(num_qubits))
    qc.x(inputs)
    for qubit in inputs:
        qc.cx(qubit, num_qubits)
    qc.x(inputs)

    return qc

f = subject_balanced_oracle()
display(f.draw("mpl"))
display(dj_algorithm(f))