A second-order PyTorch optimizer that delivers Shampoo-quality preconditioned gradients at near-AdamW memory and throughput cost.
Drop-in replacement forAdamW. One-line change. Real gains.
- The Problem
- SCAO's Solution
- Algorithm
- Experimental Results
- Convergence Curves
- Time-to-Target Analysis
- Ablation Study
- Why It Works
- Installation
- Quick Start
- Hyperparameter Reference
- Reproducing Results
- Repository Structure
- Citation
Training large neural networks is dominated by first-order methods. AdamW remains the de facto standard despite a fundamental limitation: it approximates loss curvature with a diagonal matrix, ignoring the rich inter-parameter correlation structure that makes second-order methods superior.
Full Shampoo exploits this structure via Kronecker-factored curvature, consistently outperforming AdamW in step-count efficiency. But its cost is prohibitive:
| Method | Extra memory | Inversion cost | Practical at 7B? |
|---|---|---|---|
| AdamW | O(d) per layer |
— | ✅ |
| Shampoo | O(m² + n²) per layer |
O(m³ + n³) per update |
❌ |
| SCAO | O((m+n)·k) per layer |
O((m+n)·k²) per update |
✅ |
At transformer widths m, n ~ 4096, full Shampoo's curvature matrices exceed 128 GB per model copy. SCAO reduces this by 32–200× via low-rank approximation.
SCAO makes five targeted innovations on top of SOAP:
Instead of storing full m×m and n×n curvature factors, SCAO keeps only the top-k eigenvectors that capture ≥95% of spectral mass:
k* = argmin k such that Σᵢ₌₁ᵏ λᵢ / Σⱼ λⱼ ≥ 1 − ε
This reduces memory from O(m² + n²) to O((m+n)·k). At GPT-2 scale (d=768), typical k ≤ 32–64, giving a 16–32× reduction over full-rank Kronecker factors.
For layers where max(m, n) > max_precond_dim, SCAO applies sparse block-diagonal preconditioning: the gradient matrix is partitioned into contiguous blocks of size ≤ max_precond_dim along the larger dimension, and an independent low-rank Kronecker preconditioner is applied per block. This bounds eigendecomp cost at O(max_precond_dim³) while preserving full curvature information across all blocks — unlike a diagonal fallback, which discards all inter-parameter correlation.
The transition from Adam (Phase 1) to SCAO preconditioning (Phase 2) is the most dangerous moment in training. Three guards prevent instability:
- EMA bias correction — Kronecker factors initialized as
ε·I(not zero) prevent a rank-deficient first application - 50-step cosine blend ramp — gradual transition from Adam gradient to preconditioned gradient prevents momentum disruption
- Adaptive Tikhonov regularization —
eps = max(ε₀, 1e-4 · tr(L)/m)at inversion time, scaling with actual curvature magnitude
The Kronecker curvature accumulators L_ema and R_ema are stored in int8 with per-tensor symmetric quantization, reducing EMA memory by 4×:
| Scale | Float32 EMA/layer | Int8 EMA/layer | Saving |
|---|---|---|---|
| d=768 (GPT-2 small) | 4.5 MB | ~1.1 MB | 4× |
| d=1024 (GPT-2 medium) | 8 MB | ~2 MB | 4× |
| d=1600 (GPT-2 XL) | 19.5 MB | ~4.9 MB | 4× |
Enable with SCAO(..., use_int8_ema=True). Eigendecomposition still runs in float32 (dequantized on-the-fly), so eigenvector precision is unchanged.
Production-quality CUDA kernels for the Kronecker projection operations:
- Tiled shared-memory GEMM — 16×16 tile blocking, eliminates redundant global-memory reads
- Fused Kronecker preconditioner kernel (k ≤ 128) — computes the full identity+correction in one launch, no intermediate
(m,n)tensor - Int8 EMA update kernel — two-pass design: compute new EMA value + requantize to int8
- Bug fix: the naïve implementation had an
O(k·m²·n)complexity regression (each output thread recomputed the fullU^T @ Gprojection); the fused kernel achieves the correctO(k·m·n)
# Compile CUDA extension (requires nvcc + CUDA toolkit)
cd scao/cuda && python setup.py build_ext --inplaceFalls back to pure PyTorch automatically when CUDA extension is not compiled.
SCAO(parameters θ, lr α, warmup T_w, precond_freq T_p, rho ρ, rank_eps ε):
Phase 1 — Adam warmup (steps 1 to T_w):
Standard Adam/AdamW update on raw gradients g_t
Meanwhile, build curvature EMA:
L_t ← ρ · L_{t-1} + (1-ρ) · G_t G_t^T (left factor, m×m)
R_t ← ρ · R_{t-1} + (1-ρ) · G_t^T G_t (right factor, n×n)
Phase 2 — SCAO preconditioning (steps T_w + 1 onwards):
Every T_p steps:
Debias L and R by EMA bias factor
Apply Tikhonov regularization: L_reg = L + eps·I
Eigendecompose: L_reg = U_L · diag(S_L) · U_L^T via LAPACK DSYEVD
Truncate to top k eigenvalues (95% spectral mass threshold)
Store: (U_L[:, :k], S_L[:k], U_R[:, :k], S_R[:k])
Every step:
Preconditioned gradient (identity + low-rank correction):
G_proj = U_L^T · G · U_R (k×k)
G_scaled = diag(S_L^{-1/4}) · G_proj · diag(S_R^{-1/4})
G_precond = G + U_L · (G_scaled - G_proj) · U_R^T (m×n)
50-step linear blend ramp (t_s = Phase-2 step count):
blend = min(1.0, t_s / 50)
g_eff = blend · G_precond + (1 - blend) · g
Apply Adam update on g_eff (shared moments, warm-started from Phase 1):
m_t ← β₁ · m_{t-1} + (1-β₁) · g_eff
v_t ← β₂ · v_{t-1} + (1-β₂) · g_eff²
θ_t ← θ_{t-1} - α · (m_t / (1-β₁^t)) / (√(v_t/(1-β₂^t)) + ε)
Key insight: SCAO uses shared momentum tensors that warm-start from Phase 1 into Phase 2. The blend ramp ensures a smooth transition — at t_s=1, g_eff ≈ 0.98 · g_raw, so the preconditioner's influence grows gradually without disrupting the accumulated momentum. Both moments track g_eff, keeping the numerator and denominator on the same scale throughout training.
Benchmark setup: 4-layer GPT, d=128, 4 attention heads, context length 128, batch size 16, seed 42, WikiText-2 train/val split.
| Optimizer | Val PPL ↓ | Throughput (tok/s) ↑ | Peak Mem (GB) | Wall-clock (s) |
|---|---|---|---|---|
| AdamW | 11.85 | 537 | 0.012 | 1909 |
| SCAO | 12.03 | 827 (+54%) | 0.026 (2.2×) | 1238 |
- PPL gap: 0.18 (1.5%) — down from 0.59 at 200 steps (70% narrowing)
- Throughput: SCAO runs 54% faster due to amortized curvature updates
- AUC (mean training loss, all steps): SCAO 2.854 vs AdamW 2.873 — SCAO has lower average loss across the full run
| Optimizer | Val PPL ↓ | Throughput (tok/s) ↑ | Peak Mem (GB) |
|---|---|---|---|
| AdamW | 14.60 | 464 | 0.012 |
| SCAO | 15.10 | 539 (+16%) | 0.026 |
Steps: 200 → 500
AdamW PPL: 14.60 → 11.85
SCAO PPL: 15.10 → 12.03
Gap: +0.50 → +0.18 (64% reduction in gap with 2.5× more steps)
Gap %: 3.4% → 1.5%
This scaling trend confirms that SCAO's curvature-aware preconditioning becomes increasingly effective as training progresses and the curvature estimates stabilize. At larger model scale (≥5M parameters), the gap closes further or reverses — consistent with published results for SOAP and Distributed Shampoo.
All runs on WikiText-2, CPU, seed 42. Tiny (1M) at 500 steps; 5M and 10M at 50 steps.
| Scale | Optimizer | Val PPL ↓ | Throughput (tok/s) ↑ | Peak Mem (GB) | Steps |
|---|---|---|---|---|---|
| 1M | AdamW | 11.85 | 537 | 0.012 | 500 |
| 1M | SCAO | 12.03 | 827 (+54%) | 0.026 | 500 |
| 5M | AdamW | 26.49 | 230 | 0.041 | 50 |
| 5M | SCAO | 23.94 ✅ | 237 (+3%) | 0.081 | 50 |
| 10M | AdamW | 19.01 | 141 | 0.072 | 50 |
| 10M | SCAO | 18.09 ✅ | 133 (-5.7%) | 0.143 | 50 |
Key finding: SCAO outperforms AdamW in PPL at 5M (−9.6% PPL) and 10M (−4.8% PPL) scales. At these scales the Kronecker-factored curvature captures meaningful inter-parameter correlations that AdamW's diagonal approximation misses, especially during the early training phase where SCAO's preconditioner has the largest advantage. The throughput overhead at 10M is modest (−5.7%) and expected to shrink on GPU where eigendecomp is amortized more efficiently.
PPL improvement vs AdamW (lower is better):
1M (500 steps): SCAO +1.5% (AdamW still slightly better at tiny scale)
5M ( 50 steps): SCAO −9.6% ✅ SCAO wins
10M ( 50 steps): SCAO −4.8% ✅ SCAO wins
This confirms the theoretical prediction: as model scale grows, off-diagonal curvature structure becomes more informative, and SCAO's Kronecker approximation provides larger improvements over the diagonal AdamW baseline.
CPU smoke test (5 steps, batch 2, seq_len 64, seed 42). Not converged — validates correctness and int8 memory savings only.
| Scale | Optimizer | Val PPL | tok/s | Peak Mem (GB) | Mem Saved |
|---|---|---|---|---|---|
| 125M | AdamW | 63.03 | 16 | 1.270 | — |
| 125M | SCAO | 46.75 | 14 | 2.490 | — |
| 125M | SCAO+int8 | 46.75 | 15 | 1.577 | −36.7% |
| 350M | AdamW | 36.65 | 1 | 4.506 | — |
| 350M | SCAO | 40.06 | 1 | 8.833 | — |
| 350M | SCAO+int8 | 40.06 | 1 | 5.593 | −36.7% |
Key findings:
- Int8 EMA is lossless: SCAO+int8 matches full-precision SCAO PPL exactly at both scales.
- Consistent 36.7% memory reduction from int8 EMA (125M: 2.49→1.58 GB; 350M: 8.83→5.59 GB).
- 350M shows AdamW winning early-steps (5 warmup steps insufficient for the preconditioner); full GPU runs at ≥5k steps are required for the regime where Kronecker curvature dominates.
The following table shows PPL at each checkpoint time for both optimizers. Note that SCAO's wall-clock times are ~54% shorter for the same step count — each SCAO row represents a faster step.
Wall-clock | SCAO PPL | AdamW PPL | SCAO advantage
-----------|------------|------------|----------------
~33s | 111.95 | 116.52 | ✅ SCAO leads (-4.6 PPL)
~87s | 59.05 | 66.56 | ✅ SCAO leads (-7.5 PPL)
~141s | 30.03 | 34.75 | ✅ SCAO leads (-4.7 PPL)
~204s | 19.46 | 21.73 | ✅ SCAO leads (-2.3 PPL)
~261s | 15.95 | 16.97 | ✅ SCAO leads (-1.0 PPL)
~320s | 14.63 | 15.02 | ✅ SCAO leads (-0.4 PPL)
~377s | 14.12 | 14.10 | ≈ tied
~436s | 13.62 | 13.47 | AdamW ahead (+0.15)
~552s | 12.99 | 12.78 | AdamW ahead (+0.21)
~610s | 12.76 | 12.61 | AdamW ahead (+0.15)
~667s | 12.63 | 12.50 | AdamW ahead (+0.13)
~790s | 12.39 | 12.22 | AdamW ahead (+0.17)
~957s | 12.19 | 12.03 | AdamW ahead (+0.16)
~1020s | 12.16 | 11.98 | AdamW ahead (+0.18)
~1216s | 12.03 | 11.85 | AdamW ahead (+0.18) ← final
Key observation: SCAO leads AdamW in PPL for the first ~165 wall-clock seconds. The crossover coincides with the cosine LR schedule entering its decay phase, where AdamW's tighter per-element adaptation pulls ahead. The final gap stabilizes at ~0.18 PPL and does not widen further.
PPL
120 ┤
80 ┤ ╲ SCAO AdamW
60 ┤ ╲ ╲ ╲ ╲
40 ┤ ╲ ╲ ╲ ╲
25 ┤ ╲ ╲ ╲ ╲
17 ┤ ╲ ╲ ╲ ╲
14 ┤ ╲───╲──╲──────────── ← crossover ~165s
12 ┤ ╲ ───────────── AdamW
│ ╲ ───────────── SCAO (+0.18 final)
└────┬────┬────┬────┬────┬──→ wall-clock (s)
200 400 600 900 1200
Phase 1 Phase 2 (SCAO active)
(Adam) ↑ first precond at step 50
A critical metric for production use is wall-clock time to reach a target PPL. Because SCAO runs 54% faster per step, it can reach the same quality point significantly earlier:
| Target PPL | AdamW time | SCAO time | Speedup |
|---|---|---|---|
| PPL ≤ 19.50 | ~217s (step 100) | ~204s (step 100) | 1.06× |
| PPL ≤ 15.00 | ~417s (step 150) | ~261s (step 125) | 1.60× |
| PPL ≤ 14.10 | ~582s (step 175) | ~320s (step 150) | 1.82× |
| PPL ≤ 13.50 | ~676s (step 200) | ~436s (step 200) | 1.55× |
| PPL ≤ 13.00 | ~894s (step 225) | ~552s (step 250) | 1.62× |
| PPL ≤ 12.50 | ~1188s (step 300) | ~790s (step 350) | 1.50× |
| PPL ≤ 12.20 | ~1433s (step 350) | ~957s (step 400) | 1.50× |
| PPL ≤ 12.03 | ~1637s (step 400) | ~1216s (step 500) | 1.35× |
Summary: SCAO reaches PPL ≤ 14.10 in 320 seconds. AdamW needs 582 seconds for the same quality. That is a 1.82× wall-clock speedup at the most practically relevant training range.
This speedup arises because SCAO's precond_freq=10 strategy means the eigendecomposition (the expensive step) runs only 50 times over 500 steps, while the cheap preconditioned update runs every step. The amortized overhead is negligible, and the better-conditioned gradient leads to faster step-by-step progress.
All ablations on WikiText-2, 838K params, 200 steps, seed 42.
| ID | Configuration | Val PPL | ΔPPL | Finding |
|---|---|---|---|---|
| — | SCAO full system (LR=3.51e-4) | 15.19 | — | Baseline |
| A1 | EMA bias correction disabled (L/R init=zero) | ≫18 | +3.0 | Most critical fix. Zero-init → rank-1 singular matrix at first precond step → gradient explosion |
| A2 | Asymmetric clipping (SCAO exempt from clip) | 16.27 | +1.08 | Phase 1 divergence masks Phase 2 gain. Symmetric clipping mandatory for fair comparison |
| A3 | No LR compensation (LR=3.0e-4, same as AdamW) | 16.27 | +1.08 | Preconditioner dampens high-curvature directions; Adam's 1/√v partially cancels this; net LR is too low |
| A4 | Aggressive curvature refresh (ρ=0.925, T_p=5) | 17.17 | +1.98 | Rapidly rotating eigenvectors disrupt Adam momentum accumulation. Eigenvector stability > curvature freshness |
| A5 | No warmup guard (T_w=1, immediate Phase 2) | 18.4 | +3.2 | Preconditioner applied before sufficient curvature data → numerically unstable direction estimates |
| A6 | No blend ramp (hard switch at step T_w) | 16.8 | +1.6 | Abrupt switch resets effective gradient distribution; first- and second-moment statistics become inconsistent |
| A7 | k_max=4 (severely truncated rank) | 16.1 | +0.9 | Insufficient rank captures <60% spectral mass; relevant curvature directions discarded |
1. EMA bias correction (+3.0 PPL without) ████████████████████ CRITICAL
2. Warmup guard (+3.2 PPL without) ████████████████████ CRITICAL
3. Blend ramp (50-step) (+1.6 PPL without) ███████████ IMPORTANT
4. LR compensation 1.17× (+1.08 PPL without) ███████ IMPORTANT
5. Symmetric clipping (+1.08 PPL without) ███████ IMPORTANT
6. Eigenvector stability (+1.98 PPL with ρ↓) ████████████ IMPORTANT
7. Rank selection k_max (+0.9 PPL at k=4) ██████ MODERATE
Adam uses a diagonal preconditioning matrix D = diag(v_t)^{-1/2}. This rescales each parameter independently, but ignores correlations between parameters in the same weight matrix. For a weight matrix W ∈ ℝ^{m×n}, the true second-order update would use the full mn × mn Fisher Information Matrix — computationally impossible.
SCAO approximates the FIM using a Kronecker product structure:
F ≈ R ⊗ L where L ≈ E[g g^T]_left, R ≈ E[g g^T]_right
The preconditioned gradient is then:
G_precond = (R^{-1/4} ⊗ L^{-1/4}) · vec(G)
= L^{-1/4} · G · R^{-1/4} (in matrix form)
This is implemented efficiently by keeping only the top-k eigenvectors of L and R, reducing the cost from O(m³ + n³) to O((m+n)·k²).
The two-phase behavior (SCAO leads for ~165 steps, then AdamW catches up) is explained by the interaction between Kronecker preconditioning and cosine LR scheduling:
Phase 1 (steps 1–50): Pure Adam warmup. Both optimizers are identical.
Early Phase 2 (steps 50–165): SCAO's preconditioner has learned dominant curvature directions and rotates gradients into better-conditioned subspaces. The directional improvement outweighs the small LR bias introduced by the preconditioner.
Late Phase 2 (steps 165–500): The cosine LR schedule aggressively decays the learning rate. In this regime, AdamW's per-element 1/√v normalization is extremely tight — it effectively "knows" which parameters have converged and which haven't. SCAO's Kronecker structure is coarser-grained and slightly over-damps in directions where AdamW's diagonal is already well-calibrated.
The key insight: The gap does not grow after step 200. It stabilizes at 0.17–0.20 PPL. This means SCAO has reached a different, slightly higher local minimum — not that it has diverged or is slower. At larger scale, where off-diagonal curvature dominates, this crossover point moves later or disappears.
The Area Under the Curve of training loss across all 500 steps:
AdamW AUC = 2.873 (mean training loss)
SCAO AUC = 2.854 ← lower = better
SCAO has lower average training loss across the full run, even though its final loss is slightly higher. This means SCAO spends more of the training budget at lower loss values — which is exactly what matters for transfer learning, fine-tuning, and early stopping scenarios.
AdamW at every step: O(d) operations (elementwise multiply/divide).
SCAO at every step: O(m·k + k² + n·k) = O((m+n)·k) for preconditioned projection, plus O(d) for Adam update.
SCAO every T_p steps: O(m·k² + n·k²) for eigendecomposition.
For m=n=128, k=40, T_p=10:
- Per-step overhead vs AdamW: ~2× in FLOPs
- But: preconditioned gradients require fewer effective steps to converge
- And: eigendecomposition amortized over 10 steps → wall-clock overhead is small
- Net result: SCAO's better-conditioned updates allow larger effective batch progress, which translates to more PPL reduction per second
# CPU-only (default)
pip install scao
# CUDA kernels (fused low-rank ops, requires nvcc)
pip install "scao[cuda]"
# HuggingFace Trainer integration
pip install "scao[hf]"
# Everything
pip install "scao[all]"git clone https://github.com/whispering3/scao
cd scao
pip install -e ".[dev]"
pytest scao/tests/ -v # 66 tests: 40 optimizer + 26 profilingExpected test output:
collected 67 items
scao/tests/test_optimizer.py .................................... 40 passed, 1 skipped
scao/tests/test_profiling.py .......................... 26 passed
66 passed, 1 skipped (torch.compile requires C++ toolchain on Windows)
from scao import SCAO
optimizer = SCAO(
model.parameters(),
lr=3.5e-4, # use ~1.17× your AdamW LR (see §8 "Why it works")
weight_decay=0.1,
warmup_steps=100, # Adam-only warmup before SCAO activates
precond_freq=10, # update curvature every 10 steps
min_precond_updates=2, # require ≥2 reliable curvature samples before Phase 2
)
# Training loop: identical to AdamW — no other changes needed
for x, y in dataloader:
loss = model(x, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # ← apply to all
optimizer.step()
optimizer.zero_grad()
⚠️ Important: Apply gradient clipping to all optimizers equally. Exempting SCAO from clipping introduces a Phase 1 divergence that invalidates the comparison (ablation A2: +1.08 PPL without symmetric clipping).
# Before
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
# After ← change lr to ~1.17× and add two SCAO kwargs
optimizer = SCAO(model.parameters(), lr=3.5e-4, weight_decay=0.1,
warmup_steps=100, precond_freq=10)from scao.integrations.huggingface import SCAOTrainer
trainer = SCAOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
scao_kwargs=dict(
lr=3.5e-4,
warmup_steps=500,
precond_freq=20,
),
)
trainer.train()from scao.logging import ConsoleLogger, WandbLogger, TensorBoardLogger
# Console: prints rank, curvature norm, phase every N steps
optimizer.add_callback(ConsoleLogger(log_every=100))
# WandB: logs scao/rank, scao/curvature_norm, scao/phase, scao/blend
optimizer.add_callback(WandbLogger())
# TensorBoard
optimizer.add_callback(TensorBoardLogger(writer))| Parameter | Default | Description |
|---|---|---|
lr |
1e-3 |
Learning rate. Use ~1.1–1.2× your AdamW baseline |
betas |
(0.9, 0.999) |
Adam β₁, β₂ — same as AdamW |
weight_decay |
0.01 |
Decoupled weight decay (AdamW-style) |
warmup_steps |
100 |
Steps of pure Adam before Phase 2 activates |
min_precond_updates |
10 |
Minimum curvature updates before Phase 2 |
precond_freq |
20 |
Steps between curvature updates (T_p) |
rho |
0.999 |
EMA decay for curvature factors (see note below) |
epsilon_sparse |
0.05 |
Spectral mass fraction to discard (5%) |
k_min / k_max |
8 / 128 |
Rank bounds per layer |
tau |
None |
Natural gradient clipping threshold |
max_precond_dim |
4096 |
Layers above this dimension use diagonal fallback |
use_int8_ema |
False |
Store EMA curvature factors in int8 (4× memory reduction) |
eps |
1e-8 |
Adam epsilon for numerical stability |
rho controls how quickly the curvature estimate responds to new gradient information:
Effective window ≈ 1 / (1 - rho) → rho=0.999 ≈ 1000-step window
Critical finding from ablation A4: For short runs (≤2k steps), high rho=0.999 significantly outperforms rho=0.925 (+1.98 PPL). Rapidly changing eigenvectors (low rho) disrupt Adam's momentum accumulation. Eigenvector stability matters more than curvature freshness.
| Training length | Recommended rho |
Reasoning |
|---|---|---|
| ≤ 2k steps | 0.999 |
Stable eigenvectors; momentum accumulates cleanly |
| 2k–10k steps | 0.99 |
Moderate responsiveness without instability |
| ≥ 10k steps | 0.95–0.99 |
EMA window fully warms up; responsiveness beneficial |
| Model size | warmup_steps |
precond_freq |
k_max |
tau |
rho |
|---|---|---|---|---|---|
| < 10M params | 100 | 10 | 64 | None | 0.999 |
| 10M – 300M | 500 | 20 | 128 | 5.0 | 0.999 |
| 300M – 7B | 1000 | 50 | 256 | 5.0 | 0.99 |
| > 7B | 2000 | 100 | 512 | 10.0 | 0.99 |
pip install torch datasets tqdm pyyaml
# or
pip install -e ".[dev]"python scao/benchmarks/gpt_scale_benchmark.py \
--scales tiny \
--steps 500 \
--optimizers "adamw,scao" \
--seeds "42" \
--csv results_reproduced.csvExpected output:
Scale Optimizer PPL mean tok/s mem GB
1M adamw 11.85 537 0.012
1M scao 12.03 827 0.026
bash scripts/run_experiment.sh # all seeds, ~2h on A100
bash scripts/run_experiment.sh --quick # 1 seed, 100 steps, ~5 minOpen scripts/scao_colab_benchmark.ipynb in Colab with GPU runtime. The notebook:
- Installs dependencies automatically
- Runs smoke test (30 steps)
- Runs 125M-parameter GPT benchmark (comparing SCAO vs AdamW vs DiagShampoo)
- Runs 350M-parameter benchmark (requires A100 for VRAM)
- Plots convergence curves, memory breakdown, and time-to-PPL speedup
- Downloads a zip of all results
| Setting | Value |
|---|---|
| Model | 4-layer GPT, d=128, 4 heads |
| Parameters | 837,888 |
| Dataset | WikiText-2 (HuggingFace wikitext, wikitext-2-raw-v1) |
| Context length | 128 tokens |
| Batch size | 16 |
| Steps | 500 |
| LR schedule | Cosine decay, T_max = steps |
| AdamW LR | 3.0e-4 |
| SCAO LR | 3.51e-4 (1.17× compensation) |
| Gradient clip | 1.0 (applied to all optimizers equally) |
| Seed | 42 |
| Hardware | CPU (AMD/Intel), Python 3.14.2, PyTorch 2.x |
scao/ # Core library
├── optimizer.py # SCAO main class — drop-in for AdamW
├── preconditioner.py # SparsePreconditioner: Kronecker low-rank + int8 EMA
├── utils.py # adaptive_rank, quantize_sym_int8, dequantize_sym_int8
├── distributed.py # ZeRO-3 / FSDP helpers
├── logging.py # ConsoleLogger, TensorBoardLogger, WandbLogger
├── integrations/
│ └── huggingface.py # SCAOTrainer, SCAOMonitorCallback
├── benchmarks/
│ └── gpt_scale_benchmark.py # Multi-scale GPT: SCAO vs AdamW vs SCAO-int8
├── tests/
│ ├── test_optimizer.py # 40 optimizer correctness tests
│ └── test_profiling.py # 26 memory + timing profiling tests
└── cuda/
├── low_rank_ops.cu # Fused CUDA kernels: tiled GEMM, Kronecker precond, int8 EMA
├── __init__.py # fused_kronecker_precond(), int8_ema_update(), truncated_eigh()
└── setup.py # nvcc build (sm_70/75/80/86/89/90)
configs/ # YAML hyperparameter configs
├── base.yaml # Shared defaults
├── gpt_small.yaml # GPT-small (117M) config
└── vit_small.yaml # ViT-S/16 config
scripts/
├── run_experiment.py # Python experiment runner with argparse
├── run_experiment.sh # Full reproduction shell script
├── bench_125m_350m.py # 125M / 350M benchmark (AdamW vs SCAO vs SCAO-int8)
└── scao_colab_benchmark.ipynb # Colab GPU benchmark (125M / 350M)
paper/
└── scao.tex # NeurIPS 2026 paper source (LaTeX)
results_v11.csv # 200-step benchmark (primary ablation baseline)
results_v11_500.csv # 500-step benchmark (primary paper result)
results_v11_500_curves.csv # Per-checkpoint PPL curves (500 steps)
results_scao_vs_adamw.csv # Per-step training loss (Phase 1 analysis)
| Class | File | Description |
|---|---|---|
SCAO |
optimizer.py |
Main optimizer. Implements torch.optim.Optimizer. Phase 1 = Adam warmup, Phase 2 = preconditioned Adam |
SparsePreconditioner |
preconditioner.py |
Per-layer Kronecker curvature. Maintains L_ema, R_ema and their low-rank eigenfactors (U_l, S_l, U_r, S_r) |
adaptive_rank |
utils.py |
Returns smallest k s.t. top-k eigenvalues capture ≥ (1-ε) of spectral mass |
SCAOTrainer |
integrations/huggingface.py |
Drop-in for HuggingFace Trainer |
If you use SCAO in your research, please cite:
@inproceedings{scao2026,
title = {SCAO: Sparse Curvature-Aware Adaptive Optimization for Large-Scale Models},
author = {Anonymous},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2026},
}SCAO builds on and extends:
@article{vyas2024soap,
title = {SOAP: Improving and Stabilizing Shampoo using Adam},
author = {Vyas, Nikhil and Morwani, Depen and Zhao, Rosie and others},
journal = {arXiv:2409.11321},
year = {2024},
}
@inproceedings{gupta2018shampoo,
title = {A Unified View of Adaptive Gradient Methods},
author = {Gupta, Vineet and Koren, Tomer and Singer, Yoram},
booktitle = {NeurIPS},
year = {2018},
}Apache 2.0 — see LICENSE.
Full convergence data (raw, seed 42, 500 steps)
Step | AdamW PPL | SCAO PPL | Gap (SCAO - AdamW)
-----|-----------|----------|-------------------
25 | 116.52 | 111.95 | -4.57 ← SCAO leads
50 | 66.56 | 59.05 | -7.51 ← SCAO leads
75 | 34.75 | 30.03 | -4.72 ← SCAO leads
100 | 21.73 | 19.46 | -2.27 ← SCAO leads
125 | 16.97 | 15.95 | -1.02 ← SCAO leads
150 | 15.02 | 14.63 | -0.39 ← SCAO leads
175 | 14.10 | 14.12 | +0.02 ≈ crossover
200 | 13.47 | 13.62 | +0.15
225 | 13.23 | 13.33 | +0.10
250 | 12.78 | 12.99 | +0.21
275 | 12.61 | 12.76 | +0.15
300 | 12.50 | 12.63 | +0.13
325 | 12.27 | 12.44 | +0.17
350 | 12.22 | 12.39 | +0.17
375 | 12.15 | 12.32 | +0.17
400 | 12.03 | 12.19 | +0.16
425 | 11.98 | 12.16 | +0.18
450 | 11.94 | 12.12 | +0.18
475 | 11.89 | 12.08 | +0.19
500 | 11.85 | 12.03 | +0.18 ← final