In [1]:
# ==========================================================================
# SECTION 1: Setup and Imports
# ==========================================================================
import sys
import os
import time
import warnings
import numpy as np
from dataclasses import dataclass
from typing import Optional, Dict, Any, List, Tuple
from collections import defaultdict
import traceback

warnings.filterwarnings('ignore')

# Put the parent of the current working directory at the front of sys.path
# (only once) so the local package checkout is importable from the notebook.
src_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
if sys.path.count(src_path) == 0:
    sys.path.insert(0, src_path)

# Evict every already-imported module whose name contains 'qectostim' so the
# imports below pick up the current source instead of a cached copy.
import importlib
modules_to_clear = [name for name in tuple(sys.modules) if 'qectostim' in name]
for mod in modules_to_clear:
    sys.modules.pop(mod)
print(f"Cleared {len(modules_to_clear)} cached modules")

# Import testing utilities
from qectostim.testing import load_all_decoders, STATUS_OK, STATUS_WARN, STATUS_SKIP, STATUS_FAIL

# Import code discovery
from qectostim.codes import discover_all_codes

# Import noise models
from qectostim.noise.models import NoiseModel, CircuitDepolarizingNoise

# Import gadgets - Updated to include Bell-state teleportation
from qectostim.gadgets import (
    # Transversal gates
    TransversalHadamard, TransversalS, TransversalT,
    TransversalX, TransversalY, TransversalZ,
    TransversalCNOT, TransversalCZ, TransversalSWAP,
    # Teleportation-based gates (product-state - 2 block)
    TeleportedHadamard, TeleportedIdentity,
    # Bell-state teleportation (3 block) - NEW! Correctly implements S gate
    BellTeleportedS, BellTeleportedSDag,
    # CSS surgery
    LatticeZZMerge, LatticeXXMerge, SurgeryCNOT,
)

# Import experiment infrastructure
from qectostim.experiments.ft_gadget_experiment import FaultTolerantGadgetExperiment
from qectostim.gadgets.base import PhaseResult, PhaseType, ObservableTransform
from qectostim.gadgets.layout import QubitAllocation
from qectostim.experiments.stabilizer_rounds import DetectorContext
import stim

print("✓ All imports successful")

Cleared 0 cached modules
✓ All imports successful
✓ All imports successful


In [2]:
# ==========================================================================
# Load decoders and codes
# ==========================================================================
# Mapping of decoder name -> decoder class, as returned by the testing helper.
decoder_classes = load_all_decoders()
print(f"Loaded {len(decoder_classes)} decoders: {list(decoder_classes.keys())}")

# Mapping of code name -> code instance; non-CSS codes are excluded.
all_codes = discover_all_codes(include_non_css=False)
print(f"Discovered {len(all_codes)} CSS codes")

# Standard noise model for testing. Later cells read this module-level `noise`.
noise = CircuitDepolarizingNoise(p1=0.001, p2=0.001)
print(f"Noise model: CircuitDepolarizingNoise(p1=0.001, p2=0.001)")

# ==========================================================================
# Select representative test codes from discovered codes
# ==========================================================================
# Use discover_all_codes to get properly instantiated codes.
# NOTE(review): the set built below is not capped (the size limit further down
# is commented out), so it can be much larger than the ~10 originally intended.

# Preferred code names. Names that discovery does not return are silently
# ignored by the lookup loop below.
target_codes = [
    'FourQubit422Code',     # Small CSS [[4,2,2]]
    'SixQubit622Code',      # Small CSS [[6,2,2]]
    'SteanCode713',         # Classic [[7,1,3]]
    'ShorCode91',           # Shor's [[9,1,3]]
    'HammingCSSCode_m3',    # Hamming-based
    'RotatedSurfaceCode_d3',# Rotated surface
    'ToricCode33',          # Toric code
    'XZZXSurfaceCode_d3',   # XZZX variant
    'TriangularColourCode_d3', # Color code
    'RepetitionCode_d5',    # Repetition
    'XZZX_Surface_3',
    'LoopToricCode4D',
    'Hamming_CSS_15',
    'Code_832',
    'Colour488_[[9,1,3]]',
    'QuantumTanner_4',
    'BalancedProduct_5x5_G1'


]

# Build TEST_CODES (name -> code instance). Insertion order determines the
# row/column order of every table printed later in the notebook.
TEST_CODES = {}

# First take whichever preferred names were actually discovered
for name in target_codes:
    if name in all_codes:
        TEST_CODES[name] = all_codes[name]

# Import three small codes directly so they are available regardless of what
# discovery returned.
from qectostim.codes.small.four_two_two import FourQubit422Code
from qectostim.codes.small.steane_713 import SteanCode713
from qectostim.codes.small.shor_code import ShorCode91

# Always include the basic test codes (this overwrites any discovered entry
# stored under the same key).
# NOTE(review): the key 'FourQubit422' differs from the target name
# 'FourQubit422Code'; if discovery ever returns the latter, the [[4,2,2]] code
# would be tested twice under two names.
TEST_CODES['FourQubit422'] = FourQubit422Code()
TEST_CODES['SteanCode713'] = SteanCode713()
TEST_CODES['ShorCode91'] = ShorCode91()

# Add every discovered code whose name matches one of the family patterns
for name, code in all_codes.items():
    # A size cap (stop once 10 codes are collected) used to live here; it is
    # disabled, so all matching codes are added.
    # Add interesting codes
    if any(x in name for x in ['Surface', 'Toric', 'Color', 'Colour', 'XZZX']):
        if name not in TEST_CODES:
            TEST_CODES[name] = code

# Summarise the selection as [[n, k, d]]; '?' is printed when the code object
# has no `k` / `d` attribute.
print(f"\nTest codes ({len(TEST_CODES)} total):")
for name, code in TEST_CODES.items():
    n = code.n
    k = code.k if hasattr(code, 'k') else '?'
    d = code.d if hasattr(code, 'd') else '?'
    print(f"  {name:<30}: [[{n}, {k}, {d}]]")

