# Error correction with cat qubits
# Solved by Ritesh Roshan Sahoo


In [None]:
import stim
import pymatching
import sinter
import matplotlib.pyplot as plt
import numpy as np
print(stim.__version__)

## Task 2.1: Error detection code

In [None]:
# SOLUTION ===
# Smallest bit-flip detection code: two data qubits (0, 1) and one ancilla (2)
# that measures the parity Z0*Z1.
circuit = stim.Circuit()
circuit.append("R", [0, 1, 2])
# Encode the logical |+> state (|00> + |11>)/sqrt(2).
circuit.append("H", [0])
circuit.append("CX", [0, 1])
# Bit-flip noise on the data qubits.
circuit.append("X_ERROR", [0, 1], 0.1)
# Copy the parity Z0*Z1 onto the ancilla and measure it.
circuit.append("CX", [0, 2])
circuit.append("CX", [1, 2])
circuit.append("M", [2])
circuit.append("DETECTOR", [stim.target_rec(-1)])
# The logical observable of the encoded |+> state is X0*X1. A Z-basis
# measurement of qubit 0 on this state is a random 50/50 outcome, so it cannot
# serve as a deterministic observable. Read both data qubits out in the X basis
# and track their parity. (With bit-flip-only noise this observable is never
# flipped, so every accepted shot is a logical success.)
circuit.append("MX", [0, 1])
circuit.append("OBSERVABLE_INCLUDE", [stim.target_rec(-1), stim.target_rec(-2)], 0)
# ===

In [None]:
# SOLUTION ===
sampler = circuit.compile_detector_sampler()
# The detector sampler returns only detection events unless it is asked to
# append the observable flips. Without append_observables=True, `stats` has a
# single column and stats[:, 1] raises IndexError.
stats = sampler.sample(shots=100000, append_observables=True)
# stats[:, 0] is the detector (parity check fired -> shot is discarded)
# stats[:, 1] is the observable flip (logical error)
valid = ~stats[:, 0]
simulated_logical_error_rate = np.mean(stats[valid, 1]) if np.any(valid) else 0.0
print(f"Simulated Logical Error Rate from {np.sum(valid)} valid shots: {simulated_logical_error_rate}")
# Fraction of shots kept by post-selection; this is the quantity predicted below.
print(f"Simulated Success Probability: {np.mean(valid)}")

p = 0.1  # same X_ERROR probability as the one applied in `circuit`
# A shot is kept when the parity check Z0*Z1 does not fire. That happens when the
# data suffered either no flip (probability (1-p)^2) or both flips X0X1
# (probability p^2), since X0X1 leaves the parity unchanged. Single flips X0 or
# X1 (probability 2p(1-p)) are detected and discarded.
success_prob = (1 - p) ** 2 + p ** 2
print(f"Theoretical Success Probability: {success_prob}")
# ===

# 3 - Repetition Code

In [None]:
def decode_repetition_code(meas):
    """Decode one shot of the bit-flip repetition code.

    Args:
        meas: ``(data, stabs)`` pair of bit strings, e.g. ``("00100", "0110")``.
            ``data`` holds the ``n`` data-qubit outcomes and ``stabs`` the
            ``n - 1`` parity checks, where ``stabs[i]`` is the parity of
            ``data[i]`` and ``data[i + 1]``.

    Returns:
        The decoded logical bit (0 or 1).
    """
    # SOLUTION ===
    data_bits = [int(x) for x in meas[0]]
    stab_bits = [int(x) for x in meas[1]]
    n = len(data_bits)
    # Integrate the syndrome: correction[i] is the flip pattern that explains
    # the checks under the assumption that data qubit 0 was not flipped.
    correction = [0] * n
    curr = 0
    for i, s in enumerate(stab_bits):
        curr ^= s
        correction[i + 1] = curr
    # The complementary pattern (data qubit 0 flipped) explains the syndrome
    # equally well. Choose the lower-weight, i.e. more likely, of the two.
    # Without this step every corrected bit equals data[0], and nothing is
    # actually corrected.
    if 2 * sum(correction) > n:
        correction = [1 - c for c in correction]
    corrected = [d ^ c for d, c in zip(data_bits, correction)]
    return 1 if 2 * sum(corrected) > n else 0
    # ===

