[Feat] Add JAX cpu-ref of KDA (naive and recurrent) #175
Conversation
- Add strict input shape assertions to naive_recurrent_kda and naive_chunk_kda matching chunk_kda and GLA patterns
- Fix _acc_dtype(q) -> _acc_dtype(q.dtype) in naive_recurrent_kda so fp64 inputs correctly use fp64 accumulator
- Add T-to-chunk_size padding in naive_chunk_kda for non-aligned sequence lengths
- Replace test_chunk_kda.py with test_naive_kda.py: 48 tests covering dtype verification, input assertions, recurrent-vs-chunk cross-validation (fp32/fp64), backward cross-validation, and FLA Triton comparison (GPU-optional)
- Remove chunk.py (not yet implemented) and clean up __init__.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
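The `_acc_dtype(q) -> _acc_dtype(q.dtype)` fix above is the classic bug of passing an array where a dtype is expected. A minimal numpy sketch of a hypothetical accumulator-dtype helper (the repo's `_acc_dtype` internals are not shown here) illustrates why the call site matters:

```python
import numpy as np

def acc_dtype(dtype):
    """Pick an accumulator dtype (sketch): fp64 inputs accumulate in fp64,
    everything else falls back to fp32. Hypothetical stand-in for _acc_dtype."""
    return np.float64 if np.dtype(dtype) == np.float64 else np.float32

q = np.ones((2, 3), dtype=np.float64)
# correct: pass the dtype, not the array, so fp64 inputs keep an fp64 accumulator
assert acc_dtype(q.dtype) == np.float64
assert acc_dtype(np.float16) == np.float32
```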
📝 Walkthrough

Refactors the CPU KDA references: removes the legacy chunked implementation, adds a fused recurrent version and a dedicated gate module, and updates the naive implementations and tests.
Sequence Diagram(s)sequenceDiagram
participant Test as Test Suite
participant API as fused_recurrent_kda_fwd
participant Gate as Gate Helper
participant Accum as Accumulator Loop
participant State as Hidden State
Test->>API: call(q,k,v,g,beta,initial_state,cu_seqlens,A_log,dt_bias,lower_bound)
API->>Gate: expand head params, (opt) L2-normalize Q/K
API->>Gate: compute gate = fused_kda_gate(g,A_log,dt_bias,lower_bound)
Gate-->>API: gate
API->>Accum: init accumulator/state
loop over timesteps (or cu_seqlens segments)
Accum->>State: apply gate decay, compute h_prime
State->>Accum: out[t] = combine(h_prime,k,v)
Accum->>State: update h via delta-rule using beta and v-correction
end
Accum-->>API: outputs, final_state
API-->>Test: return (out, final_state)
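The forward loop in the diagram above can be sketched per timestep. This is a minimal numpy sketch, not the repo's implementation; the exact ordering of readout versus state update varies between KDA variants, and all shapes here are illustrative:

```python
import numpy as np

def kda_step(h, q_t, k_t, v_t, g_t, beta_t):
    """One recurrent KDA timestep (sketch): gate decay, readout, delta-rule update.

    h: [K, V] hidden state; q_t, k_t, g_t: [K]; v_t: [V]; beta_t: scalar.
    """
    h = h * np.exp(g_t)[:, None]            # apply gate decay -> h_prime
    o_t = q_t @ h                           # out[t] read from the decayed state
    v_err = v_t - k_t @ h                   # v-correction term of the delta rule
    h = h + beta_t * np.outer(k_t, v_err)   # update h via delta rule with beta
    return h, o_t

# with zero gate, a unit-norm key, and beta=1, the delta rule writes v exactly
K, V = 4, 3
rng = np.random.default_rng(0)
h = rng.normal(size=(K, V))
k = rng.normal(size=K); k /= np.linalg.norm(k)
v = rng.normal(size=V)
h_new, _ = kda_step(h, rng.normal(size=K), k, v, np.zeros(K), 1.0)
assert np.allclose(k @ h_new, v)
```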
sequenceDiagram
participant Test as Test Suite
participant Bwd as fused_recurrent_kda_bwd
participant Replay as Forward Replay
participant BackPass as Reverse Pass
participant Grad as Gradient Accumulator
Test->>Bwd: call(q,k,v,g,beta,dout,dht,initial_state,cu_seqlens,A_log,dt_bias)
Bwd->>Replay: re-run forward, store h_prime_all, h_all, gate_all
Replay-->>Bwd: intermediate states
Bwd->>BackPass: compute per-timestep dq from dout
BackPass->>Grad: loop reverse-time, accumulate dk, dv, dg, dbeta
BackPass-->>Bwd: gradients (dq, dk, dv, dg, dbeta, dh0)
Bwd-->>Test: return gradients
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues
Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Code Review
This pull request reorganizes the KDA (Kernel Delta Attention) CPU reference implementation by introducing a fused recurrent version, a dedicated gate module, and updated naive and chunked implementations alongside comprehensive tests. Feedback highlights a critical bug in the intra-chunk dependency logic of the chunked implementation that results in incorrect updates. Additionally, several opportunities for performance optimization are identified, including vectorizing matrix construction loops and removing redundant dtype casting logic. Finally, a request was made to maintain language consistency by using English for all code comments.
- fused_recurrent: pre-compute gate before passing to FLA Triton, which ignores use_gate_in_kernel/A_log/dt_bias via **kwargs
- naive: relax fp32/fp64 tolerance (5e-5 → 5e-3) for large shapes where GPU vs CPU fp32 rounding accumulates over T=64 recurrence steps

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/ref/kda/test_fused_recurrent_kda.py`:
- Around line 1-10: Add a backward cross-check test that calls
fused_recurrent_kda_bwd and compares its gradients to a naive/reference backward
implementation (e.g., compute grads via naive_recurrent or autograd on
fused_recurrent_kda_fwd) using tolerance-based assertions; the test should
exercise both boolean branches use_gate_in_kernel and use_qk_l2norm_in_kernel
(run subcases with each True/False), include varlen and fixed-length cases
analogous to existing forward tests, and assert close equality for gradient
tensors (w.r.t. inputs, gates, and parameters) to catch regressions.
In `@tests/ref/kda/test_naive_kda.py`:
- Around line 282-290: The test currently drops the first time column for all
"dg" comparisons by slicing ref/test to skip t=0; instead make this conditional
on the zero-initial-state case (the h0 setting from _XVAL_SHAPES). Change the
block guarded by if name == "dg" to only remove the t=0 column when the test
config indicates no initial state (e.g., h0 is None or False) — leave ref and
test intact when h0 is True so real backward mismatches are not masked;
reference the variables name, ref, test and the _XVAL_SHAPES/h0 flag to locate
and update the conditional.
In `@tops/cpu/ops/kda/fused_recurrent.py`:
- Around line 285-287: When use_qk_l2norm_in_kernel=True the code normalizes q_f
and k_f but returns dq_f/dk_f as if they were gradients w.r.t. the original
(unnormalized) inputs; fix this by preserving the pre-normalized tensors (e.g.
q_raw = q_f; k_raw = k_f), compute norms (norm_q = sqrt(sum(q_raw**2, axis=-1,
keepdims=True)) and same for k), then after replay convert the gradient on the
normalized vectors back to gradients on the raw inputs using the
L2-normalization Jacobian: dq_raw = (dq_f - q_norm * sum(q_norm * dq_f, axis=-1,
keepdims=True)) / norm_q and similarly for dk_raw with k_norm and norm_k;
replace returns of dq_f/dk_f with these propagated dq_raw/dk_raw and apply the
same change where _l2_normalize_last_dim is used (also update the analogous
block around the 406-418 region).
- Around line 289-295: The branch for use_gate_in_kernel currently transposes g
to g_raw = jnp.transpose(g, (0, 2, 1, 3)) which produces [B, H, T, K] and passes
that into fused_kda_gate; fused_kda_gate expects [B, T, H, K] (penultimate dim =
heads), causing incorrect parameter alignment and broken backward passes. Fix by
passing g to fused_kda_gate in [B, T, H, K] order (i.e., do not transpose g
before calling fused_kda_gate) so fused_kda_gate(g, A_log, dt_bias=dt_bias,
lower_bound=lower_bound, output_dtype=acc) receives the correct layout and head
dimension matches A_log.
In `@tops/cpu/ops/kda/gate.py`:
- Around line 60-95: Add full docstrings to naive_kda_gate and
naive_kda_lowerbound_gate describing semantic meaning and exact tensor
shapes/dtypes for every input and output (e.g., g: jax.Array with shape [..., H,
K], A_log: head-wise log amplitudes with shape [H] or [H,1], dt_bias: optional
per-head or per-head-per-channel bias shape [H] or [H,K], output: same leading
dims as g with last two dims matching H,K and dtype output_dtype). At the top of
each function replace the minimal checks with strict assertions using
tops.utils.assert_shape_or_none and explicit type checks: assert g.ndim >= 2 and
g.dtype is a numeric jax dtype, assert_shape_or_none(A_log, (H,)) (or (H,1)) to
match how A_f is reshaped, assert_shape_or_none(dt_bias, (H,)) or (H,K) and that
dt_bias dtype is numeric or None, and for naive_kda_lowerbound_gate assert
isinstance(lower_bound, float) and isfinite. Ensure these assertions run before
any casting/logic, and update the docstrings to reference the parameter
semantics and return dtype (output_dtype). Use the existing helper
_expand_headwise_params and _acc_dtype but document their expectations in the
docstrings.
In `@tops/cpu/ops/kda/naive.py`:
- Around line 177-190: Add an explicit assertion that chunk_size is a positive
integer before any padding/math using it: insert something like assert
isinstance(chunk_size, int) and assert chunk_size > 0 (or use the project's
validation helper) just before computing T_padded where chunk_size/BT and
_cdiv(T, BT) are used; reference the symbols chunk_size (BT), _cdiv, and the
T_padded computation so the check prevents invalid chunk_size values from
reaching _cdiv and follows the project's assert-based input validation policy.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ede1e66e-29ad-4c02-b747-887895a60f77
📒 Files selected for processing (8)
- tests/ref/kda/test_chunk_kda.py
- tests/ref/kda/test_fused_recurrent_kda.py
- tests/ref/kda/test_naive_kda.py
- tops/cpu/ops/kda/__init__.py
- tops/cpu/ops/kda/chunk.py
- tops/cpu/ops/kda/fused_recurrent.py
- tops/cpu/ops/kda/gate.py
- tops/cpu/ops/kda/naive.py
💤 Files with no reviewable changes (2)
- tests/ref/kda/test_chunk_kda.py
- tops/cpu/ops/kda/chunk.py
| """fused_recurrent_kda(+_fwd) / kda_gate_*: JAX CPU ref (tops.cpu.ops.kda) tests. | ||
|
|
||
| Tests: | ||
| 1. Gate formula verification (no GPU) | ||
| 2. Gate chunk cumsum verification (no GPU, varlen) | ||
| 3. Dtype verification (no GPU) | ||
| 4. Cross-validation: fused_recurrent vs naive_recurrent (no GPU) | ||
| 5. Wrapper equivalence: fused_recurrent_kda_fwd vs fused_recurrent_kda (no GPU) | ||
| 6. Varlen: cu_seqlens vs per-segment calls (no GPU) | ||
| """ |
Add a backward cross-check for fused_recurrent_kda_bwd.
This suite is forward-only right now. The new backward reference API and its use_gate_in_kernel / use_qk_l2norm_in_kernel branches are never exercised, so gradient regressions can slip through untested.
Based on learnings "Each Jax/Pallas kernel implementation must have a corresponding CPU reference test comparing the optimized kernel against naive implementations with tolerance-based assertions".
| if name == "dg": | ||
| # At t=0 with zero initial state, dg_0 = 0 mathematically because | ||
| # exp(g_0) * 0 has no dependency on g_0. But the chunk version's | ||
| # cumsum-based computation creates AD paths that produce small | ||
| # numerical residuals. Skip t=0 for dg comparison. | ||
| if ref.shape[1] <= 1: | ||
| continue | ||
| ref = ref[:, 1:] | ||
| test = test[:, 1:] |
Only skip dg[:, 0] for the zero-initial-state cases.
The rationale here assumes h0 is None, but _XVAL_SHAPES also includes an h0=True case. In that configuration, g at t=0 affects the decayed initial state, so slicing off the first step masks real backward mismatches.
Suggested fix
- if name == "dg":
+ if name == "dg" and h0 is None:
# At t=0 with zero initial state, dg_0 = 0 mathematically because
# exp(g_0) * 0 has no dependency on g_0. But the chunk version's
# cumsum-based computation creates AD paths that produce small
# numerical residuals. Skip t=0 for dg comparison.
if ref.shape[1] <= 1:📝 Committable suggestion
    if use_qk_l2norm_in_kernel:
        q_f = _l2_normalize_last_dim(q_f)
        k_f = _l2_normalize_last_dim(k_f)
Propagate dq/dk through the L2-normalization Jacobian.
When use_qk_l2norm_in_kernel=True, the replay runs on normalized q_f and k_f, but the function returns dq_f and dk_f directly as if they were gradients with respect to the raw inputs. That drops the chain rule and makes the backward reference incorrect for normalized Q/K.
Suggested fix
- q_f = jnp.transpose(q, (0, 2, 1, 3)).astype(acc) # [B, H, T, K]
- k_f = jnp.transpose(k, (0, 2, 1, 3)).astype(acc) # [B, H, T, K]
+ q_in = jnp.transpose(q, (0, 2, 1, 3)).astype(acc) # [B, H, T, K]
+ k_in = jnp.transpose(k, (0, 2, 1, 3)).astype(acc) # [B, H, T, K]
+ q_f = q_in
+ k_f = k_in
v_f = jnp.transpose(v, (0, 2, 1, 3)).astype(acc) # [B, H, T, V]
g_f = jnp.transpose(g, (0, 2, 1, 3)).astype(acc) # [B, H, T, K]
beta_f = jnp.transpose(beta, (0, 2, 1)).astype(acc) # [B, H, T]
do_f = jnp.transpose(do, (0, 2, 1, 3)).astype(acc) # [B, H, T, V]
if use_qk_l2norm_in_kernel:
q_f = _l2_normalize_last_dim(q_f)
k_f = _l2_normalize_last_dim(k_f)
@@
- dq = jnp.transpose(dq_f, (0, 2, 1, 3)).astype(q.dtype)
- dk = jnp.transpose(dk_f, (0, 2, 1, 3)).astype(k.dtype)
+ if use_qk_l2norm_in_kernel:
+ _, q_pullback = jax.vjp(_l2_normalize_last_dim, q_in)
+ (dq_f,) = q_pullback(dq_f)
+ _, k_pullback = jax.vjp(_l2_normalize_last_dim, k_in)
+ (dk_f,) = k_pullback(dk_f)
+
+ dq = jnp.transpose(dq_f, (0, 2, 1, 3)).astype(q.dtype)
+ dk = jnp.transpose(dk_f, (0, 2, 1, 3)).astype(k.dtype)Also applies to: 406-418
    if use_gate_in_kernel:
        assert A_log is not None
        g_raw = jnp.transpose(g, (0, 2, 1, 3))
        g_f = fused_kda_gate(
            g_raw, A_log, dt_bias=dt_bias,
            lower_bound=lower_bound, output_dtype=acc,
        )
Call fused_kda_gate with [B, T, H, K], not [B, H, T, K].
fused_kda_gate treats the penultimate dimension as heads. This branch transposes g first, so the helper sees T as the head count and either trips the A_log shape assertion or applies head parameters across time. The backward path is therefore wrong whenever use_gate_in_kernel=True.
Suggested fix
if use_gate_in_kernel:
assert A_log is not None
- g_raw = jnp.transpose(g, (0, 2, 1, 3))
- g_f = fused_kda_gate(
- g_raw, A_log, dt_bias=dt_bias,
- lower_bound=lower_bound, output_dtype=acc,
- )
+ g_f = jnp.transpose(
+ fused_kda_gate(
+ g,
+ A_log,
+ dt_bias=dt_bias,
+ lower_bound=lower_bound,
+ output_dtype=acc,
+ ),
+ (0, 2, 1, 3),
+ )📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if use_gate_in_kernel: | |
| assert A_log is not None | |
| g_raw = jnp.transpose(g, (0, 2, 1, 3)) | |
| g_f = fused_kda_gate( | |
| g_raw, A_log, dt_bias=dt_bias, | |
| lower_bound=lower_bound, output_dtype=acc, | |
| ) | |
| if use_gate_in_kernel: | |
| assert A_log is not None | |
| g_f = jnp.transpose( | |
| fused_kda_gate( | |
| g, | |
| A_log, | |
| dt_bias=dt_bias, | |
| lower_bound=lower_bound, | |
| output_dtype=acc, | |
| ), | |
| (0, 2, 1, 3), | |
| ) |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/cpu/ops/kda/fused_recurrent.py` around lines 289 - 295, The branch for
use_gate_in_kernel currently transposes g to g_raw = jnp.transpose(g, (0, 2, 1,
3)) which produces [B, H, T, K] and passes that into fused_kda_gate;
fused_kda_gate expects [B, T, H, K] (penultimate dim = heads), causing incorrect
parameter alignment and broken backward passes. Fix by passing g to
fused_kda_gate in [B, T, H, K] order (i.e., do not transpose g before calling
fused_kda_gate) so fused_kda_gate(g, A_log, dt_bias=dt_bias,
lower_bound=lower_bound, output_dtype=acc) receives the correct layout and head
dimension matches A_log.
def naive_kda_gate(
    g: jax.Array,
    A_log: jax.Array,
    dt_bias: jax.Array | None = None,
    output_dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Reference implementation of the standard KDA gate."""
    assert g.ndim >= 2, f"g must have at least 2 dims, got shape {g.shape}"
    H, K = g.shape[-2:]
    acc = _acc_dtype(g.dtype)
    g_f = g.astype(acc)
    A_f, bias_f = _expand_headwise_params(A_log, dt_bias, H, K, acc)
    if bias_f is not None:
        g_f = g_f + bias_f.reshape(H, K)
    out = -jnp.exp(A_f.reshape(H, 1)) * jax.nn.softplus(g_f)
    return out.astype(output_dtype)


def naive_kda_lowerbound_gate(
    g: jax.Array,
    A_log: jax.Array,
    dt_bias: jax.Array | None = None,
    lower_bound: float = -5.0,
    output_dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Reference implementation of the lower-bounded KDA gate."""
    assert g.ndim >= 2, f"g must have at least 2 dims, got shape {g.shape}"
    H, K = g.shape[-2:]
    acc = _acc_dtype(g.dtype)
    g_f = g.astype(acc)
    A_f, bias_f = _expand_headwise_params(A_log, dt_bias, H, K, acc)
    if bias_f is not None:
        g_f = g_f + bias_f.reshape(H, K)
    out = lower_bound * jax.nn.sigmoid(jnp.exp(A_f.reshape(H, 1)) * g_f)
    return out.astype(output_dtype)
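As a quick numerical sanity check of the two gate formulas quoted above: the standard gate is always non-positive and unbounded below, while the lower-bounded gate stays strictly inside (lower_bound, 0). A numpy sketch with hypothetical shapes, where `softplus` and `sigmoid` stand in for the `jax.nn` calls:

```python
import numpy as np

def softplus(x):
    # numerically stable log(1 + exp(x))
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

H, K = 2, 4
rng = np.random.default_rng(0)
g = rng.normal(size=(H, K))
A_log = rng.normal(size=(H,))

# standard gate: -exp(A_log) * softplus(g) is always <= 0, unbounded below
standard = -np.exp(A_log)[:, None] * softplus(g)

# lower-bounded gate: lower_bound * sigmoid(exp(A_log) * g) lies in (lower_bound, 0)
lower_bound = -5.0
bounded = lower_bound * sigmoid(np.exp(A_log)[:, None] * g)

assert (standard <= 0).all()
assert ((bounded > lower_bound) & (bounded < 0)).all()
```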
🛠️ Refactor suggestion | 🟠 Major
Document and harden the new public gate APIs.
These exported functions only have minimal one-line docstrings, and the validation is partial. The repo contract for public Python APIs is stricter: each function should describe tensor semantics/shapes for every input/output and assert the relevant type/shape invariants up front.
As per coding guidelines **/*.py: All public functions must have a comprehensive docstring that explains the business semantics of the function and clearly details tensor shapes and dimension meanings for every input and output argument; All public functions must enforce strict input assertions on shape and types before executing main logic using assert instructions or utilities like assert_shape_or_none from tops.utils.
Also applies to: 97-204
@copilot fix last ci error.
Actionable comments posted: 2
♻️ Duplicate comments (2)
tops/cpu/ops/kda/fused_recurrent.py (2)

285-287: ⚠️ Potential issue | 🔴 Critical

Propagate gradients through the L2-normalization Jacobian when use_qk_l2norm_in_kernel=True.

When L2 normalization is applied, the returned dq_f and dk_f are gradients with respect to the normalized vectors, not the original inputs. The chain rule requires back-propagating through the normalization operation. Currently, this produces incorrect gradients when use_qk_l2norm_in_kernel=True.

Suggested fix using JAX VJP:

+    q_raw = q_f
+    k_raw = k_f
     if use_qk_l2norm_in_kernel:
         q_f = _l2_normalize_last_dim(q_f)
         k_f = _l2_normalize_last_dim(k_f)

Then before the final transpose (around line 406):

+    if use_qk_l2norm_in_kernel:
+        _, q_vjp = jax.vjp(_l2_normalize_last_dim, q_raw)
+        (dq_f,) = q_vjp(dq_f)
+        _, k_vjp = jax.vjp(_l2_normalize_last_dim, k_raw)
+        (dk_f,) = k_vjp(dk_f)
+
     dq = jnp.transpose(dq_f, (0, 2, 1, 3)).astype(q.dtype)
     dk = jnp.transpose(dk_f, (0, 2, 1, 3)).astype(k.dtype)

289-295: ⚠️ Potential issue | 🔴 Critical

Incorrect tensor layout passed to fused_kda_gate in backward.

The forward pass calls fused_kda_gate(g_h, ...) with g_h in [B, T, H, K] layout. However, this backward pass transposes g to [B, H, T, K] before calling the gate, causing a layout mismatch. This means the gate applies its per-head parameters (A_log) across the wrong dimension.

Suggested fix:

     if use_gate_in_kernel:
         assert A_log is not None
-        g_raw = jnp.transpose(g, (0, 2, 1, 3))
-        g_f = fused_kda_gate(
-            g_raw, A_log, dt_bias=dt_bias,
-            lower_bound=lower_bound, output_dtype=acc,
-        )
+        g_f = jnp.transpose(
+            fused_kda_gate(
+                g, A_log, dt_bias=dt_bias,
+                lower_bound=lower_bound, output_dtype=acc,
+            ),
+            (0, 2, 1, 3),
+        )

🧹 Nitpick comments (3)

tops/cpu/ops/kda/naive.py (1)

221, 265: Consider renaming masks to avoid shadowing.

The variable mask is defined twice with different semantics: line 221 creates an upper-triangular mask including the diagonal (for zeroing A's upper triangle), while line 265 creates a strictly upper-triangular mask excluding the diagonal (for causal attention). Distinct names would improve clarity.

♻️ Suggested naming:

-    mask = jnp.triu(jnp.ones((BT, BT), dtype=jnp.bool_))  # [BT, BT]
+    upper_mask = jnp.triu(jnp.ones((BT, BT), dtype=jnp.bool_))  # [BT, BT] includes diagonal
-    A = jnp.where(mask, 0, -A)  # [B, H, NT, BT, BT] zero upper triangle, negate lower
+    A = jnp.where(upper_mask, 0, -A)  # [B, H, NT, BT, BT] zero upper triangle, negate lower
-    mask = jnp.triu(jnp.ones((BT, BT), dtype=jnp.bool_), k=1)  # [BT, BT] strict upper triangle
+    causal_mask = jnp.triu(jnp.ones((BT, BT), dtype=jnp.bool_), k=1)  # [BT, BT] strict upper triangle
-    A_qk = jnp.where(mask, 0, A_qk)  # [B, H, BT, BT]
+    A_qk = jnp.where(causal_mask, 0, A_qk)  # [B, H, BT, BT]

tops/cpu/ops/kda/fused_recurrent.py (2)

565: Consider sorting __all__ for consistency.

Static analysis suggests sorting __all__ alphabetically for consistency with isort-style conventions.

-__all__ = ["fused_recurrent_kda_fwd", "fused_recurrent_kda_bwd", "fused_recurrent_kda"]
+__all__ = ["fused_recurrent_kda", "fused_recurrent_kda_bwd", "fused_recurrent_kda_fwd"]

115-118: Unused variable H can be suppressed.

The variable H is unpacked but never used. Consider using an underscore prefix to suppress the linter warning.

-    B, T, H, K = q.shape
+    B, T, _H, K = q.shape
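The two mask variants flagged in the naive.py nitpick differ only on the diagonal, which a tiny numpy sketch makes concrete:

```python
import numpy as np

BT = 4
upper_mask = np.triu(np.ones((BT, BT), dtype=bool))         # includes the diagonal
causal_mask = np.triu(np.ones((BT, BT), dtype=bool), k=1)   # strict upper triangle

assert upper_mask.sum() == BT * (BT + 1) // 2   # 10 True entries for BT=4
assert causal_mask.sum() == BT * (BT - 1) // 2  # 6 True entries for BT=4
assert (upper_mask & ~causal_mask).sum() == BT  # they differ exactly on the diagonal
```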
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tops/cpu/ops/kda/fused_recurrent.py`:
- Around line 524-562: The public function fused_recurrent_kda currently has
only a one-line docstring; replace it with a comprehensive docstring that (1)
describes the high-level business semantics of fused_recurrent_kda and how it
relates to fused_recurrent_kda_fwd, (2) lists every parameter (q, k, v, g, beta,
A_log, dt_bias, scale, initial_state, output_final_state,
use_qk_l2norm_in_kernel, use_gate_in_kernel, lower_bound, cu_seqlens,
transpose_state_layout, **kwargs) with expected shapes, dtypes and axis meanings
(e.g. batch, sequence, heads, head_dim), which args are optional and default
behaviors, (3) documents the return value tuple and shapes/types (output tensor
and optional final state when output_final_state=True), and (4) notes any
layout/format conventions (e.g. state layout when transpose_state_layout
toggled), edge cases and side effects; add short usage examples and mention any
requirements/constraints (e.g. broadcasting rules, supported dtypes) to make the
API clear to callers.
- Around line 476-521: The public function fused_recurrent_kda_fwd is missing a
comprehensive docstring; add a multi-line docstring (modeled on
fused_recurrent_kda_bwd) that explains the operation semantics and documents
every parameter and return value with precise tensor shapes and dimension
meanings for q, k, v, g, beta, A_log, dt_bias, initial_state, scale,
output_final_state, inplace_final_state, cu_seqlens, ssm_state_indices,
num_accepted_tokens, use_qk_l2norm_in_kernel, use_gate_in_kernel, lower_bound,
out, and transpose_state_layout; include return shapes (output array and
optional final state), any optional semantics (e.g., when output_final_state
True), allowed dtypes, and any side-effects (inplace_final_state behavior) so
callers have complete usage guidance.
---
Duplicate comments:
In `@tops/cpu/ops/kda/fused_recurrent.py`:
- Around line 285-287: When use_qk_l2norm_in_kernel is true, dq_f and dk_f are
gradients w.r.t. the normalized outputs of _l2_normalize_last_dim, so you must
backprop through that normalization: compute VJPs of _l2_normalize_last_dim for
q_f and k_f (e.g., call jax.vjp or the equivalent vjp function used in this
code) to convert dq_f/dk_f into gradients w.r.t. the original q_f and k_f, and
replace dq_f/dk_f with those propagated gradients before the code continues to
the final transpose step; ensure you use the exact _l2_normalize_last_dim
function and apply the VJP outputs in place of dq_f and dk_f so the chain rule
is satisfied when use_qk_l2norm_in_kernel=True.
- Around line 289-295: The backward pass incorrectly transposes g to g_raw =
jnp.transpose(g, (0, 2, 1, 3)) producing [B,H,T,K] before calling
fused_kda_gate, but the forward used g_h in [B,T,H,K]; remove the transpose and
pass g (or a variable named g_h) directly to fused_kda_gate so the tensor layout
matches A_log's per-head parameters (update the call site in the
use_gate_in_kernel branch where fused_kda_gate(g_raw, A_log, dt_bias=dt_bias,
lower_bound=lower_bound, output_dtype=acc) is invoked).
---
Nitpick comments:
In `@tops/cpu/ops/kda/fused_recurrent.py`:
- Line 565: The __all__ export list is unsorted; please alphabetize the entries
so exports follow a consistent order—replace __all__ =
["fused_recurrent_kda_fwd", "fused_recurrent_kda_bwd", "fused_recurrent_kda"]
with a sorted list (e.g., ["fused_recurrent_kda", "fused_recurrent_kda_bwd",
"fused_recurrent_kda_fwd"]) ensuring the symbols fused_recurrent_kda,
fused_recurrent_kda_bwd, and fused_recurrent_kda_fwd remain present and
correctly quoted.
- Around line 115-118: The unpacking of q.shape assigns H but it is never used;
change the unpack to suppress the linter by replacing H with a throwaway name
(e.g., _ or _H) in the tuple assignment so it becomes B, T, _H, K = q.shape (or
B, T, _, K = q.shape), leaving the subsequent uses of q, v, and cu_seqlens
unchanged; ensure the rest of the function (references to q, v, cu_seqlens, and
variables B, T, K, V, HV, N) still work after the rename.
In `@tops/cpu/ops/kda/naive.py`:
- Line 221: The code defines two different masks using the same name `mask` (one
is an upper-triangular including the diagonal created via jnp.triu(jnp.ones((BT,
BT), ...)) and the other is a strictly upper-triangular/causal mask); rename
them to distinct, descriptive names (e.g., mask_triu_diag or A_mask for the one
used to zero A's upper triangle, and mask_triu_strict or causal_mask for the
strictly upper-triangular causal attention mask) and update all subsequent uses
of `mask` in the surrounding function(s) so each use references the correctly
renamed symbol.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 19e8ed5b-0042-4ef7-8797-8581757a146b
📒 Files selected for processing (2)
- tops/cpu/ops/kda/fused_recurrent.py
- tops/cpu/ops/kda/naive.py
LGTM

LGTM
# =========================================================================
# Step 2: Solve lower-triangular dependency system
# Forward substitution: A[i, :i] += A[i, :] @ A[:, :i]
Rather than forward substitution, this code resembles a row-wise iterative implementation of the Neumann series expansion.
You're right, thanks a lot. I'll fix the comment in the next PR.
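The reviewer's observation can be checked directly: for a strictly lower-triangular L, the row-wise in-place loop produces (I - L)^-1 - I, which equals the (finite) Neumann series L + L^2 + ... because L is nilpotent. A numpy sketch:

```python
import numpy as np

T = 6
rng = np.random.default_rng(0)
L = np.tril(rng.normal(size=(T, T)), k=-1)  # strictly lower triangular

# row-wise in-place recurrence, as in the chunked code's comment
A = L.copy()
for i in range(1, T):
    A[i, :i] = A[i, :i] + A[i, :i] @ A[:i, :i]

# Neumann series terminates after T terms since L^T == 0
neumann = np.zeros((T, T))
P = np.eye(T)
for _ in range(T):
    P = P @ L
    neumann += P  # accumulates L + L^2 + ... + L^T

assert np.allclose(A, neumann)
assert np.allclose(np.eye(T) + A, np.linalg.inv(np.eye(T) - L))
```

So both readings are consistent: the loop performs forward substitution row by row, and the quantity it computes is the truncated-but-exact Neumann expansion of (I - L)^-1.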
Description
Related Issue
Closes #
Change Type
- feat — New feature
- fix — Bug fix
- refactor — Code refactoring
- docs — Documentation
- ci — CI/CD changes
- test — Tests
- perf — Performance improvement

Checklist
- Lint and format pass: uv run ruff check src/ tests/ and uv run ruff format src/ tests/
- Input validation in place (assert or assert_shape_or_none)
- Tests added under tests/ops/, tests/modules/, tests/layers/, or tests/ref/
- If tops/cpu/ is modified, core developers have been notified and the PR is labeled cpu-ref

Test Results
Summary by CodeRabbit
New Features
API Changes
Tests