In [None]:
import tsim
import stim
import time
import sinter
import matplotlib.pyplot as plt
import numpy as np
import pyzx as zx
from fractions import Fraction
from pyzx.graph.scalar import cexp
import math

## Part 1: Tsim demo

Tsim is a circuit sampler for Clifford + T gates. Currently, it is designed to be able to simulate Gemini circuits. For example, the [15,1,5] color code magic state distillation circuit:

In [None]:
# Load the circuit from a Stim-format file.
c = tsim.Circuit.from_file("msd_circuits/d=5_X.stim")

In [None]:
# Summary statistics of the loaded circuit.
print(
    f"""Qubits: {c.num_qubits}
T-gates: {c.tcount()}
Detectors: {c.num_detectors}
Observables: {c.num_observables}
"""
)

In [None]:
# Timeline diagram of the circuit (displayed as the cell's output).
c.diagram("timeline-svg", height=550)

In [None]:
# Compile a detector sampler once; it is reused by the sampling cells below.
sampler = c.compile_detector_sampler()

In [None]:
# Draw 1024 shots of detector samples.
sampler.sample(shots=1024)

In [None]:
n_samples = 1024 * 32
# Warm-up call with exactly the same arguments as the timed call, so that any
# one-time work triggered by the new batch size (e.g. compilation, if the
# sampler does any) is not counted in the per-shot time.
sampler.sample(shots=n_samples, batch_size=n_samples, append_observables=True)

start = time.perf_counter()
sampler.sample(shots=n_samples, batch_size=n_samples, append_observables=True)
duration = time.perf_counter() - start
print(f"Time per shot: {duration * 1e6 / n_samples:.2f} microseconds")

## Overview

`tsim.Circuit` is a thin wrapper around `stim.Circuit`. Since Stim has no T gate, T gates are represented as S gates tagged with `[T]` (i.e. `S[T]`); in the circuit text below the gate is simply written as `T`.

In [None]:
# Minimal non-Clifford example: reset qubit 0 in the X basis, apply a T gate,
# then H, and measure.
c = tsim.Circuit(
    """
    RX 0
    T 0
    H 0
    M 0
    """
)

In [None]:
c.diagram("timeline-svg", height=120)

In [None]:
# ZX-diagram view of the same circuit; the trailing ";" suppresses the
# cell's return value.
c.diagram("pyzx");

Tsim supports detectors and observables just like Stim.

In [None]:
# Bell-pair circuit: H + CNOT make both measurement outcomes equal, so the
# detector comparing them (rec[-1] XOR rec[-2]) is deterministic without noise.
c = tsim.Circuit(
    """
    R 0 1
    H 0
    CNOT 0 1
    M 0 1
    DETECTOR rec[-1] rec[-2]
    """
)
c.diagram("pyzx");

In [None]:
# Sample the detector defined above (compile_sampler() would sample the raw
# measurements instead). Without noise the detector never fires, so every
# shot should be False.
det_sampler = c.compile_detector_sampler()
det_sampler.sample(shots=5)

For Clifford circuits, Tsim behaves like a slow version of Stim. It supports all Stim gates and noise channels, except (currently) `CORRELATED_ERROR`:

In [None]:
# Physical error rate of the generated circuit.
p = 0.01
# Distance-3 rotated surface-code memory experiment (10 rounds) generated by
# Stim with depolarizing noise after Clifford gates, converted to tsim.
stim_circ = stim.Circuit.generated(
    "surface_code:rotated_memory_z",
    distance=3,
    rounds=10,
    after_clifford_depolarization=p,
)
c = tsim.Circuit.from_stim_program(stim_circ)
c.diagram("pyzx")

Tsim is also compatible with `sinter`:

In [None]:
# Physical error rates to sweep and the code parameters of every task.
noise_vals = np.logspace(-2.5, -1.5, 4)
rounds = 3
distances = [3, 5]
tasks = [
    sinter.Task(
        # Convert through tsim and back (cast_to_stim) so that sinter receives
        # a stim.Circuit.
        circuit=tsim.Circuit.from_stim_program(
            stim.Circuit.generated(
                "surface_code:rotated_memory_z",
                distance=distance,
                rounds=rounds,
                after_clifford_depolarization=noise,
            )
        ).cast_to_stim(),
        json_metadata={"p": noise, "distance": distance, "rounds": rounds},
    )
    for noise in noise_vals
    for distance in distances
]

