# In [None]:
import json
import requests
from random import choice, randint
from requests.packages.urllib3.exceptions import InsecureRequestWarning
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

def deutsch_jozsa_quokka_batch(
    my_quokka: str,
    n: int,
    shots: int = 10,
    oracle_type: str | None = None,   # "Constant0", "Constant1", "Balanced", or None=random
    verify: bool = False,
    timeout: int = 120
):
    """
    Deutsch–Jozsa on n query qubits + 1 ancilla.
    Returns: (oracle_type_used, a_bits_or_None, raw_response_json)
    """

    # Batch endpoint (as per Quokka docs)
    url = f"http://{my_quokka}.quokkacomputing.com/qsim/perform_experiment"

    # Choose oracle type if not fixed
    if oracle_type is None:
        oracle_type = choice(["Constant0", "Constant1", "Balanced"])

    a_bits = None  # only used for Balanced (a != 0)
    nq = n + 1      # total qubits (query + ancilla)
    anc = n         # ancilla index

    ops = []

    # 1) Create + initialize |0...0>
    ops.append({"operation": "create_circuit", "num_qubits": nq})
    ops.append({
        "operation": "set_state",
        "state": 0,
        "complex_value": {"re": 1, "im": 0}
    })

    # 2) Prepare |+...+> on query
    for i in range(n):
        ops.append({"operation": "gate", "gate": "hadamard", "q": i})

    # 3) Prepare ancilla in |-> = H X |0>
    ops.append({"operation": "gate", "gate": "X", "q": anc})
    ops.append({"operation": "gate", "gate": "hadamard", "q": anc})


    # 4) Oracle U_f
    if oracle_type == "Constant0":
        # do nothing
        pass

    elif oracle_type == "Constant1":
        # f(x)=1 => flip ancilla (because ancilla is |-> this only adds a global phase, DJ still says "constant")
        ops.append({"operation": "gate", "gate": "X", "q": anc})

    elif oracle_type == "Balanced":
        # f(x) = a·x (mod 2), with a != 0 to guarantee balanced
        a = randint(1, 2**n - 1)
        a_bits = format(a, f"0{n}b")
        for i, bit in enumerate(a_bits):
            if bit == "1":
                ops.append({
                    "operation": "gate",
                    "gate": "cnot",
                    "q_control": i,
                    "q_target": anc
                })
    else:
        raise ValueError("oracle_type must be Constant0, Constant1, Balanced, or None")

   
    # 5) Interference: H on query again
    for i in range(n):
        ops.append({"operation": "gate", "gate": "hadamard", "q": i})

    # 6) Measure query qubits ONLY
    # Quokka returns an integer for the measured subset; format into bits yourself.
    ops.append({"operation": "measure", "lq2m": list(range(n))})
    
    # 7) Tear down
    ops.append({"operation": "destroy_circuit"})

    # --- Run shots (batch runs one experiment; repeat shots times) ---
    # This matches how your earlier raw outputs looked: per-shot raw classical data.
    results = []
    for _ in range(shots):
        r = requests.post(url, json=ops, verify=verify, timeout=timeout)
        if r.status_code != 200:
            raise RuntimeError(f"HTTP {r.status_code} {r.headers.get('Content-Type')}\n{r.text[:800]}")
        results.append(r.json())

    return oracle_type, a_bits, results

def extract_measure_int_from_shot(shot_ops):
    """Return the integer outcome of the last "measure" entry in one shot.

    shot_ops is the list of op-result entries for a single run. Entries that
    are not dicts, or whose "operation" is not "measure", are ignored.

    Raises:
        RuntimeError: If no measure entry exists, or if the last measure
            entry's "result" is not a plain int (a bool counts as invalid).
    """
    # Scan from the end so the most recent measurement wins.
    measure_entries = (
        entry
        for entry in reversed(shot_ops)
        if isinstance(entry, dict) and entry.get("operation") == "measure"
    )
    item = next(measure_entries, None)
    if item is None:
        raise RuntimeError(f"No measure op found in shot: {shot_ops}")

    outcome = item.get("result")
    # bool is a subclass of int, so it has to be excluded explicitly.
    is_plain_int = isinstance(outcome, int) and not isinstance(outcome, bool)
    if not is_plain_int:
        raise RuntimeError(f"Measure returned non-int: {item}")
    return outcome


def predict_from_results(results, n):
    """
    Apply the Deutsch–Jozsa decision rule to a set of shot results.

    Handles results shaped like:
      results = [ [shot1_ops, shot2_ops, ...] ]
    or:
      results = [shot1_ops, shot2_ops, ...]

    Args:
        results: Per-shot op-result lists, optionally wrapped in one extra
            single-element list.
        n: Number of query qubits. Currently unused; kept so existing
            callers keep working.

    Returns:
        (pred, measured_ints) where pred is "Constant" if every shot measured
        0 and "Balanced" otherwise, and measured_ints holds one integer per
        shot. With no shots at all the prediction is vacuously "Constant".

    Raises:
        RuntimeError: Propagated from extract_measure_int_from_shot when a
            shot has no usable measurement (this includes an empty shot).
    """
    # Flatten one level only when the single element is itself a non-empty
    # list of lists. Checking non-emptiness first avoids an IndexError on
    # input such as [[]].
    is_wrapped = (
        len(results) == 1
        and isinstance(results[0], list)
        and len(results[0]) > 0
        and isinstance(results[0][0], list)
    )
    shots = results[0] if is_wrapped else results

    measured_ints = [extract_measure_int_from_shot(shot_ops) for shot_ops in shots]

    # Deutsch–Jozsa decision rule: all-zero outcomes mean a constant oracle.
    pred = "Constant" if all(v == 0 for v in measured_ints) else "Balanced"
    return pred, measured_ints


# ---------- Example ----------
# Runs at module level: builds a Deutsch–Jozsa circuit, submits it to the
# Quokka device over HTTP and prints the resulting prediction.
my_quokka = "quokka1"   # device name; becomes the subdomain of quokkacomputing.com
n = 24  # query qubits; the circuit has n + 1 qubits including the ancilla

oracle_type, a_bits, results = deutsch_jozsa_quokka_batch(
    my_quokka=my_quokka,
    n=n,
    shots=10,
    oracle_type=None,   # None => the oracle type is picked at random
    verify=False
)

# Uncomment to inspect the raw per-shot JSON payloads:
#print(results)

# Extract one measured integer per shot and apply the all-zero => "Constant" rule.
pred, ints = predict_from_results(results, n)

print("Oracle was:", oracle_type)
print("Measured ints:", ints)
print("Prediction:", pred)
