Skip to content

Commit

Permalink
Add Tags and TaggedOperation class (#2670)
Browse files Browse the repository at this point in the history
Add TaggedOperation class that wraps Operation and adds ability to tag operations.

This PR adds a TaggedOperation class which is a wrapper around a sub-Operation but contains unstructured meta-data that can be associated with this operation.  This meta-data can be used to distinguish this instance of the operation from other operations of the same type.

Possible uses include using tags to signal optimization passes to skip the operation, marking operations as noise to aid in composing noise models, and giving hardware specific information.

Any modification of this operation will default to dropping the tags, as the tags are considered to apply only to that specific operation.
  • Loading branch information
dstrain115 committed Feb 8, 2020
1 parent 1e5bf55 commit 2715417
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 2 deletions.
1 change: 1 addition & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@
SWAP,
SwapPowGate,
T,
TaggedOperation,
ThreeQubitGate,
ThreeQubitDiagonalGate,
TOFFOLI,
Expand Down
1 change: 1 addition & 0 deletions cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
Gate,
Operation,
Qid,
TaggedOperation,
)

from cirq.ops.swap_gates import (
Expand Down
168 changes: 166 additions & 2 deletions cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@

"""Basic types defining qubits, gates, and operations."""

from typing import (Any, Callable, Collection, Optional, Sequence, Tuple,
TYPE_CHECKING, Union)
from typing import (Any, Callable, Collection, Hashable, Optional, Sequence,
Tuple, TYPE_CHECKING, Union)

import abc
import functools
import numpy as np

from cirq import protocols, value
from cirq.type_workarounds import NotImplementedType
Expand Down Expand Up @@ -392,6 +393,26 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.Operation':
`qubits` property.
"""

def with_tags(self, *new_tags: Hashable) -> 'cirq.TaggedOperation':
"""Creates a new TaggedOperation, with this op and the specified tags.
This method can be used to attach meta-data to specific operations
without affecting their functionality. The intended usage is to
attach classes intended for this purpose or strings to mark operations
for specific usage that will be recognized by consumers. Specific
examples include ignoring this operation in optimization passes,
hardware-specific functionality, or circuit diagram customizability.
Tags can be a list of any type of object that is useful to identify
this operation as long as the type is hashable. If you wish the
resulting operation to be eventually serialized into JSON, you should
also restrict the operation to be JSON serializable.
Args:
new_tags: The tags to wrap this operation in.
"""
return TaggedOperation(self, *new_tags)

def transform_qubits(self, func: Callable[['cirq.Qid'], 'cirq.Qid']
) -> 'Operation':
"""Returns the same operation, but with different qubits.
Expand Down Expand Up @@ -445,6 +466,149 @@ def validate_args(self, qubits: Sequence['cirq.Qid']):
_validate_qid_shape(self, qubits)


@value.value_equality
class TaggedOperation(Operation):
"""A specific operation instance that has been identified with a set
of Tags for special processing. This can be initialized with
Using Operation.with_tags(tag) or by TaggedOperation(op, tag).
Tags added can be of any type, but they should be Hashable in order
to allow equality checking. If you wish to serialize operations into
JSON, you should restrict yourself to only use objects that have a JSON
serialization.
See Operation.with_tags() for more information on intended usage.
"""

def __init__(self, sub_operation: 'cirq.Operation', *tags: Hashable):
self.sub_operation = sub_operation
self._tags = tuple(tags)

@property
def qubits(self) -> Tuple['cirq.Qid', ...]:
return self.sub_operation.qubits

@property
def gate(self) -> Optional['cirq.Gate']:
return self.sub_operation.gate

def with_qubits(self, *new_qubits: 'cirq.Qid'):
return TaggedOperation(self.sub_operation.with_qubits(*new_qubits),
*self._tags)

def controlled_by(self,
*control_qubits: 'cirq.Qid',
control_values: Optional[Sequence[
Union[int, Collection[int]]]] = None
) -> 'cirq.Operation':
return self.sub_operation.controlled_by(*control_qubits,
control_values=control_values)

@property
def tags(self) -> Tuple[Hashable, ...]:
"""Returns a tuple of the operation's tags."""
return self._tags

def with_tags(self, *new_tags: Hashable) -> 'cirq.TaggedOperation':
"""Creates a new TaggedOperation with combined tags.
Overloads Operation.with_tags to create a new TaggedOperation
that has the tags of this operation combined with the new_tags
specified as the parameter.
"""
return TaggedOperation(self.sub_operation, *self._tags, *new_tags)

def __str__(self):
tag_repr = ','.join(repr(t) for t in self._tags)
return f"cirq.TaggedOperation({repr(self.sub_operation)}, {tag_repr})"

def __repr__(self):
return str(self)

def _value_equality_values_(self):
return (self.sub_operation, self._tags)

@classmethod
def _from_json_dict_(cls, sub_operation, tags, **kwargs):
return cls(sub_operation, *tags)

def _json_dict_(self):
return protocols.obj_to_dict_helper(self, ['sub_operation', 'tags'])

def _decompose_(self) -> 'cirq.OP_TREE':
return protocols.decompose(self.sub_operation)

def _pauli_expansion_(self) -> value.LinearDict[str]:
return protocols.pauli_expansion(self.sub_operation)

def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs'
) -> Union[np.ndarray, None, NotImplementedType]:
return protocols.apply_unitary(self.sub_operation, args, default=None)

