# Pipeline Stages: T-gate Cultivation Sampler

This notebook walks through each stage of the compilation pipeline,
printing intermediate structures.

In [None]:
import sys, re, time
from pathlib import Path
from copy import deepcopy

# Put this notebook's directory at the front of sys.path so modules that
# presumably live next to it (e.g. d_3_circuit_definitions, run) can be
# imported. Only insert when it is not already first, so re-running the
# cell does not keep stacking duplicate entries onto sys.path.
_THIS_DIR = Path('.').resolve()
if not sys.path or sys.path[0] != str(_THIS_DIR):
    sys.path.insert(0, str(_THIS_DIR))

import tsim, stim, numpy as np
import pyzx_param as zx
import gen

from d_3_circuit_definitions import circuit_source_injection_T

def replace_t_with_s(s):
    """Return *s* with line-leading ``T_DAG``/``T`` gates renamed to ``S_DAG``/``S``.

    A gate name is rewritten only when it is the first token on its line
    (optionally indented) and is followed by whitespace, so instructions such
    as ``TICK`` are left alone. ``T_DAG`` is handled before ``T``.
    """
    rules = (
        (r'^(\s*)T_DAG(\s)', r'\1S_DAG\2'),
        (r'^(\s*)T(\s)', r'\1S\2'),
    )
    out = s
    for pattern, replacement in rules:
        out = re.sub(pattern, replacement, out, flags=re.MULTILINE)
    return out

def replace_s_with_t(c):
    """Stringify *c* and rename line-leading ``S_DAG``/``S`` gates to ``T_DAG``/``T``.

    *c* can be any object whose ``str()`` is circuit text; this notebook
    passes the noisy circuit produced by the gen noise model. The matching
    rules mirror ``replace_t_with_s``: only the first token of a line,
    followed by whitespace, is rewritten, so e.g. ``SQRT_X`` is untouched.
    """
    text = str(c)
    for pattern, replacement in (
        (r'^(\s*)S_DAG(\s)', r'\1T_DAG\2'),
        (r'^(\s*)S(\s)', r'\1T\2'),
    ):
        text = re.sub(pattern, replacement, text, flags=re.MULTILINE)
    return text

# Strength passed to gen.NoiseModel.uniform_depolarizing in the Stage 0 cell.
NOISE_STRENGTH = 1e-3

## Stage 0: Raw Stim Circuit

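The starting point is a stim circuit defining the distance-3 T-state cultivation
protocol. Because stim cannot parse T gates, the T gates are first replaced by S
gates so that uniform depolarising noise can be added to the circuit; the S gates
are then converted back to T gates for non-Clifford simulation with tsim.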

In [None]:
# Build the projection + observable
# Hardcoded Clifford projection used only for the stim statistics further
# down: a CX "decode" sequence, then a measurement layer (MX/M) with six
# detectors. An "S 6" is spliced between the two halves when proj_clifford
# is assembled below.
_projection_T_Proj_decode = "\nTICK\nCX 8 3\nCX 11 6\nCX 0 9\nTICK\nCX 8 14\nCX 11 9\nCX 0 3\nTICK\nCX 11 14\nCX 8 9\nCX 0 6\nTICK\nCX 6 14\nTICK\nCX 6 3\nTICK\n"
_projection_T_Proj_measure = "\nTICK\nMX 0 11 8\nM 9 3 14\nMX 6\nDETECTOR(0.625, 0.125, 0, -1, -9) rec[-20] rec[-19] rec[-14] rec[-7]\nDETECTOR(0.875, 0.125, 0, -1, -9) rec[-17] rec[-4]\nDETECTOR(1.25, 1.4375, 0, -1, -9) rec[-20] rec[-14] rec[-6] rec[-5]\nDETECTOR(1.5, 1.4375, 0, -1, -9) rec[-16] rec[-3]\nDETECTOR(2.5, 0.9375, 0, -1, -9) rec[-14] rec[-6]\nDETECTOR(2.75, 0.9375, 0, -1, -9) rec[-15] rec[-2]\n"

# Clifford version of injection circuit
# (line-leading T / T_DAG gates renamed to S / S_DAG so stim can parse it)
clifford_source = replace_t_with_s(circuit_source_injection_T)

# Find observable include
# Returns the projection circuit text used for the tsim circuit (proj_str_t)
# and the reference observable value used in Stage 6 (noiseless_raw).
# NOTE(review): exact semantics of both values are defined in run.py — confirm.
from run import compute_projection_obs_include
proj_str_t, noiseless_raw = compute_projection_obs_include(clifford_source)

# Build noisy circuit
# Noise is added to the Clifford (S-gate) circuit, then S/S_DAG are mapped
# back to T/T_DAG. NOTE(review): this would also turn any S gates already
# present in circuit_source_injection_T into T gates — confirm there are none.
clifford_circuit = stim.Circuit(clifford_source)
noise_model = gen.NoiseModel.uniform_depolarizing(NOISE_STRENGTH)
noisy_clifford = noise_model.noisy_circuit_skipping_mpp_boundaries(clifford_circuit)
noisy_injection_str = replace_s_with_t(noisy_clifford)

# Non-Clifford circuit for tsim: noisy injection followed by the projection
# (no noise model is applied to the projection in this cell).
c_injection = tsim.Circuit(noisy_injection_str)
c_projection = tsim.Circuit(proj_str_t)
circ = c_injection + c_projection

# Use Clifford version for stim stats (stim can't parse T gates)
# NOTE(review): stim_circ uses the hardcoded projection strings above, not
# proj_str_t, so its counts only describe `circ` if the two agree — confirm.
# OBSERVABLE_INCLUDE(0) rec[-1] refers to the last measurement, which is the
# final `MX 6` of the measurement layer.
proj_clifford = _projection_T_Proj_decode + "S 6\n" + _projection_T_Proj_measure
stim_circ = stim.Circuit(str(noisy_clifford) + proj_clifford + "OBSERVABLE_INCLUDE(0) rec[-1]\n")

print('=== Stage 0: Raw Circuit ===')
print(f'  Qubits:       {stim_circ.num_qubits}')
print(f'  Measurements: {stim_circ.num_measurements}')
print(f'  Detectors:    {stim_circ.num_detectors}')
print(f'  Observables:  {stim_circ.num_observables}')
print(f'  Noiseless observable raw: {noiseless_raw}')
print()
# Preview the T-gate version of the noisy injection circuit.
print('First 20 lines of noisy injection circuit (T-gate version):')
lines = noisy_injection_str.strip().split('\n')
for line in lines[:20]:
    print(f'  {line}')
print(f'  ... ({len(lines)} lines total)')

## Stage 1: Graph Preparation (ZX Calculus)

`prepare_graph()` converts the tsim Circuit into a ZX-calculus graph representation,
extracting boundary vertices, parameter names (f-params for noise, m-params for
measurements), and error channel information.

In [None]:
from collections import Counter

from tsim.core.graph import prepare_graph, get_params, connected_components


def _param_sort_key(name):
    """Sort key that orders parameter names like 'f2' before 'f10'.

    Names whose part after the first character is a decimal integer are
    ordered numerically; any other name sorts after those, lexicographically.
    """
    suffix = name[1:]
    if suffix.isdecimal():
        return (0, int(suffix), name)
    return (1, 0, name)


# Convert the tsim circuit into a prepared ZX graph plus output/channel bookkeeping.
prepared = prepare_graph(circ, sample_detectors=True)
graph = prepared.graph

print('=== Stage 1: Prepared ZX Graph ===')
print(f'  Vertices:      {graph.num_vertices()}')
print(f'  Edges:         {graph.num_edges()}')
print(f'  Inputs:        {len(graph.inputs())}')
print(f'  Outputs:       {len(graph.outputs())}')
print(f'  Num outputs:   {prepared.num_outputs}')
print(f'  Num detectors: {prepared.num_detectors}')
print()

# Parameters: 'f*' names are the noise parameters, 'm*' the measurement ones.
all_params = get_params(graph)
# Plain sorted() would give 'f0', 'f1', 'f10', 'f100', ... so "First 10
# f-params" would not be the first ten; sort by numeric suffix instead.
f_params = sorted((p for p in all_params if p.startswith('f')), key=_param_sort_key)
m_params = sorted((p for p in all_params if p.startswith('m')), key=_param_sort_key)
print(f'  f-params (noise):       {len(f_params)}')
print(f'  m-params (measurement): {len(m_params)}')
print(f'  First 10 f-params: {f_params[:10]}')
print(f'  m-params: {m_params}')
print()

# Channel info: one entry of prepared.channel_probs per channel; its length
# is reported as that channel's number of outcomes.
n_channels = len(prepared.channel_probs)
outcome_counts = [len(cp) for cp in prepared.channel_probs]
outcome_dist = Counter(outcome_counts)
print(f'  Channels: {n_channels}')
print(f'  Outcomes per channel: {dict(sorted(outcome_dist.items()))}')
print(f'  Error transform entries: {len(prepared.error_transform)}')

# Vertex type distribution (type codes: 0 boundary, 1 Z-spider, 2 X-spider).
type_names = {0: 'boundary', 1: 'Z-spider', 2: 'X-spider'}
type_counts = Counter(graph.type(v) for v in graph.vertices())
print('\n  Vertex types:')
for t, count in sorted(type_counts.items()):
    print(f'    {type_names.get(t, f"type-{t}")}: {count}')

### Full ZX Graph (Stage 1)
The complete prepared graph — likely large. Green = Z-spiders, red = X-spiders, grey = boundary.

In [None]:
# Draw the whole prepared ZX graph with vertex labels (large canvas).
zx.draw(graph, labels=True, figsize=(18, 10))

## Stage 2: Connected Components

The ZX graph decomposes into independent connected components. Each component
maps to a subset of output indices (detectors + observables). Components are
compiled independently.

In [None]:
# Imported once here rather than on every loop iteration.
from stab_rank_cut import is_t_like

components = connected_components(prepared.graph)
# Order by number of outputs: [0] is the smallest component, [-1] the largest.
sorted_components = sorted(components, key=lambda c: len(c.output_indices))

print('=== Stage 2: Connected Components ===')
print(f'  Total components: {len(components)}')
print()

for i, cc in enumerate(sorted_components):
    g = cc.graph
    n_out = len(g.outputs())
    params = get_params(g)
    f_p = [p for p in params if p.startswith('f')]
    m_p = [p for p in params if p.startswith('m')]

    # Histogram of vertex type codes (1 = Z, 2 = X, 0 = boundary, as printed below).
    tc = Counter(g.type(v) for v in g.vertices())

    # Number of vertices whose phase is_t_like() reports as T-like.
    t_count = sum(1 for v in g.vertices() if is_t_like(g.phase(v)))

    print(f'  Component {i}: {g.num_vertices()} vertices, '
          f'{g.num_edges()} edges, {n_out} outputs')
    print(f'    Output indices: {cc.output_indices}')
    print(f'    f-params: {len(f_p)}, m-params: {len(m_p)}')
    print(f'    T-count: {t_count}')
    print(f'    Z-spiders: {tc.get(1, 0)}, X-spiders: {tc.get(2, 0)}, '
          f'boundary: {tc.get(0, 0)}')
    print()

### Connected Component Graphs (Stage 2)
Draw the smallest and largest components.

In [None]:
# sorted_components is ordered by output count, so the first entry is the
# smallest component and the last entry the largest.
smallest_cc = sorted_components[0]
largest_cc = sorted_components[-1]
last_index = len(sorted_components) - 1

print(f'Smallest component (component 0): {smallest_cc.graph.num_vertices()} vertices, '
      f'{len(smallest_cc.output_indices)} outputs')
zx.draw(smallest_cc.graph, figsize=(10, 5), labels=True)

print(f'\nLargest component (component {last_index}): '
      f'{largest_cc.graph.num_vertices()} vertices, '
      f'{len(largest_cc.output_indices)} outputs')
zx.draw(largest_cc.graph, figsize=(18, 10), labels=True)

## Stage 3: Cutting Decomposition

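For each component, non-Clifford (T-gate) spiders are decomposed via the
cutting rule: each cut splits one T-gate into 2 branches, doubling the term count.
After cutting, `full_reduce` simplifies each term, and any T-gates that remain are
handled by the BSS fallback. The demo below disables that fallback
(`use_bss_fallback=False`) so that only the cutting step is shown.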

We demonstrate on the largest component.

In [None]:
from stab_rank_cut import decompose as stab_rank_decompose

# Pick the largest component (most outputs); sorted_components is ordered by output count.
demo_cc = sorted_components[-1]
demo_g = deepcopy(demo_cc.graph)

print('=== Stage 3: Cutting Decomposition ===')
print(f'  Demonstrating on component with {len(demo_cc.output_indices)} outputs')
print(f'  Initial: {demo_g.num_vertices()} vertices, {demo_g.num_edges()} edges')
print()

# Reduce a separate copy so demo_g itself stays unreduced for the cutting run.
g_work = deepcopy(demo_g)
zx.full_reduce(g_work, paramSafe=True)
tc = zx.simplify.tcount(g_work)
print(f'  After full_reduce: {g_work.num_vertices()} vertices, '
      f'{g_work.num_edges()} edges, T-count={tc}')
print()

# Show cutting iterations (BSS fallback disabled, so only cutting is exercised).
print('  Cutting iterations (on a fresh copy):')
cut_terms = stab_rank_decompose(
    deepcopy(demo_g),
    debug=True,
    use_bss_fallback=False,
    max_iterations=10,
    param_safe=True,
    cut_strategy='fewest_neighbors',
    use_tsim_bss=False,
)
print(f'\n  Cutting produced {len(cut_terms)} terms')

# Reduce each term in place and classify it: T-count 0 means Clifford;
# otherwise its T-count is added to a running total across all such terms.
clifford_count = 0
remaining_t = 0
for term in cut_terms:
    zx.full_reduce(term, paramSafe=True)
    tc = zx.simplify.tcount(term)
    if tc == 0:
        clifford_count += 1
    else:
        remaining_t += tc

print(f'  After reducing cut terms: {clifford_count} Clifford, '
      f'{len(cut_terms) - clifford_count} with total remaining T-count={remaining_t}')

### Cutting Decomposition Graphs (Stage 3)
The reduced graph before cutting, and the first few terms produced by cutting (after `full_reduce`; some may still contain T-gates).

In [None]:
# Reduced graph before cutting (T-gates still present)
print(f'Reduced graph (before cutting): T-count={zx.simplify.tcount(g_work)}')
zx.draw(g_work, figsize=(14, 8), labels=True)

# First few cut terms. The Stage 3 cell ran full_reduce on each term in place,
# so the T-count shown is what remains after reduction (it may be nonzero).
n_show = min(3, len(cut_terms))
for i, term in enumerate(cut_terms[:n_show]):
    tc_i = zx.simplify.tcount(term)
    print(f'\nCut term {i}: {term.num_vertices()} vertices, T-count={tc_i}, '
          f'scalar.power2={term.scalar.power2}')
    zx.draw(term, figsize=(10, 5), labels=True)

## Stage 4: Output Plugging & Disconnection

For enumeration-based sampling, outputs are plugged (set to 0-effect with
m-parameter phase). After `full_reduce`, the fully-plugged graph may disconnect
into independent sub-components, enabling product evaluation.

In [None]:
from tsim_cutting import _get_f_indices, _plug_outputs, _find_zx_components

# Use same demo component
demo_g2 = deepcopy(demo_cc.graph)
output_indices = demo_cc.output_indices
num_outputs = len(demo_g2.outputs())

# Global f-parameter indices filtered to those present in this component,
# keeping the global ordering.
f_indices_global = _get_f_indices(prepared.graph)
component_f_set = set(_get_f_indices(demo_g2))
f_selection = [i for i in f_indices_global if i in component_f_set]
# One measurement-parameter name ('m<index>') per global output index.
component_m_chars = [f'm{i}' for i in output_indices]

print('=== Stage 4: Output Plugging ===')
print(f'  Component: {num_outputs} outputs, output_indices={output_indices}')
print(f'  f-params used: {len(f_selection)}')
print()

# Plug at two levels: 0 outputs (normalization) and all outputs (fully plugged).
plugged_graphs = _plug_outputs(demo_g2, component_m_chars, [0, num_outputs])

# Level 0: normalization (reduce a copy so plugged_graphs stays untouched)
g_level0 = deepcopy(plugged_graphs[0])
print('  Level-0 (norm, no outputs plugged):')
print(f'    Before reduce: {g_level0.num_vertices()} vertices, '
      f'{g_level0.num_edges()} edges')
zx.full_reduce(g_level0, paramSafe=True)
g_level0.normalize()
tc0 = zx.simplify.tcount(g_level0)
print(f'    After reduce:  {g_level0.num_vertices()} vertices, '
      f'{g_level0.num_edges()} edges, T-count={tc0}')
print(f'    power2_base:   {g_level0.scalar.power2}')
print()

# Fully plugged
g_plugged = deepcopy(plugged_graphs[1])
print(f'  Fully-plugged (all {num_outputs} outputs plugged):')
print(f'    Before reduce: {g_plugged.num_vertices()} vertices, '
      f'{g_plugged.num_edges()} edges')
zx.full_reduce(g_plugged, paramSafe=True)
g_plugged.normalize()
tc_full = zx.simplify.tcount(g_plugged)
print(f'    After reduce:  {g_plugged.num_vertices()} vertices, '
      f'{g_plugged.num_edges()} edges, T-count={tc_full}')
print()

# Check disconnection
zx_comps = _find_zx_components(g_plugged)
print('  Disconnection check:')
if len(zx_comps) >= 2:
    print(f'    Graph DISCONNECTS into {len(zx_comps)} sub-components!')
    for j, comp_verts in enumerate(zx_comps):
        # NOTE(review): only string-valued phases are collected as parameter
        # names — confirm this matches how pyzx_param stores symbolic phases.
        comp_params = {ph for ph in (g_plugged.phase(v) for v in comp_verts)
                       if isinstance(ph, str)}
        shown = sorted(comp_params)
        # List at most 5 parameter names; append '...' when some are omitted.
        more = '...' if len(shown) > 5 else ''
        print(f'    Sub-component {j}: {len(comp_verts)} vertices, '
              f'params: {shown[:5]}{more}')
elif zx_comps:
    print('    Graph does NOT disconnect (monolithic evaluation)')
    print(f'    Single component: {len(zx_comps[0])} vertices')
else:
    # No components at all — presumably every vertex was reduced away and only
    # the scalar remains (TODO confirm _find_zx_components returns [] then).
    # Indexing zx_comps[0] here would raise IndexError.
    print('    Graph does NOT disconnect (monolithic evaluation)')
    print('    No vertices left after reduction (scalar only)')

### Plugged Graph & Sub-components (Stage 4)
The fully-plugged reduced graph, and its disconnected sub-components (if any).

In [None]:
from tsim_cutting import _extract_subgraph

# Fully-plugged graph as left by full_reduce in the Stage 4 cell.
n_plugged_vertices = g_plugged.num_vertices()
plugged_tcount = zx.simplify.tcount(g_plugged)
print(f'Fully-plugged graph (reduced): {n_plugged_vertices} vertices, '
      f'T-count={plugged_tcount}')
zx.draw(g_plugged, figsize=(14, 8), labels=True)

# One drawing per sub-component, only when the graph actually split.
# reset_scalar is False for sub-component 0 and True for every later one.
if len(zx_comps) > 1:
    for j, comp_verts in enumerate(zx_comps):
        sub_g = _extract_subgraph(g_plugged, comp_verts, reset_scalar=(j != 0))
        print(f'\nSub-component {j}: {sub_g.num_vertices()} vertices')
        zx.draw(sub_g, figsize=(10, 5), labels=True)

## Stage 5: Compiled Program

The full compilation produces a `SubcompEnumCompiledProgram` with per-component
data. Each component stores compiled scalar graphs (A/B/C/D terms), combo
tables, and parameter index maps.

In [None]:
from tsim_cutting import compile_program_subcomp_enum_general
from tsim_cutting import SubcompEnumComponentData, SubcompComponentData

# Time the full compilation. perf_counter() is monotonic and high-resolution,
# which makes it the right clock for elapsed time (time.time() can jump).
t0 = time.perf_counter()
program = compile_program_subcomp_enum_general(
    prepared, max_cut_iterations=10, debug=False
)
compile_time = time.perf_counter() - t0

print('=== Stage 5: Compiled Program ===')
print(f'  Compilation time: {compile_time:.2f}s')
print(f'  Total components: {len(program.component_data)}')
print(f'  Total outputs:    {program.num_outputs}')
print(f'  Total f-params:   {program.num_f_params}')
print(f'  Num detectors:    {program.num_detectors}')
print()

n_enum = 0         # components compiled as SubcompEnumComponentData
n_subcomp = 0      # all other components (reported as AUTOREGRESSIVE)
total_d_terms = 0  # D-term counts summed over every component

for i, cd in enumerate(program.component_data):
    if isinstance(cd, SubcompEnumComponentData):
        n_enum += 1
        n_out = cd.num_component_outputs
        n_combos = cd.m_combos.shape[0]
        n_sub = cd.num_subcomps

        # D-terms per compiled sub-component = shape[0] * shape[1] of d_const_alpha.
        d_terms = sum(
            sc.d_const_alpha.shape[0] * sc.d_const_alpha.shape[1]
            for sc in cd.subcomp_compiled
        )
        total_d_terms += d_terms

        print(f'  Component {i} [ENUM]: {n_out} outputs, '
              f'{n_combos} combos, {n_sub} sub-comp(s), '
              f'{d_terms} D-terms')
        print(f'    output_indices: {cd.output_indices}')
        print(f'    f_selection: {len(cd.f_selection)} params')
        for j, sc in enumerate(cd.subcomp_compiled):
            n_graphs = sc.phase_indices.shape[0]
            print(f'    Sub-comp {j}: {n_graphs} scalar graphs, '
                  f'A-terms={sc.a_const_phases.shape}, '
                  f'D-terms={sc.d_const_alpha.shape}')
    else:
        # NOTE(review): presumably SubcompComponentData; the type is not
        # checked, so any other component type also lands in this branch.
        n_subcomp += 1
        n_out = len(cd.output_indices)
        n_levels = len(cd.compiled_scalar_graphs)
        has_prod = cd.has_product_level

        d_terms = sum(
            csg.d_const_alpha.shape[0] * csg.d_const_alpha.shape[1]
            for csg in cd.compiled_scalar_graphs
        )
        total_d_terms += d_terms

        print(f'  Component {i} [AUTOREGRESSIVE]: {n_out} outputs, '
              f'{n_levels} levels, product={has_prod}, '
              f'{d_terms} D-terms')
        print(f'    output_indices: {cd.output_indices}')

print(f'\n  Summary: {n_enum} enum + {n_subcomp} autoregressive components')
print(f'  Total D-terms: {total_d_terms}')

## Stage 6: Sampler + Optimizations

The final sampler wraps the compiled program with a channel sampler for noise.
We then apply the acceleration patches:
1. Noiseless cache (skip evaluation for identity-channel shots)
2. Dedup (numpy hash-based deduplication of f-params)
3. Inverse CDF channel sampling

In [None]:
from tsim_cutting import compile_detector_sampler_subcomp_enum_general
from evaluate_matmul_cfloat import evaluate_batch as evaluate_batch_cfloat
import tsim_cutting as mod_cutting
# Rebind tsim_cutting's module-level `evaluate_batch` to the cfloat version.
# This only affects code that looks the name up on the module at call time;
# a `from tsim_cutting import evaluate_batch` done elsewhere keeps the old one.
mod_cutting.evaluate_batch = evaluate_batch_cfloat

from sampler_noiseless_cache import add_noiseless_cache
from sampler_dedup import patch_sampler_fast
from channel_sampler_fast import patch_channel_sampler_fast

print('=== Stage 6: Sampler Assembly ===')
print()

# Base sampler
sampler = compile_detector_sampler_subcomp_enum_general(circ, seed=42, max_cut_iterations=10)
cs = sampler._channel_sampler
print('Channel sampler:')
print(f'  Channels:    {len(cs.channels)}')
print(f'  Outcomes per channel: {[len(ch.logits) for ch in cs.channels[:5]]}... ')
print()

# Patch 1: Noiseless cache
add_noiseless_cache(sampler)
print()

# Patch 2: Dedup
patch_sampler_fast(sampler, top_k=None, max_unique=None, use_dedup=True, verbose=True)
print()

# Patch 3: Inverse CDF channel sampling
patch_channel_sampler_fast(sampler, verbose=True)
print()

# Quick test: one batch of n_shots shots (single variable instead of a
# repeated literal, so the message, the sample call and the ratio stay in sync).
n_shots = 1024
print(f'Quick sample test ({n_shots} shots)...')
det, obs = sampler.sample(shots=n_shots, batch_size=n_shots, separate_observables=True)
# Post-selection: keep only shots in which every detector is 0.
trivial = np.all(det == 0, axis=1)
n_kept = int(np.sum(trivial))
# Kept shots whose observable 0 differs from the Stage 0 reference value.
n_errors = int(np.sum(obs[trivial, 0].astype(int) != noiseless_raw))
# PSR printed here is the fraction of shots kept.
print(f'  Kept: {n_kept}/{n_shots} (PSR={n_kept/n_shots:.3f})')
print(f'  Errors: {n_errors}')