In [None]:
def repetition_code_circuit(n: int, p: float = 0.1) -> stim.Circuit:
    """Build a single-round bit-flip repetition code measurement circuit.

    Data qubits sit on the even indices ``0, 2, ..., 2n-2``. The ancilla for
    the check between data qubits ``a - 1`` and ``a + 1`` is the odd index
    ``a`` between them. Every data qubit suffers an independent
    ``X_ERROR(p)`` before the checks are measured.

    The measurement record holds the ``n - 1`` check outcomes first, followed
    by the ``n`` data-qubit outcomes.
    """
    # SOLUTION ===
    circ = stim.Circuit()
    circ.append("R", range(2 * n - 1))
    data = list(range(0, 2 * n, 2))
    # Bit-flip noise on the data qubits only.
    circ.append("X_ERROR", data, p)
    for anc in range(1, 2 * n - 1, 2):
        # Accumulate the parity of the two neighbouring data qubits, then read it out.
        circ.append("CX", [anc - 1, anc])
        circ.append("CX", [anc + 1, anc])
        circ.append("M", [anc])
    circ.append("M", data)
    return circ
    # ===

def simulate_circuit(circuit, n, num_shots=100_000):
    """Sample ``circuit`` and histogram the measurement outcomes.

    The measurement record is expected to hold the ``n - 1`` check outcomes
    first and then the ``n`` data outcomes, as laid out by
    ``repetition_code_circuit``.

    Returns:
        dict mapping ``(data_bits, stab_bits)`` bit-string pairs to the number
        of shots that produced them.
    """
    # SOLUTION ===
    split = n - 1
    histogram = {}
    for shot in circuit.compile_sampler().sample(shots=num_shots):
        stab_str = "".join(str(int(bit)) for bit in shot[:split])
        data_str = "".join(str(int(bit)) for bit in shot[split:])
        outcome = (data_str, stab_str)
        histogram[outcome] = histogram.get(outcome, 0) + 1
    return histogram
    # ===

def logical_error_rate(results, logical_prepared=0):
    """Fraction of shots whose decoded logical bit differs from ``logical_prepared``.

    ``results`` is a histogram ``{(data_bits, stab_bits): count}`` such as the
    one returned by ``simulate_circuit``. Each outcome is decoded with
    ``decode_repetition_code``. Returns 0 when the histogram holds no shots.
    """
    # SOLUTION ===
    failures = sum(
        count
        for (data, stabs), count in results.items()
        if decode_repetition_code((data, stabs)) != logical_prepared
    )
    shots = sum(results.values())
    return failures / shots if shots > 0 else 0
    # ===

In [None]:
# SOLUTION ===
# Sweep the physical error rate logarithmically from 1% to 50% and estimate the
# repetition-code logical error rate for several code distances.
error_probabilities = np.logspace(-2, np.log10(0.5), 10)
distances = [3, 5, 7, 9]

plt.figure(figsize=(10, 6))

for d in distances:
    logical_errors = []
    for p in error_probabilities:
        circuit = repetition_code_circuit(d, p)
        results = simulate_circuit(circuit, d, num_shots=10000)
        p_L = logical_error_rate(results)
        logical_errors.append(p_L)
    # One curve per distance on log-log axes.
    plt.loglog(error_probabilities, logical_errors, 'o-', label=f'd={d}')

plt.xlabel("Physical Error Probability p")
plt.ylabel("Logical Error Probability p_L")
plt.title("Repetition Code Threshold Simulation")
plt.grid(True, which="both", ls="-", alpha=0.4)
plt.legend()
# plt.show() # Uncomment to see plot
# ===

# 4 - Sinter Benchmarks

In [None]:
def generate_rep_code_bit_flips(d, noise):
    """Repetition-code memory circuit with bit-flip-only (cat qubit) noise.

    Starts from stim's generated ``repetition_code:memory`` circuit with
    ``after_clifford_depolarization=noise``. Every depolarizing channel is then
    rewritten as an ``X_ERROR`` with the same probability on each of its
    target qubits.

    Args:
        d: Code distance; the circuit runs ``3 * d`` rounds.
        noise: Per-qubit bit-flip probability.

    Returns:
        stim.Circuit with only X_ERROR noise channels.
    """
    # SOLUTION ===
    c = stim.Circuit.generated("repetition_code:memory", distance=d, rounds=3*d, after_clifford_depolarization=noise)
    # The generated repetition code applies its after-Clifford noise as
    # DEPOLARIZE2 after the CX gates. Rewriting only DEPOLARIZE1 would leave
    # full depolarizing noise in place, so both channels are converted.
    text = str(c)
    for channel in ("DEPOLARIZE1", "DEPOLARIZE2"):
        text = text.replace(channel, "X_ERROR")
    return stim.Circuit(text)
    # ===

