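This notebook demonstrates tsim (labelled "ZX" in the plots). It converts a noisy Stim surface-code memory circuit into a tsim circuit, then samples measurements, detectors, and observables. Finally, it compares the sample distributions and the sampling speed against Stim.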

In [None]:
import stim
from tsim.circuit import Circuit
import matplotlib.pyplot as plt
import numpy as np
import time

In [None]:
# Base physical error rate; every noise channel below is a multiple of it.
p = 0.01
stim_circ = stim.Circuit.generated(
    # Alternative code task to try instead: "repetition_code:memory"
    "surface_code:rotated_memory_z",
    distance=5,
    rounds=2,
    after_clifford_depolarization=p,
    after_reset_flip_probability=p * 2,
    before_measure_flip_probability=p,
    before_round_data_depolarization=p * 3,
)
# Last expression of the cell, so the notebook renders the timeline diagram.
stim_circ.diagram("timeline-svg")

In [None]:
# Convert the Stim circuit into a tsim Circuit.
c = Circuit.from_stim_program(stim_circ)
# NOTE(review): in a notebook only the cell's last expression is rendered
# automatically. Unless tsim's `diagram()` draws as a side effect, this
# result is discarded; wrap it in display(...) if both diagrams should show.
c.diagram()
# Diagram of the circuit without its noise channels (presumably what
# `without_noise` strips — confirm). This is the last expression, so it is rendered.
c.without_noise().diagram()

In [None]:
# Compile a tsim sampler (presumably of measurement results, mirroring Stim's
# `compile_sampler`) and print its summary.
sampler = c.compile_sampler()
print(sampler)

In [None]:
# Small smoke-test draw of 200 shots with batch_size=100. The returned samples
# are the cell's last expression, so the notebook displays them.
n_samples = 200
sampler.sample(n_samples, batch_size=100)

In [None]:
# Compile a tsim detector sampler. The bare name on the last line displays its repr.
det_sampler = c.compile_detector_sampler()
det_sampler

In [None]:
# Draw `n_samples` detector shots (200 if the cells run in order) and display them.
det_sampler.sample(n_samples)

In [None]:
# Stim reference samplers for the same circuit, used in the comparisons below.
stim_sampler = stim_circ.compile_sampler()
stim_det_sampler = stim_circ.compile_detector_sampler()

In [None]:
# Draw many measurement shots from both simulators for a distribution comparison.
n_samples = 50_000
# Sample everything in a single batch. The batch size is tied to n_samples
# (as in the MSD timing cell) rather than repeating the literal, so editing
# n_samples cannot leave the two out of sync.
samples = sampler.sample(n_samples, batch_size=n_samples)
stim_samples = stim_sampler.sample(n_samples)

In [None]:
def compare_hist(s1, s2, bins=50, *, labels=("ZX", "Stim")):
    """Overlay histograms of the per-shot number of nonzero entries in two sample sets.

    Each row of ``s1`` / ``s2`` is treated as one shot. The histogram is over
    ``np.count_nonzero(row)``, the number of set bits in that shot.

    Args:
        s1: First set of shots (2-D array-like, one shot per row), drawn in blue.
        s2: Second set of shots (2-D array-like, one shot per row), drawn in red.
        bins: If an int, the maximum number of bins. Edges are then placed on
            half-integers so that every bin covers the same number of integer
            counts. Equal-width float bins over integer data alias and produce
            a misleading sawtooth pattern. Any other value (a sequence of edges
            or a string rule) is passed straight through to ``plt.hist``.
        labels: Legend labels for ``s1`` and ``s2``.
    """
    h1 = np.count_nonzero(s1, axis=1)
    h2 = np.count_nonzero(s2, axis=1)
    # initial=0 keeps this working when either input contains no shots.
    m = int(max(h1.max(initial=0), h2.max(initial=0)))
    if isinstance(bins, (int, np.integer)):
        # Group the integer values 0..m into at most `bins` equal-size bins.
        # Ceiling division gives the integer width of each bin.
        width = max(1, -(-(m + 1) // bins))
        bins = np.arange(-0.5, m + width, width)
    # `range` only matters for string bin rules; it is ignored for edge arrays.
    plt.hist(h1, alpha=0.5, label=labels[0], range=(0, m), bins=bins, color="blue")
    plt.hist(h2, alpha=0.5, label=labels[1], range=(0, m), bins=bins, color="red")
    plt.legend()


compare_hist(samples, stim_samples, bins=20)

In [None]:
# Time detector sampling with observables appended: tsim vs Stim, same circuit.
# Each measurement is a single run. The Stim timing may include any first-call
# overhead, since stim_det_sampler is sampled for the first time here.
n_samples = 5_000

start = time.perf_counter()
# append_observables=True: in Stim this appends the observable flips as extra
# columns after the detectors (tsim presumably mirrors this — confirm).
obs_samples = det_sampler.sample(n_samples, append_observables=True)
duration_zx = time.perf_counter() - start

start = time.perf_counter()
obs_stim_samples = stim_det_sampler.sample(n_samples, append_observables=True)
duration_stim = time.perf_counter() - start


print("\nTime per sample:")
print(f"(ZX)   {duration_zx / n_samples:.2e} seconds")
print(f"(Stim) {duration_stim / n_samples:.2e} seconds")

In [None]:
# Compare per-shot counts of fired detectors/observables between tsim and Stim.
compare_hist(obs_samples, obs_stim_samples, bins=20)

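Magic state distillation

This section loads a magic-state-distillation circuit from `msd_circuits/d=5_X.stim`, samples its detectors and observables with tsim, and measures the time per shot.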

In [None]:
# Load the magic-state-distillation circuit. The path is relative to the
# notebook's working directory.
# NOTE(review): `Circuit()` builds a throwaway instance just to call from_file.
# If from_file is a classmethod (like from_stim_program), then
# `Circuit.from_file(...)` suffices — confirm.
c = Circuit().from_file("msd_circuits/d=5_X.stim")
c.diagram(labels=False)

In [None]:
# Compile a detector sampler for the MSD circuit and print its summary.
# This rebinds `sampler`, which previously held the surface-code sampler.
sampler = c.compile_detector_sampler()
print(sampler)

In [None]:
# Time one large single-batch draw of detector + observable shots for the MSD
# circuit. Compilation happened in the previous cell, so it is not timed here.
start = time.perf_counter()
n_samples = 1024 * 64
sampler.sample(shots=n_samples, batch_size=n_samples, append_observables=True)
duration = time.perf_counter() - start
print(f"Time per shot: {duration * 1e6 / n_samples:.2f} microseconds")

In [None]:
# Stim baseline timing.
# NOTE: this times `stim_circ`, the distance-5 surface-code circuit generated
# at the top of the notebook, NOT the MSD circuit loaded above
# (NOTE(review): presumably because Stim cannot load that circuit — confirm).
# The printed label names the circuit so the two timings are not mistaken for
# a like-for-like comparison.
# Stim gets 100x more shots, presumably so its much shorter run stays measurable.
num_stim_samples = n_samples * 100
# Bind to stim_det_sampler, not stim_sampler. stim_sampler holds the Stim
# measurement sampler used by the histogram cell. Overwriting it with a
# detector sampler would silently change that comparison if the cell were re-run.
stim_det_sampler = stim_circ.compile_detector_sampler()
start = time.perf_counter()
stim_det_sampler.sample(shots=num_stim_samples, append_observables=True)
duration = time.perf_counter() - start
print(
    f"(Stim, surface-code circuit) Time per shot: "
    f"{duration * 1e6 / num_stim_samples:.2f} microseconds"
)