Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Resolve differences betweenjax_backend.concatenate and other backends #1655

Closed
wants to merge 5 commits into from
Closed

fix: Resolve differences betweenjax_backend.concatenate and other backends #1655

wants to merge 5 commits into from

Conversation

phinate
Copy link
Contributor

@phinate phinate commented Oct 19, 2021

Resolves part of the jax component for #1654.

This should make the behaviour consistent with numpy, e.g. when using a list as input. The perf slowdown in the trivial case (input == jax array) is on the order of µs on my machine (8 core mbp) with an example 1D array of size 10.

I made this PR purely for syntactical consistency -- of course, the slightly neater way may be to just always cast sequence-type inputs to tensors explicitly, so this isn't a hidden detail.
I couldn't think of a way to best test this consistency in the scope of this PR; if there was a suite that tested the output of common ops, then it fits nicely in there, but AFAIK you only test the core pyhf functionality instead.

Checklist Before Requesting Reviewer

  • Tests are passing
  • "WIP" removed from the title of the pull request
  • Selected an Assignee for the PR to be responsible for the log summary

Before Merging

For the PR Assignees:

  • Summarize commit messages into a comprehensive review of the PR

This should make the behaviour consistent with numpy, e.g. when using a list as input.
@phinate
Copy link
Contributor Author

phinate commented Oct 19, 2021

@kratsg thoughts on ths type of approach?

@phinate phinate changed the title Cast the input of jax_backend.concatenate to a jax array fix: Cast the input of jax_backend.concatenate to a jax array Oct 19, 2021
@phinate
Copy link
Contributor Author

phinate commented Oct 22, 2021

Explaining the current CI failures:

there are cases when doing stitches where this type of encounter is had:

jnp.concatenate((jnp.array([]), jnp.array([3])))
> DeviceArray([3.], dtype=float64)

but this fails if you cast the input to an array.

The actual thing that would fix this PR is casting the elements of sequence to arrays, but my naive implementation of that involves a for loop, e.g.

return jnp.concatenate([jnp.asarray(s) for s in sequence], axis=axis)

which obviously incurs a noticable slowdown.

@phinate
Copy link
Contributor Author

phinate commented Oct 22, 2021

I can directly address the issue with

return jnp.concatenate(jnp.array(tuple(filter(lambda item: len(item) != 0, sequence))))

which is around 3 times faster, but this is still 10000x slower than numpy :/

Any suggestions here? do we go the long-winded approach of looking at every concatenate instance and making sure it's jax-friendly individually?

@phinate phinate changed the title fix: Cast the input of jax_backend.concatenate to a jax array fix: Resolve differences betweenjax_backend.concatenate and other backends Oct 27, 2021
This pull request was closed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants