@rogerwaleffe
Contributor

This PR adds the following:

  1. Differences of negative numbers that should themselves be non-positive are clamped to at most zero. In most cases this clamping has no effect. When numerical error makes the difference positive, the clamp prevents a positive number from being exponentiated, which can otherwise lead to infinities/NaNs.
  2. xBC is made contiguous before causal_conv1d to prevent a stride error (RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8) that is thrown when the number of heads is not a multiple of 8. With this patch, the number of heads can be any positive integer.
  3. A ddt -> dt typo is fixed.
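
The clamping in item 1 can be illustrated with a minimal sketch in plain Python. This is not the Triton kernel code from the PR; the function name `decay_factor` and the numeric values are hypothetical, and the rounding error is simulated by hand. It assumes the exponent is a difference of cumulative sums that is non-positive in exact arithmetic, so the exponentiated result should never exceed 1.

```python
import math


def decay_factor(cumsum_later, cumsum_earlier, clamp=True):
    """Return exp(cumsum_later - cumsum_earlier) for a non-increasing cumulative sum.

    In exact arithmetic the difference is <= 0, so the result lies in (0, 1].
    """
    diff = cumsum_later - cumsum_earlier
    if clamp:
        # Cap the exponent at zero; a no-op whenever diff is already <= 0.
        diff = min(diff, 0.0)
    return math.exp(diff)


# Well-behaved case: the clamp changes nothing.
normal = decay_factor(-3.0, -1.0)

# Simulated rounding error: the "later" value ends up slightly larger.
earlier = -1000.0
later = earlier + 1e-4
unclamped = decay_factor(later, earlier, clamp=False)  # slightly above 1
clamped = decay_factor(later, earlier)                 # exactly 1.0
```

In this toy case the unclamped result is only slightly above 1. In low-precision arithmetic a sufficiently large positive exponent overflows to infinity instead, which is the failure mode the clamp guards against.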

@tridao
Collaborator

tridao commented Mar 27, 2025

I think contiguous should only be called if the stride is not a multiple of 8. In other cases, wouldn't calling contiguous incur an extra kernel launch?

@rogerwaleffe
Contributor Author

Yeah, I agree. I'll add that.
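
The conditional check discussed above might look roughly like the following sketch. It works on plain stride tuples so it runs without PyTorch; `needs_contiguous` and `channel_last_strides` are hypothetical helpers, not functions from the repository, and the assumption is that the tensor passed to causal_conv1d is a transposed slice of a wider contiguous buffer whose row width depends on the number of heads.

```python
def needs_contiguous(strides, alignment=8):
    """Return True if stride(0) or stride(2) is not a multiple of `alignment`.

    `strides` is the stride tuple of a 3-D tensor, e.g. what `x.stride()`
    returns in PyTorch.
    """
    return strides[0] % alignment != 0 or strides[2] % alignment != 0


def channel_last_strides(seqlen, row_width):
    """Strides (in elements) of a (batch, channels, seqlen) view.

    Hypothetical helper: models slicing the last dimension of a contiguous
    (batch, seqlen, row_width) buffer, then transposing dimensions 1 and 2.
    """
    return (seqlen * row_width, 1, row_width)


aligned = channel_last_strides(seqlen=16, row_width=264)     # 264 = 8 * 33
misaligned = channel_last_strides(seqlen=16, row_width=261)  # 261 % 8 == 5

copy_for_aligned = needs_contiguous(aligned)        # False
copy_for_misaligned = needs_contiguous(misaligned)  # True
```

With a real tensor, the guard would be along the lines of `if needs_contiguous(x.stride()): x = x.contiguous()`, so inputs whose strides are already aligned skip the extra copy.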

@tridao tridao merged commit 2e16fc3 into state-spaces:main Apr 1, 2025
@peterbjorgensen

This patch fixes only the training path (Triton ops); the numerical issues remain in the inference paths.

4 participants