
[Feat] Add JAX cpu-ref of KDA (naive and recurrent)#175

Merged
0xaskr merged 9 commits into main from feat/kda-cpu-gold-new on Apr 13, 2026

Conversation


@lingebeng lingebeng commented Apr 8, 2026

Description

  • Add JAX naive recurrent and naive chunk implementations
  • Add JAX recurrent_fwd and recurrent_bwd

Related Issue

Closes #

Change Type

  • feat — New feature
  • fix — Bug fix
  • refactor — Code refactoring
  • docs — Documentation
  • ci — CI/CD changes
  • test — Tests
  • perf — Performance improvement

Checklist

  • Code passes uv run ruff check src/ tests/ and uv run ruff format src/ tests/
  • New/modified public APIs have complete docstrings (tensor shapes, dimension semantics, business logic)
  • Public functions have input assertions (assert or assert_shape_or_none)
  • Tests added at the appropriate layer (tests/ops/, tests/modules/, tests/layers/, or tests/ref/)
  • If tops/cpu/ is modified, core developers have been notified and PR is labeled cpu-ref

Test Results

  • Causes of error:
    • FMA: the GPU computes a*b+c with a fused multiply-add (rounded once), while the CPU computes it in two steps (rounded twice); results can differ by up to 1 ULP.
    • Different exp() implementations: CUDA and the CPU libm use different polynomial approximations.
    • Reduction order: floating-point addition is not associative, (a+b)+c ≠ a+(b+c); the GPU uses a tree reduction while the CPU reduces sequentially.
    • Each factor alone produces a tiny per-step difference (< 1 ULP), but KDA's recurrence compounds them: larger shapes (more heads, longer sequences) accumulate more error.
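The reduction-order point is easy to reproduce in a few lines of pure Python (a toy, not the actual kernels): summing ten copies of 0.1 sequentially versus with a tree reduction lands on different last bits.

```python
def sequential_sum(xs):
    # CPU-style left-to-right reduction: ((x0 + x1) + x2) + ...
    acc = 0.0
    for x in xs:
        acc += x
    return acc

def tree_sum(xs):
    # GPU-style pairwise/tree reduction: (x0+..+x4) + (x5+..+x9)
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return tree_sum(xs[:mid]) + tree_sum(xs[mid:])

vals = [0.1] * 10  # 0.1 is not exactly representable in binary floating point
# The two orders round differently: tree_sum(vals) gives exactly 1.0 here,
# while sequential_sum(vals) gives 0.9999999999999999, a 1-ULP difference.
```

Per-step differences this small are harmless in isolation; it is the recurrence over T steps that amplifies them.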


Summary by CodeRabbit

  • New Features

    • Added CPU reference fused-recurrent KDA and gate operations with multiple gate formulas and chunk-local cumsum support.
  • API Changes

    • Reorganized KDA exports and renamed naive KDA to recurrent/chunked variants to expose fused forward/backward entry points.
  • Tests

    • Added comprehensive reference tests for naive, fused-recurrent, and chunked KDA; removed an obsolete legacy test module.

lingebeng and others added 2 commits April 7, 2026 17:50
- Add strict input shape assertions to naive_recurrent_kda and
  naive_chunk_kda matching chunk_kda and GLA patterns
- Fix _acc_dtype(q) -> _acc_dtype(q.dtype) in naive_recurrent_kda
  so fp64 inputs correctly use fp64 accumulator
- Add T-to-chunk_size padding in naive_chunk_kda for non-aligned
  sequence lengths
- Replace test_chunk_kda.py with test_naive_kda.py: 48 tests covering
  dtype verification, input assertions, recurrent-vs-chunk cross
  validation (fp32/fp64), backward cross-validation, and FLA Triton
  comparison (GPU-optional)
- Remove chunk.py (not yet implemented) and clean up __init__.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

coderabbitai bot commented Apr 8, 2026

📝 Walkthrough


Refactors CPU KDA references: removes legacy chunk_kda and its tests, renames/refactors naive_kda to naive_recurrent_kda and adds naive_chunk_kda, adds gate utilities, implements fused recurrent forward/backward (fused_recurrent_kda_*), updates exports, and introduces new JAX CPU test suites for naive and fused implementations.

Changes

  • Naive KDA & chunked variant (tops/cpu/ops/kda/naive.py): Renamed naive_kda → naive_recurrent_kda, changed dtype/casting and transposed inputs; added naive_chunk_kda implementing intra-chunk forward-substitution and inter-chunk recurrence with padding/reshape logic.
  • Fused recurrent KDA (tops/cpu/ops/kda/fused_recurrent.py): New CPU reference module providing fused_recurrent_kda_fwd, fused_recurrent_kda_bwd, and fused_recurrent_kda; handles varlen, L2 normalization, optional gate transform, dtype/accumulator casting, forward recurrence, and reverse-time gradient accumulation.
  • KDA gate utilities (tops/cpu/ops/kda/gate.py): New gate implementations and APIs: naive_kda_gate, naive_kda_lowerbound_gate, kda_gate_fwd, kda_gate_bwd, fused_kda_gate, and kda_gate_chunk_cumsum with head expansion and chunk-local cumsum support.
  • Exports & legacy removal (tops/cpu/ops/kda/__init__.py, tops/cpu/ops/kda/chunk.py): __init__ updated to export the new fused/naive/gate entry points; removed the prior chunk_kda implementation (file chunk.py deleted).
  • Test suites (tests/ref/kda/test_chunk_kda.py deleted; tests/ref/kda/test_naive_kda.py, tests/ref/kda/test_fused_recurrent_kda.py added): Removed legacy chunk tests; added comprehensive CPU JAX tests for naive and fused implementations (dtype checks, forward/backward cross-validation, varlen behavior, optional Triton GPU comparisons).

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test Suite
    participant API as fused_recurrent_kda_fwd
    participant Gate as Gate Helper
    participant Accum as Accumulator Loop
    participant State as Hidden State

    Test->>API: call(q,k,v,g,beta,initial_state,cu_seqlens,A_log,dt_bias,lower_bound)
    API->>Gate: expand head params, (opt) L2-normalize Q/K
    API->>Gate: compute gate = fused_kda_gate(g,A_log,dt_bias,lower_bound)
    Gate-->>API: gate
    API->>Accum: init accumulator/state
    loop over timesteps (or cu_seqlens segments)
        Accum->>State: apply gate decay, compute h_prime
        State->>Accum: out[t] = combine(h_prime,k,v)
        Accum->>State: update h via delta-rule using beta and v-correction
    end
    Accum-->>API: outputs, final_state
    API-->>Test: return (out, final_state)
sequenceDiagram
    participant Test as Test Suite
    participant Bwd as fused_recurrent_kda_bwd
    participant Replay as Forward Replay
    participant BackPass as Reverse Pass
    participant Grad as Gradient Accumulator

    Test->>Bwd: call(q,k,v,g,beta,dout,dht,initial_state,cu_seqlens,A_log,dt_bias)
    Bwd->>Replay: re-run forward, store h_prime_all, h_all, gate_all
    Replay-->>Bwd: intermediate states
    Bwd->>BackPass: compute per-timestep dq from dout
    BackPass->>Grad: loop reverse-time, accumulate dk, dv, dg, dbeta
    BackPass-->>Bwd: gradients (dq, dk, dv, dg, dbeta, dh0)
    Bwd-->>Test: return gradients
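The forward loop in the first diagram can be sketched as a scalar toy (K = V = 1, plain Python floats; the function and variable names here are illustrative, not the module's API):

```python
import math

def fused_recurrent_step(h, q, k, v, gate, beta):
    """One timestep of a KDA-style gated delta-rule recurrence, as a scalar
    toy: with K = V = 1 the hidden state h collapses to a single number."""
    h = math.exp(gate) * h   # gate decay applied to the carried state
    o = q * (h * k)          # read-out: out[t] combines the state with q and k
    err = v - h * k          # v-correction: gap between the state's prediction and v
    h = h + beta * err * k   # delta-rule update scaled by beta
    return h, o
```

With gate = 0 and beta = 1, repeated steps drive the state's prediction h*k onto v, which is the delta-rule fixed point.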

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Suggested reviewers

  • 0xaskr
  • pathfinder-pf

"🐰
I hop through gates and shape the state,
New kernels hum and tests await,
Old chunk files tucked away to rest,
Recurrent hops give outputs best,
JAX ears twitch — the code's well dressed."

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Docstring Coverage ⚠️ Warning: coverage is 58.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

Title check ✅ Passed: the title accurately and specifically describes the main change, adding JAX CPU reference implementations of KDA (naive and recurrent variants).
Description check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request reorganizes the KDA (Kernel Delta Attention) CPU reference implementation by introducing a fused recurrent version, a dedicated gate module, and updated naive and chunked implementations alongside comprehensive tests. Feedback highlights a critical bug in the intra-chunk dependency logic of the chunked implementation that results in incorrect updates. Additionally, several opportunities for performance optimization are identified, including vectorizing matrix construction loops and removing redundant dtype casting logic. Finally, a request was made to maintain language consistency by using English for all code comments.

@lingebeng lingebeng added the cpu-ref Modifies tops/cpu/ reference implementations label Apr 8, 2026
lingebeng and others added 4 commits April 8, 2026 13:24
- fused_recurrent: pre-compute gate before passing to FLA Triton, which
  ignores use_gate_in_kernel/A_log/dt_bias via **kwargs
- naive: relax fp32/fp64 tolerance (5e-5 → 5e-3) for large shapes where
  GPU vs CPU fp32 rounding accumulates over T=64 recurrence steps

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@lingebeng lingebeng marked this pull request as ready for review April 10, 2026 01:22

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ref/kda/test_fused_recurrent_kda.py`:
- Around line 1-10: Add a backward cross-check test that calls
fused_recurrent_kda_bwd and compares its gradients to a naive/reference backward
implementation (e.g., compute grads via naive_recurrent or autograd on
fused_recurrent_kda_fwd) using tolerance-based assertions; the test should
exercise both boolean branches use_gate_in_kernel and use_qk_l2norm_in_kernel
(run subcases with each True/False), include varlen and fixed-length cases
analogous to existing forward tests, and assert close equality for gradient
tensors (w.r.t. inputs, gates, and parameters) to catch regressions.

In `@tests/ref/kda/test_naive_kda.py`:
- Around line 282-290: The test currently drops the first time column for all
"dg" comparisons by slicing ref/test to skip t=0; instead make this conditional
on the zero-initial-state case (the h0 setting from _XVAL_SHAPES). Change the
block guarded by if name == "dg" to only remove the t=0 column when the test
config indicates no initial state (e.g., h0 is None or False) — leave ref and
test intact when h0 is True so real backward mismatches are not masked;
reference the variables name, ref, test and the _XVAL_SHAPES/h0 flag to locate
and update the conditional.

In `@tops/cpu/ops/kda/fused_recurrent.py`:
- Around line 285-287: When use_qk_l2norm_in_kernel=True the code normalizes q_f
and k_f but returns dq_f/dk_f as if they were gradients w.r.t. the original
(unnormalized) inputs; fix this by preserving the pre-normalized tensors (e.g.
q_raw = q_f; k_raw = k_f), compute norms (norm_q = sqrt(sum(q_raw**2, axis=-1,
keepdims=True)) and same for k), then after replay convert the gradient on the
normalized vectors back to gradients on the raw inputs using the
L2-normalization Jacobian: dq_raw = (dq_f - q_norm * sum(q_norm * dq_f, axis=-1,
keepdims=True)) / norm_q and similarly for dk_raw with k_norm and norm_k;
replace returns of dq_f/dk_f with these propagated dq_raw/dk_raw and apply the
same change where _l2_normalize_last_dim is used (also update the analogous
block around the 406-418 region).
- Around line 289-295: The branch for use_gate_in_kernel currently transposes g
to g_raw = jnp.transpose(g, (0, 2, 1, 3)) which produces [B, H, T, K] and passes
that into fused_kda_gate; fused_kda_gate expects [B, T, H, K] (penultimate dim =
heads), causing incorrect parameter alignment and broken backward passes. Fix by
passing g to fused_kda_gate in [B, T, H, K] order (i.e., do not transpose g
before calling fused_kda_gate) so fused_kda_gate(g, A_log, dt_bias=dt_bias,
lower_bound=lower_bound, output_dtype=acc) receives the correct layout and head
dimension matches A_log.

In `@tops/cpu/ops/kda/gate.py`:
- Around line 60-95: Add full docstrings to naive_kda_gate and
naive_kda_lowerbound_gate describing semantic meaning and exact tensor
shapes/dtypes for every input and output (e.g., g: jax.Array with shape [..., H,
K], A_log: head-wise log amplitudes with shape [H] or [H,1], dt_bias: optional
per-head or per-head-per-channel bias shape [H] or [H,K], output: same leading
dims as g with last two dims matching H,K and dtype output_dtype). At the top of
each function replace the minimal checks with strict assertions using
tops.utils.assert_shape_or_none and explicit type checks: assert g.ndim >= 2 and
g.dtype is a numeric jax dtype, assert_shape_or_none(A_log, (H,)) (or (H,1)) to
match how A_f is reshaped, assert_shape_or_none(dt_bias, (H,)) or (H,K) and that
dt_bias dtype is numeric or None, and for naive_kda_lowerbound_gate assert
isinstance(lower_bound, float) and isfinite. Ensure these assertions run before
any casting/logic, and update the docstrings to reference the parameter
semantics and return dtype (output_dtype). Use the existing helper
_expand_headwise_params and _acc_dtype but document their expectations in the
docstrings.

In `@tops/cpu/ops/kda/naive.py`:
- Around line 177-190: Add an explicit assertion that chunk_size is a positive
integer before any padding/math using it: insert something like assert
isinstance(chunk_size, int) and assert chunk_size > 0 (or use the project's
validation helper) just before computing T_padded where chunk_size/BT and
_cdiv(T, BT) are used; reference the symbols chunk_size (BT), _cdiv, and the
T_padded computation so the check prevents invalid chunk_size values from
reaching _cdiv and follows the project's assert-based input validation policy.

📥 Commits

Reviewing files that changed from the base of the PR and between dfe6ffa and 9cc7b30.

📒 Files selected for processing (8)
  • tests/ref/kda/test_chunk_kda.py
  • tests/ref/kda/test_fused_recurrent_kda.py
  • tests/ref/kda/test_naive_kda.py
  • tops/cpu/ops/kda/__init__.py
  • tops/cpu/ops/kda/chunk.py
  • tops/cpu/ops/kda/fused_recurrent.py
  • tops/cpu/ops/kda/gate.py
  • tops/cpu/ops/kda/naive.py
💤 Files with no reviewable changes (2)
  • tests/ref/kda/test_chunk_kda.py
  • tops/cpu/ops/kda/chunk.py

Comment on lines +1 to +10
"""fused_recurrent_kda(+_fwd) / kda_gate_*: JAX CPU ref (tops.cpu.ops.kda) tests.

Tests:
1. Gate formula verification (no GPU)
2. Gate chunk cumsum verification (no GPU, varlen)
3. Dtype verification (no GPU)
4. Cross-validation: fused_recurrent vs naive_recurrent (no GPU)
5. Wrapper equivalence: fused_recurrent_kda_fwd vs fused_recurrent_kda (no GPU)
6. Varlen: cu_seqlens vs per-segment calls (no GPU)
"""

⚠️ Potential issue | 🟠 Major

Add a backward cross-check for fused_recurrent_kda_bwd.

This suite is forward-only right now. The new backward reference API and its use_gate_in_kernel / use_qk_l2norm_in_kernel branches are never exercised, so gradient regressions can slip through untested.

Based on learnings "Each Jax/Pallas kernel implementation must have a corresponding CPU reference test comparing the optimized kernel against naive implementations with tolerance-based assertions".


Comment on lines +282 to +290
if name == "dg":
# At t=0 with zero initial state, dg_0 = 0 mathematically because
# exp(g_0) * 0 has no dependency on g_0. But the chunk version's
# cumsum-based computation creates AD paths that produce small
# numerical residuals. Skip t=0 for dg comparison.
if ref.shape[1] <= 1:
continue
ref = ref[:, 1:]
test = test[:, 1:]

⚠️ Potential issue | 🟡 Minor

Only skip dg[:, 0] for the zero-initial-state cases.

The rationale here assumes h0 is None, but _XVAL_SHAPES also includes an h0=True case. In that configuration, g at t=0 affects the decayed initial state, so slicing off the first step masks real backward mismatches.

Suggested fix
-        if name == "dg":
+        if name == "dg" and h0 is None:
             # At t=0 with zero initial state, dg_0 = 0 mathematically because
             # exp(g_0) * 0 has no dependency on g_0. But the chunk version's
             # cumsum-based computation creates AD paths that produce small
             # numerical residuals. Skip t=0 for dg comparison.
             if ref.shape[1] <= 1:

Comment on lines +285 to +287
if use_qk_l2norm_in_kernel:
q_f = _l2_normalize_last_dim(q_f)
k_f = _l2_normalize_last_dim(k_f)

⚠️ Potential issue | 🔴 Critical

Propagate dq/dk through the L2-normalization Jacobian.

When use_qk_l2norm_in_kernel=True, the replay runs on normalized q_f and k_f, but the function returns dq_f and dk_f directly as if they were gradients with respect to the raw inputs. That drops the chain rule and makes the backward reference incorrect for normalized Q/K.

Suggested fix
-    q_f = jnp.transpose(q, (0, 2, 1, 3)).astype(acc)   # [B, H, T, K]
-    k_f = jnp.transpose(k, (0, 2, 1, 3)).astype(acc)   # [B, H, T, K]
+    q_in = jnp.transpose(q, (0, 2, 1, 3)).astype(acc)  # [B, H, T, K]
+    k_in = jnp.transpose(k, (0, 2, 1, 3)).astype(acc)  # [B, H, T, K]
+    q_f = q_in
+    k_f = k_in
     v_f = jnp.transpose(v, (0, 2, 1, 3)).astype(acc)    # [B, H, T, V]
     g_f = jnp.transpose(g, (0, 2, 1, 3)).astype(acc)    # [B, H, T, K]
     beta_f = jnp.transpose(beta, (0, 2, 1)).astype(acc)  # [B, H, T]
     do_f = jnp.transpose(do, (0, 2, 1, 3)).astype(acc)  # [B, H, T, V]
 
     if use_qk_l2norm_in_kernel:
         q_f = _l2_normalize_last_dim(q_f)
         k_f = _l2_normalize_last_dim(k_f)
@@
-    dq = jnp.transpose(dq_f, (0, 2, 1, 3)).astype(q.dtype)
-    dk = jnp.transpose(dk_f, (0, 2, 1, 3)).astype(k.dtype)
+    if use_qk_l2norm_in_kernel:
+        _, q_pullback = jax.vjp(_l2_normalize_last_dim, q_in)
+        (dq_f,) = q_pullback(dq_f)
+        _, k_pullback = jax.vjp(_l2_normalize_last_dim, k_in)
+        (dk_f,) = k_pullback(dk_f)
+
+    dq = jnp.transpose(dq_f, (0, 2, 1, 3)).astype(q.dtype)
+    dk = jnp.transpose(dk_f, (0, 2, 1, 3)).astype(k.dtype)

Also applies to: 406-418

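The Jacobian formula quoted above can be sanity-checked without JAX; this pure-Python toy compares the closed-form VJP against central finite differences for a single vector:

```python
import math

def l2_normalize(x):
    n = math.sqrt(sum(v * v for v in x))
    return [v / n for v in x]

def l2_normalize_vjp(x, dy):
    # dx = (dy - x_hat * <x_hat, dy>) / ||x||  (Jacobian-transpose product)
    n = math.sqrt(sum(v * v for v in x))
    x_hat = [v / n for v in x]
    dot = sum(h * d for h, d in zip(x_hat, dy))
    return [(d - h * dot) / n for d, h in zip(dy, x_hat)]

def check(x, dy, eps=1e-6):
    # finite-difference gradient of f(x) = <l2_normalize(x), dy>
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        fp = sum(a * b for a, b in zip(l2_normalize(xp), dy))
        fm = sum(a * b for a, b in zip(l2_normalize(xm), dy))
        fd = (fp - fm) / (2 * eps)
        assert abs(fd - l2_normalize_vjp(x, dy)[i]) < 1e-5
```

This is the same identity the suggested jax.vjp fix computes automatically; skipping it is exactly the dropped chain-rule term the comment describes.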

Comment on lines +289 to +295
if use_gate_in_kernel:
assert A_log is not None
g_raw = jnp.transpose(g, (0, 2, 1, 3))
g_f = fused_kda_gate(
g_raw, A_log, dt_bias=dt_bias,
lower_bound=lower_bound, output_dtype=acc,
)

⚠️ Potential issue | 🔴 Critical

Call fused_kda_gate with [B, T, H, K], not [B, H, T, K].

fused_kda_gate treats the penultimate dimension as heads. This branch transposes g first, so the helper sees T as the head count and either trips the A_log shape assertion or applies head parameters across time. The backward path is therefore wrong whenever use_gate_in_kernel=True.

Suggested fix
     if use_gate_in_kernel:
         assert A_log is not None
-        g_raw = jnp.transpose(g, (0, 2, 1, 3))
-        g_f = fused_kda_gate(
-            g_raw, A_log, dt_bias=dt_bias,
-            lower_bound=lower_bound, output_dtype=acc,
-        )
+        g_f = jnp.transpose(
+            fused_kda_gate(
+                g,
+                A_log,
+                dt_bias=dt_bias,
+                lower_bound=lower_bound,
+                output_dtype=acc,
+            ),
+            (0, 2, 1, 3),
+        )
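A layout toy showing why the transpose breaks head alignment (apply_headwise is an illustrative stand-in for fused_kda_gate's broadcasting, using plain nested lists):

```python
def apply_headwise(g, a_log):
    """g: nested list of shape [T, H] (heads on the penultimate-from-last,
    i.e. last, axis here); a_log: one parameter per head, len(a_log) == H."""
    T, H = len(g), len(g[0])
    assert len(a_log) == H, f"head dim mismatch: {len(a_log)} params vs H={H}"
    return [[g[t][h] * a_log[h] for h in range(H)] for t in range(T)]

g = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # [T=3, H=2]
a_log = [10.0, 100.0]                      # per-head params, len == H
ok = apply_headwise(g, a_log)              # heads line up with params

g_t = [list(col) for col in zip(*g)]       # transposed to [H=2, T=3]
# apply_headwise(g_t, a_log) would now treat T=3 as the head count and
# trip the assertion: the same failure mode as passing [B, H, T, K] into
# a helper expecting [B, T, H, K].
```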

Comment on lines +60 to +95
def naive_kda_gate(
g: jax.Array,
A_log: jax.Array,
dt_bias: jax.Array | None = None,
output_dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
"""Reference implementation of the standard KDA gate."""
assert g.ndim >= 2, f"g must have at least 2 dims, got shape {g.shape}"
H, K = g.shape[-2:]
acc = _acc_dtype(g.dtype)
g_f = g.astype(acc)
A_f, bias_f = _expand_headwise_params(A_log, dt_bias, H, K, acc)
if bias_f is not None:
g_f = g_f + bias_f.reshape(H, K)
out = -jnp.exp(A_f.reshape(H, 1)) * jax.nn.softplus(g_f)
return out.astype(output_dtype)


def naive_kda_lowerbound_gate(
g: jax.Array,
A_log: jax.Array,
dt_bias: jax.Array | None = None,
lower_bound: float = -5.0,
output_dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
"""Reference implementation of the lower-bounded KDA gate."""
assert g.ndim >= 2, f"g must have at least 2 dims, got shape {g.shape}"
H, K = g.shape[-2:]
acc = _acc_dtype(g.dtype)
g_f = g.astype(acc)
A_f, bias_f = _expand_headwise_params(A_log, dt_bias, H, K, acc)
if bias_f is not None:
g_f = g_f + bias_f.reshape(H, K)
out = lower_bound * jax.nn.sigmoid(jnp.exp(A_f.reshape(H, 1)) * g_f)
return out.astype(output_dtype)


🛠️ Refactor suggestion | 🟠 Major

Document and harden the new public gate APIs.

These exported functions only have minimal one-line docstrings, and the validation is partial. The repo contract for public Python APIs is stricter: each function should describe tensor semantics/shapes for every input/output and assert the relevant type/shape invariants up front.

As per coding guidelines **/*.py: All public functions must have a comprehensive docstring that explains the business semantics of the function and clearly details tensor shapes and dimension meanings for every input and output argument; All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils.

Also applies to: 97-204

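The two gate formulas being documented can be restated as scalar pure-Python checks of their key invariants (illustrative code, not the JAX implementation): the standard gate is always negative, and the lower-bounded gate stays strictly inside (lower_bound, 0).

```python
import math

def softplus(x):
    # numerically stable log(1 + e^x)
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def naive_gate(g, a_log, dt_bias=0.0):
    # standard KDA gate: -exp(A_log) * softplus(g + dt_bias)
    return -math.exp(a_log) * softplus(g + dt_bias)

def lowerbound_gate(g, a_log, dt_bias=0.0, lower_bound=-5.0):
    # bounded variant: lower_bound * sigmoid(exp(A_log) * (g + dt_bias))
    return lower_bound * sigmoid(math.exp(a_log) * (g + dt_bias))
```

Invariants like these are exactly what the requested docstrings should state alongside the [..., H, K] shape contracts.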


0xaskr commented Apr 10, 2026

@copilot fix last ci error.

Copilot stopped work on behalf of 0xaskr due to an error April 10, 2026 09:16
Copilot AI requested a review from 0xaskr April 10, 2026 09:16

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (2)
tops/cpu/ops/kda/fused_recurrent.py (2)

285-287: ⚠️ Potential issue | 🔴 Critical

Propagate gradients through L2-normalization Jacobian when use_qk_l2norm_in_kernel=True.

When L2 normalization is applied, the returned dq_f and dk_f are gradients with respect to the normalized vectors, not the original inputs. The chain rule requires back-propagating through the normalization operation. Currently, this produces incorrect gradients when use_qk_l2norm_in_kernel=True.

Suggested fix using JAX VJP
+    q_raw = q_f
+    k_raw = k_f
     if use_qk_l2norm_in_kernel:
         q_f = _l2_normalize_last_dim(q_f)
         k_f = _l2_normalize_last_dim(k_f)

Then before the final transpose (around line 406):

+    if use_qk_l2norm_in_kernel:
+        _, q_vjp = jax.vjp(_l2_normalize_last_dim, q_raw)
+        (dq_f,) = q_vjp(dq_f)
+        _, k_vjp = jax.vjp(_l2_normalize_last_dim, k_raw)
+        (dk_f,) = k_vjp(dk_f)
+
     dq = jnp.transpose(dq_f, (0, 2, 1, 3)).astype(q.dtype)
     dk = jnp.transpose(dk_f, (0, 2, 1, 3)).astype(k.dtype)

289-295: ⚠️ Potential issue | 🔴 Critical

Incorrect tensor layout passed to fused_kda_gate in backward.

The forward pass calls fused_kda_gate(g_h, ...) with g_h in [B, T, H, K] layout. However, this backward pass transposes g to [B, H, T, K] before calling the gate, causing a layout mismatch. This means the gate applies its per-head parameters (A_log) across the wrong dimension.

Suggested fix
     if use_gate_in_kernel:
         assert A_log is not None
-        g_raw = jnp.transpose(g, (0, 2, 1, 3))
-        g_f = fused_kda_gate(
-            g_raw, A_log, dt_bias=dt_bias,
+        g_f = jnp.transpose(
+            fused_kda_gate(
+                g, A_log, dt_bias=dt_bias,
+                lower_bound=lower_bound, output_dtype=acc,
+            ),
+            (0, 2, 1, 3),
-            lower_bound=lower_bound, output_dtype=acc,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/fused_recurrent.py` around lines 289 - 295, The backward
pass incorrectly transposes g to g_raw = jnp.transpose(g, (0, 2, 1, 3))
producing [B,H,T,K] before calling fused_kda_gate, but the forward used g_h in
[B,T,H,K]; remove the transpose and pass g (or a variable named g_h) directly to
fused_kda_gate so the tensor layout matches A_log's per-head parameters (update
the call site in the use_gate_in_kernel branch where fused_kda_gate(g_raw,
A_log, dt_bias=dt_bias, lower_bound=lower_bound, output_dtype=acc) is invoked).
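A minimal sketch of why the layout matters, using a simple per-head decay as a stand-in for `fused_kda_gate` (not the real gate formula). When `T == H`, both layouts broadcast without error, so the bug is a silently wrong result rather than a shape mismatch:

```python
import jax.numpy as jnp

B, T, H, K = 1, 2, 2, 4  # T == H: both layouts broadcast without error
g = jnp.arange(B * T * H * K, dtype=jnp.float32).reshape(B, T, H, K)
decay = -jnp.exp(jnp.array([0.0, 1.0]))  # per-head parameter, shape [H]

# Correct: in [B, T, H, K] the head axis is -2, matching decay's [H].
good = g * decay[None, None, :, None]

# Buggy: transpose to [B, H, T, K] first -- the same indexing now
# scales along *time*, applying head 0's parameter at every head for t=0.
g_t = jnp.transpose(g, (0, 2, 1, 3))
bad = jnp.transpose(g_t * decay[None, None, :, None], (0, 2, 1, 3))

assert not jnp.allclose(good, bad)  # silently wrong when T == H
```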
🧹 Nitpick comments (3)
tops/cpu/ops/kda/naive.py (1)

221-221: Consider renaming masks to avoid shadowing.

The variable mask is defined twice with different semantics: line 221 creates an upper-triangular mask including the diagonal (for zeroing A's upper triangle), while line 265 creates a strictly upper-triangular mask excluding the diagonal (for causal attention). Distinct names would improve clarity.

♻️ Suggested naming
-  mask = jnp.triu(jnp.ones((BT, BT), dtype=jnp.bool_))  # [BT, BT]
+  upper_mask = jnp.triu(jnp.ones((BT, BT), dtype=jnp.bool_))  # [BT, BT] includes diagonal
-  A = jnp.where(mask, 0, -A)  # [B, H, NT, BT, BT]  zero upper triangle, negate lower
+  A = jnp.where(upper_mask, 0, -A)  # [B, H, NT, BT, BT]  zero upper triangle, negate lower
-  mask = jnp.triu(jnp.ones((BT, BT), dtype=jnp.bool_), k=1)  # [BT, BT] strict upper triangle
+  causal_mask = jnp.triu(jnp.ones((BT, BT), dtype=jnp.bool_), k=1)  # [BT, BT] strict upper triangle
-    A_qk = jnp.where(mask, 0, A_qk)  # [B, H, BT, BT]
+    A_qk = jnp.where(causal_mask, 0, A_qk)  # [B, H, BT, BT]

Also applies to: 265-265

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/naive.py` at line 221, The code defines two different masks
using the same name `mask` (one is an upper-triangular including the diagonal
created via jnp.triu(jnp.ones((BT, BT), ...)) and the other is a strictly
upper-triangular/causal mask); rename them to distinct, descriptive names (e.g.,
mask_triu_diag or A_mask for the one used to zero A's upper triangle, and
mask_triu_strict or causal_mask for the strictly upper-triangular causal
attention mask) and update all subsequent uses of `mask` in the surrounding
function(s) so each use references the correctly renamed symbol.
tops/cpu/ops/kda/fused_recurrent.py (2)

565-565: Consider sorting __all__ for consistency.

Static analysis suggests sorting __all__ alphabetically for consistency with isort-style conventions.

Suggested fix
-__all__ = ["fused_recurrent_kda_fwd", "fused_recurrent_kda_bwd", "fused_recurrent_kda"]
+__all__ = ["fused_recurrent_kda", "fused_recurrent_kda_bwd", "fused_recurrent_kda_fwd"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/fused_recurrent.py` at line 565, The __all__ export list is
unsorted; please alphabetize the entries so exports follow a consistent
order—replace __all__ = ["fused_recurrent_kda_fwd", "fused_recurrent_kda_bwd",
"fused_recurrent_kda"] with a sorted list (e.g., ["fused_recurrent_kda",
"fused_recurrent_kda_bwd", "fused_recurrent_kda_fwd"]) ensuring the symbols
fused_recurrent_kda, fused_recurrent_kda_bwd, and fused_recurrent_kda_fwd remain
present and correctly quoted.

115-118: Unused variable H can be suppressed.

The variable H is unpacked but never used. Consider using an underscore prefix to suppress the linter warning.

Suggested fix
-    B, T, H, K = q.shape
+    B, T, _H, K = q.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/kda/fused_recurrent.py` around lines 115 - 118, The unpacking of
q.shape assigns H but it is never used; change the unpack to suppress the linter
by replacing H with a throwaway name (e.g., _ or _H) in the tuple assignment so
it becomes B, T, _H, K = q.shape (or B, T, _, K = q.shape), leaving the
subsequent uses of q, v, and cu_seqlens unchanged; ensure the rest of the
function (references to q, v, cu_seqlens, and variables B, T, K, V, HV, N) still
work after the rename.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/cpu/ops/kda/fused_recurrent.py`:
- Around line 524-562: The public function fused_recurrent_kda currently has
only a one-line docstring; replace it with a comprehensive docstring that (1)
describes the high-level business semantics of fused_recurrent_kda and how it
relates to fused_recurrent_kda_fwd, (2) lists every parameter (q, k, v, g, beta,
A_log, dt_bias, scale, initial_state, output_final_state,
use_qk_l2norm_in_kernel, use_gate_in_kernel, lower_bound, cu_seqlens,
transpose_state_layout, **kwargs) with expected shapes, dtypes and axis meanings
(e.g. batch, sequence, heads, head_dim), which args are optional and default
behaviors, (3) documents the return value tuple and shapes/types (output tensor
and optional final state when output_final_state=True), and (4) notes any
layout/format conventions (e.g. state layout when transpose_state_layout
toggled), edge cases and side effects; add short usage examples and mention any
requirements/constraints (e.g. broadcasting rules, supported dtypes) to make the
API clear to callers.
- Around line 476-521: The public function fused_recurrent_kda_fwd is missing a
comprehensive docstring; add a multi-line docstring (modeled on
fused_recurrent_kda_bwd) that explains the operation semantics and documents
every parameter and return value with precise tensor shapes and dimension
meanings for q, k, v, g, beta, A_log, dt_bias, initial_state, scale,
output_final_state, inplace_final_state, cu_seqlens, ssm_state_indices,
num_accepted_tokens, use_qk_l2norm_in_kernel, use_gate_in_kernel, lower_bound,
out, and transpose_state_layout; include return shapes (output array and
optional final state), any optional semantics (e.g., when output_final_state
True), allowed dtypes, and any side-effects (inplace_final_state behavior) so
callers have complete usage guidance.
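As a sketch of what such a docstring could look like: the parameter names come from the review prompts above, but the shapes and dtypes here are illustrative assumptions, not the actual API contract.

```python
def fused_recurrent_kda(q, k, v, g, beta, scale=None, initial_state=None,
                        output_final_state=False, cu_seqlens=None, **kwargs):
    """CPU reference for the fused recurrent KDA op (docstring skeleton).

    Args:
        q: Queries, float array of shape [B, T, H, K].
        k: Keys, float array of shape [B, T, H, K].
        v: Values, float array of shape [B, T, HV, V].
        g: Gates, float array of shape [B, T, H, K].
        beta: Per-token mixing coefficients, shape [B, T, HV].
        scale: Optional attention scale; e.g. K ** -0.5 when None.
        initial_state: Optional recurrent state, shape [N, HV, K, V].
        output_final_state: When True, also return the final state.
        cu_seqlens: Optional cumulative sequence lengths for varlen input.

    Returns:
        (o, final_state): output of shape [B, T, HV, V] and the final
        state (None when output_final_state is False).
    """
    raise NotImplementedError  # skeleton only; see the real implementation
```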

---

Duplicate comments:
In `@tops/cpu/ops/kda/fused_recurrent.py`:
- Around line 285-287: When use_qk_l2norm_in_kernel is true, dq_f and dk_f are
gradients w.r.t. the normalized outputs of _l2_normalize_last_dim, so you must
backprop through that normalization: compute VJPs of _l2_normalize_last_dim for
q_f and k_f (e.g., call jax.vjp or the equivalent vjp function used in this
code) to convert dq_f/dk_f into gradients w.r.t. the original q_f and k_f, and
replace dq_f/dk_f with those propagated gradients before the code continues to
the final transpose step; ensure you use the exact _l2_normalize_last_dim
function and apply the VJP outputs in place of dq_f and dk_f so the chain rule
is satisfied when use_qk_l2norm_in_kernel=True.
- Around line 289-295: The backward pass incorrectly transposes g to g_raw =
jnp.transpose(g, (0, 2, 1, 3)) producing [B,H,T,K] before calling
fused_kda_gate, but the forward used g_h in [B,T,H,K]; remove the transpose and
pass g (or a variable named g_h) directly to fused_kda_gate so the tensor layout
matches A_log's per-head parameters (update the call site in the
use_gate_in_kernel branch where fused_kda_gate(g_raw, A_log, dt_bias=dt_bias,
lower_bound=lower_bound, output_dtype=acc) is invoked).

---

Nitpick comments:
In `@tops/cpu/ops/kda/fused_recurrent.py`:
- Line 565: The __all__ export list is unsorted; please alphabetize the entries
so exports follow a consistent order—replace __all__ =
["fused_recurrent_kda_fwd", "fused_recurrent_kda_bwd", "fused_recurrent_kda"]
with a sorted list (e.g., ["fused_recurrent_kda", "fused_recurrent_kda_bwd",
"fused_recurrent_kda_fwd"]) ensuring the symbols fused_recurrent_kda,
fused_recurrent_kda_bwd, and fused_recurrent_kda_fwd remain present and
correctly quoted.
- Around line 115-118: The unpacking of q.shape assigns H but it is never used;
change the unpack to suppress the linter by replacing H with a throwaway name
(e.g., _ or _H) in the tuple assignment so it becomes B, T, _H, K = q.shape (or
B, T, _, K = q.shape), leaving the subsequent uses of q, v, and cu_seqlens
unchanged; ensure the rest of the function (references to q, v, cu_seqlens, and
variables B, T, K, V, HV, N) still work after the rename.

In `@tops/cpu/ops/kda/naive.py`:
- Line 221: The code defines two different masks using the same name `mask` (one
is an upper-triangular including the diagonal created via jnp.triu(jnp.ones((BT,
BT), ...)) and the other is a strictly upper-triangular/causal mask); rename
them to distinct, descriptive names (e.g., mask_triu_diag or A_mask for the one
used to zero A's upper triangle, and mask_triu_strict or causal_mask for the
strictly upper-triangular causal attention mask) and update all subsequent uses
of `mask` in the surrounding function(s) so each use references the correctly
renamed symbol.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 19e8ed5b-0042-4ef7-8797-8581757a146b

📥 Commits

Reviewing files that changed from the base of the PR and between 9cc7b30 and f73a524.

📒 Files selected for processing (2)
  • tops/cpu/ops/kda/fused_recurrent.py
  • tops/cpu/ops/kda/naive.py

@lingebeng lingebeng changed the title from [Feat] Add JAX CPU ref KDA to [Feat] Add JAX cpu-ref of KDA (naive and recurrent) on Apr 11, 2026
@0xaskr
Collaborator

0xaskr commented Apr 13, 2026

LGTM

@0xaskr 0xaskr added this pull request to the merge queue Apr 13, 2026
Merged via the queue into main with commit 734cf52 Apr 13, 2026
3 of 4 checks passed
@0xaskr 0xaskr deleted the feat/kda-cpu-gold-new branch April 13, 2026 02:19
@FENP
Contributor

FENP commented Apr 13, 2026

LGTM


# =========================================================================
# Step 2: Solve lower-triangular dependency system
# Forward substitution: A[i, :i] += A[i, :] @ A[:, :i]
Contributor

Rather than forward substitution, this code resembles a row-wise iterative implementation of the Neumann series expansion.
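Both readings can be checked numerically: for a strictly lower-triangular L, the quoted loop reproduces (I - L)^{-1} - I, which is at once the result of forward substitution on (I - L) X = I and the finite Neumann series I + L + ... + L^{BT-1} (L is nilpotent, so the series terminates). A small NumPy sketch:

```python
import numpy as np

BT = 4
rng = np.random.default_rng(0)
L = np.tril(rng.standard_normal((BT, BT)), k=-1)  # strictly lower-triangular

# Row-wise iterative update, as in the quoted kernel comment:
A = L.copy()
for i in range(1, BT):
    A[i, :i] += A[i, :] @ A[:, :i]

# Closed form via the (finite) Neumann series / forward substitution:
inv = np.linalg.inv(np.eye(BT) - L)
assert np.allclose(np.eye(BT) + A, inv)
```

Each row i only reads already-updated rows j < i, which is exactly the forward-substitution order; summing the accumulated products recovers the L^k terms of the series.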

Collaborator Author

You're right, thanks a lot. I'll fix the comment in the next PR.


Labels

cpu-ref Modifies tops/cpu/ reference implementations


4 participants