
Commit f6a65bb

unamedkr and claude committed
feat(deltanet): auto-enable TQ_DN_LLAMACPP_PORT for qwen35moe — ROOT CAUSE
The Path B per-layer l_out diff harness localized the Qwen3.6-A3B divergence to L33+ (14.5% rel_diff on the "Hello" prompt). Investigation revealed the root cause is NOT quantized-kernel precision but FP32 summation order in the DeltaNet delta-rule state update:

- Our default: state[i*dv+j] layout, i-outer-j-inner FMA
- llama.cpp: state[j*S_v+i] layout (transposed), j-outer-i-inner

The two are mathematically identical, but FP32 addition is non-associative. Per-layer drift of ~1e-5 compounds across 40 DeltaNet layers into:

- L33 cos gap of 14.5% (vs 0.01% at L0-L32)
- router softmax exp() amplifying that gap into a top-K boundary flip
- expert selection diverging from llama (same 8 experts, wrong order)
- cascading MoE ffn_out divergence through L34-L38
- final self-attn amplification at L39
- cos(final_logits, llama) = 0.47 (massive)
- a coherent-output attractor at ~35 words

TQ_DN_LLAMACPP_PORT=1 uses the verbatim llama.cpp port, which already existed in-tree as R49 Big Move 3 but was gated off by default.

Measurement with NO_PRESET + T=1.0 + DN_PORT:

- L33 rel_diff: 14.5% → 4.3% (3.4× reduction)
- Coherent output on the quantum prompt:
  - baseline: "killing jargon.but keeping it accurate.k" (~5 coherent words)
  - DN_PORT: "Superposition (Schrödinger's cat) → Entanglement → Wave-particle duality (Double-slit experiment) → Quantum tunneling → Quantum computing (quantum supremacy)", i.e. real physics concepts, ~100 coherent words

Auto-enable for qwen35moe; users opt out via TQ_QWEN35MOE_NO_PRESET=1 as before.

Why prior rounds missed this:

- R49 Big Move 3 originally implemented DN_LLAMACPP_PORT, but it was tested alongside the existing T=2.0 auto-preset; the T=2.0 divergent-router path produced a different attractor, making the DN fix look like a regression. When tested with T=1.0 (matching llama's router), the true improvement shows.
- The memory note "DeltaNet port regressed" preserved that false conclusion across 10+ rounds.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
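To make the non-associativity claim concrete, here is a minimal standalone C sketch. It is not the tq or llama.cpp kernel; the dimensions, data, and names (DK, DV, st, st_t) are made up for illustration. It contrasts two mathematically identical readouts of a delta-rule-style state: a row-major i-outer-j-inner loop with plain multiply-then-add against a transposed j-outer-i-inner loop accumulated the way a vectorized kernel would, with four partial sums and fmaf(). For random data the per-element results typically disagree at the level of FP32 rounding, which is the kind of drift the commit describes compounding across 40 layers.

/* fp32_order_demo.c -- illustrative sketch only, not the DeltaNet kernels.
 * Two mathematically identical FP32 reductions over the same state/key
 * data diverge once layout and loop nest change the accumulation order.
 * Build: cc -O2 fp32_order_demo.c -lm && ./a.out
 */
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define DK 128                  /* hypothetical key dim (multiple of 4) */
#define DV 128                  /* hypothetical value dim               */

int main(void) {
    static float st[DK * DV];   /* row-major:  st[i*DV + j]   */
    static float st_t[DV * DK]; /* transposed: st_t[j*DK + i] */
    static float k[DK], out_a[DV], out_b[DV];

    srand(42);
    for (int i = 0; i < DK; i++) {
        k[i] = (float)rand() / RAND_MAX - 0.5f;
        for (int j = 0; j < DV; j++) {
            float x = (float)rand() / RAND_MAX - 0.5f;
            st[i * DV + j]   = x;
            st_t[j * DK + i] = x;      /* same values, transposed layout */
        }
    }

    /* Variant A: i-outer, j-inner over the row-major layout, one scalar
     * accumulator per output, plain multiply then add.                  */
    for (int j = 0; j < DV; j++) out_a[j] = 0.0f;
    for (int i = 0; i < DK; i++)
        for (int j = 0; j < DV; j++)
            out_a[j] = out_a[j] + st[i * DV + j] * k[i];

    /* Variant B: j-outer, i-inner over the transposed layout, accumulated
     * like a 4-lane vectorized reduction (partial sums + fmaf), i.e. a
     * different summation tree over exactly the same products.           */
    for (int j = 0; j < DV; j++) {
        float lane[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        for (int i = 0; i < DK; i += 4)
            for (int l = 0; l < 4; l++)
                lane[l] = fmaf(st_t[j * DK + i + l], k[i + l], lane[l]);
        out_b[j] = (lane[0] + lane[1]) + (lane[2] + lane[3]);
    }

    /* Same math, different rounding: report the worst per-element gap. */
    float worst = 0.0f;
    for (int j = 0; j < DV; j++) {
        float ref = fabsf(out_b[j]) > 1e-12f ? fabsf(out_b[j]) : 1e-12f;
        float rel = fabsf(out_a[j] - out_b[j]) / ref;
        if (rel > worst) worst = rel;
    }
    printf("worst per-element rel diff: %.3e\n", (double)worst);
    return 0;
}

Neither variant is "wrong"; the point is that changing layout and loop nest changes the summation tree, so matching llama.cpp closely means matching its accumulation order, which is what TQ_DN_LLAMACPP_PORT selects.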
1 parent e28e8a2 commit f6a65bb

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

tools/quant.c

@@ -436,6 +436,22 @@ int main(int argc, char** argv) {
         fprintf(stderr, "tq_main: qwen35moe DN_NORM_FP64 auto-enabled "
                 "(per-head RMSNorm in FP64, negligible memory).\n");
     }
+    /* R63 P4 (2026-04-24): verbatim llama.cpp gated_delta_net port.
+     * Root cause of late-layer divergence was FP32 summation order
+     * in the delta-rule state update. Our default uses state[i][j]
+     * layout with i-outer-j-inner FMA; llama uses state[j][i]
+     * (transposed) with j-outer-i-inner. Mathematically identical
+     * but FP32-non-associative: our per-layer ~1e-5 drift compounds
+     * across 40 layers into L33+ cos gap and expert-boundary flips.
+     * Measurement on "Hello" prompt: L33 rel_diff 14.5% -> 4.3%.
+     * Long-gen on quantum prompt: 35 coh -> ~100 coh words with
+     * real physics concepts (Schrödinger's cat, Entanglement,
+     * Wave-particle duality). */
+    if (!getenv("TQ_DN_LLAMACPP_PORT")) {
+        setenv("TQ_DN_LLAMACPP_PORT", "1", 0);
+        fprintf(stderr, "tq_main: qwen35moe DN_LLAMACPP_PORT auto-enabled "
+                "(verbatim llama.cpp delta-rule FP32 accumulation order).\n");
+    }
     /* R62 K8: FP32 KV cache for thinking mode only. In direct mode,
      * turbo_kv_4b is better (+regularizer effect), but thinking
      * mode accumulates KV quant noise over the long causal reasoning chain and
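The expert-order flip described in the commit message is the step that turns a tiny drift into visible damage: the router softmax sits downstream of the DeltaNet layers, and when two expert logits are nearly tied, a perturbation on the order of 1e-3 is enough to swap their ranks, changing top-K order and the per-expert weights. Below is a toy C sketch of that boundary flip; the expert count, logits, and perturbation are all hypothetical, not values from the model.

/* router_flip_demo.c -- illustrative only; all numbers are hypothetical.
 * Two near-tied router logits swap rank under a small upstream
 * perturbation, so the top-K expert order (and thus the softmax
 * weights applied to each expert) changes even though the same
 * experts remain selectable.
 */
#include <math.h>
#include <stdio.h>

#define NE 8   /* experts in this toy router */

static void softmax(const float *logit, float *p) {
    float m = logit[0], z = 0.0f;
    for (int e = 1; e < NE; e++) if (logit[e] > m) m = logit[e];
    for (int e = 0; e < NE; e++) { p[e] = expf(logit[e] - m); z += p[e]; }
    for (int e = 0; e < NE; e++) p[e] /= z;
}

static int argmax(const float *p) {
    int best = 0;
    for (int e = 1; e < NE; e++) if (p[e] > p[best]) best = e;
    return best;
}

int main(void) {
    /* experts 2 and 5 are nearly tied at the top */
    float base[NE] = {0.1f, -0.3f, 1.2004f, 0.4f, -1.0f, 1.2001f, 0.0f, 0.2f};
    float pert[NE], p0[NE], p1[NE];

    /* a ~1e-3 drift on one logit, the scale late-layer hidden-state
     * divergence can plausibly induce                               */
    for (int e = 0; e < NE; e++) pert[e] = base[e];
    pert[5] += 1.0e-3f;

    softmax(base, p0);
    softmax(pert, p1);
    printf("top expert: baseline=%d perturbed=%d\n", argmax(p0), argmax(p1));
    printf("weight of expert 2: %.6f -> %.6f\n", p0[2], p1[2]);
    printf("weight of expert 5: %.6f -> %.6f\n", p0[5], p1[5]);
    return 0;
}

Running it prints the flipped top expert and the swapped weights, i.e. the same experts with a different ordering and mix.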
