In [9]:
!pip install qiskit matplotlib-venn

from qiskit import QuantumCircuit, QuantumRegister
from qiskit.quantum_info import Statevector
import numpy as np

def modular_mult_oracle(a, N):
    n = int(np.ceil(np.log2(N)))

    if N == 15:
        x_reg = QuantumRegister(1, 'x')
        y_reg = QuantumRegister(n, 'y')
        anc_reg = QuantumRegister(2, 'anc')
        qc = QuantumCircuit(x_reg, y_reg, anc_reg)

        # Step 1: Set anc[0] = 1 if y == 15
        qc.x(y_reg)
        qc.mcx(y_reg, anc_reg[0], ctrl_state='0'*n)
        qc.x(y_reg)

        # Step 2: Compute ctrl = x AND (y != 15)
        qc.x(anc_reg[0])
        qc.ccx(x_reg[0], anc_reg[0], anc_reg[1])
        qc.x(anc_reg[0])

        # Step 3: Apply multiplication by a mod 15 when anc[1] == 1
        # Hardcode classical map: i ↦ (a * i) % 15 for i = 0 to 14
        for i in range(15):
            result = (a * i) % 15
            input_bits = format(i, f'0{n}b')
            output_bits = format(result, f'0{n}b')
            for j in range(n):
                if input_bits[j] != output_bits[j]:
                    # Exclude the target qubit from the controls
                    controls = [anc_reg[1]] + [y_reg[k] for k in range(n) if k != n - 1 - j]
                    # Manually prepare the state of y_reg for the control
                    for k, bit in enumerate(reversed(input_bits)):
                        if bit == '0':
                            qc.x(y_reg[k])

                    qc.mcx(controls, y_reg[n - 1 - j], ctrl_state='1' + '1'*(n-1))

                    # Uncompute the X gates
                    for k, bit in enumerate(reversed(input_bits)):
                        if bit == '0':
                            qc.x(y_reg[k])

        # Step 4: Uncompute ancillas
        qc.x(anc_reg[0])
        qc.ccx(x_reg[0], anc_reg[0], anc_reg[1])
        qc.x(anc_reg[0])

        qc.x(y_reg)
        qc.mcx(y_reg, anc_reg[0], ctrl_state='0'*n)
        qc.x(y_reg)

        return qc
    else:
        raise NotImplementedError("Only N=15 is implemented without UnitaryGate.")

# Demo
def demo_oracle(a, N, y_values):
    n = int(np.ceil(np.log2(N)))
    oracle = modular_mult_oracle(a, N)

    for y in y_values:
        x_reg = QuantumRegister(1, 'x')
        y_reg = QuantumRegister(n, 'y')
        anc_reg = QuantumRegister(2, 'anc')
        qc = QuantumCircuit(x_reg, y_reg, anc_reg)

        qc.x(x_reg)  # Set control to 1
        y_bin = format(y, f'0{n}b')
        for i, bit in enumerate(reversed(y_bin)):
            if bit == '1':
                qc.x(y_reg[i])

        qc.compose(oracle, inplace=True)

        state = Statevector(qc)
        state_dict = state.probabilities_dict()

        print(f"\nInput y = {y} ({y_bin}):")
        if y < N:
            expected = (a * y) % N
            expected_bin = format(expected, f'0{n}b')
            print(f"Expected: |1⟩|{expected_bin}⟩ ({a}*{y} mod {N} = {expected})")
        else:
            print(f"Expected: |1⟩|{y_bin}⟩ (unchanged)")

        print("Output state (non-zero amplitudes):")
        for basis_state, prob in state_dict.items():
            if not np.isclose(prob, 0):
                print(f"  |{basis_state}⟩: amplitude = {np.sqrt(prob):.4f}")

# Example
demo_oracle(a=7, N=15, y_values=[5, 15])


Input y = 5 (0101):
Expected: |1⟩|0101⟩ (7*5 mod 15 = 5)
Output state (non-zero amplitudes):
  |0010011⟩: amplitude = 1.0000

Input y = 15 (1111):
Expected: |1⟩|1111⟩ (unchanged)
Output state (non-zero amplitudes):
  |0011111⟩: amplitude = 1.0000
