JIT compilation using Qobj #6
We could try to register the Qobj as a pytree node when importing qutip-jax. We should be able to set it to work with …
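A minimal sketch of the pytree-registration idea. It uses a hypothetical ToyQobj class as a stand-in for qutip.Qobj; it is not QuTiP's class, and the real Qobj carries more metadata than an array and dims. The array is returned as a pytree child, so jit traces it. dims goes into aux_data, which jit treats as static, so it must be hashable. Here that means tuples rather than lists.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyQobj:
    """Hypothetical stand-in for qutip.Qobj: a JAX array plus static dims."""

    def __init__(self, data, dims):
        # No validation here: JAX may call tree_unflatten with placeholder leaves
        self.data = data  # JAX array, becomes a traced leaf under jit
        self.dims = dims  # e.g. ((2,), (1,)) for a qubit ket; must be hashable

    def tree_flatten(self):
        # children are traced by JAX; aux_data is kept as static metadata
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

    def __matmul__(self, other):
        # Matrix product: keep the row dims of self and the column dims of other
        return ToyQobj(self.data @ other.data, (self.dims[0], other.dims[1]))


@jax.jit
def apply(op, state):
    # Both arguments and the return value are ToyQobj instances
    return op @ state


sigma_x = ToyQobj(jnp.array([[0.0, 1.0], [1.0, 0.0]]), ((2,), (2,)))
ket0 = ToyQobj(jnp.array([[1.0], [0.0]]), ((2,), (1,)))
ket1 = apply(sigma_x, ket0)
```

With this split, new array values reuse the compiled function. A different dims value triggers a retrace.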
Thank you @Ericgig and @quantshah for making this repo and pushing this forward! This is amazing! I just want to add a bit of my understanding of JAX JIT. It seems to me that JIT (and jax.grad etc.) works as long as the function to be jitted only takes inputs and returns outputs that JAX supports. For example, the following JIT example works fine for me with the current master branch:
```python
import qutip
import qutip_jax  # makes the "jax" data type available to .to("jax")
import jax


@jax.jit
def fun(a):
    M = qutip.sigmay().to("jax")
    N = qutip.sigmaz().to("jax")
    # `*` between Qobj is the matrix product; return the raw JAX array
    # stored in the JaxArray (not the Qobj), so jit can handle the output
    return (a * M.conj() * N + N).data._jxa


fun(3.)
```
This looks sufficient to me. The input is just a number and the output is a JAX array. Maybe instead of making the whole …
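Along the same lines, jax.grad composes with jax.jit as long as the function maps JAX-compatible inputs to a scalar. The sketch below uses plain jax.numpy matrices as stand-ins for the unwrapped .data._jxa arrays. The rotation, observable, and function names are chosen only for illustration.

```python
import jax
import jax.numpy as jnp

SIGMA_Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])


def ry(theta):
    # exp(-i theta sigma_y / 2), which is a real matrix
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])


@jax.jit
def expect_z(theta):
    # <sigma_z> in the state Ry(theta)|0> = (cos(theta/2), sin(theta/2)), i.e. cos(theta)
    psi = ry(theta) @ jnp.array([1.0, 0.0])
    return psi @ SIGMA_Z @ psi


dexpect_z = jax.jit(jax.grad(expect_z))  # derivative, i.e. -sin(theta)
value = expect_z(0.3)
slope = dexpect_z(0.3)
```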
We need to define what is sufficient. Are …
No. This is just an example showing that even if you cannot directly … In the way JAX is implemented, you cannot JIT arbitrary functions, only pure ones. It is very likely that we can never make …
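Here is a small illustration of the purity requirement; the variable and function names are only illustrative. jit traces a function once per input shape and dtype. Any global the function reads is frozen into the compiled code at that point, and later changes to it are silently ignored. Passing the value as an argument keeps the function pure.

```python
import jax
import jax.numpy as jnp

scale = 2.0  # module-level state read inside the function (impure)


@jax.jit
def impure_scale(x):
    # `scale` is looked up only while tracing and baked in as a constant
    return scale * x


@jax.jit
def pure_scale(x, factor):
    # Everything the result depends on is an explicit argument
    return factor * x


x = jnp.array(1.0)
first = impure_scale(x)       # traced and compiled with scale == 2.0
scale = 10.0                  # mutate the global
second = impure_scale(x)      # same shape/dtype, cached trace reused: still 2.0
third = pure_scale(x, scale)  # 10.0
```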
Yes, that would be great. It should be feasible with some global settings, e.g., with a default …
Yes, it should also work with other …
Thanks Boxi. Yes, I have been using JIT by shuttling QuTiP Qobj to and fro, and it all works with the extra step of .to('jax'). The JaxArray type registered as a PyTree works for JIT, and I was hoping we could do something like that for all Qobj when qutip_jax is imported, to get rid of the extra .to('jax').
Several things will break when we use JAX and JIT, as Eric pointed out. Even something as simple as displace() will not work, but there are workarounds. I am wondering if we can somehow mark (and change) all functions that are not natively JITable and suggest a JITable qutip_jax version, e.g., qutip_jax.ptrace() if the user calls ptrace and also wants to JIT things.
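One possible shape for a JIT-friendly partial trace is sketched below. ptrace_jit is a hypothetical helper, not an existing qutip_jax function, and it assumes a dense density matrix stored as a plain JAX array. The subsystem dimensions and the kept subsystems are declared static. Python can then branch on them at trace time, and each distinct selection gets its own compiled version.

```python
import math
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("dims", "sel"))
def ptrace_jit(rho, dims, sel):
    """Partial trace of a density matrix over the subsystems not in `sel`.

    dims: tuple of subsystem dimensions, e.g. (2, 2); static, so must be hashable.
    sel: tuple of subsystem indices to keep; static, so Python can branch on it.
    Kept subsystems appear in ascending index order in the result.
    """
    n = len(dims)
    # One axis per subsystem: first the row indices, then the column indices
    t = rho.reshape(dims + dims)
    # Trace out from the highest index down so lower axis numbers stay valid
    for j in sorted(set(range(n)) - set(sel), reverse=True):
        half = t.ndim // 2
        t = jnp.trace(t, axis1=j, axis2=j + half)
    d = math.prod(dims[k] for k in sel)
    return t.reshape(d, d)


proj0 = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # |0><0|
proj1 = jnp.array([[0.0, 0.0], [0.0, 1.0]])  # |1><1|
rho = jnp.kron(proj0, proj1)                 # |01><01|
rho_a = ptrace_jit(rho, (2, 2), (0,))        # keeps subsystem 0
rho_b = ptrace_jit(rho, (2, 2), (1,))        # keeps subsystem 1
```

The trade-off is recompilation for every new (dims, sel) pair. Lists would fail here because static arguments must be hashable.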
mesolve cannot be JITted directly, of course, but again it's possible to write a custom solver with JAX that is JITable. I will try to post some examples this week.
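Here is a minimal sketch of such a JIT-able solver. It assumes a time-independent Hamiltonian, ħ = 1, and a fixed-step fourth-order Runge-Kutta scheme chosen for brevity, which is not how QuTiP's solvers integrate. It also handles only pure states, i.e. the Schrödinger equation rather than the master equation. The function name sesolve_rk4 is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n_steps",))
def sesolve_rk4(H, psi0, t_final, n_steps=1000):
    """Integrate d|psi>/dt = -i H |psi> from t = 0 to t_final (hbar = 1).

    psi0 must be complex so that the loop carry keeps a fixed dtype.
    """
    dt = t_final / n_steps

    def rhs(psi):
        return -1j * (H @ psi)

    def rk4_step(_, psi):
        k1 = rhs(psi)
        k2 = rhs(psi + 0.5 * dt * k1)
        k3 = rhs(psi + 0.5 * dt * k2)
        k4 = rhs(psi + dt * k3)
        return psi + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

    # n_steps is static, so the trip count is known at trace time
    return jax.lax.fori_loop(0, n_steps, rk4_step, psi0)


sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
psi0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)
# A pi rotation about x takes |0> to -i|1>
psi_t = sesolve_rk4(0.5 * sigma_x, psi0, jnp.pi)
```

Making n_steps static lets fori_loop lower to a scan rather than a while_loop. Per the fori_loop documentation, that is what allows reverse-mode jax.grad through the loop. t_final stays traced, so changing it does not recompile.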
On Fri, 7 Oct 2022 at 18:12, Eric Giguère wrote:
> tensor will just need a new specialisation, but ptrace will not be that simple. By default you cannot branch on input values under jit, so ptrace's sel argument will cause issues. eigenstates returns a pair of eigenvalues and a list of Qobj, so I don't see how that could work...
>
> Solvers and integrators cannot be jitted, nor does it make any sense to try. But we need to think about getting grad working with solvers.
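On the eigenstates point, one possible JIT-compatible shape returns plain arrays instead of a list of Qobj. This assumes a Hermitian operator stored as a JAX array, and the function name is hypothetical. jnp.linalg.eigh returns the eigenvalues and a matrix whose columns are the eigenvectors, and a tuple of arrays is a valid jit output. If Qobj were registered as a pytree, a list of kets would in principle be a valid output too, since lists of pytrees are themselves pytrees.

```python
import jax
import jax.numpy as jnp


@jax.jit
def eigenstates_jit(H):
    """Eigen-decomposition of a Hermitian matrix, returned as plain arrays.

    Returns (evals, evecs): evals in ascending order, and evecs[:, k] is the
    eigenvector belonging to evals[k].
    """
    evals, evecs = jnp.linalg.eigh(H)
    return evals, evecs


sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]])
evals, evecs = eigenstates_jit(sigma_x)
```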
On second thought, maybe we should just write good documentation with a "sharp edges" section, like the JAX documentation has, on creating functions that are pure and only take JAX data as inputs and outputs. In the simplest use case, I was thinking of "learning unitaries" using gradient descent: you send in angles that define the unitary and get out some measure of fidelity as a number. That is probably enough for all the JITing we want to do. Let me post a notebook and we can discuss more.
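A minimal sketch of that use case follows, with every modelling choice made for illustration only. It assumes a single-qubit unitary U(theta) = exp(-i(theta_0 X + theta_1 Z)/2) and a target built from known angles. The figure of merit is the gate fidelity |Tr(U_target^† U)|^2 / d^2, optimised by plain gradient descent with a fixed learning rate. The target angles, starting point, step size, and step count are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

SIGMA_X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
SIGMA_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


def unitary(theta):
    # U(theta) = exp(-i (theta[0] X + theta[1] Z) / 2)
    return expm(-0.5j * (theta[0] * SIGMA_X + theta[1] * SIGMA_Z))


TARGET = unitary(jnp.array([0.7, -0.3]))  # illustrative target gate


def fidelity(theta):
    # |Tr(U_target^dagger U)|^2 / d^2 with d = 2; ignores global phase
    overlap = jnp.trace(TARGET.conj().T @ unitary(theta))
    return jnp.abs(overlap) ** 2 / 4.0


@jax.jit
def step(theta, lr):
    # One gradient-descent step on the infidelity 1 - F
    grad_loss = jax.grad(lambda t: 1.0 - fidelity(t))(theta)
    return theta - lr * grad_loss


theta = jnp.array([0.1, 0.1])
for _ in range(200):
    theta = step(theta, 0.5)
final_fidelity = fidelity(theta)
```

Only the angles go in and only a number comes out, so jit and grad never see a Qobj.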
In reply to Shahnawaz Ahmed (Sun, 9 Oct 2022 at 18:45).
The benefit of using JAX is the ability to JIT compile. With the current setup, it is not clear what the best way is to make JAX recognize QuTiP objects as valid inputs, since jax.jit only traces JAX-compatible types (arrays, or pytrees of arrays). There are workarounds, e.g., https://github.com/google/jax/blob/cc13fd1e5892a08f5360db933d4dfd64c0fc66eb/jax/experimental/lapax.py#L164. The alternative is to pass around data._jxa instead of quantum objects, as: