Skip to content

Commit

Permalink
Refactor AbstractControlValues and it's implementations to fix multip…
Browse files Browse the repository at this point in the history
…le bugs and improve consistency (#5788)

* Remove stale TODO

* Refactor AbstractControlValues and it's implementations to fix multiple bugs and improve consistency

* Revert unrelate change in consistent_protocols

* Fix tests and update json

* Add more json

* Add a lot more tests to control_values_test.py

* Address feedback from maffoo@

* Fix pylint and mypy errors

* maffoo@ feedback part-2

* Update diagrams and change SumOfProducts API to Collection[Sequence[int]]
  • Loading branch information
tanujkhattar authored Jul 18, 2022
1 parent 3c67cd7 commit beab8b7
Show file tree
Hide file tree
Showing 18 changed files with 819 additions and 425 deletions.
339 changes: 187 additions & 152 deletions cirq-core/cirq/ops/control_values.py

Large diffs are not rendered by default.

404 changes: 304 additions & 100 deletions cirq-core/cirq/ops/control_values_test.py

Large diffs are not rendered by default.

50 changes: 20 additions & 30 deletions cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,29 +82,31 @@ def __init__(
_validate_sub_object(sub_gate)
if num_controls is None:
if control_values is not None:
num_controls = len(control_values)
num_controls = (
control_values._num_qubits_()
if isinstance(control_values, cv.AbstractControlValues)
else len(control_values)
)
elif control_qid_shape is not None:
num_controls = len(control_qid_shape)
else:
num_controls = 1
if control_values is None:
control_values = ((1,),) * num_controls
if num_controls != len(control_values):
raise ValueError('len(control_values) != num_controls')

# Convert to `cv.ProductOfSums` if input is a tuple of control values for each qubit.
if not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(control_values)

if num_controls != protocols.num_qubits(control_values):
raise ValueError('cirq.num_qubits(control_values) != num_controls')

if control_qid_shape is None:
control_qid_shape = (2,) * num_controls
if num_controls != len(control_qid_shape):
raise ValueError('len(control_qid_shape) != num_controls')
self._control_qid_shape = tuple(control_qid_shape)

# Convert to sorted tuples
if not isinstance(control_values, cv.AbstractControlValues):
control_values = cv.ProductOfSums(
tuple(
(val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values
)
)
self._control_values = control_values

# Verify control values not out of bounds
Expand Down Expand Up @@ -141,14 +143,11 @@ def _decompose_(self, qubits):
protocols.has_unitary(self.sub_gate)
and protocols.num_qubits(self.sub_gate) == 1
and self._qid_shape_() == (2,) * len(self._qid_shape_())
and isinstance(self.control_values, cv.ProductOfSums)
):
if not isinstance(self.control_values, cv.ProductOfSums):
return NotImplemented
control_qubits = list(qubits[: self.num_controls()])
invert_ops: List['cirq.Operation'] = []
for cvals, cqbit in zip(
self.control_values._identifier(), qubits[: self.num_controls()]
):
for cvals, cqbit in zip(self.control_values, qubits[: self.num_controls()]):
if set(cvals) == {0}:
invert_ops.append(common_gates.X(cqbit))
elif set(cvals) == {0, 1}:
Expand Down Expand Up @@ -271,8 +270,6 @@ def _trace_distance_bound_(self) -> Optional[float]:
def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
if not isinstance(self.control_values, cv.ProductOfSums):
return NotImplemented
sub_args = protocols.CircuitDiagramInfoArgs(
known_qubit_count=(
args.known_qubit_count - self.num_controls()
Expand All @@ -290,27 +287,20 @@ def _circuit_diagram_info_(
if sub_info is None:
return NotImplemented

def get_symbol(vals):
if tuple(vals) == (1,):
return '@'
return f"({','.join(map(str, vals))})"
cv_info = protocols.circuit_diagram_info(self.control_values)

return protocols.CircuitDiagramInfo(
wire_symbols=(
*(get_symbol(vals) for vals in self.control_values._identifier()),
*sub_info.wire_symbols,
),
exponent=sub_info.exponent,
wire_symbols=(*cv_info.wire_symbols, *sub_info.wire_symbols), exponent=sub_info.exponent
)

def __str__(self) -> str:
return self.control_values.diagram_repr() + str(self.sub_gate)
return str(self.control_values) + str(self.sub_gate)

def __repr__(self) -> str:
if self.num_controls() == 1 and self.control_values._are_ones():
if self.num_controls() == 1 and self.control_values.is_trivial:
return f'cirq.ControlledGate(sub_gate={self.sub_gate!r})'

if self.control_values._are_ones() and set(self.control_qid_shape) == {2}:
if self.control_values.is_trivial and set(self.control_qid_shape) == {2}:
return (
f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
f'num_controls={self.num_controls()!r})'
Expand All @@ -323,7 +313,7 @@ def __repr__(self) -> str:

def _json_dict_(self) -> Dict[str, Any]:
return {
'control_values': self.control_values._identifier(),
'control_values': self.control_values,
'control_qid_shape': self.control_qid_shape,
'sub_gate': self.sub_gate,
}
Expand Down
150 changes: 90 additions & 60 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import cirq
from cirq.type_workarounds import NotImplementedType
from cirq.ops import AbstractControlValues


class GateUsingWorkspaceForApplyUnitary(cirq.testing.SingleQubitGate):
Expand Down Expand Up @@ -89,14 +88,25 @@ def __str__(self):

C0Y = cirq.ControlledGate(cirq.Y, control_values=[0])
C0C1H = cirq.ControlledGate(cirq.ControlledGate(cirq.H, control_values=[1]), control_values=[0])

nand_control_values = cirq.SumOfProducts([(0, 1), (1, 0), (1, 1)])
xor_control_values = cirq.SumOfProducts([[0, 1], [1, 0]], name="xor")
C_01_10_11H = cirq.ControlledGate(cirq.H, control_values=nand_control_values)
C_xorH = cirq.ControlledGate(cirq.H, control_values=xor_control_values)
C0C_xorH = cirq.ControlledGate(C_xorH, control_values=[0])

C0Restricted = cirq.ControlledGate(RestrictedGate(), control_values=[0])
C_xorRestricted = cirq.ControlledGate(RestrictedGate(), control_values=xor_control_values)

C2Y = cirq.ControlledGate(cirq.Y, control_values=[2], control_qid_shape=(3,))
C2C2H = cirq.ControlledGate(
cirq.ControlledGate(cirq.H, control_values=[2], control_qid_shape=(3,)),
control_values=[2],
control_qid_shape=(3,),
)
C_02_20H = cirq.ControlledGate(
cirq.H, control_values=cirq.SumOfProducts([[0, 2], [1, 0]]), control_qid_shape=(2, 3)
)
C2Restricted = cirq.ControlledGate(RestrictedGate(), control_values=[2], control_qid_shape=(3,))


Expand All @@ -107,7 +117,7 @@ def test_init():


def test_init2():
with pytest.raises(ValueError, match=r'len\(control_values\) != num_controls'):
with pytest.raises(ValueError, match=r'cirq\.num_qubits\(control_values\) != num_controls'):
cirq.ControlledGate(cirq.Z, num_controls=1, control_values=(1, 0))
with pytest.raises(ValueError, match=r'len\(control_qid_shape\) != num_controls'):
cirq.ControlledGate(cirq.Z, num_controls=1, control_qid_shape=(2, 2))
Expand All @@ -125,15 +135,15 @@ def test_init2():
gate = cirq.ControlledGate(cirq.Z, 1)
assert gate.sub_gate is cirq.Z
assert gate.num_controls() == 1
assert gate.control_values == ((1,),)
assert gate.control_values == cirq.ProductOfSums(((1,),))
assert gate.control_qid_shape == (2,)
assert gate.num_qubits() == 2
assert cirq.qid_shape(gate) == (2, 2)

gate = cirq.ControlledGate(cirq.Z, 2)
assert gate.sub_gate is cirq.Z
assert gate.num_controls() == 2
assert gate.control_values == ((1,), (1,))
assert gate.control_values == cirq.ProductOfSums(((1,), (1,)))
assert gate.control_qid_shape == (2, 2)
assert gate.num_qubits() == 3
assert cirq.qid_shape(gate) == (2, 2, 2)
Expand All @@ -143,7 +153,7 @@ def test_init2():
)
assert gate.sub_gate is cirq.Z
assert gate.num_controls() == 7
assert gate.control_values == ((1,),) * 7
assert gate.control_values == cirq.ProductOfSums(((1,),) * 7)
assert gate.control_qid_shape == (2,) * 7
assert gate.num_qubits() == 8
assert cirq.qid_shape(gate) == (2,) * 8
Expand All @@ -162,15 +172,15 @@ def test_init2():
gate = cirq.ControlledGate(cirq.Z, control_values=(0, (0, 1)))
assert gate.sub_gate is cirq.Z
assert gate.num_controls() == 2
assert gate.control_values == ((0,), (0, 1))
assert gate.control_values == cirq.ProductOfSums(((0,), (0, 1)))
assert gate.control_qid_shape == (2, 2)
assert gate.num_qubits() == 3
assert cirq.qid_shape(gate) == (2, 2, 2)

gate = cirq.ControlledGate(cirq.Z, control_qid_shape=(3, 3))
assert gate.sub_gate is cirq.Z
assert gate.num_controls() == 2
assert gate.control_values == ((1,), (1,))
assert gate.control_values == cirq.ProductOfSums(((1,), (1,)))
assert gate.control_qid_shape == (3, 3)
assert gate.num_qubits() == 3
assert cirq.qid_shape(gate) == (3, 3, 2)
Expand Down Expand Up @@ -232,9 +242,15 @@ def test_eq():
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[1, (0, 2)], control_qid_shape=[2, 3]),
cirq.ControlledGate(cirq.H, control_values=(1, [0, 2]), control_qid_shape=(2, 3)),
cirq.ControlledGate(
cirq.H, control_values=cirq.SumOfProducts([[1, 0], [1, 2]]), control_qid_shape=(2, 3)
),
)
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[(2, 0), 1], control_qid_shape=[3, 2])
cirq.ControlledGate(cirq.H, control_values=[(2, 0), 1], control_qid_shape=[3, 2]),
cirq.ControlledGate(
cirq.H, control_values=cirq.SumOfProducts([[2, 1], [0, 1]]), control_qid_shape=(3, 2)
),
)
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[1, 0], control_qid_shape=[2, 3]),
Expand Down Expand Up @@ -278,18 +294,21 @@ def _has_mixture_(self):
g.controlled(control_values=[1]),
g.controlled(control_qid_shape=(2,)),
cirq.ControlledGate(g, num_controls=1),
g.controlled(control_values=cirq.SumOfProducts([[1]])),
)
eq.add_equality_group(
cirq.ControlledGate(g, num_controls=2),
g.controlled(control_values=[1, 1]),
g.controlled(control_qid_shape=[2, 2]),
g.controlled(num_controls=2),
g.controlled().controlled(),
g.controlled(control_values=cirq.SumOfProducts([[1, 1]])),
)
eq.add_equality_group(
cirq.ControlledGate(g, control_values=[0, 1]),
g.controlled(control_values=[0, 1]),
g.controlled(control_values=[1]).controlled(control_values=[0]),
g.controlled(control_values=cirq.SumOfProducts([[1]])).controlled(control_values=[0]),
)
eq.add_equality_group(g.controlled(control_values=[0]).controlled(control_values=[1]))
eq.add_equality_group(
Expand Down Expand Up @@ -350,6 +369,20 @@ def test_unitary():
atol=1e-8,
)