Loaded 13 decoders: ['PyMatching', 'FusionBlossom', 'BeliefMatching', 'BPOSD', 'Tesseract', 'UnionFind', 'MLE', 'Hypergraph', 'Chromobius', 'Concatenated', 'FlatConcat', 'Hierarchical', 'SingleShot']
Discovered 76 CSS codes
Noise model: CircuitDepolarizingNoise(p1=0.001, p2=0.001)

Test codes (36 total):
  XZZX_Surface_3                : [[9, 1, ?]]
  LoopToricCode4D               : [[8, 2, ?]]
  Hamming_CSS_15                : [[15, 7, ?]]
  Code_832                      : [[8, 3, ?]]
  Colour488_[[9,1,3]]           : [[9, 1, ?]]
  QuantumTanner_4               : [[32, 2, ?]]
  BalancedProduct_5x5_G1        : [[41, 1, ?]]
  FourQubit422                  : [[4, 2, ?]]
  SteanCode713                  : [[7, 1, ?]]
  ShorCode91                    : [[9, 1, ?]]
  RotatedSurface_[[9,1,3]]      : [[9, 1, ?]]
  RotatedSurface_[[25,1,5]]     : [[25, 1, ?]]
  ToricCode_3x3                 : [[18, 2, ?]]
  ToricCode_5x5                 : [[50, 2, ?]]
  XZZX_Surface_5                : [[25, 1, ?

## Section 2: Zero-Noise Verification

**Critical Sanity Check**: With no noise, the logical error rate MUST be 0.

If LER ≠ 0 with no noise, it indicates:
- Observable tracking bug (wrong qubits in OBSERVABLE_INCLUDE)
- Detector emission bug (wrong measurements referenced)
- Gate implementation bug (wrong physical operations)

In [3]:
# ==========================================================================
# SECTION 2: Zero-Noise Verification Helper
# ==========================================================================

# Known non-working gadgets (expected to fail).
# Maps gadget name -> human-readable reason. Gadgets listed here are reported
# as skipped instead of being executed.
KNOWN_BROKEN_GADGETS = {
    'TransversalT': 'Non-Clifford (Stim cannot simulate)',
}

def verify_zero_noise(
    gadget,
    codes: List,
    gadget_name: str,
    num_shots: int = 1000,
    measurement_basis: str = "Z",
) -> Tuple[bool, float, Optional[str]]:
    """
    Run a noiseless experiment for ``gadget`` and report its logical error rate.

    The experiment is built with ``noise_model=None`` and one stabilizer round
    on each side of the gadget, then delegated to the experiment's own
    ``verify_zero_noise_ler`` check.

    Parameters
    ----------
    gadget
        Gadget instance under test.
    codes : list
        Code blocks handed to the experiment.
    gadget_name : str
        Display name of the gadget (not used by this function).
    num_shots : int
        Number of shots forwarded to ``verify_zero_noise_ler``.
    measurement_basis : str
        Measurement basis forwarded to the experiment.

    Returns
    -------
    (passed, ler, error_message)
        Whatever ``verify_zero_noise_ler`` returns, or
        ``(False, 1.0, "Exception: ...")`` if anything raised.
    """
    try:
        experiment = FaultTolerantGadgetExperiment(
            codes=codes,
            gadget=gadget,
            noise_model=None,  # deliberately noiseless
            num_rounds_before=1,
            num_rounds_after=1,
            measurement_basis=measurement_basis,
        )
        passed, ler, error_msg = experiment.verify_zero_noise_ler(num_shots)
    except Exception as exc:
        # Any failure (construction, simulation, unpacking) counts as a
        # failed verification with the worst-case error rate.
        return False, 1.0, f"Exception: {exc}"

    return passed, ler, error_msg


def run_zero_noise_tests(gadgets_config: Dict, test_code, test_code_name: str):
    """
    Zero-noise check for every gadget in ``gadgets_config`` on one code.

    ``gadgets_config`` maps a gadget name to ``(class, kwargs, is_two_qubit)``.
    Gadgets listed in ``KNOWN_BROKEN_GADGETS`` are reported as skipped. A table
    and a summary are printed; the per-gadget result dicts are returned keyed
    by gadget name.
    """
    banner = "=" * 80
    divider = "-" * 80

    print(banner)
    print(f"ZERO-NOISE VERIFICATION on {test_code_name}")
    print(banner)
    print("Gadget                        | Status | LER    | Notes")
    print(divider)

    results = {}
    for name, spec in gadgets_config.items():
        cls, kwargs, is_two_qubit = spec

        # Known-broken gadgets are recorded but never instantiated.
        if name in KNOWN_BROKEN_GADGETS:
            reason = KNOWN_BROKEN_GADGETS[name]
            print(f"{name:<30} | ~ SKIP | --     | {reason[:35]}")
            results[name] = {'passed': None, 'ler': None, 'error': reason, 'skipped': True}
            continue

        try:
            gadget = cls(**kwargs)
            # Two-qubit gadgets act on two copies of the same code.
            if is_two_qubit:
                codes = [test_code, test_code]
            else:
                codes = [test_code]

            passed, ler, error_msg = verify_zero_noise(gadget, codes, name, measurement_basis="Z")

            if passed:
                status = "✓ PASS"
            else:
                status = "✗ FAIL"
            notes = error_msg or ""
            print(f"{name:<30} | {status:<6} | {ler:.4f} | {notes[:35]}")
            results[name] = {'passed': passed, 'ler': ler, 'error': error_msg}

        except Exception as exc:
            print(f"{name:<30} | ✗ ERR  | --     | {str(exc)[:35]}")
            results[name] = {'passed': False, 'ler': None, 'error': str(exc)}

    print(divider)
    outcomes = list(results.values())
    passed_count = sum(bool(entry.get('passed', False)) for entry in outcomes)
    skipped_count = sum(bool(entry.get('skipped', False)) for entry in outcomes)
    tested_count = len(results) - skipped_count
    print(f"\nSummary: {passed_count}/{tested_count} passed zero-noise verification")
    print(f"         {skipped_count} skipped (known broken)")

    return results

In [4]:
# ==========================================================================
# Define all gadgets for testing
# ==========================================================================

# Every gadget table below maps a display name to a 3-tuple:
#   (gadget class, constructor kwargs, is_two_qubit)
# The runner functions unpack the tuple in that order; when `is_two_qubit` is
# True the gadget is given two copies of the code under test, otherwise one.

SINGLE_QUBIT_GADGETS = {
    'TransversalH': (TransversalHadamard, {'include_stabilizer_rounds': False}, False),
    'TransversalS': (TransversalS, {'include_stabilizer_rounds': False}, False),
    'TransversalT': (TransversalT, {'include_stabilizer_rounds': False}, False),  # Non-Clifford - expected to fail
    'TransversalX': (TransversalX, {'include_stabilizer_rounds': False}, False),
    'TransversalY': (TransversalY, {'include_stabilizer_rounds': False}, False),
    'TransversalZ': (TransversalZ, {'include_stabilizer_rounds': False}, False),
}

TWO_QUBIT_GADGETS = {
    'TransversalCNOT': (TransversalCNOT, {'include_stabilizer_rounds': False}, True),
    'TransversalCZ': (TransversalCZ, {'include_stabilizer_rounds': False}, True),
    'TransversalSWAP': (TransversalSWAP, {'include_stabilizer_rounds': False}, True),
}

# Teleportation gadgets:
# - Product-state (2 blocks): TeleportedH, TeleportedIdentity
# - Bell-state (3 blocks): BellTeleportedS, BellTeleportedSDag (NEW - fixes S gate issue!)
# NOTE(review): all four are flagged is_two_qubit=False, so the runners pass a
# single code; presumably the gadgets allocate their ancilla blocks themselves
# - confirm against the gadget implementations.
TELEPORTATION_GADGETS = {
    # Product-state teleportation (2 blocks) - works for H and Identity
    'TeleportedH': (TeleportedHadamard, {'include_stabilizer_rounds': False}, False),
    'TeleportedIdentity': (TeleportedIdentity, {'include_stabilizer_rounds': False}, False),
    # Bell-state teleportation (3 blocks) - CORRECTLY implements S gate!
    'BellTeleportedS': (BellTeleportedS, {}, False),
    'BellTeleportedSDag': (BellTeleportedSDag, {}, False),
}

# Note: Legacy TeleportedS and TeleportedT are BROKEN and excluded
# - TeleportedS: Product-state ancilla |+i⟩ doesn't give correct S transformation
# - TeleportedT: Same issue PLUS Stim cannot simulate non-Clifford T gate

SURGERY_GADGETS = {
    'LatticeZZMerge': (LatticeZZMerge, {'num_merge_rounds': 1}, True),
    'LatticeXXMerge': (LatticeXXMerge, {'num_merge_rounds': 1}, True),
    'SurgeryCNOT': (SurgeryCNOT, {'num_rounds_before': 0, 'num_rounds_after': 0, 'num_merge_rounds': 1}, True),
}

# Union of all categories; later categories would override earlier ones on a
# name clash (the names used here are all distinct).
ALL_GADGETS = {
    **SINGLE_QUBIT_GADGETS,
    **TWO_QUBIT_GADGETS,
    **TELEPORTATION_GADGETS,
    **SURGERY_GADGETS,
}

print(f"Defined {len(ALL_GADGETS)} gadgets for testing:")
print(f"  Single-qubit transversal: {list(SINGLE_QUBIT_GADGETS.keys())}")
print(f"  Two-qubit transversal: {list(TWO_QUBIT_GADGETS.keys())}")
print(f"  Teleportation (product+Bell): {list(TELEPORTATION_GADGETS.keys())}")
print(f"  Surgery: {list(SURGERY_GADGETS.keys())}")
print(f"\nNote: TransversalT is included but expected to fail (non-Clifford)")

Defined 16 gadgets for testing:
  Single-qubit transversal: ['TransversalH', 'TransversalS', 'TransversalT', 'TransversalX', 'TransversalY', 'TransversalZ']
  Two-qubit transversal: ['TransversalCNOT', 'TransversalCZ', 'TransversalSWAP']
  Teleportation (product+Bell): ['TeleportedH', 'TeleportedIdentity', 'BellTeleportedS', 'BellTeleportedSDag']
  Surgery: ['LatticeZZMerge', 'LatticeXXMerge', 'SurgeryCNOT']

Note: TransversalT is included but expected to fail (non-Clifford)


In [5]:
# ==========================================================================
# Decoder Compatibility Test Infrastructure
# ==========================================================================

# Status labels used by the decoder-matrix helpers in this cell.
# NOTE(review): STATUS_OK / STATUS_WARN / STATUS_FAIL / STATUS_SKIP rebind the
# names imported from qectostim.testing in the first cell.
STATUS_OK = 'OK'
STATUS_WARN = 'WARN'
STATUS_FAIL = 'FAIL'
STATUS_SKIP = 'SKIP'
STATUS_NA = 'N/A'

# Decoders that have special requirements (decoder name -> requirement tag)
SPECIALIZED_DECODERS = {
    'Chromobius': 'color_code',  # Requires color code with 4th coordinate annotations
    'Concatenated': 'concatenated_code',  # Requires ConcatenatedCode
    'FlatConcat': 'concatenated_code',  # Requires ConcatenatedCode
    'MLE': 'small_detector_count',  # MLE is only practical for small detector counts
}

MLE_MAX_DETECTORS = 30  # MLEDecoder limit

def check_decoder_compatibility(decoder_name: str, code, code_name: str, circuit=None) -> Tuple[bool, str]:
    """Check if a decoder is compatible with a code.

    Returns (is_compatible, reason) tuple. ``reason`` is the empty string when
    the decoder is compatible.

    Parameters
    ----------
    decoder_name : str
        Name of the decoder; looked up in ``SPECIALIZED_DECODERS``. Decoders
        not listed there are always reported as compatible.
    code : Code
        The code being tested
    code_name : str
        Name of the code (not used by the checks)
    circuit : stim.Circuit, optional
        The Stim circuit (needed for detector count checks). When omitted, the
        detector-count check is not applied.
    """
    requirement = SPECIALIZED_DECODERS.get(decoder_name)

    if requirement == 'color_code':
        # Chromobius requires color codes with proper annotations.
        # A missing or None `metadata` attribute is treated as "no metadata"
        # instead of raising on `.get`.
        meta = getattr(code, 'metadata', None) or {}
        if not meta.get('is_chromobius_compatible', False):
            return False, "Requires color code"

    elif requirement == 'concatenated_code':
        # Concatenated decoders require ConcatenatedCode (matched by class name)
        code_class_name = type(code).__name__
        if 'Concatenated' not in code_class_name:
            return False, "Requires concatenated code"

    elif requirement == 'small_detector_count':
        # MLE decoder is only practical for small detector counts
        if circuit is not None and circuit.num_detectors > MLE_MAX_DETECTORS:
            # Report the configured limit rather than a hard-coded number so
            # the message stays correct if MLE_MAX_DETECTORS changes.
            return False, f">{MLE_MAX_DETECTORS} detectors ({circuit.num_detectors})"

    return True, ""

def test_gadget_decoder(
    code,
    code_name: str,
    gadget_class,
    gadget_kwargs: Dict,
    decoder_class,
    decoder_name: str,
    is_two_qubit: bool = False,
    shots: int = 1000,
) -> Dict[str, Any]:
    """
    Run one gadget + decoder + code combination and classify the outcome.

    Returns a dict with keys ``status``, ``ler`` and ``error``:

    * ``STATUS_SKIP`` - decoder incompatible, or the circuit has no detectors
    * ``STATUS_OK``   - decoding succeeded; ``ler`` holds the logical error rate
    * ``STATUS_WARN`` - the circuit was built but DEM extraction, decoder
      construction, sampling or decoding raised
    * ``STATUS_FAIL`` - building the gadget, experiment or circuit raised
    """
    result = {
        'status': STATUS_FAIL,
        'ler': None,
        'error': None,
    }

    def _skipped(why):
        # Mark the shared result as skipped and hand it back.
        result['status'] = STATUS_SKIP
        result['error'] = why
        return result

    # Code-level compatibility can be decided before any circuit exists.
    compatible, why = check_decoder_compatibility(decoder_name, code, code_name)
    if not compatible:
        return _skipped(why)

    try:
        gadget = gadget_class(**gadget_kwargs)
        blocks = [code, code] if is_two_qubit else [code]

        noise = CircuitDepolarizingNoise(p1=0.001, p2=0.001)
        experiment = FaultTolerantGadgetExperiment(
            codes=blocks,
            gadget=gadget,
            noise_model=noise,
            num_rounds_before=2,
            num_rounds_after=2,
        )

        circuit = experiment.to_stim()

        if circuit.num_detectors == 0:
            return _skipped('No detectors')

        # Circuit-level compatibility (e.g. detector count) needs the circuit.
        compatible, why = check_decoder_compatibility(decoder_name, code, code_name, circuit=circuit)
        if not compatible:
            return _skipped(why)

        # Decoding stage: failures here are downgraded to a warning.
        try:
            dem = circuit.detector_error_model(decompose_errors=True)

            # Decoders are constructed with the dataclass-style `dem=` keyword.
            decoder = decoder_class(dem=dem)

            sampler = circuit.compile_detector_sampler()
            shots_table = sampler.sample(shots, append_observables=True)

            # Detector bits come first, observable bits are appended after.
            split_at = circuit.num_detectors
            det_samples = shots_table[:, :split_at]
            obs_samples = shots_table[:, split_at:]

            predictions = decoder.decode_batch(det_samples)
            if len(predictions.shape) == 1:
                predictions = predictions.reshape(-1, 1)

            # A shot is a logical error if any observable is mispredicted.
            wrong = np.any(predictions != obs_samples, axis=1)
            result['ler'] = float(np.mean(wrong))
            result['status'] = STATUS_OK

        except Exception as decode_exc:
            result['status'] = STATUS_WARN
            result['error'] = str(decode_exc)[:50]

    except Exception as build_exc:
        result['status'] = STATUS_FAIL
        result['error'] = str(build_exc)[:50]

    return result


def run_decoder_matrix(
    gadgets_config: Dict,
    test_codes: Dict,
    decoder_classes: Dict,
    category_name: str,
) -> Dict:
    """
    Evaluate every (gadget, code, decoder) combination and print a table.

    All combinations are run first; the table (one row per gadget/code pair,
    one column per decoder) is printed afterwards. Returns the nested result
    mapping ``gadget -> code -> decoder -> result dict``.
    """
    dec_names = list(decoder_classes.keys())
    all_results = {}
    rows = []

    for gadget_name, gadget_spec in gadgets_config.items():
        gadget_class, gadget_kwargs, is_two_qubit = gadget_spec
        per_code = all_results[gadget_name] = {}

        for code_name, code in test_codes.items():
            per_decoder = per_code[code_name] = {}
            cells = [f"{gadget_name:<20} | {code_name:<20}"]

            for dec_name in dec_names:
                result = test_gadget_decoder(
                    code=code,
                    code_name=code_name,
                    gadget_class=gadget_class,
                    gadget_kwargs=gadget_kwargs,
                    decoder_class=decoder_classes[dec_name],
                    decoder_name=dec_name,
                    is_two_qubit=is_two_qubit,
                )
                per_decoder[dec_name] = result

                # Turn the status into a short table cell.
                status = result['status']
                if status == STATUS_OK:
                    if result['ler'] is not None:
                        cell = f"{result['ler']:.3f}"
                    else:
                        cell = "  OK  "
                elif status == STATUS_SKIP:
                    cell = " SKIP "
                elif status == STATUS_WARN:
                    cell = " WARN "
                else:
                    cell = " FAIL "

                cells.append(f" | {cell:^8}")

            rows.append("".join(cells))

    wide_rule = "=" * 140
    print(wide_rule)
    print(f"DECODER COMPATIBILITY: {category_name}")
    print(wide_rule)

    header = f"{'Gadget':<20} | {'Code':<20}" + "".join(
        f" | {dec[:8]:^8}" for dec in dec_names
    )
    thin_rule = "-" * len(header)
    print(header)
    print(thin_rule)
    for line in rows:
        print(line)

    print(thin_rule)
    return all_results

In [None]:
# ==========================================================================
# COMPREHENSIVE TEST: All Gadgets x All Codes
# ==========================================================================
from qectostim.decoders import PyMatchingDecoder
print("=" * 100)
print("COMPREHENSIVE GADGET TEST: All Gadgets x All Codes")
print("=" * 100)

# Results matrix: gadget name -> code name -> (status, info).
# status is one of 'OK', 'DET', 'SKIP', 'FAIL'; info is the logical error rate
# for 'OK' and a short message otherwise.
all_results = {}
code_names = list(TEST_CODES.keys())

# Header (code names are truncated to 12 characters per column)
header = f"{'Gadget':<20}"
for code_name in code_names:
    header += f" | {code_name[:12]:<12}"
print(header)
print("-" * len(header))

for gadget_name, (cls, kwargs, is_two) in ALL_GADGETS.items():
    row = f"{gadget_name:<20}"
    all_results[gadget_name] = {}

    for code_name, code in TEST_CODES.items():
        try:
            gadget = cls(**kwargs)
            # Two-qubit gadgets act on two copies of the same code.
            codes = [code, code] if is_two else [code]

            exp = FaultTolerantGadgetExperiment(
                codes=codes,
                gadget=gadget,
                noise_model=noise,
                num_rounds_before=2,
                num_rounds_after=2,
            )

            circuit = exp.to_stim()

            if circuit.num_detectors == 0:
                all_results[gadget_name][code_name] = ('SKIP', 'No dets')
                row += f" | {'SKIP':^12}"
                continue

            dem = circuit.detector_error_model(decompose_errors=True)
            decoder = PyMatchingDecoder(dem=dem)

            sampler = circuit.compile_detector_sampler()
            samples = sampler.sample(2000, append_observables=True)

            # Detector bits come first, observable bits are appended after.
            det_samples = samples[:, :circuit.num_detectors]
            obs_samples = samples[:, circuit.num_detectors:]

            predictions = decoder.decode_batch(det_samples)
            if len(predictions.shape) == 1:
                predictions = predictions.reshape(-1, 1)

            # A shot is a logical error if any observable is mispredicted.
            errors = np.any(predictions != obs_samples, axis=1)
            ler = float(np.mean(errors))

            all_results[gadget_name][code_name] = ('OK', ler)
            row += f" | {ler:^12.4f}"

        except Exception as e:
            err_msg = str(e).lower()
            if "non-deterministic" in err_msg:
                all_results[gadget_name][code_name] = ('DET', str(e)[:30])
                row += f" | {'DET_ERR':^12}"
            elif (
                isinstance(e, NotImplementedError)
                or "non-clifford" in err_msg
                or "notimplementederror" in err_msg
            ):
                # The exception *type* must be checked explicitly: str(e) only
                # contains the message, never the class name, so the substring
                # test alone misses real NotImplementedError instances.
                all_results[gadget_name][code_name] = ('SKIP', 'NonCliff')
                row += f" | {'SKIP':^12}"
            else:
                all_results[gadget_name][code_name] = ('FAIL', str(e)[:30])
                row += f" | {'FAIL':^12}"

    print(row)

print("-" * len(header))

# Summary statistics
print("\n" + "=" * 100)
print("RESULTS SUMMARY")
print("=" * 100)

status_counts = {'OK': 0, 'DET': 0, 'SKIP': 0, 'FAIL': 0}
problem_cases = []

for gadget_name, code_results in all_results.items():
    for code_name, (status, info) in code_results.items():
        if status == 'OK':
            status_counts['OK'] += 1
        elif status == 'DET':
            status_counts['DET'] += 1
            problem_cases.append((gadget_name, code_name, 'Non-deterministic detectors'))
        elif status == 'SKIP':
            status_counts['SKIP'] += 1
        else:
            status_counts['FAIL'] += 1
            problem_cases.append((gadget_name, code_name, info))

total = sum(status_counts.values())


def _percent_of_total(count: int) -> float:
    """Share of `count` in `total` as a percentage; 0.0 when nothing was run."""
    return 100 * count / total if total else 0.0


print(f"\nTotal test cases: {total}")
print(f"  ✓ OK:      {status_counts['OK']} ({_percent_of_total(status_counts['OK']):.1f}%)")
print(f"  ⚠ DET_ERR: {status_counts['DET']} ({_percent_of_total(status_counts['DET']):.1f}%)")
print(f"  ~ SKIP:    {status_counts['SKIP']} ({_percent_of_total(status_counts['SKIP']):.1f}%)")
print(f"  ✗ FAIL:    {status_counts['FAIL']} ({_percent_of_total(status_counts['FAIL']):.1f}%)")

if problem_cases:
    print("\n" + "=" * 100)
    print("PROBLEM CASES REQUIRING INVESTIGATION")
    print("=" * 100)
    for gadget, code, msg in problem_cases:
        print(f"  {gadget:<20} + {code:<20}: {msg[:40]}")

COMPREHENSIVE GADGET TEST: All Gadgets x All Codes
Gadget               | XZZX_Surface | LoopToricCod | Hamming_CSS_ | Code_832     | Colour488_[[ | QuantumTanne | BalancedProd | FourQubit422 | SteanCode713 | ShorCode91   | RotatedSurfa | RotatedSurfa | ToricCode_3x | ToricCode_5x | XZZX_Surface | ToricCode4D_ | ToricCode4D_ | ToricCode3D_ | ToricCode3D_ | ToricCode3DF | HyperbolicSu | TwistedToric | TwistedToric | ProjectivePl | TriangularCo | TriangularCo | HexagonalCol | HexagonalCol | TruncatedTri | ColorCode3D_ | ColorCode3D_ | ColorCode3DP | TetrahedralC | BallColorCod | HyperbolicCo | GaugeColor_3
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [7]:
# ==========================================================================
# COMPREHENSIVE TEST: Code × Decoder Matrix for Each Gadget
# ==========================================================================
# This cell runs all tests and stores results in `all_gadget_results`

from qectostim.decoders import PyMatchingDecoder

# Decoders exercised by the matrix: everything that was loaded except the ones
# with special code requirements (colour-code or concatenated-code only).
TEST_DECODERS = {
    dec_name: dec_cls
    for dec_name, dec_cls in decoder_classes.items()
    if dec_name not in ('Chromobius', 'Concatenated', 'FlatConcat')
}

print(f"Testing with {len(TEST_DECODERS)} decoders: {list(TEST_DECODERS.keys())}")
print(f"Testing on {len(TEST_CODES)} codes")
print(f"Testing {len(ALL_GADGETS)} gadgets")
# Size of the full gadget x code x decoder grid.
total_combos = len(ALL_GADGETS) * len(TEST_CODES) * len(TEST_DECODERS)
print(f"Total test combinations: {len(ALL_GADGETS)} × {len(TEST_CODES)} × {len(TEST_DECODERS)} = {total_combos}")
print("\nRunning tests...")

def run_gadget_code_decoder_tests(
    gadget_name: str,
    gadget_class,
    gadget_kwargs: Dict,
    is_two_qubit: bool,
    test_codes: Dict,
    decoder_classes: Dict,
    shots: int = 1000,
) -> Dict:
    """
    Run Code × Decoder matrix for a single gadget.
    Returns results dict (no printing).

    For each code the circuit, detector error model and samples are built
    once and then shared by every decoder. The experiment uses the
    module-level ``noise`` model.

    Parameters
    ----------
    gadget_name : str
        Display name of the gadget (not used inside this function).
    gadget_class
        Gadget class; instantiated once per code as ``gadget_class(**gadget_kwargs)``.
    gadget_kwargs : dict
        Keyword arguments for the gadget constructor.
    is_two_qubit : bool
        When True the experiment receives two copies of the code.
    test_codes : dict
        Mapping code name -> code instance.
    decoder_classes : dict
        Mapping decoder name -> decoder class (constructed with ``dem=``).
    shots : int
        Number of shots sampled per code.

    Returns
    -------
    dict
        ``code name -> decoder name -> {'status', 'ler', ...}``. ``status`` is
        'OK', 'SKIP', 'FAIL', or a circuit-level error label that is copied
        to every decoder of that code.
    """
    results = {}
    dec_names = list(decoder_classes.keys())
    
    for code_name, code in test_codes.items():
        results[code_name] = {}
        
        # First, try to build the circuit once (shared across decoders)
        circuit = None
        circuit_error = None  # stays None on success; otherwise a short label
        dem = None
        det_samples = None
        obs_samples = None
        
        try:
            gadget = gadget_class(**gadget_kwargs)
            codes = [code, code] if is_two_qubit else [code]
            
            exp = FaultTolerantGadgetExperiment(
                codes=codes,
                gadget=gadget,
                noise_model=noise,
                num_rounds_before=2,
                num_rounds_after=2,
            )
            
            circuit = exp.to_stim()
            
            if circuit.num_detectors == 0:
                circuit_error = "No detectors"
            else:
                # Try to build DEM
                try:
                    dem = circuit.detector_error_model(decompose_errors=True)
                    sampler = circuit.compile_detector_sampler()
                    samples = sampler.sample(shots, append_observables=True)
                    # Detector bits come first, observable bits are appended after
                    det_samples = samples[:, :circuit.num_detectors]
                    obs_samples = samples[:, circuit.num_detectors:]
                except ValueError as e:
                    # Only ValueError is classified here; any other exception
                    # type propagates to the outer handlers below.
                    if "non-deterministic" in str(e).lower():
                        circuit_error = "DET_ERR"
                    else:
                        circuit_error = f"DEM: {str(e)[:20]}"
                        
        except ValueError as e:
            # ValueError raised while building the gadget/experiment/circuit
            if "not supported" in str(e).lower() or "placeholder" in str(e).lower():
                circuit_error = "UNSUPPORTED"
            else:
                circuit_error = f"VAL: {str(e)[:15]}"
        except Exception as e:
            # Everything else, including non-ValueError failures from the
            # DEM/sampling step above
            err_str = str(e).lower()
            if "non-clifford" in err_str:
                circuit_error = "NonCliff"
            else:
                circuit_error = f"ERR: {str(e)[:15]}"
        
        # If circuit failed, mark all decoders
        if circuit_error:
            for dec_name in dec_names:
                results[code_name][dec_name] = {'status': circuit_error, 'ler': None}
        else:
            # Test each decoder
            for dec_name, dec_class in decoder_classes.items():
                try:
                    is_compat, reason = check_decoder_compatibility(dec_name, code, code_name, circuit)
                    if not is_compat:
                        results[code_name][dec_name] = {'status': 'SKIP', 'ler': None, 'reason': reason}
                        continue
                    
                    decoder = dec_class(dem=dem)
                    predictions = decoder.decode_batch(det_samples)
                    if len(predictions.shape) == 1:
                        predictions = predictions.reshape(-1, 1)
                    
                    # A shot is a logical error if any observable is mispredicted
                    errors = np.any(predictions != obs_samples, axis=1)
                    ler = float(np.mean(errors))
                    
                    results[code_name][dec_name] = {'status': 'OK', 'ler': ler}
                    
                except Exception as e:
                    # A decoder failure affects only this (code, decoder) cell
                    results[code_name][dec_name] = {'status': 'FAIL', 'ler': None, 'error': str(e)}
    
    return results


# ==========================================================================
# Run the full matrix for each gadget
# ==========================================================================
# gadget name -> results from run_gadget_code_decoder_tests
all_gadget_results = {}
# (gadget name, reason) for every gadget listed in KNOWN_BROKEN_GADGETS
skipped_gadgets = []

for i, (gadget_name, (cls, kwargs, is_two)) in enumerate(ALL_GADGETS.items()):
    # Skip known broken gadgets
    if gadget_name in KNOWN_BROKEN_GADGETS:
        skipped_gadgets.append((gadget_name, KNOWN_BROKEN_GADGETS[gadget_name]))
        continue

    # flush=True: this is a partial line (no newline), so without an explicit
    # flush it may not be shown until the slow test run below has finished.
    print(f"  [{i+1}/{len(ALL_GADGETS)}] Testing {gadget_name}...", end=" ", flush=True)
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during a long run.
    start_time = time.perf_counter()

    all_gadget_results[gadget_name] = run_gadget_code_decoder_tests(
        gadget_name=gadget_name,
        gadget_class=cls,
        gadget_kwargs=kwargs,
        is_two_qubit=is_two,
        test_codes=TEST_CODES,
        decoder_classes=TEST_DECODERS,
        shots=1000,
    )

    elapsed = time.perf_counter() - start_time
    # Count (code, decoder) cells that decoded successfully for this gadget.
    ok_count = sum(1 for c in all_gadget_results[gadget_name].values()
                   for d in c.values() if d.get('status') == 'OK')
    total_count = len(TEST_CODES) * len(TEST_DECODERS)
    print(f"done ({elapsed:.1f}s) - {ok_count}/{total_count} OK")

print(f"\n✓ All tests complete! Results stored in `all_gadget_results`")
print(f"  Tested: {len(all_gadget_results)} gadgets")
print(f"  Skipped: {len(skipped_gadgets)} gadgets")

Testing with 10 decoders: ['PyMatching', 'FusionBlossom', 'BeliefMatching', 'BPOSD', 'Tesseract', 'UnionFind', 'MLE', 'Hypergraph', 'Hierarchical', 'SingleShot']
Testing on 36 codes
Testing 16 gadgets
Total test combinations: 16 × 36 × 10 = 5760

Running tests...
done (52.5s) - 108/360 OK
  [2/16] Testing TransversalS... done (52.5s) - 108/360 OK
done (54.0s) - 100/360 OK
  [4/16] Testing TransversalX... done (54.0s) - 100/360 OK
done (74.9s) - 144/360 OK
  [5/16] Testing TransversalY... done (74.9s) - 144/360 OK
done (80.1s) - 144/360 OK
  [6/16] Testing TransversalZ... done (80.1s) - 144/360 OK
done (70.4s) - 144/360 OK
  [7/16] Testing TransversalCNOT... done (70.4s) - 144/360 OK
done (96.1s) - 50/360 OK
  [8/16] Testing TransversalCZ... done (96.1s) - 50/360 OK
done (110.9s) - 92/360 OK
  [9/16] Testing TransversalSWAP... done (110.9s) - 92/360 OK
done (98.8s) - 92/360 OK
  [10/16] Testing TeleportedH... done (98.8s) - 92/360 OK
done (96.8s) - 159/360 OK
  [11/16] Testing Teleporte

In [8]:
# ==========================================================================
# DISPLAY RESULTS: Code × Decoder Matrix for Each Gadget
# ==========================================================================
# This cell pretty-prints the results from `all_gadget_results`

def display_gadget_matrix(gadget_name: str, results: Dict, decoder_names: List[str]):
    """Pretty print a Code × Decoder matrix for one gadget.

    Args:
        gadget_name: Label printed in the banner above the matrix.
        results: Mapping ``code name -> {decoder name -> result dict}``. Only
            the ``'status'`` and ``'ler'`` entries of each result dict are read.
        decoder_names: Decoder names in column order; these select which
            entries of ``results`` are shown.

    Each cell shows ``ler`` with four decimals when the status is ``'OK'`` and
    ``ler`` is not None; otherwise the first 10 characters of the status. A
    decoder missing from a code's results is shown as ``'?'``.

    The closing summary counts exactly the cells that were displayed, so the
    numerator and denominator always describe the same set of cells, and an
    empty matrix reports 0% instead of raising ZeroDivisionError.
    """
    code_names = list(results.keys())
    dec_display = [name[:10] for name in decoder_names]

    # Header
    header = f"{'Code':<18}"
    for dec in dec_display:
        header += f" | {dec:^10}"

    print(f"\n{'='*len(header)}")
    print(f"GADGET: {gadget_name}")
    print(f"{'='*len(header)}")
    print(header)
    print("-" * len(header))

    ok_count = 0
    for code_name in code_names:
        row = f"{code_name[:18]:<18}"
        for dec_name in decoder_names:
            result = results[code_name].get(dec_name, {'status': '?', 'ler': None})
            status = result.get('status', '?')
            ler = result.get('ler')

            if status == 'OK':
                # Tally here so the summary matches what is shown in the matrix.
                ok_count += 1

            if status == 'OK' and ler is not None:
                cell = f"{ler:.4f}"
            else:
                # str() guards against a non-string status (e.g. None).
                cell = str(status)[:10]
            row += f" | {cell:^10}"
        print(row)

    print("-" * len(header))

    # Summary
    total = len(code_names) * len(decoder_names)
    percent_ok = 100 * ok_count / total if total > 0 else 0
    print(f"Summary: {ok_count}/{total} OK ({percent_ok:.0f}%)")


# Render one Code × Decoder matrix per tested gadget, with the columns in
# the order the decoders appear in TEST_DECODERS.
dec_names = [decoder_name for decoder_name in TEST_DECODERS.keys()]

for gadget_name, results in all_gadget_results.items():
    display_gadget_matrix(gadget_name, results, dec_names)

# List the gadgets that were not run, with the recorded reason for each.
if len(skipped_gadgets) > 0:
    print("\n" + "=" * 60)
    print("SKIPPED GADGETS")
    print("=" * 60)
    for name, reason in skipped_gadgets:
        print(f"  {name:<25}: {reason}")

# ==========================================================================
# OVERALL SUMMARY
# ==========================================================================
print("\n" + "=" * 100)
print("OVERALL SUMMARY: All Gadgets × All Codes × All Decoders")
print("=" * 100)

# Statuses that get their own line in the overall breakdown. Any other status
# (including free-form error strings) is tallied as 'FAIL'.
_OVERALL_STATUSES = ('OK', 'SKIP', 'DET_ERR', 'UNSUPPORTED', 'NonCliff')
# Statuses that get their own column in the per-gadget / per-code /
# per-decoder tables. Everything else lands in the FAIL column, so each row's
# columns add up to its total.
_TABLE_STATUSES = ('OK', 'DET_ERR', 'SKIP')


def _status_bucket(status, known_statuses):
    """Return `status` if it is one of `known_statuses`, otherwise 'FAIL'."""
    return status if status in known_statuses else 'FAIL'


def _percent(count, total):
    """Return `count` as a percentage of `total`, or 0 when `total` is 0."""
    return 100 * count / total if total > 0 else 0


def _print_summary_table(title, label, width, rows, truncate=False):
    """Print one OK/DET_ERR/FAIL/SKIP/Rate table.

    Args:
        title: Heading printed between the two rule lines.
        label: Header text of the first (name) column.
        width: Width of the name column.
        rows: Iterable of ``(name, stats)`` pairs, where ``stats`` maps
            'OK', 'DET_ERR', 'FAIL', 'SKIP' and 'total' to counts.
        truncate: When True, names longer than `width` are cut to fit.
    """
    print("\n" + "-" * 80)
    print(title)
    print("-" * 80)
    print(f"{label:<{width}} | {'OK':>6} | {'DET_ERR':>7} | {'FAIL':>6} | {'SKIP':>6} | {'Rate':>6}")
    print("-" * 80)
    for name, stats in rows:
        shown = name[:width] if truncate else name
        rate = _percent(stats['OK'], stats['total'])
        print(f"{shown:<{width}} | {stats['OK']:>6} | {stats['DET_ERR']:>7} | {stats['FAIL']:>6} | {stats['SKIP']:>6} | {rate:>5.0f}%")


# Aggregate statistics: a single pass over every (gadget, code, decoder) result
# feeds the overall breakdown and all three tables, so they cannot disagree
# about how a status is classified.
status_counts = defaultdict(int)
gadget_stats = defaultdict(lambda: defaultdict(int))
code_stats = defaultdict(lambda: defaultdict(int))
dec_stats = defaultdict(lambda: defaultdict(int))

for gadget_name, code_results in all_gadget_results.items():
    for code_name, dec_results in code_results.items():
        for dec_name, result in dec_results.items():
            status = result.get('status', 'FAIL')
            status_counts[_status_bucket(status, _OVERALL_STATUSES)] += 1

            table_bucket = _status_bucket(status, _TABLE_STATUSES)
            for stats in (gadget_stats[gadget_name], code_stats[code_name], dec_stats[dec_name]):
                stats[table_bucket] += 1
                stats['total'] += 1

total = sum(status_counts.values())
print(f"\nTotal test combinations: {total}")
# _percent() keeps these lines from raising ZeroDivisionError when there are
# no results at all.
print(f"  ✓ OK:          {status_counts['OK']:4d} ({_percent(status_counts['OK'], total):.1f}%)")
print(f"  ⚠ DET_ERR:     {status_counts['DET_ERR']:4d} ({_percent(status_counts['DET_ERR'], total):.1f}%) - Non-deterministic detectors")
print(f"  ⊘ UNSUPPORTED: {status_counts['UNSUPPORTED']:4d} ({_percent(status_counts['UNSUPPORTED'], total):.1f}%) - Code not supported")
# 'NonCliff' is tallied separately above, so report it; otherwise the lines
# of this breakdown would not add up to the total.
print(f"  ∅ NonCliff:    {status_counts['NonCliff']:4d} ({_percent(status_counts['NonCliff'], total):.1f}%)")
print(f"  ~ SKIP:        {status_counts['SKIP']:4d} ({_percent(status_counts['SKIP'], total):.1f}%) - Decoder incompatible")
print(f"  ✗ FAIL:        {status_counts['FAIL']:4d} ({_percent(status_counts['FAIL'], total):.1f}%) - Decoder error")

# Per-gadget summary table. Iterating all_gadget_results (rather than
# gadget_stats) keeps a row for gadgets that produced no results.
_print_summary_table(
    "Per-Gadget Summary:", "Gadget", 25,
    [(name, gadget_stats[name]) for name in all_gadget_results],
)

# Per-code summary table (long code names are cut to the column width)
_print_summary_table(
    "Per-Code Summary:", "Code", 25,
    list(code_stats.items()),
    truncate=True,
)

# Per-decoder summary table
_print_summary_table(
    "Per-Decoder Summary:", "Decoder", 15,
    list(dec_stats.items()),
)


GADGET: TransversalH
Code               | PyMatching | FusionBlos | BeliefMatc |   BPOSD    | Tesseract  | UnionFind  |    MLE     | Hypergraph | Hierarchic | SingleShot
----------------------------------------------------------------------------------------------------------------------------------------------------
XZZX_Surface_3     |   0.0080   |   0.0100   |   0.0100   |   0.0020   |   0.0000   |   0.0080   |   0.0010   |   0.0230   |    FAIL    |    FAIL   
LoopToricCode4D    | DEM: Faile | DEM: Faile | DEM: Faile | DEM: Faile | DEM: Faile | DEM: Faile | DEM: Faile | DEM: Faile | DEM: Faile | DEM: Faile
Hamming_CSS_15     |   0.0540   |   0.0450   |   0.0520   |   0.0150   |   0.0110   |   0.0540   |   0.0120   |   0.0870   |    FAIL    |    FAIL   
Code_832           |   0.0240   |   0.0420   |   0.0320   |   0.0180   |   0.0040   |   0.0240   |   0.0040   |   0.0240   |    FAIL    |    FAIL   
Colour488_[[9,1,3] |   0.0260   |   0.0290   |   0.0430   |   0.0290   |   0.0100   