def _has_unitary_(self) -> bool:
return protocols.has_unitary(self.sub_operation)

def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return protocols.unitary(self.sub_operation, default=None)

def _commutes_(self, other: Any, *, atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType, None]:
return protocols.commutes(self.sub_operation, other, atol=atol)

def _has_mixture_(self) -> bool:
return protocols.has_mixture(self.sub_operation)

def _mixture_(self) -> Sequence[Tuple[float, Any]]:
return protocols.mixture(self.sub_operation, NotImplemented)

def _has_channel_(self) -> bool:
return protocols.has_channel(self.sub_operation)

def _channel_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
return protocols.channel(self.sub_operation, NotImplemented)

def _measurement_key_(self) -> str:
return protocols.measurement_key(self.sub_operation, NotImplemented)

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self.sub_operation)

def _resolve_parameters_(self, resolver):
return protocols.resolve_parameters(self.sub_operation, resolver)

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
return protocols.circuit_diagram_info(self.sub_operation, args,
NotImplemented)

def _trace_distance_bound_(self) -> float:
return protocols.trace_distance_bound(self.sub_operation)

def _phase_by_(self, phase_turns: float,
qubit_index: int) -> 'cirq.Operation':
return protocols.phase_by(self.sub_operation, phase_turns, qubit_index)

def __pow__(self, exponent: Any) -> 'cirq.Operation':
return self.sub_operation**exponent

def __mul__(self, other: Any) -> Any:
return self.sub_operation * other

def __rmul__(self, other: Any) -> Any:
return other * self.sub_operation

def _qasm_(self, args: 'protocols.QasmArgs') -> Optional[str]:
return protocols.qasm(self.sub_operation, args=args, default=None)

def _equal_up_to_global_phase_(self,
other: Any,
atol: Union[int, float] = 1e-8
) -> Union[NotImplementedType, bool]:
return protocols.equal_up_to_global_phase(self.sub_operation,
other,
atol=atol)


@value.value_equality
class _InverseCompositeGate(Gate):
"""The inverse of a composite gate."""
Expand Down
108 changes: 108 additions & 0 deletions cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import pytest
import numpy as np
import sympy

import cirq

Expand Down Expand Up @@ -423,3 +425,109 @@ def _circuit_diagram_info_(self, args):
c = cirq.inverse(Gate2())
assert cirq.circuit_diagram_info(c) == cirq.CircuitDiagramInfo(
wire_symbols=('s!',), exponent=-1)


def test_tagged_operation_equality():
eq = cirq.testing.EqualsTester()
q1 = cirq.GridQubit(1, 1)
op = cirq.X(q1)
op2 = cirq.Y(q1)

eq.add_equality_group(op)
eq.add_equality_group(op.with_tags('tag1'),
cirq.TaggedOperation(op, 'tag1'))
eq.add_equality_group(op2.with_tags('tag1'),
cirq.TaggedOperation(op2, 'tag1'))
eq.add_equality_group(op.with_tags('tag2'),
cirq.TaggedOperation(op, 'tag2'))
eq.add_equality_group(op.with_tags('tag1', 'tag2'),
op.with_tags('tag1').with_tags('tag2'),
cirq.TaggedOperation(op, 'tag1', 'tag2'))


def test_tagged_operation():
q1 = cirq.GridQubit(1, 1)
q2 = cirq.GridQubit(2, 2)
op = cirq.X(q1).with_tags('tag1')
op_repr = "cirq.X.on(cirq.GridQubit(1, 1))"
assert repr(op) == f"cirq.TaggedOperation({op_repr}, 'tag1')"

assert op.qubits == (q1,)
assert op.tags == ('tag1',)
assert op.gate == cirq.X
assert op.with_qubits(q2) == cirq.X(q2).with_tags('tag1')
assert op.with_qubits(q2).qubits == (q2,)


def test_tagged_operation_forwards_protocols():
"""The results of all protocols applied to an operation with a tag should
be equivalent to the result without tags.
"""
q1 = cirq.GridQubit(1, 1)
q2 = cirq.GridQubit(1, 2)
h = cirq.H(q1)
tag = 'tag1'
tagged_h = cirq.H(q1).with_tags(tag)

