In [None]:
!pip install qiskit
from qiskit import QuantumCircuit, QuantumRegister
from qiskit.quantum_info import Statevector
from qiskit.circuit.library import ZGate
import numpy as np
import math
from collections import Counter

def _sudoku_constraints():
    """Return the 12 cell groups of a 4x4 Sudoku: 4 rows, 4 columns, 4 2x2 boxes."""
    rows = [[(i, j) for j in range(4)] for i in range(4)]
    cols = [[(i, j) for i in range(4)] for j in range(4)]
    boxes = [
        [(r + di, c + dj) for di in range(2) for dj in range(2)]
        for r in (0, 2)
        for c in (0, 2)
    ]
    return rows + cols + boxes


def _append_mcz(circuit, qubits):
    """Append a Z phase flip that fires only when every qubit in *qubits* is |1>."""
    if len(qubits) == 1:
        circuit.z(qubits[0])
    else:
        # Multi-controlled Z: all but the last qubit control, the last is the target.
        circuit.append(ZGate().control(len(qubits) - 1), list(qubits))


def sudoku_solver(grid):
    """Build a Grover-search circuit whose most likely outcome completes a 4x4 Sudoku.

    Args:
        grid: 4x4 nested sequence; 1-4 are clues, 0 marks an empty cell.

    Returns:
        ``(circuit, variable_qubits)``. ``variable_qubits`` lists the qubit
        indices of the empty cells in row-major order, two per cell, most
        significant bit first (cell value = 2*msb + lsb + 1). A complete grid
        yields an empty circuit and ``[]``.

    Raises:
        ValueError: if ``grid`` is not 4x4, holds a value other than 0-4, or
            its clues already repeat a value within a row, column or box.

    Only the empty cells are given qubits; clues enter the oracle classically.
    The circuit uses 2*k + 4 + g qubits (k empty cells, g row/column/box
    groups touching an empty cell), so exact statevector simulation is only
    practical for a handful of empty cells. The Grover iteration count
    assumes the puzzle has a unique solution.
    """
    if len(grid) != 4 or any(len(row) != 4 for row in grid):
        raise ValueError("grid must be a 4x4 nested sequence")
    for row in grid:
        for value in row:
            if value not in (0, 1, 2, 3, 4):
                raise ValueError(f"cell values must be 0 (empty) or 1-4, got {value!r}")

    constraints = _sudoku_constraints()
    for group in constraints:
        clues = [grid[i][j] for (i, j) in group if grid[i][j] != 0]
        if len(clues) != len(set(clues)):
            raise ValueError(f"clues repeat a value within cells {group}: no solution exists")

    variable_cells = [(i, j) for i in range(4) for j in range(4) if grid[i][j] == 0]
    if not variable_cells:
        # Already complete: there is nothing to search over.
        return QuantumCircuit(), []

    # Qubit layout: [2 per empty cell | 4 value flags | 1 "group ok" per active group].
    # Clue cells get no qubits -- their values are applied classically below,
    # which keeps the register small enough for statevector simulation.
    cell_qubits = {cell: (2 * pos, 2 * pos + 1) for pos, cell in enumerate(variable_cells)}
    variable_qubits = [q for cell in variable_cells for q in cell_qubits[cell]]
    num_var_qubits = len(variable_qubits)
    flags = list(range(num_var_qubits, num_var_qubits + 4))
    # Groups made only of clues were validated above and need no quantum check.
    active_groups = [g for g in constraints if any(cell in cell_qubits for cell in g)]
    first_ok = num_var_qubits + len(flags)
    ok_qubits = list(range(first_ok, first_ok + len(active_groups)))

    qr = QuantumRegister(first_ok + len(active_groups), 'q')
    qc = QuantumCircuit(qr)
    # Uniform superposition over all 4**k assignments of the empty cells.
    qc.h(variable_qubits)

    def mark_values(circuit, group):
        """XOR flags[v] once for each cell of *group* holding value v (0-3).

        Self-inverse, so calling it a second time clears the flags again.
        """
        for (i, j) in group:
            if (i, j) in cell_qubits:
                msb, lsb = cell_qubits[(i, j)]
                for value, flag in enumerate(flags):
                    # Integer ctrl_state is little-endian over the control list:
                    # bit 0 -> lsb qubit, bit 1 -> msb qubit, so this fires iff
                    # the cell encodes `value`.
                    circuit.mcx([qr[lsb], qr[msb]], qr[flag], ctrl_state=value)
            else:
                # Clue: value known classically, so flip its flag unconditionally.
                circuit.x(qr[flags[grid[i][j] - 1]])

    def check_group(circuit, group, ok):
        """Toggle qubit *ok* iff the four cells of *group* hold distinct values.

        The flags are restored to |0>, and the whole step is self-inverse.
        """
        mark_values(circuit, group)
        # Four cells set all four flags exactly when no value repeats.
        circuit.mcx([qr[f] for f in flags], qr[ok])
        mark_values(circuit, group)

    # Phase oracle: compute every group check, flip the phase only when all of
    # them pass (logical AND, not parity), then uncompute so every ancilla
    # returns to |0> and the oracle is a clean phase oracle.
    oracle = QuantumCircuit(qr, name="Oracle")
    for group, ok in zip(active_groups, ok_qubits):
        check_group(oracle, group, ok)
    _append_mcz(oracle, [qr[q] for q in ok_qubits])
    for group, ok in reversed(list(zip(active_groups, ok_qubits))):
        check_group(oracle, group, ok)
    oracle_gate = oracle.to_gate()

    # Diffuser: reflection about the uniform superposition of the empty cells.
    diffuser = QuantumCircuit(qr, name="Diffuser")
    diffuser.h(variable_qubits)
    diffuser.x(variable_qubits)
    _append_mcz(diffuser, [qr[q] for q in variable_qubits])
    diffuser.x(variable_qubits)
    diffuser.h(variable_qubits)
    diffuser_gate = diffuser.to_gate()

    # Optimal iteration count for one marked item: floor(pi / (4*theta)) with
    # sin(theta) = 1/sqrt(N). The cruder round(pi/4*sqrt(N)) overshoots for
    # small N (N=4 gives 2 iterations and only 25% success instead of 100%).
    search_space = 4 ** len(variable_cells)
    theta = math.asin(1 / math.sqrt(search_space))
    iterations = max(1, math.floor(math.pi / (4 * theta)))
    for _ in range(iterations):
        qc.append(oracle_gate, qr)
        qc.append(diffuser_gate, qr)

    return qc, variable_qubits

