[BUG] JIT + Global measurements without device wires does not work #5813

Open
1 task done
dwierichs opened this issue Jun 6, 2024 · 0 comments
Labels
bug 🐛 Something isn't working

Comments

@dwierichs
Contributor

Expected behavior

If the device is created without specifying wires, the wires of a global measurement (such as qml.probs() with no wires argument) should be inferred from the operations on the tape, both with and without JIT compilation.

Actual behavior

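Wire inference only works without JIT. Under jax.jit, executing the QNode raises an XlaRuntimeError because the result returned by the callback has shape (2,) instead of the expected shape (1,) (see the traceback below).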

Additional information

@albi3ro suggested modifying MeasurementProcess.shape to take the number of device wires (num_device_wires) as an argument instead of the full device.
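To illustrate the suggestion, here is a minimal sketch of a shape rule that depends only on an integer wire count rather than on a device object. probs_shape is a hypothetical stand-alone function, not PennyLane's MeasurementProcess.shape, and it assumes that probabilities over n wires have shape (2**n,). Note that a wire count of 0 gives (1,), which is consistent with (though not proof of) the "Expected: (1,)" shape in the traceback for a device without wires.

```python
from typing import Optional, Sequence, Tuple


def probs_shape(measurement_wires: Sequence, num_device_wires: Optional[int]) -> Tuple[int]:
    """Hypothetical shape rule for a probability measurement.

    Explicit measurement wires take precedence. For a global measurement
    (no wires given), the caller must supply the number of wires to use.
    """
    if len(measurement_wires) > 0:
        num_wires = len(measurement_wires)
    elif num_device_wires is not None:
        num_wires = num_device_wires
    else:
        raise ValueError("cannot infer the shape of a global measurement without a wire count")
    # Probabilities over num_wires qubits: one entry per computational basis state.
    return (2 ** num_wires,)


tape_wires = [0]  # wires used by the tape in the reproducer (only RX on wire 0)
print(probs_shape([], len(tape_wires)))  # (2,): wire count taken from the tape
print(probs_shape([0, 1], None))         # (4,): explicit wires, as in qml.probs(wires=[0, 1])
print(probs_shape([], 0))                # (1,): zero wires, matching the "Expected" shape
```

Passing an integer keeps the shape computation independent of whether the device declares wires, so the caller can supply the wire count inferred from the tape.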

Source code

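import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit")  # no wires specified

@jax.jit
@qml.qnode(dev, diff_method="parameter-shift")
def node(x):
    qml.RX(x, 0)
    return qml.probs()  # global measurement: crashes under jax.jit, see traceback below
    # return qml.probs(wires=[0, 1])  # explicit wires: works

node(jnp.array(0.5))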

Tracebacks

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[25], line 9
      6     qml.RX(x,0)
      7     return qml.probs()
----> 9 node(jnp.array(0.5))

    [... skipping hidden 10 frame]

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1151, in ExecuteReplicated.__call__(self, *args)
   1148 if (self.ordered_effects or self.has_unordered_effects
   1149     or self.has_host_callbacks):
   1150   input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1151   results = self.xla_executable.execute_sharded(
   1152       input_bufs, with_tokens=True
   1153   )
   1154   result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
   1155       len(self.ordered_effects))
   1156   sharded_runtime_token = results.consume_token()

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: Incorrect output shape for return value 0: Expected: (1,), Actual: (2,)

At:
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py(2374): _wrapped_callback
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py(1151): __call__
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/profiler.py(336): wrapper
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(1185): _pjit_call_impl_python
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(1229): call_impl_cache_miss
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(1245): _pjit_call_impl
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/core.py(935): process_primitive
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/core.py(447): bind_with_trace
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/core.py(2740): bind
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(168): _python_pjit_helper
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(257): cache_miss
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/traceback_util.py(179): reraise_with_filtered_traceback
  /tmp/ipykernel_292987/2553214436.py(9): <module>
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3577): run_code
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3517): run_ast_nodes
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3334): run_cell_async
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3130): _run_cell
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3075): run_cell
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/zmqshell.py(549): run_cell
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/ipkernel.py(449): do_execute
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelbase.py(778): execute_request
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/ipkernel.py(362): execute_request
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelbase.py(437): dispatch_shell
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelbase.py(534): process_one
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelbase.py(545): dispatch_queue
  /usr/lib/python3.10/asyncio/events.py(80): _run
  /usr/lib/python3.10/asyncio/base_events.py(1909): _run_once
  /usr/lib/python3.10/asyncio/base_events.py(603): run_forever
  /home/david/venvs/dev/lib/python3.10/site-packages/tornado/platform/asyncio.py(205): start
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelapp.py(739): start
  /home/david/venvs/dev/lib/python3.10/site-packages/traitlets/config/application.py(1075): launch_instance
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel_launcher.py(18): <module>
  /usr/lib/python3.10/runpy.py(86): _run_code
  /usr/lib/python3.10/runpy.py(196): _run_module_as_main
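The CpuCallback error above is the kind of failure JAX raises when a host callback returns an array whose shape differs from the shape declared at trace time. The sketch below reproduces that mechanism in plain JAX, assuming (not confirmed here) that the JAX-JIT path executes the device through a callback such as jax.pure_callback whose result shape is fixed before execution. host_probs and run are hypothetical stand-ins, not PennyLane code.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_probs(x):
    # Hypothetical host-side "device": RX(x) on a single wire, so it always
    # returns two probabilities, i.e. shape (2,), regardless of what was declared.
    c = np.cos(np.asarray(x) / 2) ** 2
    return np.array([c, 1.0 - c], dtype=np.float32)


def run(x, declared_num_wires):
    # pure_callback needs the result shape/dtype up front; here it is derived
    # from a wire count known at trace time, which may be wrong.
    result_spec = jax.ShapeDtypeStruct((2 ** declared_num_wires,), jnp.float32)
    return jax.pure_callback(host_probs, result_spec, x)


run_jit = jax.jit(run, static_argnums=1)
p = run_jit(jnp.float32(0.5), 1)  # declared (2,), actual (2,): works
print(p.shape)  # (2,)
```

Calling run_jit(jnp.float32(0.5), 0) instead declares a (1,) result while the callback still returns two probabilities, which should fail with a shape-mismatch error similar to the one in the traceback (the exact exception class differs between JAX versions).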

System information

PennyLane development version (pl dev)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
dwierichs added the bug 🐛 Something isn't working label on Jun 6, 2024
dwierichs changed the title from [BUG] to [BUG] JIT + Global measurements without device wires does not work on Jun 6, 2024