In [None]:
import stim

from faulttools.glue.stim import from_stim

r = 2  # number of syndrome-extraction rounds passed to stim's generator
d = 7  # code distance passed to stim's generator

# Placeholder physical error rate used only to build the (symbolic) noise model below;
# concrete rates are substituted later in the sweep.
# NOTE(review): all four noise channels deliberately share this one value — presumably so
# that from_stim maps them onto a single noise symbol (the sweep substitutes only one
# symbol). Confirm against faulttools.glue.stim.from_stim before changing any of them.
P_SYMBOLIC_PLACEHOLDER = 1e-3

c = stim.Circuit.generated(
    "surface_code:rotated_memory_z",
    rounds=r,
    distance=d,
    after_clifford_depolarization=P_SYMBOLIC_PLACEHOLDER,
    after_reset_flip_probability=P_SYMBOLIC_PLACEHOLDER,
    before_measure_flip_probability=P_SYMBOLIC_PLACEHOLDER,
    before_round_data_depolarization=P_SYMBOLIC_PLACEHOLDER,
)
# Remove REPEAT / SHIFT_COORDS so from_stim sees a flat instruction list.
c = c.flattened()
_, nm, nm_symbols, measurement_nodes, observables, detectors = from_stim(c)

In [None]:
from faulttools.glue.stim import export_to_stim_dem, push_out_for_measurement_detectors, wrap_dem_as_sinter_task
from faulttools.noise import NoiseModel

pushed_out, logical_regions, detector_regions = push_out_for_measurement_detectors(
    nm,
    measurement_nodes=measurement_nodes,
    logicals=list(observables.values()),
    detectors=detectors,
)

# The sweep below substitutes a numeric value for exactly one symbol `p` and then calls
# float() on each fault weight, so any second symbol would make that conversion fail.
# State the single-symbol assumption explicitly instead of silently picking an arbitrary
# element of nm_symbols.
noise_symbols = list(nm_symbols)
assert len(noise_symbols) == 1, f"expected exactly one noise symbol, got {noise_symbols!r}"
p = noise_symbols[0]

In [None]:
import sinter


def generate_tasks_for_p(p_val: float):
    """Build the two sinter tasks compared at a single physical error rate.

    Args:
        p_val: physical error rate applied uniformly to all four noise channels of the
            generated rotated-surface-code memory-Z circuit.

    Returns:
        A 2-tuple ``(stim_task, replica_task)``:
        * ``stim_task`` — a ``sinter.Task`` on the stim-generated circuit at ``p_val``,
          using the DEM stim derives from it (``detector_error_model()`` with its
          default arguments), labelled ``"stim original"``.
        * ``replica_task`` — the faulttools noise model ``pushed_out`` with its symbolic
          weights evaluated at ``p_val``, exported to a stim DEM and wrapped as a sinter
          task, labelled ``"replica"``.
    """
    # All four stim noise knobs get the same rate.
    shared_noise = dict.fromkeys(
        (
            "after_clifford_depolarization",
            "after_reset_flip_probability",
            "before_measure_flip_probability",
            "before_round_data_depolarization",
        ),
        p_val,
    )
    circuit_at_p = stim.Circuit.generated(
        "surface_code:rotated_memory_z",
        rounds=r,
        distance=d,
        **shared_noise,
    )
    reference_dem = circuit_at_p.detector_error_model()

    # Evaluate every symbolic fault weight at the concrete rate.
    numeric_faults = []
    for fault, weight in pushed_out.atomic_faults_with_weight():
        numeric_faults.append((fault, float(weight.subs(p, p_val))))
    numeric_model = NoiseModel(nm.diagram, numeric_faults)

    def odd_parity(a, b):
        # Probability that exactly one of two independent events fires (they cancel in pairs).
        return a * (1 - b) + (1 - a) * b

    # NOTE(review): compress's return value is unused, so it presumably merges in place.
    numeric_model.compress(odd_parity)
    replica_dem = export_to_stim_dem(
        numeric_model,
        logical_regions=logical_regions,
        detector_regions=detector_regions,
    )

    stim_task = sinter.Task(
        circuit=circuit_at_p,
        detector_error_model=reference_dem,
        json_metadata={"p": p_val, "name": "stim original"},
    )
    replica_task = wrap_dem_as_sinter_task(replica_dem, json_metadata={"p": p_val, "name": "replica"})
    return stim_task, replica_task


# Sweep grid: {1, 2, 5} × 10^-j for j = 1..5, i.e. from 1e-5 up to 0.5.
physical_error_rates = [
    i * 10 ** (-j)
    for j in range(1, 6)
    for i in [1, 2, 5]
]

# Each rate contributes two tasks (stim baseline, then replica), kept in sweep order.
sweep_tasks = []
for rate in physical_error_rates:
    sweep_tasks.extend(generate_tasks_for_p(rate))

collected_stats = sinter.collect(
    num_workers=16,
    tasks=sweep_tasks,
    max_shots=100_000_000,
    max_errors=10_000,
    decoders=["pymatching"],
    print_progress=True,
)

In [None]:
from matplotlib import pyplot as plt

fig, ax = plt.subplots(1, 1)
# One curve per task name ("stim original" / "replica"), x = physical error rate.
sinter.plot_error_rate(
    ax=ax,
    stats=collected_stats,
    x_func=lambda stats: stats.json_metadata["p"],
    group_func=lambda stats: stats.json_metadata["name"],
)

# collected_stats holds one entry per (p, name) task. Its order is not documented to follow
# task order, and even if it did, entries 0 and 1 would only cover the first swept p value.
# So label every entry by its own metadata rather than by list position.
for stats in sorted(collected_stats, key=lambda s: (s.json_metadata["name"], s.json_metadata["p"])):
    task_name = stats.json_metadata["name"]
    phys_p = stats.json_metadata["p"]
    # Guard against an entry that recorded no shots.
    logical_rate = stats.errors / stats.shots if stats.shots else float("nan")
    print(f"{task_name} @ p={phys_p:g}: {logical_rate} error rate ({stats.errors}/{stats.shots} shots)")

ax.set_ylim(auto=True)
ax.set_xlim(auto=True)
ax.loglog()
ax.set_title(f"Rotated surface code memory-Z (d={d}, r={r}): stim vs replica DEM")
ax.set_xlabel("Physical Error Rate")
ax.set_ylabel("Logical Error Rate per Shot")
ax.grid(which="major")
ax.grid(which="minor")
ax.legend()
fig.set_dpi(120)  # Show it bigger