In [None]:
# Running Sinter Task 4.1
# One task per (distance, physical error rate) pair of the bit-flip repetition code.
tasks = [
    sinter.Task(
        circuit=generate_rep_code_bit_flips(d, noise),
        json_metadata={'d': d, 'p': noise},
    )
    for d in [3, 5, 7]
    for noise in [0.001, 0.005, 0.01, 0.05, 0.1]
]
# Sample every task until it reaches 100k shots or 1000 logical errors.
stats = sinter.collect(
    num_workers=4,
    tasks=tasks,
    decoders=['pymatching'],
    max_shots=100_000,
    max_errors=1000,
)

# Logical error rate vs. physical error rate, one curve per distance.
fig, ax = plt.subplots(1, 1)
sinter.plot_error_rate(
    ax=ax,
    stats=stats,
    x_func=lambda stat: stat.json_metadata['p'],
    group_func=lambda stat: stat.json_metadata['d'],
)
ax.loglog()
ax.set_title("Cat Repetition Code Threshold")
ax.grid(which='major')
# plt.show()


In [None]:
def generate_surface_code_depolarizing_noise(d, noise):
    """Rotated surface-code X-memory circuit with after-Clifford depolarizing noise.

    Args:
        d: Code distance; the experiment runs ``3 * d`` rounds.
        noise: Depolarizing probability applied after each Clifford gate.
    """
    # SOLUTION ===
    circuit = stim.Circuit.generated(
        "surface_code:rotated_memory_x",
        distance=d,
        rounds=3 * d,
        after_clifford_depolarization=noise,
    )
    return circuit
    # ===

# 5 - Core Task: Hamming Code

In [None]:
# SOLUTION ===
def hamming_7_4_x_memory(p):
    """Single-round memory experiment for the [7,4,3] Hamming code under bit flips.

    All seven data qubits (0-6) start in |0> and suffer independent
    ``X_ERROR(p)``. The three Hamming parity checks are then measured once
    using ancillas 7, 8 and 9, after which every data qubit is read out in
    the Z basis.

    Args:
        p: Bit-flip probability applied to every data qubit.

    Returns:
        stim.Circuit with one detector per parity check and one observable
        (the final Z-basis outcome of data qubit 0).
    """
    c = stim.Circuit()
    c.append("R", range(10))
    c.append("X_ERROR", range(7), p)
    # Hamming [7,4,3] parity checks: data qubit q sits at position q + 1, and
    # check k covers the positions whose binary expansion has bit k set.
    checks = [
        (7, [0, 2, 4, 6]),
        (8, [1, 2, 5, 6]),
        (9, [3, 4, 5, 6]),
    ]
    for ancilla, support in checks:
        # Accumulate the Z-parity of `support` on the ancilla. The ancilla must
        # stay in the Z basis: wrapping these CXs in H gates puts the ancilla
        # in |+>, where data-controlled CXs act trivially, so the check could
        # never detect an X error.
        for q in support:
            c.append("CX", [q, ancilla])
        c.append("M", [ancilla])
        c.append("DETECTOR", [stim.target_rec(-1)])
    # Read out all data qubits.
    c.append("M", range(7))
    # The code stores 4 logical bits; for simplicity only one observable is
    # tracked: data qubit 0 (rec[-7] is the first of the 7 data measurements).
    c.append("OBSERVABLE_INCLUDE", [stim.target_rec(-7)], 0)
    return c

# Benchmarking Hamming Code
# One sinter task per physical bit-flip rate, decoded with pymatching.
hamming_tasks = [
    sinter.Task(circuit=hamming_7_4_x_memory(p), json_metadata={'p': p, 'd': 3})
    for p in [0.001, 0.01, 0.1]
]

hamming_stats = sinter.collect(
    num_workers=4,
    tasks=hamming_tasks,
    decoders=['pymatching'],
    max_shots=10000,
)

# Report the observed logical error rate (errors / shots) for each task.
for stat in hamming_stats:
    print(f"p={stat.json_metadata['p']}: Logical Error Rate = {stat.errors / stat.shots}")

# ===