# ----------------------------------------------
# Run the solver on an example & post-process it
# ----------------------------------------------

# Example puzzle written one row per string; '0' marks an empty cell.
grid_example = [[int(ch) for ch in row] for row in ("1034", "3412", "2143", "4320")]

circuit, var_qubits = sudoku_solver(grid_example)
# Exact simulation: the statevector holds 2**circuit.num_qubits amplitudes.
state = Statevector.from_instruction(circuit)
probs = state.probabilities_dict()

# Pull the empty-cell qubits out of a full basis-state label.
def extract_var(bits, qubits=None):
    """Return the bits of *qubits* from a Qiskit basis-state label.

    Qiskit labels are little-endian (qubit 0 is the rightmost character), so
    the label is reversed once and then indexed by qubit number.

    Args:
        bits: basis-state label, e.g. a key of ``probabilities_dict()``.
        qubits: qubit indices to read, in output order; defaults to the
            module-level ``var_qubits``.

    Returns:
        A string with one '0'/'1' character per entry of ``qubits``.
    """
    if qubits is None:
        qubits = var_qubits
    # Reverse once instead of once per selected qubit.
    little_endian = bits[::-1]
    return ''.join(little_endian[q] for q in qubits)

# Marginalise over the ancilla qubits: several full basis states can share the
# same assignment of the empty cells, so their probabilities are summed before
# picking the most likely assignment.
filtered = [(extract_var(k), v) for k, v in probs.items() if v > 1e-6]
marginal = Counter()
for assignment, p in filtered:
    marginal[assignment] += p
peak = marginal.most_common(1)[0]  # (assignment bits, total probability)
solution_bits = peak[0]

print("Most probable solution (binary):", solution_bits)

# Format back to grid (optional)
def bits_to_grid(bits, template):
    """Fill the empty (0) cells of *template* from a solution bitstring.

    Args:
        bits: two characters per empty cell, in row-major order, most
            significant bit first; '00' decodes to 1 and '11' to 4.
        template: 4x4 grid with 0 marking empty cells; it is not modified.

    Returns:
        A new 4x4 list-of-lists grid with every empty cell filled.

    Raises:
        ValueError: if ``bits`` does not hold exactly two bits per empty cell
            (a pair that is not valid binary also raises, via ``int``).
    """
    grid = [list(row) for row in template]
    empty = [(i, j) for i in range(4) for j in range(4) if grid[i][j] == 0]
    # A length mismatch would otherwise silently mis-decode or drop bits.
    if len(bits) != 2 * len(empty):
        raise ValueError(
            f"expected {2 * len(empty)} bits for {len(empty)} empty cells, got {len(bits)}"
        )
    for pos, (i, j) in enumerate(empty):
        grid[i][j] = int(bits[2 * pos:2 * pos + 2], 2) + 1
    return grid

# Decode the winning assignment into the grid and print it row by row.
sol_grid = bits_to_grid(solution_bits, grid_example)
print("Completed Sudoku:")
print("\n".join(str(row) for row in sol_grid))


