Skip to content

Commit

Permalink
Allow qudits in deferred measurements (quantumlib#5850)
Browse files Browse the repository at this point in the history
For this, we have to define a multidimensional ModAdd gate, for use in applying the state from the source qudit to the ancilla qudit representing the creg. 

That done, we insert it into the deferred measurements algorithm instead of the ordinary CX gate, and add a qudit test to make sure it all works.

cc @viathor for sanity check on the gate logic
  • Loading branch information
daxfohl authored and rht committed May 1, 2023
1 parent 05b6577 commit 42fa984
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 9 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def all_subclasses(cls):
cirq.Pauli,
# Private gates.
cirq.transformers.analytical_decompositions.two_qubit_to_fsim._BGate,
cirq.transformers.measurement_transformers._ModAdd,
cirq.transformers.routing.visualize_routed_circuit._SwapPrintGate,
cirq.ops.raw_types._InverseCompositeGate,
cirq.circuits.qasm_output.QasmTwoQubitGate,
Expand Down
43 changes: 39 additions & 4 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

from cirq import ops, protocols, value
from cirq.transformers import transformer_api, transformer_primitives
Expand Down Expand Up @@ -46,7 +46,7 @@ def dimension(self) -> int:
return self._qid.dimension

def _comparison_key(self) -> Any:
return (str(self._key), self._qid._comparison_key())
return str(self._key), self._qid._comparison_key()

def __str__(self) -> str:
return f"M('{self._key}', q={self._qid})"
Expand Down Expand Up @@ -104,7 +104,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
key = value.MeasurementKey.parse_serialized(gate.key)
targets = [_MeasurementQid(key, q) for q in op.qubits]
measurement_qubits[key] = targets
cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)]
cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)]
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
return cxs + xs
elif protocols.is_measurement(op):
Expand All @@ -117,7 +117,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
raise ValueError(f'Deferred measurement for key={c.key} not found.')
qs = measurement_qubits[c.key]
if len(qs) == 1:
control_values: Any = range(1, qs[0].dimension)
control_values: Any = [range(1, qs[0].dimension)]
else:
all_values = itertools.product(*[range(q.dimension) for q in qs])
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
Expand Down Expand Up @@ -227,3 +227,38 @@ def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return transformer_primitives.map_operations(
circuit, flip_inversion, deep=context.deep if context else True, tags_to_ignore=ignored
).unfreeze()


@value.value_equality
class _ModAdd(ops.ArithmeticGate):
"""Adds two qudits of the same dimension.
Operates on two qudits by modular addition:
|a,b> -> |a,a+b mod d>"""

def __init__(self, dimension: int):
self._dimension = dimension

def registers(self) -> Tuple[Tuple[int], Tuple[int]]:
return (self._dimension,), (self._dimension,)

def with_registers(self, *new_registers) -> '_ModAdd':
raise NotImplementedError()

def apply(self, *register_values: int) -> Tuple[int, int]:
return register_values[0], sum(register_values)

def _value_equality_values_(self) -> int:
return self._dimension


def _mod_add(source: 'cirq.Qid', target: 'cirq.Qid') -> 'cirq.Operation':
assert source.dimension == target.dimension
if source.dimension == 2:
# Use a CX gate in 2D case for simplicity.
return ops.CX(source, target)
# We can use a ModAdd gate in the qudit case, since the ancilla qudit corresponding to the
# measurement is always zero, so "adding" the measured qudit to it sets the ancilla qudit to
# the same state, which is the quantum equivalent to a measurement onto a creg.
return _ModAdd(source.dimension).on(source, target)
31 changes: 26 additions & 5 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
import sympy

import cirq
from cirq.transformers.measurement_transformers import _MeasurementQid
from cirq.transformers.measurement_transformers import _mod_add, _MeasurementQid


def assert_equivalent_to_deferred(circuit: cirq.Circuit):
qubits = list(circuit.all_qubits())
sim = cirq.Simulator()
num_qubits = len(qubits)
for i in range(2**num_qubits):
bits = cirq.big_endian_int_to_bits(i, bit_count=num_qubits)
dimensions = [q.dimension for q in qubits]
for i in range(np.prod(dimensions)):
bits = cirq.big_endian_int_to_digits(i, base=dimensions)
modified = cirq.Circuit()
for j in range(num_qubits):
if bits[j]:
modified.append(cirq.X(qubits[j]))
modified.append(cirq.XPowGate(dimension=qubits[j].dimension)(qubits[j]) ** bits[j])
modified.append(circuit)
deferred = cirq.defer_measurements(modified)
result = sim.simulate(modified)
Expand Down Expand Up @@ -58,6 +58,27 @@ def test_basic():
)


def test_qudits():
q0, q1 = cirq.LineQid.range(2, dimension=3)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.XPowGate(dimension=3).on(q1).with_classical_controls('a'),
cirq.measure(q1, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
_mod_add(q0, q_ma),
cirq.XPowGate(dimension=3).on(q1).controlled_by(q_ma, control_values=[[1, 2]]),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_nocompile_context():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down

0 comments on commit 42fa984

Please sign in to comment.