C_xorX = cirq.ControlledGate(cirq.X, control_values=xor_control_values)
# fmt: off
np.testing.assert_allclose(cirq.unitary(C_xorX), np.array([
[1, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1]]
))
# fmt: on


@pytest.mark.parametrize(
'gate, should_decompose_to_target',
Expand Down Expand Up @@ -380,6 +413,10 @@ def test_unitary():
(cirq.MatrixGate(cirq.testing.random_unitary(4, random_state=1234)), False),
(cirq.XX ** sympy.Symbol("s"), True),
(cirq.CZ ** sympy.Symbol("s"), True),
# Non-trivial `cirq.ProductOfSum` controls.
(C_01_10_11H, False),
(C_xorH, False),
(C0C_xorH, False),
],
)
def test_controlled_gate_is_consistent(gate: cirq.Gate, should_decompose_to_target):
Expand Down Expand Up @@ -507,7 +544,7 @@ def _has_unitary_(self):
return True


def test_circuit_diagram():
def test_circuit_diagram_product_of_sums():
qubits = cirq.LineQubit.range(3)
c = cirq.Circuit()
c.append(cirq.ControlledGate(MultiH(2))(*qubits))
Expand Down Expand Up @@ -542,6 +579,35 @@ def test_circuit_diagram():
)


