Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow qudits in deferred measurements #5850

Merged
merged 10 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def all_subclasses(cls):
cirq.Pauli,
# Private gates.
cirq.transformers.analytical_decompositions.two_qubit_to_fsim._BGate,
cirq.transformers.measurement_transformers._CX,
cirq.ops.raw_types._InverseCompositeGate,
cirq.circuits.qasm_output.QasmTwoQubitGate,
cirq.ops.MSGate,
Expand Down
57 changes: 54 additions & 3 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

import numpy as np

from cirq import ops, protocols, value
from cirq.transformers import transformer_api, transformer_primitives
from cirq.transformers.synchronize_terminal_measurements import find_terminal_measurements
Expand Down Expand Up @@ -46,7 +48,7 @@ def dimension(self) -> int:
return self._qid.dimension

def _comparison_key(self) -> Any:
return (str(self._key), self._qid._comparison_key())
return str(self._key), self._qid._comparison_key()

def __str__(self) -> str:
return f"M('{self._key}', q={self._qid})"
Expand Down Expand Up @@ -104,7 +106,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
key = value.MeasurementKey.parse_serialized(gate.key)
targets = [_MeasurementQid(key, q) for q in op.qubits]
measurement_qubits[key] = targets
cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)]
cxs = [_cx(q.dimension).on(q, target) for q, target in zip(op.qubits, targets)]
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
return cxs + xs
elif protocols.is_measurement(op):
Expand All @@ -117,7 +119,7 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
raise ValueError(f'Deferred measurement for key={c.key} not found.')
qs = measurement_qubits[c.key]
if len(qs) == 1:
control_values: Any = range(1, qs[0].dimension)
control_values: Any = [range(1, qs[0].dimension)]
else:
all_values = itertools.product(*[range(q.dimension) for q in qs])
anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None))
Expand Down Expand Up @@ -227,3 +229,52 @@ def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return transformer_primitives.map_operations(
circuit, flip_inversion, deep=context.deep if context else True, tags_to_ignore=ignored
).unfreeze()


class _CX(ops.Gate):
"""A CX gate generalized for qudits.

This represents a Controlled-NOT gate for qudits, in the following sense. If the target is
zero, then it becomes the value of the control. If the target is equal to the control, then
it becomes zero. All other cases do not affect the target. The first rule is the only one
required for measurement-deferral purposes: the ancilla qudit representing the creg is always
zero at the time of measurement. The remaining rules are for symmetry.

|k0> -> |kk>
|kk> -> |k0>
|kj> -> |kj> otherwise
daxfohl marked this conversation as resolved.
Show resolved Hide resolved

The unitary is formed directly from these rules, thus fully defining behavior in the presence
of superposition and entanglement. This definition preserves 2D CX behavior, such as
CX∘CX == I, and CX∘XC∘CX == SWAP.

Note that this is explicitly different from a controlled multidimensional X gate. This is easy
to see, in that the latter is a STEP gate, allowing the qudit value to increase by at most one.
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
For this gate however, the target qudit can jump from zero to any value, depending on the
control.

Note also that the 2D definition of CX follows as a special case.
"""

def __init__(self, dimension: int):
self._dimension = dimension

def _qid_shape_(self):
return self._dimension, self._dimension

def _unitary_(self):
u = np.zeros((self._dimension**2, self._dimension**2))
for i in range(self._dimension):
for j in range(self._dimension):
offset = i * self._dimension
row = offset + j
col = offset + 0 if i == j else i if j == 0 else j
u[row, col] = 1
return u

def __eq__(self, other):
return isinstance(other, _CX) and other._dimension == self._dimension


def _cx(dimension: int):
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
return ops.CX if dimension == 2 else _CX(dimension)
31 changes: 26 additions & 5 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
import sympy

import cirq
from cirq.transformers.measurement_transformers import _MeasurementQid
from cirq.transformers.measurement_transformers import _cx, _MeasurementQid


def assert_equivalent_to_deferred(circuit: cirq.Circuit):
qubits = list(circuit.all_qubits())
sim = cirq.Simulator()
num_qubits = len(qubits)
for i in range(2**num_qubits):
bits = cirq.big_endian_int_to_bits(i, bit_count=num_qubits)
dimensions = [q.dimension for q in qubits]
for i in range(np.prod(dimensions)):
bits = cirq.big_endian_int_to_digits(i, base=dimensions)
modified = cirq.Circuit()
for j in range(num_qubits):
if bits[j]:
modified.append(cirq.X(qubits[j]))
modified.append(cirq.XPowGate(dimension=qubits[j].dimension)(qubits[j]) ** bits[j])
modified.append(circuit)
deferred = cirq.defer_measurements(modified)
result = sim.simulate(modified)
Expand Down Expand Up @@ -58,6 +58,27 @@ def test_basic():
)


def test_qudits():
q0, q1 = cirq.LineQid.range(2, dimension=3)
circuit = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.XPowGate(dimension=3).on(q1).with_classical_controls('a'),
cirq.measure(q1, key='b'),
)
assert_equivalent_to_deferred(circuit)
deferred = cirq.defer_measurements(circuit)
q_ma = _MeasurementQid('a', q0)
cirq.testing.assert_same_circuits(
deferred,
cirq.Circuit(
_cx(3)(q0, q_ma),
cirq.XPowGate(dimension=3).on(q1).controlled_by(q_ma, control_values=[[1, 2]]),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
),
)


def test_nocompile_context():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down