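Expected behavior

When the device has no wires specified, inferring the wires of a global measurement from the tape's operations should work both with and without JIT compilation.

Actual behavior

It only works without JIT compilation. Under `jax.jit` the QNode fails with an `XlaRuntimeError`.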
Additional information

@albi3ro suggested modifying `MeasurementProcess.shape` to take `num_device_wires` as an argument instead of the full device.
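To illustrate the suggestion, here is a minimal plain-Python sketch. `probs_shape` and `infer_num_wires` are hypothetical stand-ins, not PennyLane's actual `MeasurementProcess.shape` code. The sketch assumes that a `probs` measurement over n wires has shape `(2**n,)` and that, when the device defines no wires, the wire count should come from the tape.

```python
from typing import Optional, Sequence, Tuple


def infer_num_wires(device_wires: Optional[Sequence[int]],
                    tape_wires: Sequence[int]) -> int:
    """Hypothetical helper: prefer the device's wires, else fall back to the tape's."""
    return len(device_wires) if device_wires else len(tape_wires)


def probs_shape(measurement_wires: Sequence[int],
                num_device_wires: int) -> Tuple[int]:
    """Hypothetical probs shape that needs only an integer, not a device object.

    A global measurement (no wires given) spans all num_device_wires wires.
    """
    n = len(measurement_wires) if measurement_wires else num_device_wires
    return (2 ** n,)


# Device created without wires; the tape only acts on wire 0 (as in the repro).
n = infer_num_wires(None, tape_wires=[0])
print(probs_shape([], n))      # global probs -> (2,)
print(probs_shape([0, 1], n))  # explicit wires -> (4,)
```

Because the shape depends only on an integer, the caller can compute that integer once and reuse the same value when tracing and when executing. It comes from the device if the device defines wires, and from the tape otherwise. The declared and actual shapes then agree.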
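Source code

```python
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit")  # no wires specified on the device

@jax.jit
@qml.qnode(dev, diff_method="parameter-shift")
def node(x):
    qml.RX(x, 0)
    return qml.probs()  # Crashes under jax.jit, see traceback
    # return qml.probs(wires=[0, 1])  # Works

node(jnp.array(0.5))
```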
Tracebacks

```
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[25], line 9
      6     qml.RX(x,0)
      7     return qml.probs()
----> 9 node(jnp.array(0.5))

    [... skipping hidden 10 frame]

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1151, in ExecuteReplicated.__call__(self, *args)
   1148 if (self.ordered_effects or self.has_unordered_effects
   1149     or self.has_host_callbacks):
   1150   input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1151   results = self.xla_executable.execute_sharded(
   1152       input_bufs, with_tokens=True
   1153   )
   1154   result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
   1155       len(self.ordered_effects))
   1156   sharded_runtime_token = results.consume_token()

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: Incorrect output shape for return value 0: Expected: (1,), Actual: (2,)

At:
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py(2374): _wrapped_callback
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py(1151): __call__
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/profiler.py(336): wrapper
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(1185): _pjit_call_impl_python
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(1229): call_impl_cache_miss
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(1245): _pjit_call_impl
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/core.py(935): process_primitive
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/core.py(447): bind_with_trace
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/core.py(2740): bind
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(168): _python_pjit_helper
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/pjit.py(257): cache_miss
  /home/david/venvs/dev/lib/python3.10/site-packages/jax/_src/traceback_util.py(179): reraise_with_filtered_traceback
  /tmp/ipykernel_292987/2553214436.py(9): <module>
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3577): run_code
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3517): run_ast_nodes
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3334): run_cell_async
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3130): _run_cell
  /home/david/venvs/dev/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3075): run_cell
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/zmqshell.py(549): run_cell
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/ipkernel.py(449): do_execute
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelbase.py(778): execute_request
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/ipkernel.py(362): execute_request
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelbase.py(437): dispatch_shell
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelbase.py(534): process_one
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelbase.py(545): dispatch_queue
  /usr/lib/python3.10/asyncio/events.py(80): _run
  /usr/lib/python3.10/asyncio/base_events.py(1909): _run_once
  /usr/lib/python3.10/asyncio/base_events.py(603): run_forever
  /home/david/venvs/dev/lib/python3.10/site-packages/tornado/platform/asyncio.py(205): start
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel/kernelapp.py(739): start
  /home/david/venvs/dev/lib/python3.10/site-packages/traitlets/config/application.py(1075): launch_instance
  /home/david/venvs/dev/lib/python3.10/site-packages/ipykernel_launcher.py(18): <module>
  /usr/lib/python3.10/runpy.py(86): _run_code
  /usr/lib/python3.10/runpy.py(196): _run_module_as_main
```
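The error message says the callback was declared to return shape `(1,)` but actually returned `(2,)`. One plausible reading fits `2**0 = 1` and `2**1 = 2`. While tracing, the shape of `qml.probs()` was computed from a device with no wires, so zero wires were assumed. Execution, by contrast, inferred one wire (wire 0) from the tape. The toy sketch below imitates the kind of shape check JAX applies to host callbacks. `checked_callback` and `probs_rx` are hypothetical names, and this is not JAX's implementation.

```python
import math
from typing import Callable, List, Tuple


def checked_callback(fn: Callable[..., List[float]],
                     declared_shape: Tuple[int, ...]) -> Callable[..., List[float]]:
    """Toy imitation of a host-callback shape check: the result must match
    the shape declared at trace time (1-D results only, for simplicity)."""
    def wrapped(*args):
        out = fn(*args)
        actual = (len(out),)
        if actual != declared_shape:
            raise RuntimeError(
                "Incorrect output shape for return value 0: "
                f"Expected: {declared_shape}, Actual: {actual}"
            )
        return out
    return wrapped


def probs_rx(x: float) -> List[float]:
    """Probabilities of RX(x) applied to |0> on a single wire."""
    return [math.cos(x / 2) ** 2, math.sin(x / 2) ** 2]


# Shape declared from a device with no wires: 2**0 entries.
node = checked_callback(probs_rx, declared_shape=(2 ** 0,))
try:
    node(0.5)
except RuntimeError as err:
    print(err)  # ... Expected: (1,), Actual: (2,)
```

Declaring `(2 ** 1,)` instead, which matches the single wire the tape acts on, lets the same call succeed.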
System information

PennyLane `dev` (development version)