collected_stats = sinter.collect(
    num_workers=8,
    tasks=tasks,
    decoders=["pymatching"],
    max_shots=1024 * 64,
    max_errors=1024 * 32,
    start_batch_size=1024 * 32,
    max_batch_size=1024 * 32,
)

fig, ax = plt.subplots(1, 1)
# failure_units_per_shot_func = rounds makes sinter plot a per-round rate.
sinter.plot_error_rate(
    ax=ax,
    stats=collected_stats,
    x_func=lambda stats: stats.json_metadata["p"],
    group_func=lambda stats: stats.json_metadata["distance"],
    failure_units_per_shot_func=lambda stats: stats.json_metadata["rounds"],
)
# Reference line y = p for comparison with an unencoded qubit.
ax.plot(noise_vals, noise_vals, color="k", linestyle="--", lw=0.5, label="uncoded")
ax.loglog()
ax.set_xlabel("Physical Error Rate")
# Plain string on purpose: in an f-string "{1}" is a replacement field, so
# "\bar{1}" would lose its braces.
ax.set_ylabel("Probability of logical $|\\bar{1}\\rangle$ per round")
ax.legend();

## Part 2: How tsim works

Tsim is based on [Sutcliffe and Kissinger (2025)](https://arxiv.org/abs/2403.06777) and the corresponding code at [mjsutcliffe99/ParamZX](https://github.com/mjsutcliffe99/ParamZX). In particular, the parametric pyzx extension is taken from there.

In [None]:
def scalar_to_str(scalar: "zx.Scalar") -> str:
    """Render a parametric pyzx scalar as a human-readable product of factors.

    The rendered factors are the ones multiplied together by
    ``evaluate_scalar``: ``exp(iπ x)`` stands for e^{iπx}, and every
    parameter name stands for its value in units of π.

    Args:
        scalar: The scalar to render. It is only read, never modified.

    Returns:
        The factors concatenated into one string. Returns ``"0"`` if the
        scalar is zero and ``"1"`` in place of an empty product.
    """
    if scalar.is_zero:
        return "0"

    def format_phase_str(alpha, params):
        # "alpha+p1+p2+..." with a zero constant dropped; "0" if nothing is left.
        terms = [str(alpha)] if alpha != 0 else []
        terms.extend(str(param) for param in params)
        return "+".join(terms) if terms else "0"

    scalar_str = ""

    # Phase nodes: 1 + exp(iπ (const + Σ params)).
    for const, params in zip(scalar.phasenodes, scalar.phasenodevars):
        scalar_str += f"(1 + exp(iπ {format_phase_str(const, params)}))"

    # Phase pairs: alpha and beta are divided by 4 (as in evaluate_scalar) to
    # express them in units of π.
    for pp in scalar.phasepairs:
        a_str = format_phase_str(pp.alpha / 4, pp.paramsA)
        b_str = format_phase_str(pp.beta / 4, pp.paramsB)
        scalar_str += (
            f"(1 + exp(iπ {a_str}) + exp(iπ {b_str}) - exp(iπ ({a_str} + {b_str})))"
        )

    # Half-pi terms: exp(iπ (Σ params) * c/2), matching the c / 2 factor used
    # by evaluate_scalar. An empty parameter list gives exp(0) = 1 and is skipped.
    for c in (1, 3):
        if c not in scalar.phasevars_halfpi:
            continue
        for params in scalar.phasevars_halfpi[c]:
            if len(params) == 0:
                continue
            a_str = " + ".join(params)
            scalar_str += f"exp(iπ ({a_str}) * {c}/2)"

    # Pi pairs: exp(iπ (Σ A)(Σ B)); a pair with an empty side is 1 and skipped.
    for pp in scalar.phasevars_pi_pair:
        if len(pp[0]) == 0 or len(pp[1]) == 0:
            continue
        a_str = " + ".join(pp[0])
        b_str = " + ".join(pp[1])
        scalar_str += f"exp(iπ ({a_str}) * ({b_str}))"

    # Constant phase and floating-point factor (both used by evaluate_scalar).
    if scalar.phase != 0:
        scalar_str += f"exp(iπ {scalar.phase})"
    if not scalar_str:
        scalar_str = "1"
    floatfactor = scalar.floatfactor.to_complex()
    if floatfactor != 1:
        scalar_str += f" * {floatfactor}"

    # Power of sqrt(2): integral powers of 2 are shown directly.
    if scalar.power2 % 2 == 0:
        if scalar.power2 > 0:
            scalar_str += f" * {2 ** (scalar.power2 // 2)}"
        elif scalar.power2 < 0:
            scalar_str += f" / {2 ** ((-scalar.power2) // 2)}"
    else:
        scalar_str += f" * sqrt(2) ** {scalar.power2}"

    return scalar_str


def evaluate_scalar(scalar: zx.Scalar, vals: dict[str, Fraction]) -> complex:
    """Numerically evaluate a parametric pyzx scalar for given parameter values.

    Phases follow the ``cexp`` convention (a value ``x`` contributes
    e^{iπx}), so parameter values in ``vals`` are in units of π.

    Args:
        scalar: The parametric scalar to evaluate. It is only read.
        vals: Value of every parameter occurring in ``scalar``. The dict is
            not modified; the parameter name ``"1"`` is always bound to 1.

    Returns:
        The complex value of the scalar (``0j`` if ``scalar.is_zero``).

    Raises:
        KeyError: If a nonzero ``scalar`` references a parameter that is
            missing from ``vals``.
    """
    if scalar.is_zero:
        return 0j

    # Work on a copy so the caller's dict is left untouched. Parameter lists
    # may refer to the literal name "1", which always evaluates to 1.
    values = {**vals, "1": Fraction(1)}

    def param_sum(names):
        # Sum of the values of the given parameters (0 for an empty list).
        return sum(values[name] for name in names)

    number = 1

    # Phase nodes: 1 + exp(iπ (const + Σ params)).
    for const, names in zip(scalar.phasenodes, scalar.phasenodevars):
        number *= 1 + cexp(const + param_sum(names))

    # Phase pairs: 1 + e^{iπψ} + e^{iπφ} - e^{iπ(ψ+φ)}; alpha and beta are
    # divided by 4 to express them in units of π like the parameters.
    for pp in scalar.phasepairs:
        psi = pp.alpha / 4 + param_sum(pp.paramsA)
        phi = pp.beta / 4 + param_sum(pp.paramsB)
        number *= 1 + cexp(psi) + cexp(phi) - cexp(psi + phi)

    # Half-pi terms: exp(iπ (Σ params) * c/2) for c in {1, 3}.
    for c in (1, 3):
        if c not in scalar.phasevars_halfpi:
            continue
        for names in scalar.phasevars_halfpi[c]:
            number *= cexp(param_sum(names) * c / 2)

    # Pi pairs: exp(iπ (Σ A)(Σ B)).
    for pp in scalar.phasevars_pi_pair:
        number *= cexp(param_sum(pp[0]) * param_sum(pp[1]))

    # Constant phase, power of sqrt(2) and floating-point factor.
    number *= cexp(scalar.phase)
    number *= math.sqrt(2) ** scalar.power2
    number *= scalar.floatfactor.to_complex()

    return number

In [None]:
# Small Clifford example: a 30% X error on qubit 0 and a Bell pair on qubits
# 1 and 2, with all three qubits measured.
c = tsim.Circuit(
    """
    R 0 1 2
    X_ERROR(0.3) 0
    H 1
    CNOT 1 2
    M 0 1 2
    """
)

# Build the ZX graph of the circuit and normalize it with pyzx. The vertex ids
# used in the following cells presumably refer to this normalized graph —
# check them against the drawing.
g = c.get_graph()
g.normalize()
zx.draw(g)

In [None]:
# NOTE(review): this cell appears to build the (unnormalized) marginal of the
# first measurement as a function of a parametric outcome "a". Vertex ids 6-8
# are presumably the measurement spiders and 9-11 the outputs of the
# normalized graph `g` drawn above — verify against the drawing.
g_ = g.copy()
g_.set_type(6, 2)  # 2 == pyzx VertexType.X
g_.set_phase(6, "a")  # parametric phase "a"
g_.remove_vertex(9)
g_.set_phase(7, 0)
g_.set_phase(8, 0)
g_.scalar.add_power(2)  # multiplies the scalar by sqrt(2) ** 2
g1 = g_ + g_.adjoint()  # g_ combined with its adjoint via pyzx's graph `+`
zx.draw(g1)
zx.full_reduce(g1)  # simplify; the resulting scalar is printed below
print(scalar_to_str(g1.scalar))

In [None]:
# NOTE(review): like the previous cell, but with the first two measurements
# parametrized by outcomes "a" and "b". Vertex ids are presumably measurement
# spiders (6-8) and outputs (9-11) of `g` — verify against the drawing.
g_ = g.copy()
g_.set_type(6, 2)  # 2 == pyzx VertexType.X
g_.set_phase(6, "a")
g_.remove_vertex(9)
g_.set_type(7, 2)
g_.set_phase(7, "b")
g_.remove_vertex(10)
g_.set_phase(8, 0)
g_.scalar.add_power(1)  # multiplies the scalar by sqrt(2)
g2 = g_ + g_.adjoint()  # g_ combined with its adjoint via pyzx's graph `+`
zx.draw(g2)
zx.full_reduce(g2)  # simplify; the resulting scalar is printed below
print(scalar_to_str(g2.scalar))

In [None]:
# All three measurements parametrized by outcomes "a", "b" and "c"; no
# sqrt(2) correction is added since no measurement is left unparametrized.
# NOTE(review): vertex ids are presumably measurement spiders (6-8) and
# outputs (9-11) of `g` — verify against its drawing.
g_ = g.copy()
g_.set_type(6, 2)  # 2 == pyzx VertexType.X
g_.set_phase(6, "a")
g_.remove_vertex(9)
g_.set_type(7, 2)
g_.set_phase(7, "b")
g_.remove_vertex(10)
g_.set_type(8, 2)
g_.set_phase(8, "c")
g_.remove_vertex(11)
# Named g3 so that the result of the previous cell (g2) is not overwritten.
g3 = g_ + g_.adjoint()
zx.draw(g3)
zx.full_reduce(g3)
print(scalar_to_str(g3.scalar))

In [None]:
# Sample 6 shots of the circuit's three measurements; the seed makes the
# output reproducible.
s = c.compile_sampler(seed=1)
s.sample(shots=6)


<img src="figures/sampling_flow.png" alt="overview" width=1000/>

### The non-Clifford Case: Stabilizer Rank Decomposition

In [None]:
import random

# Random Clifford+T circuit on 5 qubits (depth 200, T probability 0.1).
# Seeding Python's `random` (which pyzx's generator presumably draws from)
# makes the circuit reproducible.
num_qubits = 5
random.seed(0)
g = zx.generate.cliffordT(num_qubits, 200, p_t=0.1)

# Plug <0...0| onto the outputs and |0...0> into the inputs, leaving a closed
# diagram whose value is a single amplitude.
all_zeros = "0" * num_qubits
g.apply_effect(all_zeros)
g.apply_state(all_zeros)

# Simplify, tidy up, and show the diagram together with its scalar.
zx.full_reduce(g)
g.normalize()
zx.draw(g, show_scalar=True)

In [None]:
from IPython.display import display, Markdown
import ipywidgets as widgets

# Stabilizer-rank decomposition: rewrite g as a sum of graphs and simplify
# every term in place.
gsum = zx.simulate.replace_magic_states(g)
for gi in gsum.graphs:
    zx.full_reduce(gi)
graphs = [zx.draw_matplotlib(gi, figsize=(6.5, 4)) for gi in gsum.graphs]
num_t = [zx.tcount(gi) for gi in gsum.graphs]


def plotter(term):
    """Display the T-count and drawing of decomposition term ``term``."""
    display(Markdown(f"Number of T-gates: {num_t[term]}"))
    display(graphs[term])


# One toggle button per term. The number of terms depends on the T-count of
# g, so it is derived from the decomposition instead of being hard-coded.
widgets.interactive(
    plotter, term=widgets.ToggleButtons(options=list(range(len(graphs))))
)

<img src="figures/datastructure.png" alt="data structure" width=400/>

```mermaid
flowchart LR
    Start([Stim Circuit]) --> Step1[Parse into ZX Graph]
    
    Step1 --> Step2[Split Into<br/>Connected Components]
    
    Step2 --> Step3[Stabilizer Rank<br/>Decomposition]
    
    Step3 --> Step4[Fully Reduce<br/>Each Diagram Into Scalar]
    
    Step4 --> Step5[Compile Into<br/>Contiguous JAX Arrays]
    
    Step5 --> Step6[Sampling Loop:<br/>Sample Errors,<br/>Autoregressively Build<br/>Measurement Bitstring]
    
    Step6 --> Output([Samples])
    
    %% Styling
    classDef step fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
    classDef startEnd fill:#fff3e0,stroke:#f57c00,stroke-width:3px
    
    class Step1,Step2,Step3,Step4,Step5,Step6 step
    class Start,Output startEnd
```

# Appendix

### Detector Error Models as ZX Diagrams

In [None]:
from tsim.graph_util import squash_graph, transform_error_basis

<img src="figures/tannergraph.png" alt="detector error model" width=1000/>

In [None]:
# Toy circuit (following https://arxiv.org/pdf/2407.13826): data qubits 1, 2, 3
# get X errors. Per the CNOTs, ancilla 0 collects the parity of qubits 1 and 2,
# and ancilla 4 the parity of qubits 2 and 3; both ancillas also get X errors
# before measurement. The first two detectors are the ancilla outcomes. The
# last two compare each ancilla outcome with the parity of the matching final
# data measurements.
c = tsim.Circuit(
    """
    X_ERROR(0.0) 0 # Dummy error that ensures 1-indexing like in https://arxiv.org/pdf/2407.13826
    R 0 4
    X_ERROR(0.1) 1 2 3
    TICK
    CNOT 1 0 2 4 2 0 3 4
    X_ERROR(0.1) 0 4
    TICK
    M 0 4
    X_ERROR(0.1)  1 2 3
    TICK
    M 1 2 3
    DETECTOR rec[-5]
    DETECTOR rec[-4]
    DETECTOR rec[-5] rec[-3] rec[-2]
    DETECTOR rec[-4] rec[-2] rec[-1]
"""
)
# Optional visualizations of the circuit:
# c.diagram("timeline-svg", height=300)
# c.diagram("pyzx")

In [None]:
# Sampling graph including detectors, simplified with pyzx's full_reduce and
# then passed through tsim's squash_graph helper before drawing.
g = c.get_sampling_graph(sample_detectors=True)
zx.full_reduce(g)
squash_graph(g)
zx.draw(g)

Repetition code with general noise:

In [None]:
# Distance-3 repetition-code memory experiment (9 rounds) generated by Stim
# with depolarizing noise after Clifford gates, converted to a tsim circuit.
p = 0.01
stim_circ = stim.Circuit.generated(
    "repetition_code:memory",
    distance=3,
    rounds=9,
    after_clifford_depolarization=p,
)
c = tsim.Circuit.from_stim_program(stim_circ)
c.diagram("pyzx")

In [None]:
# Sampling graph including detectors, simplified with pyzx's full_reduce and
# then passed through tsim's squash_graph helper.
g = c.get_sampling_graph(sample_detectors=True)
zx.full_reduce(g)
squash_graph(g)
zx.draw(g)

In [None]:
# tsim's transform_error_basis returns a new graph together with the transform
# it applied; the transform is not used by the cells shown here.
graph, error_transform = transform_error_basis(g)
zx.draw(graph)

Inspecting the MSD circuit, we see that each detector has its own connected component of the ZX diagram, while the 5 observables form a single connected component. Essentially, ZX reduction has separated the stabilizer part of the circuit from the observable part: we have reduced the problem of simulating an 85-qubit physical circuit to the equivalent problem of simulating a 5-qubit logical circuit.

In [None]:
# Sampling graph of the MSD circuit before any simplification.
c = tsim.Circuit.from_file("msd_circuits/d=5_X.stim")
g = c.get_sampling_graph(sample_detectors=True)
zx.draw(g)

In [None]:
# Reduce, squash, and transform the error basis; g is replaced by the result.
zx.full_reduce(g)
squash_graph(g)
g, _ = transform_error_basis(g)
zx.draw(g)

In [None]:
from tsim.graph_util import connected_components

# Sort components by their number of outputs and draw the one with the most
# outputs (presumably the observable component described above).
components = connected_components(g)
components = sorted(components, key=lambda x: len(x.output_indices))
zx.draw(components[-1].graph)

In [None]:
# d=3 circuit from msc_circuits: print its T-count, then reduce, squash, and
# transform the error basis of its sampling graph and draw the result.
c = tsim.Circuit.from_file("msc_circuits/d=3-degenerate-basis=Y-p=0.001_T.stim")
print("T-gates:", c.tcount())
g = c.get_sampling_graph(sample_detectors=True)
zx.full_reduce(g)
squash_graph(g)
g, _ = transform_error_basis(g)
zx.draw(g)

In [None]:
# Same pipeline for the d=5 variant.
c = tsim.Circuit.from_file("msc_circuits/d=5-degenerate-basis=Y-p=0.001_T.stim")
print("T-gates:", c.tcount())
g = c.get_sampling_graph(sample_detectors=True)
zx.full_reduce(g)
squash_graph(g)
g, _ = transform_error_basis(g)
zx.draw(g)

In [None]:
# NOTE(review): this reloads the d=3 circuit already loaded two cells above
# (whose T-count is printed there) and overwrites `c`; looks redundant.
c = tsim.Circuit.from_file("msc_circuits/d=3-degenerate-basis=Y-p=0.001_T.stim")
c.tcount()