This repository is the code artifact for the SparseFuse NeurIPS 2026 submission.
For a direct paper-to-code mapping, see PAPER_RESULTS_MAP.md.
The repo supports three result families:
- SparseFuse kernel and systems results: layer microbenchmarks, sparse-vs-dense contraction, grid scaling, compile/checkpoint retention, Scale-B feasibility, ViT-KAN feasibility, and PINN runtime/memory/accuracy experiments.
- Appendix experiments: heatmaps, convergence curves, grid-quality probes, MatrixKAN comparison, cross-basis scaling, and PINN supplements.
- JHCG appendix experiments: module-boundary microbenchmarks, controlled Transformer audits, S4 attribution/adaptation diagnostics, cached fast-path checks, compile-drift diagnostics, and nanochat portability scouts.
Recommended baseline:
- Python 3.10 or newer.
- Linux + NVIDIA GPU for Triton fast paths.
- PyTorch 2.3 or newer; the paper measurements used CUDA 12.6 builds with PyTorch 2.10 and Triton 3.6 on RTX 3080 and H100.
- CPU is sufficient for import checks, config validation, and the small JHCG smoke tests.
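The baseline above can be checked before running anything heavier. The following is a minimal sketch, not a script shipped with the repo: the helper names `parse_major_minor` and `check_environment` are hypothetical, and the thresholds simply mirror the recommended versions listed above.

```python
import re
import sys


def parse_major_minor(version: str) -> tuple:
    """Turn a version string such as '2.3.1+cu121' into (2, 3)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def check_environment(min_python=(3, 10), min_torch=(2, 3)) -> list:
    """Return human-readable notes; an empty list means nothing to report."""
    notes = []
    if sys.version_info[:2] < min_python:
        notes.append(f"Python {min_python[0]}.{min_python[1]} or newer is recommended")
    try:
        import torch
    except ImportError:
        notes.append("PyTorch is not installed")
        return notes
    if parse_major_minor(torch.__version__) < min_torch:
        notes.append(f"PyTorch {min_torch[0]}.{min_torch[1]} or newer is recommended")
    if not torch.cuda.is_available():
        # Assumption: without CUDA only the CPU smoke tests are meaningful.
        notes.append("CUDA is not available: expect CPU smoke tests only")
    return notes


if __name__ == "__main__":
    for note in check_environment():
        print(note)
```

The check reports notes rather than raising, because a CPU-only machine is still a valid setup for import checks and smoke tests.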
Install:
python -m venv .venv
. .venv/bin/activate
pip install -r requirements.txt
pip install -e .
The editable install makes the sparsefuse package visible to scripts. If you
skip the editable install during a quick local audit, run commands as
PYTHONPATH=. python ... from the repo root. Main paper comparisons against
efficient_kan require the external baseline installed by requirements.txt;
without it, SparseFuse-only smoke tests still run but baseline comparison
scripts fail fast.
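To see in advance which of these two situations applies, one can ask Python whether the packages are importable. This is a hedged sketch: `is_installed` is a hypothetical helper, and it assumes the external baseline is importable under the module name `efficient_kan`, which the text above suggests but does not confirm.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Assumption: both import names match the package names used in this README.
    for name in ("sparsefuse", "efficient_kan"):
        status = "available" if is_installed(name) else "missing"
        print(f"{name}: {status}")
```

If `sparsefuse` is reported missing after skipping the editable install, running from the repo root with `PYTHONPATH=.` should make it visible.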
For full H100 reproduction, install and configure Modal separately, then use the Modal launcher scripts described below.
No datasets or checkpoints are bundled in this repo.
- Kernel microbenchmarks, sparse/dense lowering, heatmaps, and most synthetic PINN/grid-quality runs generate data directly.
- CIFAR-100 and other torchvision-style datasets should be downloaded by the relevant script into DATA_ROOT, or the default dataset cache if the script supports it.
- Controlled Transformer and JHCG training scripts expect tokenized language-model data. Set DATA_ROOT or pass the config-specific data paths in configs/controlled_transformer/*.yaml.
- Nanochat portability scripts require an external nanochat checkout. Run scripts/prepare_nanochat_external.sh to clone karpathy/nanochat into the ignored ./nanochat_external/ directory and mirror nanochat_integration/*.py there, or set NANOCHAT_ROOT=/path/to/nanochat_checkout before using scripts/modal_nanochat_h100.py nanochat tasks.
- Write large logs/results outside git or under the ignored benchmark_logs/ tree.
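The environment-variable conventions above can be pictured as follows. This is only a sketch: the function names are hypothetical, and the precedence shown (environment variable first, then the ignored ./nanochat_external/ directory) is an assumption about how the scripts behave, not something read from their source.

```python
import os
from pathlib import Path


def resolve_data_root(env=None):
    """Return DATA_ROOT as a Path, or None so a script can use its own default."""
    env = os.environ if env is None else env
    value = env.get("DATA_ROOT")
    return Path(value) if value else None


def resolve_nanochat_root(env=None, repo_root="."):
    """Prefer NANOCHAT_ROOT; otherwise fall back to ./nanochat_external/."""
    env = os.environ if env is None else env
    override = env.get("NANOCHAT_ROOT")
    if override:
        return Path(override)
    # Assumed default: the directory created by scripts/prepare_nanochat_external.sh.
    return Path(repo_root) / "nanochat_external"


if __name__ == "__main__":
    print("DATA_ROOT:", resolve_data_root())
    print("nanochat checkout:", resolve_nanochat_root())
```

Passing `env` explicitly keeps the helpers easy to test without touching the real process environment.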
Run these from the repo root:
python - <<'PY'
import torch
from sparsefuse import KANLinear
from sparsefuse.jhcg import SparseFuseJHCG
layer = KANLinear(4, 3, grid_size=5)
x = torch.randn(2, 4)
print(layer(x).shape)
jhcg = SparseFuseJHCG(d=16, rho=0.25, grid_size=5, variant="expanded")
print(jhcg(torch.randn(2, 8, 16)).shape)
PY
python scripts/rtx3080_local_dispatcher.py --list
OUTPUT_DIR="${OUTPUT_DIR:-smoke_outputs}"
python scripts/bench_jhcg_micro.py --method sf_jhcg --variant expanded --d 16 --rho 0.25 --grid-size 5 --batch-tokens 4 --device cpu --warmup 1 --iters 1 --out-dir "$OUTPUT_DIR"
pytest -q tests/test_basis_interface.py tests/test_jhcg_shapes.py tests/test_controlled_transformer_methods.py
The CPU JHCG microbenchmark is a smoke test only. Paper latency and VRAM claims require CUDA.
The release baseline is the annotated tag paper-neurips2026-artifact-final: tag object 3a59755f3b1929577eae72eb2a31b63ae795deb1, peeled commit 512e6013c37a918d3c20f577bbfd6b2e60ba05bc. An earlier camera-ready artifact snapshot is recorded in PAPER_ARTIFACT.md.
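A checkout can be compared against these two hashes with `git rev-parse`: resolving an annotated tag yields the tag object, and the `^{commit}` suffix peels it to the commit it points at. The sketch below assumes it is run inside a clone that has fetched the tag; `rev_parse` and `verify_release_tag` are hypothetical helper names.

```python
import subprocess

TAG = "paper-neurips2026-artifact-final"
EXPECTED_TAG_OBJECT = "3a59755f3b1929577eae72eb2a31b63ae795deb1"
EXPECTED_COMMIT = "512e6013c37a918d3c20f577bbfd6b2e60ba05bc"


def rev_parse(ref: str, cwd: str = ".") -> str:
    """Resolve a git ref to its full object hash; raises if the ref is unknown."""
    result = subprocess.run(
        ["git", "rev-parse", "--verify", ref],
        cwd=cwd, capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


def verify_release_tag(resolve=rev_parse) -> bool:
    """Check both the annotated tag object and the commit it peels to."""
    tag_ref = f"refs/tags/{TAG}"
    tag_object = resolve(tag_ref)                # hash of the tag object itself
    commit = resolve(tag_ref + "^{commit}")      # hash of the tagged commit
    return tag_object == EXPECTED_TAG_OBJECT and commit == EXPECTED_COMMIT
```

Calling `verify_release_tag()` from the repo root returns True only when both hashes match; the `resolve` parameter exists so the comparison logic can be exercised without a git checkout.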
Use:
python scripts/rtx3080_local_dispatcher.py --list
python scripts/rtx3080_local_dispatcher.py --all
modal run scripts/modal_h100_dispatcher.py::main --task all
Main-paper coverage:
- Layer microbenchmark and kernel summary: scripts/benchmark_micro.py, scripts/bench_section53.py, scripts/rtx3080_local_dispatcher.py.
- Sparse-vs-dense contraction and padded geometry: scripts/bench_section43_sparse_dense.py, scripts/bench_sparse_baseline.py, scripts/paper_padded_l_geometry.py.
- H100 grid scaling, compile/checkpoint, Scale-B, sequence length, and feasibility: scripts/modal_h100_dispatcher.py, scripts/h100_scaleb_compile_short.py, scripts/h100_minigpt_scale_b.py, scripts/h100_g_scaling.py, scripts/h100_seqlen_sweep.py.
- ViT-KAN CIFAR-100 feasibility: scripts/bench_vit_kan_cifar100.py.
- PINN tables: scripts/h100_multiscale_poisson.py, scripts/h100_nonsep_poisson.py, scripts/h100_advection_diffusion.py, scripts/h100_poisson_fourier_mlp.py.
See scripts/README_PAPER_REPRO.md and PAPER_ARTIFACT.md for the section-to-script mapping and original log-path manifest.
Appendix coverage:
- Heatmaps: scripts/benchmark_heatmap.py, scripts/plot_heatmaps.py.
- Convergence curves and large-grid quality: scripts/benchmark_convergence_curves.py, scripts/benchmark_large_grid.py, scripts/bench_g_scaling_quality.py, scripts/bench_g_scaling_quality_2d.py.
- Cross-basis and MatrixKAN comparisons: scripts/bench_crossbasis_gscaling.py, scripts/bench_matrixkan_compare.py.
- Grid extension and symbolic/tabular supplements: scripts/bench_grid_extension.py, scripts/benchmark_symbolic.py, scripts/benchmark_tabular.py.
- Plot/table generation: scripts/generate_paper_figures.py, scripts/redraw_figures_neurips_2026_style_v6.py, analysis/*.py, tables/*.csv.
The repo intentionally excludes large generated figure outputs and raw run directories. Re-run scripts to regenerate them.
Core JHCG implementation:
- sparsefuse/jhcg.py
- sparsefuse/ffn.py
- sparsefuse/param_utils.py
- sparsefuse/diagnostics.py
- sparsefuse/branch_profiler.py
- sparsefuse/gated_refiner.py
- configs/controlled_transformer/*.yaml
Entry points:
- Module-boundary microbenchmarks: scripts/bench_jhcg_micro.py, scripts/run_jhcg_sweep.py.
- Cached fast-path and branch profiling: scripts/bench_jhcg_speedup_fastpath.py, scripts/bench_jhcg_profile.py.
- Controlled Transformer training and audits: scripts/train_controlled_transformer.py, scripts/run_controlled_transformer_smoke.py, scripts/run_controlled_transformerc_smoke.py, scripts/run_nanochat_h100_pilot.py, analysis/controlled_transformer_controlled_transformer_audit.py.
- S4 attribution/adaptation: scripts/eval_s4_attribution.py, scripts/run_adaptation_diagnostics.py, analysis/s4_attribution_2026-04-28/, analysis/s2_adaptation_2026-04-28/.
- Nanochat portability: nanochat_integration/nanochat_controlled_transformer_m7_adapter.py, nanochat_integration/nanochat_controlled_transformer_m7_s4_eval.py, scripts/run_nanochat_measurement_n0n1n2_launch.py, scripts/run_nanochat_m7_s4_attribution_launch.py, scripts/modal_nanochat_h100.py.
Small CSV summaries used by the paper are under tables/, especially tables/jhcg_micro_bonus_h100_summary.csv, tables/controlled_transformer/controlled_transformer_main_matrix_audit.csv, and tables/controlled_transformer/controlled_transformer_adapt_summary.csv.
- CPU smoke tests: seconds to a few minutes.
- RTX 3080 local dispatcher: hours for the complete suite; individual microbenchmarks are much shorter.
- H100 Modal dispatcher: several hours across Scale-B, compile frontier, ViT-KAN, and sequence-length sweeps.
- Full paper reproduction is approximately 50 GPU-hours across RTX 3080 and H100-class runs.
- JHCG controlled Transformer full runs require prepared tokenized data and H100-class GPU memory. Use the smoke configs in configs/controlled_transformer/tiny.yaml and scripts/run_controlled_transformer_smoke.py for reduced checks.
- The repo does not bundle datasets, checkpoints, Modal credentials, raw benchmark logs, or generated figures.
- Some H100 and nanochat results require external services/checkouts and cannot be fully reproduced by a single local command.
- CPU smoke tests validate imports, paths, and small tensor execution; they do not validate performance claims.
- SparseFuse assumes uniform cubic B-spline grids for the benchmarked main path. Non-B-spline bases are correctness-tested but not benchmarked at Transformer scale.
- The sparse-format comparison is limited to the CSR/BSR paths tested by the scripts.
- JHCG nanochat transfer is reported as exploratory/negative in the paper and is not a main SparseFuse systems claim.
- sparsefuse/: SparseFuse layers, basis evaluators, kernels, JHCG modules, and FFN factories.
- scripts/: benchmark, training, dispatcher, plotting, and audit entry points.
- configs/controlled_transformer/: controlled Transformer and JHCG method configs.
- analysis/: aggregation and audit scripts plus small summary CSV/JSON metadata for JHCG appendix diagnostics.
- tables/: small paper table inputs and JHCG summary CSVs.
- tests/: unit and smoke tests for SparseFuse and JHCG.
- docs/: user-facing API and migration documentation.
- nanochat_integration/: adapters and measurement scripts for external nanochat experiments.
MIT; see LICENSE.