# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, pauli_error
from joblib import Parallel, delayed
import networkx as nx

def initialize_toric_code(circuit, k, logical_state="00"):
    """
    Prepare a (simplified) toric-code state on a k x k lattice and apply logical X operators.

    ``circuit.qregs[0]`` is used as the data register. Neighbour indices wrap
    modulo ``2 * k * k``, so the register is expected to hold ``2 * k * k``
    qubits (which is how ``run_single_simulation`` creates it).

    NOTE(review): the Hadamard/CNOT pattern below uses a simplified qubit
    indexing rather than an explicit toric-code edge layout. Logical X1 and X2
    also both act on data qubit 0. Confirm this prepares the intended code
    state before trusting downstream results.

    Args:
        circuit: Qiskit ``QuantumCircuit`` whose first register holds the data qubits.
        k: Linear lattice size.
        logical_state: Two-character string of "0"/"1".
            - Character 0 selects logical X1 (data qubits 0..k-1).
            - Character 1 selects logical X2 (data qubits 0, k, ..., k*(k-1)).

    Returns:
        The same ``circuit``, modified in place.

    Raises:
        ValueError: If ``logical_state`` is not a two-character string of "0"/"1".
    """
    # Validate before mutating the circuit. Previously "0" raised IndexError,
    # and any non-"1" character (e.g. "ab") was silently treated as "0".
    if (
        not isinstance(logical_state, str)
        or len(logical_state) != 2
        or not set(logical_state) <= {"0", "1"}
    ):
        raise ValueError(
            "logical_state must be a two-character string of '0'/'1', "
            f"got {logical_state!r}"
        )

    n = k * k
    num_data = 2 * n
    data_qubits = circuit.qregs[0]

    # Per-vertex Hadamard followed by a CNOT fan-out onto the four neighbouring
    # indices (simplified indexing; wraps around the 2n-qubit data register).
    for v in range(n):
        circuit.h(data_qubits[v])
        for offset in (1, k, -1, -k):
            circuit.cx(data_qubits[v], data_qubits[(v + offset) % num_data])

    # Logical X operators for states other than |00>_L.
    if logical_state[0] == "1":  # Logical X1: horizontal loop
        for i in range(k):
            circuit.x(data_qubits[i])
    if logical_state[1] == "1":  # Logical X2: vertical loop
        for i in range(k):
            circuit.x(data_qubits[k * i])

    return circuit

def measure_stabilizers(circuit, k):
    """
    Measure vertex (A_v, X-type) and plaquette (B_p, Z-type) stabilizers.

    Adds two ancilla registers and two classical registers to ``circuit``:
    ``anc_v``, ``anc_p``, ``meas_v`` and ``meas_p``, each of size ``k * k``.
    Stabilizer ``i`` of each family is measured into bit ``i`` of its
    classical register.

    Both families use ancilla phase kickback: H on the ancilla, a controlled
    Pauli onto each supporting data qubit, H again, then measure. The result
    is the parity of the corresponding Pauli product.

    Args:
        circuit: Qiskit ``QuantumCircuit`` whose ``qregs[0]`` holds the
            ``2 * k * k`` data qubits.
        k: Linear lattice size.

    Returns:
        The same ``circuit``, modified in place.
    """
    n = k * k
    num_data = 2 * n
    data_qubits = circuit.qregs[0]
    neighbor_offsets = (1, k, -1, -k)

    ancilla_v = QuantumRegister(n, name="anc_v")
    ancilla_p = QuantumRegister(n, name="anc_p")
    circuit.add_register(ancilla_v, ancilla_p)

    meas_v = ClassicalRegister(n, name="meas_v")
    meas_p = ClassicalRegister(n, name="meas_p")
    circuit.add_register(meas_v, meas_p)

    # Vertex stabilizers (X-type): CNOTs from an ancilla in |+> kick back the
    # X parity of the supporting data qubits.
    for v in range(n):
        circuit.h(ancilla_v[v])
        for offset in neighbor_offsets:
            circuit.cx(ancilla_v[v], data_qubits[(v + offset) % num_data])
        circuit.h(ancilla_v[v])
        circuit.measure(ancilla_v[v], meas_v[v])

    # Plaquette stabilizers (Z-type): same pattern with CZ. The ancilla must be
    # put in |+> first. A CZ controlled on |0> is a no-op, so without these
    # Hadamards every plaquette outcome was 0 and X errors were never detected.
    for p in range(n):
        circuit.h(ancilla_p[p])
        for offset in neighbor_offsets:
            # "+ n" shifts plaquette supports into the second half of the data register.
            circuit.cz(ancilla_p[p], data_qubits[(p + offset + n) % num_data])
        circuit.h(ancilla_p[p])
        circuit.measure(ancilla_p[p], meas_p[p])

    return circuit

def _defect_graph(defects, k):
    """Return a complete graph over ``defects`` weighted by toroidal Manhattan distance on a k x k grid."""
    graph = nx.Graph()
    # ``defects`` is ascending, so slicing after idx yields each unordered pair once (i < j).
    for idx, i in enumerate(defects):
        row_i, col_i = divmod(i, k)
        for j in defects[idx + 1:]:
            row_j, col_j = divmod(j, k)
            d_row = abs(row_i - row_j)
            d_col = abs(col_i - col_j)
            # Take the shorter way around each periodic dimension.
            graph.add_edge(i, j, weight=min(d_row, k - d_row) + min(d_col, k - d_col))
    return graph


def apply_recovery(circuit, k, syndromes):
    """
    Apply recovery operations using minimum-weight matching on syndromes.

    ``syndromes`` maps ``"meas_v"`` and ``"meas_p"`` to bit strings in
    Qiskit's counts ordering (rightmost character = classical bit 0). These
    are the per-register fields of a ``get_counts()`` key for the registers
    created by ``measure_stabilizers``.

    Defects of each type are paired with networkx minimum-weight matching,
    using toroidal Manhattan distance on the k x k grid.

    NOTE(review): the correction is still simplified. One Pauli is applied to
    a single qubit per matched pair, not along the path joining the pair. With
    an odd number of defects, one defect is left unmatched.

    Args:
        circuit: Qiskit ``QuantumCircuit`` whose ``qregs[0]`` holds the
            ``2 * k * k`` data qubits.
        k: Linear lattice size.
        syndromes: Dict with ``"meas_v"`` and ``"meas_p"`` bit strings of length ``k * k``.

    Returns:
        The same ``circuit``, modified in place.
    """
    n = k * k
    num_data = 2 * n
    data_qubits = circuit.qregs[0]

    # Qiskit prints bits most-significant first. Reverse the string so that
    # position i corresponds to stabilizer i (meas_v[i] / meas_p[i]).
    vertex_defects = [i for i, bit in enumerate(reversed(syndromes["meas_v"])) if bit == "1"]
    plaquette_defects = [i for i, bit in enumerate(reversed(syndromes["meas_p"])) if bit == "1"]

    vertex_pairs = nx.algorithms.matching.min_weight_matching(_defect_graph(vertex_defects, k))
    plaquette_pairs = nx.algorithms.matching.min_weight_matching(_defect_graph(plaquette_defects, k))

    # X-type vertex stabilizers (A_v) are flipped by Z errors, so they get a Z
    # correction. Z-type plaquette stabilizers (B_p) are flipped by X errors,
    # so they get an X correction. The original code had these swapped.
    for i, _ in vertex_pairs:
        circuit.z(data_qubits[i % num_data])  # Simplified single-qubit recovery
    for i, _ in plaquette_pairs:
        circuit.x(data_qubits[(i + n) % num_data])  # Simplified single-qubit recovery

    return circuit

def measure_logical(circuit, k):
    """
    Measure the two logical Z operators using ancilla phase kickback.

    Adds a 2-qubit ancilla register ``log_anc`` and a 2-bit classical register
    ``log_meas``.
    - ``log_meas[0]`` receives the Z parity of the horizontal loop (data qubits 0..k-1).
    - ``log_meas[1]`` receives the Z parity of the vertical loop (data qubits 0, k, ..., k*(k-1)).

    Args:
        circuit: Qiskit ``QuantumCircuit`` whose ``qregs[0]`` holds the data qubits.
        k: Linear lattice size.

    Returns:
        The same ``circuit``, modified in place.
    """
    data_qubits = circuit.qregs[0]

    logical_anc = QuantumRegister(2, name="log_anc")
    logical_meas = ClassicalRegister(2, name="log_meas")
    circuit.add_register(logical_anc, logical_meas)

    supports = (
        [data_qubits[i] for i in range(k)],      # Logical Z1 (horizontal loop)
        [data_qubits[k * i] for i in range(k)],  # Logical Z2 (vertical loop)
    )
    for idx, support in enumerate(supports):
        # The ancilla must start in |+>. CZ controlled on |0> is a no-op, so
        # without the Hadamards both logical readouts were always 0.
        circuit.h(logical_anc[idx])
        for qubit in support:
            circuit.cz(logical_anc[idx], qubit)
        circuit.h(logical_anc[idx])
        circuit.measure(logical_anc[idx], logical_meas[idx])

    return circuit

def _split_counts_key(outcome, cregs):
    """Map each classical register name to its bit string within a Qiskit counts key."""
    # Counts keys list registers last-added first (optionally space-separated),
    # so walk the registers in creation order, peeling bits from the right.
    bits = outcome.replace(" ", "")
    fields = {}
    end = len(bits)
    for creg in cregs:
        fields[creg.name] = bits[end - creg.size:end]
        end -= creg.size
    return fields