np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h))
assert cirq.has_unitary(tagged_h)
assert cirq.decompose(tagged_h) == cirq.decompose(h)
assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
assert cirq.equal_up_to_global_phase(h, tagged_h)
assert np.isclose(cirq.channel(h), cirq.channel(tagged_h)).all()
assert cirq.circuit_diagram_info(h) == cirq.circuit_diagram_info(tagged_h)

assert (cirq.measurement_key(cirq.measure(
q1, key='blah').with_tags(tag)) == 'blah')

parameterized_op = cirq.XPowGate(
exponent=sympy.Symbol('t'))(q1).with_tags(tag)
assert cirq.is_parameterized(parameterized_op)
resolver = cirq.study.ParamResolver({'t': 0.25})
assert (cirq.resolve_parameters(
parameterized_op, resolver) == cirq.XPowGate(exponent=0.25)(q1))

y = cirq.Y(q1)
tagged_y = cirq.Y(q1).with_tags(tag)
assert tagged_y**0.5 == cirq.YPowGate(exponent=0.5)(q1)
assert tagged_y * 2 == (y * 2)
assert (3 * tagged_y == (3 * y))
assert cirq.phase_by(y, 0.125, 0) == cirq.phase_by(tagged_y, 0.125, 0)
controlled_y = tagged_y.controlled_by(q2)
assert controlled_y.qubits == (
q2,
q1,
)
assert isinstance(controlled_y, cirq.Operation)
assert not isinstance(controlled_y, cirq.TaggedOperation)

clifford_x = cirq.SingleQubitCliffordGate.X(q1)
tagged_x = cirq.SingleQubitCliffordGate.X(q1).with_tags(tag)
assert cirq.commutes(clifford_x, clifford_x)
assert cirq.commutes(tagged_x, clifford_x)
assert cirq.commutes(clifford_x, tagged_x)
assert cirq.commutes(tagged_x, tagged_x)

assert (cirq.trace_distance_bound(y**0.001) == cirq.trace_distance_bound(
(y**0.001).with_tags(tag)))

flip = cirq.bit_flip(0.5)(q1)
tagged_flip = cirq.bit_flip(0.5)(q1).with_tags(tag)
assert cirq.has_mixture(tagged_flip)
assert cirq.has_channel(tagged_flip)

flip_mixture = cirq.mixture(flip)
tagged_mixture = cirq.mixture(tagged_flip)
assert len(tagged_mixture) == 2
assert len(tagged_mixture[0]) == 2
assert len(tagged_mixture[1]) == 2
assert tagged_mixture[0][0] == flip_mixture[0][0]
assert np.isclose(tagged_mixture[0][1], flip_mixture[0][1]).all()
assert tagged_mixture[1][0] == flip_mixture[1][0]
assert np.isclose(tagged_mixture[1][1], flip_mixture[1][1]).all()

qubit_map = {q1: 'q1'}
qasm_args = cirq.QasmArgs(qubit_id_map=qubit_map)
assert (cirq.qasm(h, args=qasm_args) == cirq.qasm(tagged_h, args=qasm_args))

cirq.testing.assert_has_consistent_apply_unitary(tagged_h)
1 change: 1 addition & 0 deletions cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def two_qubit_matrix_gate(matrix):
cirq.experiments.SingleQubitReadoutCalibrationResult,
'SwapPowGate': cirq.SwapPowGate,
'SycamoreGate': cirq.google.SycamoreGate,
'TaggedOperation': cirq.TaggedOperation,
'TrialResult': cirq.TrialResult,
'TwoQubitMatrixGate': two_qubit_matrix_gate,
'_UnconstrainedDevice':
Expand Down
22 changes: 22 additions & 0 deletions cirq/protocols/json_test_data/TaggedOperation.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"cirq_type": "TaggedOperation",
"sub_operation":
{
"cirq_type": "SingleQubitPauliStringGateOperation",
"pauli":
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
"qubit": {
"cirq_type": "NamedQubit",
"name": "q1"
}
},
"tags":
[
"tag1",
"tag2"
]
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/TaggedOperation.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.TaggedOperation(cirq.X.on(cirq.NamedQubit('q1')), 'tag1', 'tag2')
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ Unitary effects that can be applied to one or more qubits.
cirq.SingleQubitGate
cirq.SingleQubitMatrixGate
cirq.SwapPowGate
cirq.TaggedOperation
cirq.ThreeQubitDiagonalGate
cirq.TwoQubitMatrixGate
cirq.WaitGate
Expand Down

0 comments on commit 2715417

Please sign in to comment.