def test_circuit_diagram_sum_of_products():
q = cirq.LineQubit.range(4)
c = cirq.Circuit(C_xorH.on(*q[:3]), C_01_10_11H.on(*q[:3]), C0C_xorH.on(*q))
cirq.testing.assert_has_diagram(
c,
"""
0: ───@────────@(011)───@(00)───
│ │ │
1: ───@(xor)───@(101)───@(01)───
│ │ │
2: ───H────────H────────@(10)───
3: ─────────────────────H───────
""",
)
q = cirq.LineQid.for_qid_shape((2, 3, 2))
c = cirq.Circuit(C_02_20H(*q))
cirq.testing.assert_has_diagram(
c,
"""
0 (d=2): ───@(01)───
1 (d=3): ───@(20)───
2 (d=2): ───H───────
""",
)


class MockGate(cirq.testing.TwoQubitGate):
def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
self.captured_diagram_args = args
Expand Down Expand Up @@ -571,12 +637,21 @@ def test_bounded_effect():
assert cirq.trace_distance_bound(cirq.ControlledGate(cirq.X**foo)) == 1


def test_repr():
cirq.testing.assert_equivalent_repr(cirq.ControlledGate(cirq.Z))
cirq.testing.assert_equivalent_repr(cirq.ControlledGate(cirq.Z, num_controls=1))
cirq.testing.assert_equivalent_repr(cirq.ControlledGate(cirq.Z, num_controls=2))
cirq.testing.assert_equivalent_repr(C0C1H)
cirq.testing.assert_equivalent_repr(C2C2H)
@pytest.mark.parametrize(
'gate',
[
cirq.ControlledGate(cirq.Z),
cirq.ControlledGate(cirq.Z, num_controls=1),
cirq.ControlledGate(cirq.Z, num_controls=2),
C0C1H,
C2C2H,
C_01_10_11H,
C_xorH,
C_02_20H,
],
)
def test_repr(gate):
cirq.testing.assert_equivalent_repr(gate)


def test_str():
Expand All @@ -597,48 +672,3 @@ def test_controlled_mixture():
c_yes = cirq.ControlledGate(sub_gate=cirq.phase_flip(0.25), num_controls=1)
assert cirq.has_mixture(c_yes)
assert cirq.approx_eq(cirq.mixture(c_yes), [(0.75, np.eye(4)), (0.25, cirq.unitary(cirq.CZ))])


class MockControlValues(AbstractControlValues):
def __and__(self, other):
pass

def _expand(self):
pass

def diagram_repr(self):
pass

def _number_variables(self):
pass

def __len__(self):
return 1

def _identifier(self):
pass

def __hash__(self):
pass

def __repr__(self):
pass

def validate(self, shapes):
pass

def _are_ones(self):
pass

def _json_dict_(self):
pass


def test_decompose_applies_only_to_ProductOfSums():
g = cirq.ControlledGate(cirq.X, control_values=MockControlValues())
assert g._decompose_(None) is NotImplemented


def test_circuit_diagram_info_applies_only_to_ProductOfSums():
g = cirq.ControlledGate(cirq.X, control_values=MockControlValues())
assert g._circuit_diagram_info_(None) is NotImplemented
Loading

0 comments on commit beab8b7

Please sign in to comment.