### Build a heavy-hex code circuit in Stim ###

In [None]:
import os
import stim
import sinter
import pymatching
import numpy as np
import matplotlib.pyplot as plt


from heavy_hex_code import HeavyHexCode

In [None]:
# Single-instance sanity check: build one heavy-hex memory circuit and verify
# its detector/observable annotations before running the threshold sweep.
CODE_DISTANCE = 11
NUM_ROUNDS = 7
BASIS = 'Z'
p_err = 1e-3  # one physical error rate shared by every noise channel below

hhc = HeavyHexCode(
    code_distance=CODE_DISTANCE,
    num_rounds=NUM_ROUNDS,
    basis=BASIS,
    after_clifford_depolarization=p_err,
    after_reset_flip_probability=p_err,
    before_measure_flip_probability=p_err,
    before_round_data_depolarization=p_err,
)

# The output of create_heavy_hex_code() is passed straight to stim.Circuit(...),
# so it is presumably Stim circuit text -- confirm in heavy_hex_code.py.
circuit_block = hhc.create_heavy_hex_code()
stim_hhc = stim.Circuit(circuit_block)

# Preliminary check. Stim documents that, in most cases, a QEC circuit's number
# of deterministic measurements should equal num_detectors + num_observables.
# A mismatch usually points at a missing or redundant DETECTOR /
# OBSERVABLE_INCLUDE annotation. An explicit raise (rather than `assert`)
# keeps the check active under `python -O` and reports the disagreeing counts.
# The count is computed once because count_determined_measurements() simulates
# the whole circuit.
num_determined = stim_hhc.count_determined_measurements()
num_annotated = stim_hhc.num_detectors + stim_hhc.num_observables
if num_determined != num_annotated:
    raise RuntimeError(
        f"count_determined_measurements()={num_determined} but "
        f"num_detectors + num_observables={num_annotated} "
        f"({stim_hhc.num_detectors} + {stim_hhc.num_observables})"
    )

In [None]:
# Convert the noisy circuit into a detector error model.
# decompose_errors=True asks Stim to split composite error mechanisms into
# graphlike (<= 2 detector) pieces, the form matching decoders such as
# PyMatching consume. With Stim's default settings, a decomposition failure
# aborts the conversion, so this line also acts as a decodability check.
# NOTE(review): `dem` is not referenced by the later cells shown here.
dem = stim_hhc.detector_error_model(
    decompose_errors=True,
)

### Threshold study: heavy-hex code vs. rotated surface code ###

In [None]:
# Threshold sweep: for each distance d and physical error rate p, simulate a
# d-round memory experiment for both the heavy-hex code and Stim's generated
# rotated surface code, then collect logical error statistics with sinter.
ds = [3, 5, 7, 9, 11]
ps = np.geomspace(1e-4, 1e-2, 7)

# One memory basis for BOTH codes, so the comparison is like-for-like. The
# surface-code task name is derived from it, so the two cannot drift apart.
basis = 'X'

tasks = []
for d in ds:
    for p in ps:
        # Identical circuit-level noise model for both codes. If a channel is
        # dropped from one code only, that code looks artificially better.
        noise = dict(
            after_clifford_depolarization=p,
            after_reset_flip_probability=p,
            before_measure_flip_probability=p,
            before_round_data_depolarization=p,
        )

        heavy_hex = HeavyHexCode(
            code_distance=d,
            num_rounds=d,
            basis=basis,
            **noise,
        )
        circ_heavyhex = stim.Circuit(heavy_hex.create_heavy_hex_code())
        tasks.append(sinter.Task(
            circuit=circ_heavyhex,
            json_metadata={'d': d, 'p': p, 'basis': basis, 'name': f'heavy hex d={d}'},
        ))

        circ_surface = stim.Circuit.generated(
            code_task=f'surface_code:rotated_memory_{basis.lower()}',
            rounds=d,
            distance=d,
            **noise,
        )
        tasks.append(sinter.Task(
            circuit=circ_surface,
            json_metadata={'d': d, 'p': p, 'basis': basis, 'name': f'surface d={d}'},
        ))

stats = sinter.collect(
    # Leave two cores free for the notebook/OS, but always use at least one
    # worker: os.cpu_count() may return None, and small machines would
    # otherwise get a worker count <= 0.
    num_workers=max(1, (os.cpu_count() or 1) - 2),
    tasks=tasks,
    max_shots=10**7,
    max_errors=1000,
    print_progress=True,
    # sinter documents `decoders` as an iterable of decoder names.
    decoders=['pymatching'],
)

In [None]:
# Plot logical error rate against physical error rate, one curve per
# (code, distance) group. Curves crossing indicate the threshold region.
fig, ax = plt.subplots()
sinter.plot_error_rate(
    ax=ax,
    stats=stats,
    # sinter passes TaskStats objects (not Tasks) to these callbacks.
    group_func=lambda stat: stat.json_metadata['name'],
    x_func=lambda stat: stat.json_metadata['p'],
)
ax.loglog()
ax.set_xlabel('Physical error rate p')
# plot_error_rate reports failures per shot by default; each shot here is a
# d-round memory experiment.
ax.set_ylabel('Logical error rate (per shot)')
ax.set_title('Heavy-hex vs. rotated surface code')
ax.grid(which='both', alpha=0.3)
ax.legend()
plt.show()