Skip to content

Commit

Permalink
Preserve (and resolve) tags when resolving TaggedOperation (#3172)
Browse files Browse the repository at this point in the history
Previously, tags would get dropped from tagged ops when resolving parameters. Instead, we would like to preserve these tags. In addition, we allow tags themselves to be parameterized and resolve them as well when resolving the tagged op.
  • Loading branch information
maffoo committed Jul 23, 2020
1 parent 3577ea2 commit d16f3a1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
8 changes: 6 additions & 2 deletions cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,10 +609,14 @@ def _measurement_key_(self) -> str:
return protocols.measurement_key(self.sub_operation, NotImplemented)

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self.sub_operation)
return protocols.is_parameterized(self.sub_operation) or any(
protocols.is_parameterized(tag) for tag in self.tags)

def _resolve_parameters_(self, resolver):
return protocols.resolve_parameters(self.sub_operation, resolver)
resolved_op = protocols.resolve_parameters(self.sub_operation, resolver)
resolved_tags = (
protocols.resolve_parameters(tag, resolver) for tag in self._tags)
return TaggedOperation(resolved_op, *resolved_tags)

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
Expand Down
29 changes: 28 additions & 1 deletion cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,8 @@ def test_tagged_operation_forwards_protocols():
assert cirq.is_parameterized(parameterized_op)
resolver = cirq.study.ParamResolver({'t': 0.25})
assert (cirq.resolve_parameters(
parameterized_op, resolver) == cirq.XPowGate(exponent=0.25)(q1))
parameterized_op,
resolver) == cirq.XPowGate(exponent=0.25)(q1).with_tags(tag))

y = cirq.Y(q1)
tagged_y = cirq.Y(q1).with_tags(tag)
Expand Down Expand Up @@ -649,6 +650,32 @@ def test_tagged_operation_forwards_protocols():
cirq.testing.assert_has_consistent_apply_unitary(tagged_h)


class ParameterizableTag:

def __init__(self, value):
self.value = value

def __eq__(self, other):
return self.value == other.value

def _is_parameterized_(self) -> bool:
return cirq.is_parameterized(self.value)

def _resolve_parameters_(self, resolver) -> 'ParameterizableTag':
return ParameterizableTag(cirq.resolve_parameters(self.value, resolver))


def test_tagged_operation_resolves_parameterized_tags():
q = cirq.GridQubit(0, 0)
tag = ParameterizableTag(sympy.Symbol('t'))
assert cirq.is_parameterized(tag)
op = cirq.Z(q).with_tags(tag)
assert cirq.is_parameterized(op)
resolved_op = cirq.resolve_parameters(op, {'t': 10})
assert resolved_op == cirq.Z(q).with_tags(ParameterizableTag(10))
assert not cirq.is_parameterized(resolved_op)


def test_inverse_composite_standards():

@cirq.value_equality
Expand Down

0 comments on commit d16f3a1

Please sign in to comment.