In [4]:
!pip install qiskit
!pip install matplotlib-venn
from qiskit import QuantumCircuit, QuantumRegister
from qiskit.circuit.library import QFT, UnitaryGate # Import UnitaryGate from circuit.library
from qiskit.quantum_info import Statevector
import numpy as np

def modular_mult_oracle(a, N):
    """Build a controlled modular-multiplication oracle.

    On a circuit laid out as (x, y, anc), the oracle maps
    ``|x>|y>|anc> -> |x>|a*y mod N>|anc>`` when ``x = 1`` and ``0 <= y < N``.
    Every other basis state (``x = 0``, or an out-of-range ``y >= N``) is
    left unchanged.

    Multiplication by ``a`` mod ``N`` is written as an explicit permutation
    matrix on the ``n = ceil(log2(N))``-qubit y register. It is applied as a
    single ``UnitaryGate`` controlled by ``x``. The matrix is
    ``2**n x 2**n``, so this construction is only practical for small ``N``.

    Args:
        a: Multiplier. Must be coprime to ``N`` so that the map is a
            permutation, i.e. unitary.
        N: Modulus, ``N >= 2``.

    Returns:
        QuantumCircuit with registers ``x`` (1 qubit, control), ``y``
        (``n`` qubits, little-endian: ``y[0]`` is the least significant bit)
        and ``anc`` (2 qubits). The ancillas are not acted on. They are kept
        so the circuit keeps its ``1 + n + 2`` qubit layout.

    Raises:
        ValueError: If ``N < 2`` or ``gcd(a, N) != 1``.
    """
    from math import gcd

    if N < 2:
        raise ValueError(f"N must be at least 2, got {N}")
    if gcd(a, N) != 1:
        raise ValueError(
            f"a={a} is not coprime to N={N}; multiplication by a mod N is "
            "not invertible and cannot be implemented as a unitary"
        )

    # Exact integer form of ceil(log2(N)); avoids floating-point rounding.
    n = int(N - 1).bit_length()

    x_reg = QuantumRegister(1, 'x')
    y_reg = QuantumRegister(n, 'y')
    anc_reg = QuantumRegister(2, 'anc')
    qc = QuantumCircuit(x_reg, y_reg, anc_reg)

    # Column i holds a single 1 in row (a*i) % N for 0 <= i < N. Indices
    # N .. 2**n - 1 map to themselves, so out-of-range inputs pass through
    # unchanged without separate comparison/ancilla logic.
    dim = 2 ** n
    U = np.zeros((dim, dim))
    for i in range(dim):
        target = (a * i) % N if i < N else i
        U[target, i] = 1

    # The control qubit comes first in the controlled gate's qubit list. The
    # y qubits follow in register order, so matrix index == integer value of y.
    cu = UnitaryGate(U).control(1)
    qc.append(cu, [x_reg[0]] + list(y_reg))

    return qc

# Demo: print the oracle's action on chosen basis inputs (the example below uses N=15)
def demo_oracle(a, N, y_values):
    """Print the oracle's output state for each basis input ``|x=1>|y>``.

    For every ``y`` in ``y_values``, this:

    * sets the control ``x`` to 1;
    * encodes ``y`` in the y register (little-endian, ``y[0]`` is the least
      significant bit), with any ancillas left in ``|0>``;
    * applies ``modular_mult_oracle(a, N)``;
    * prints the expected result next to every simulated basis state whose
      probability is not close to zero.

    Printed basis-state labels follow Qiskit's bit ordering. The rightmost
    character is qubit 0 (``x``), then ``y[0] .. y[n-1]`` moving left, then
    any ancillas. With two ancillas a label reads ``anc1 anc0 y[n-1]..y[0] x``.
    The printed "amplitude" is ``sqrt(probability)``, i.e. the magnitude only.

    Args:
        a: Multiplier passed to the oracle.
        N: Modulus passed to the oracle.
        y_values: Integers to use as the y input. Each must fit in
            ``n = ceil(log2(N))`` bits.

    Raises:
        ValueError: If a ``y`` value is negative or does not fit in ``n`` bits.
    """
    n = int(np.ceil(np.log2(N)))
    oracle = modular_mult_oracle(a, N)

    for y in y_values:
        if not 0 <= y < 2 ** n:
            raise ValueError(f"y={y} does not fit in the {n}-qubit y register")

        # Build the input circuit on the oracle's own registers. compose()
        # then lines up for however many ancillas the oracle uses.
        # Qubit 0 is x and qubits 1..n are y (register order x, y, anc).
        qc = QuantumCircuit(*oracle.qregs)
        y_qubits = qc.qubits[1:1 + n]

        qc.x(0)  # Set control to 1
        y_bin = format(y, f'0{n}b')
        for i, bit in enumerate(reversed(y_bin)):
            if bit == '1':
                qc.x(y_qubits[i])

        qc.compose(oracle, inplace=True)

        state = Statevector(qc)
        state_dict = state.probabilities_dict()

        print(f"\nInput y = {y} ({y_bin}):")
        if y < N:
            expected = (a * y) % N
            expected_bin = format(expected, f'0{n}b')
            print(f"Expected: |1⟩|{expected_bin}⟩ ({a}*{y} mod {N} = {expected})")
        else:
            print(f"Expected: |1⟩|{y_bin}⟩ (unchanged)")

        print("Output state (non-zero amplitudes):")
        for basis_state, prob in state_dict.items():
            # Skip floating-point noise left over from gate synthesis.
            if not np.isclose(prob, 0):
                print(f"  |{basis_state}⟩: amplitude = {np.sqrt(prob):.4f}")

# Example usage: multiply by a = 7 modulo N = 15. y = 5 lies in [0, N);
# y = 15 lies outside it, so the demo expects it to come back unchanged.
N, a = 15, 7
test_values = [5, 15]
demo_oracle(a, N, y_values=test_values)


Input y = 5 (0101):
Expected: |1⟩|0101⟩ (7*5 mod 15 = 5)
Output state (non-zero amplitudes):
  |0001011⟩: amplitude = 1.0000

Input y = 15 (1111):
Expected: |1⟩|1111⟩ (unchanged)
Output state (non-zero amplitudes):
  |0011111⟩: amplitude = 1.0000