def run_single_simulation(error_prob, lattice_size, logical_state, shots, sim_engine):
    """
    Simulate the toric code for a single error probability.

    Builds the circuit in four stages:
    1. State preparation.
    2. One ``id`` gate per data qubit; these carry the noise.
    3. Stabilizer measurement.
    4. Logical-Z measurement.

    The circuit runs on a noisy stabilizer simulator. Each ``id`` applies X, Y
    or Z with probability ``error_prob`` each, so the total error probability
    is ``3 * error_prob``.

    NOTE(review): no decoding is applied yet. Syndromes are measured, but
    ``apply_recovery`` is not used, so this is the raw (uncorrected) logical
    success rate.

    Args:
        error_prob: Per-Pauli error probability, ``0 <= error_prob <= 1/3``.
        lattice_size: Linear lattice size k.
        logical_state: Two-character target logical state, e.g. ``"00"``.
        shots: Number of shots.
        sim_engine: Unused. Kept for backward compatibility; a dedicated
            noisy ``AerSimulator`` is built on every call.

    Returns:
        Fraction of shots whose logical measurement equals ``logical_state``.

    Raises:
        ValueError: If ``error_prob`` is outside ``[0, 1/3]``. Otherwise the
            identity probability ``1 - 3p`` would be invalid.
    """
    if not 0 <= error_prob <= 1 / 3:
        raise ValueError(f"error_prob must be in [0, 1/3], got {error_prob!r}")

    n = lattice_size * lattice_size
    qubits = QuantumRegister(2 * n, name="data")
    qc = QuantumCircuit(qubits)

    # Initialize and prepare logical state
    qc = initialize_toric_code(qc, lattice_size, logical_state)

    # Identity gates are the noise insertion points (the noise model targets "id").
    for q in qubits:
        qc.id(q)

    qc = measure_stabilizers(qc, lattice_size)
    # Previously missing: without a logical measurement, the "logical outcome"
    # was read from stabilizer ancilla bits.
    qc = measure_logical(qc, lattice_size)

    error_model = NoiseModel()
    error_model.add_all_qubit_quantum_error(
        pauli_error([("X", error_prob), ("Y", error_prob), ("Z", error_prob), ("I", 1 - 3 * error_prob)]),
        ["id"]
    )
    noisy_sim = AerSimulator(method="stabilizer", noise_model=error_model)
    # optimization_level=0: higher levels may merge or remove identity gates,
    # which would silently strip out all of the noise.
    compiled_qc = transpile(qc, noisy_sim, optimization_level=0)

    counts = noisy_sim.run(compiled_qc, shots=shots).result().get_counts()

    success_count = 0
    for outcome, count in counts.items():
        fields = _split_counts_key(outcome, qc.cregs)
        # log_meas prints as "<Z2><Z1>". Reverse it so character 0 refers to
        # logical qubit 1, matching the indexing of ``logical_state``.
        if fields["log_meas"][::-1] == logical_state:
            success_count += count

    return success_count / shots

def simulate_toric_code(lattice_size=5, logical_state="00", p_min=0.01, p_max=0.2, num_points=20, shots=50):
    """
    Sweep the error probability and plot the logical success rate.

    Evaluates ``num_points`` evenly spaced probabilities from ``p_min`` to
    ``p_max`` (inclusive), with one ``run_single_simulation`` call per point.
    The calls are dispatched through joblib with ``n_jobs=-1``. The curve is
    then shown with matplotlib.

    Returns:
        Tuple ``(error_probs, success_probs)``: the numpy array of sampled
        probabilities and the matching list of success probabilities.
    """
    engine = AerSimulator(method="stabilizer")
    probabilities = np.linspace(p_min, p_max, num_points)

    jobs = (
        delayed(run_single_simulation)(prob, lattice_size, logical_state, shots, engine)
        for prob in probabilities
    )
    rates = Parallel(n_jobs=-1)(jobs)

    plt.figure(figsize=(8, 6))
    plt.plot(probabilities, rates, 'o-', color='blue', label='Success Probability')
    plt.xlabel('Error Probability per Qubit (p)')
    plt.ylabel('Probability of Correct Logical State')
    plt.title(f'Toric Code Error Correction (k={lattice_size})')
    plt.ylim(-0.05, 1.05)
    plt.grid(True)
    plt.legend()
    plt.show()

    return probabilities, rates

if __name__ == "__main__":
    # Default sweep on a 5x5 torus starting from logical |00>.
    error_probs, success_probs = simulate_toric_code(logical_state="00", lattice_size=5)