Skip to content

Commit

Permalink
Add tags to circuit diagrams (#2759)
Browse files Browse the repository at this point in the history
* Add tags to circuit diagrams

- Print out tags in circuit diagrams
- Also includes a parameter 'include_tags' that can disable this feature.
- Also some minor fixes to global phase circuit diagrams.
  • Loading branch information
dstrain115 committed Feb 18, 2020
1 parent 37a222f commit 82d4fb3
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 28 deletions.
56 changes: 40 additions & 16 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,7 @@ def to_text_diagram(
*,
use_unicode_characters: bool = True,
transpose: bool = False,
include_tags: bool = True,
precision: Optional[int] = 3,
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT
) -> str:
Expand All @@ -1582,6 +1583,7 @@ def to_text_diagram(
use_unicode_characters: Determines if unicode characters are
allowed (as opposed to ascii-only diagrams).
transpose: Arranges qubit wires vertically instead of horizontally.
include_tags: Whether tags on TaggedOperations should be printed
precision: Number of digits to display in text diagram
qubit_order: Determines how qubits are ordered in the diagram.
Expand All @@ -1590,6 +1592,7 @@ def to_text_diagram(
"""
diagram = self.to_text_diagram_drawer(
use_unicode_characters=use_unicode_characters,
include_tags=include_tags,
precision=precision,
qubit_order=qubit_order,
transpose=transpose)
Expand All @@ -1607,6 +1610,7 @@ def to_text_diagram_drawer(
use_unicode_characters: bool = True,
qubit_namer: Optional[Callable[['cirq.Qid'], str]] = None,
transpose: bool = False,
include_tags: bool = True,
precision: Optional[int] = 3,
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT,
get_circuit_diagram_info: Optional[
Expand Down Expand Up @@ -1638,22 +1642,21 @@ def to_text_diagram_drawer(
diagram.write(0, 0, '')
for q, i in qubit_map.items():
diagram.write(0, i, qubit_namer(q))

if any(
isinstance(op, cirq.GlobalPhaseOperation)
isinstance(op, cirq.GlobalPhaseOperation) or
(isinstance(op, cirq.TaggedOperation) and
isinstance(op.sub_operation, cirq.GlobalPhaseOperation))
for op in self.all_operations()):
diagram.write(0,
max(qubit_map.values(), default=0) + 1,
'global phase:')

moment_groups = [] # type: List[Tuple[int, int]]
for moment in self._moments:
_draw_moment_in_diagram(moment,
use_unicode_characters,
qubit_map,
diagram,
precision,
moment_groups,
get_circuit_diagram_info)
_draw_moment_in_diagram(moment, use_unicode_characters, qubit_map,
diagram, precision, moment_groups,
get_circuit_diagram_info, include_tags)

w = diagram.width()
for i in qubit_map.values():
Expand Down Expand Up @@ -1809,12 +1812,20 @@ def _get_operation_circuit_diagram_info_with_fallback(
# Fallback to a default representation using the operation's __str__.
name = str(op)

# For TaggedOperation, use the sub_operations __str__ instead
if isinstance(op, cirq.TaggedOperation):
name = str(op.sub_operation)

# Representation usually looks like 'gate(qubit1, qubit2, etc)'.
# Try to cut off the qubit part, since that would be redundant information.
redundant_tail = '({})'.format(', '.join(str(e) for e in op.qubits))
if name.endswith(redundant_tail):
name = name[:-len(redundant_tail)]

# Add tags onto the representation, if they exist
if isinstance(op, cirq.TaggedOperation):
name += f'{list(op.tags)}'

# Include ordering in the qubit labels.
symbols = (name,) + tuple('#{}'.format(i + 1)
for i in range(1, len(op.qubits)))
Expand Down Expand Up @@ -1876,7 +1887,8 @@ def _draw_moment_in_diagram(
moment_groups: List[Tuple[int, int]],
get_circuit_diagram_info: Optional[
Callable[['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'],
'cirq.CircuitDiagramInfo']] = None):
'cirq.CircuitDiagramInfo']] = None,
include_tags: bool = True):
if get_circuit_diagram_info is None:
get_circuit_diagram_info = (
_get_operation_circuit_diagram_info_with_fallback)
Expand All @@ -1902,7 +1914,8 @@ def _draw_moment_in_diagram(
known_qubit_count=len(op.qubits),
use_unicode_characters=use_unicode_characters,
qubit_map=qubit_map,
precision=precision)
precision=precision,
include_tags=include_tags)
info = get_circuit_diagram_info(op, args)

# Draw vertical line linking the gate's qubits.
Expand All @@ -1929,15 +1942,26 @@ def _draw_moment_in_diagram(
if x > max_x:
max_x = x

global_phase = np.product([
complex(e.coefficient)
for e in moment
if isinstance(e, ops.GlobalPhaseOperation)
])
if global_phase != 1:
global_phase = None
tags: List[Any] = []
for op in moment:
if (isinstance(op, ops.TaggedOperation) and
isinstance(op.sub_operation, ops.GlobalPhaseOperation)):
tags.extend(op.tags)
op = op.sub_operation
if isinstance(op, ops.GlobalPhaseOperation):
if global_phase:
global_phase *= complex(op.coefficient)
else:
global_phase = complex(op.coefficient)

# Print out global phase, unless it's 1 (phase of 0pi) or it's the only op.
if global_phase and (global_phase != 1 or not non_global_ops):
desc = _formatted_phase(global_phase, use_unicode_characters, precision)
if desc:
y = max(qubit_map.values(), default=0) + 1
if tags and include_tags:
desc = desc + str(tags)
out_diagram.write(x0, y, desc)

if not non_global_ops:
Expand Down
33 changes: 33 additions & 0 deletions cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2079,6 +2079,39 @@ def test_diagram_wgate_none_precision():
precision=None)


def test_diagram_global_phase():
qa = cirq.NamedQubit('a')
global_phase = cirq.GlobalPhaseOperation(coefficient=1j)
c = cirq.Circuit([global_phase])
cirq.testing.assert_has_diagram(c,
"\n\nglobal phase: 0.5pi",
use_unicode_characters=False,
precision=2)
cirq.testing.assert_has_diagram(c,
"\n\nglobal phase: 0.5π",
use_unicode_characters=True,
precision=2)

c = cirq.Circuit([cirq.X(qa), global_phase, global_phase])
cirq.testing.assert_has_diagram(c,
"""\
a: ─────────────X───
global phase: π""",
use_unicode_characters=True,
precision=2)
c = cirq.Circuit([cirq.X(qa), global_phase],
cirq.Moment([cirq.X(qa), global_phase]))
cirq.testing.assert_has_diagram(c,
"""\
a: ─────────────X──────X──────
global phase: 0.5π 0.5π
""",
use_unicode_characters=True,
precision=2)


def test_has_unitary():

class NonUnitary(cirq.SingleQubitGate):
Expand Down
11 changes: 9 additions & 2 deletions cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,15 @@ def _resolve_parameters_(self, resolver):

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
return protocols.circuit_diagram_info(self.sub_operation, args,
NotImplemented)
sub_op_info = protocols.circuit_diagram_info(self.sub_operation, args,
NotImplemented)
# Add tag to wire symbol if it exists.
if (sub_op_info is not NotImplemented and args.include_tags and
sub_op_info.wire_symbols):
sub_op_info.wire_symbols = (
(sub_op_info.wire_symbols[0] + str(list(self._tags)),) +
sub_op_info.wire_symbols[1:])
return sub_op_info

def _trace_distance_bound_(self) -> float:
return protocols.trace_distance_bound(self.sub_operation)
Expand Down
104 changes: 103 additions & 1 deletion cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,109 @@ def test_tagged_operation():
assert op.with_qubits(q2).qubits == (q2,)


def test_circuit_diagram():
h = cirq.H(cirq.GridQubit(1, 1))
tagged_h = h.with_tags('tag1')
expected = cirq.CircuitDiagramInfo(wire_symbols=("H['tag1']",),
exponent=1.0,
connected=True,
exponent_qubit_index=None,
auto_exponent_parens=True)
args = cirq.CircuitDiagramInfoArgs(None, None, None, None, None, False)
assert cirq.circuit_diagram_info(tagged_h) == expected
assert (cirq.circuit_diagram_info(tagged_h,
args) == cirq.circuit_diagram_info(h))

c = cirq.Circuit(tagged_h)
diagram_with_tags = "(1, 1): ───H['tag1']───"
diagram_without_tags = "(1, 1): ───H───"
assert str(cirq.Circuit(tagged_h)) == diagram_with_tags
assert c.to_text_diagram() == diagram_with_tags
assert c.to_text_diagram(include_tags=False) == diagram_without_tags


def test_circuit_diagram_tagged_global_phase():
# Tests global phase operation
q = cirq.NamedQubit('a')
global_phase = cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0')

# Just global phase in a circuit
assert (cirq.circuit_diagram_info(global_phase,
default='default') == 'default')
cirq.testing.assert_has_diagram(cirq.Circuit(global_phase),
"\n\nglobal phase: π['tag0']",
use_unicode_characters=True)
cirq.testing.assert_has_diagram(cirq.Circuit(global_phase),
"\n\nglobal phase: π",
use_unicode_characters=True,
include_tags=False)

expected = cirq.CircuitDiagramInfo(wire_symbols=(),
exponent=1.0,
connected=True,
exponent_qubit_index=None,
auto_exponent_parens=True)

# Operation with no qubits and returns diagram info with no wire symbols
class NoWireSymbols(cirq.GlobalPhaseOperation):

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
return expected

no_wire_symbol_op = NoWireSymbols(coefficient=-1.0).with_tags('tag0')
assert (cirq.circuit_diagram_info(no_wire_symbol_op,
default='default') == expected)
cirq.testing.assert_has_diagram(cirq.Circuit(no_wire_symbol_op),
"\n\nglobal phase: π['tag0']",
use_unicode_characters=True)

# Two global phases in one moment
tag1 = cirq.GlobalPhaseOperation(coefficient=1j).with_tags('tag1')
tag2 = cirq.GlobalPhaseOperation(coefficient=1j).with_tags('tag2')
c = cirq.Circuit([cirq.X(q), tag1, tag2])
cirq.testing.assert_has_diagram(c,
"""\
a: ─────────────X───────────────────
global phase: π['tag1', 'tag2']""",
use_unicode_characters=True,
precision=2)

# Two moments with global phase, one with another tagged gate
c = cirq.Circuit([cirq.X(q).with_tags('x_tag'), tag1])
c.append(cirq.Moment([cirq.X(q), tag2]))
for m in c:
print(m)
print('----')
cirq.testing.assert_has_diagram(c,
"""\
a: ─────────────X['x_tag']─────X──────────────
global phase: 0.5π['tag1'] 0.5π['tag2']
""",
use_unicode_characters=True,
include_tags=True)


def test_circuit_diagram_no_circuit_diagram():

class NoCircuitDiagram(cirq.Gate):

def num_qubits(self) -> int:
return 1

def __repr__(self):
return 'guess-i-will-repr'

q = cirq.GridQubit(1, 1)
expected = "(1, 1): ───guess-i-will-repr───"
assert cirq.Circuit(NoCircuitDiagram()(q)).to_text_diagram() == expected
expected = "(1, 1): ───guess-i-will-repr['taggy']───"
assert cirq.Circuit(
NoCircuitDiagram()(q).with_tags('taggy')).to_text_diagram() == expected


def test_tagged_operation_forwards_protocols():
"""The results of all protocols applied to an operation with a tag should
be equivalent to the result without tags.
Expand All @@ -475,7 +578,6 @@ def test_tagged_operation_forwards_protocols():
assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
assert cirq.equal_up_to_global_phase(h, tagged_h)
assert np.isclose(cirq.channel(h), cirq.channel(tagged_h)).all()
assert cirq.circuit_diagram_info(h) == cirq.circuit_diagram_info(tagged_h)

assert (cirq.measurement_key(cirq.measure(
q1, key='blah').with_tags(tag)) == 'blah')
Expand Down
24 changes: 16 additions & 8 deletions cirq/protocols/circuit_diagram_info_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,37 +102,45 @@ class CircuitDiagramInfoArgs:
precision: The number of digits after the decimal to show for numbers in
the text diagram. None means use full precision.
qubit_map: The map from qubits to diagram positions.
include_tags: Whether to print tags from TaggedOperations
"""

UNINFORMED_DEFAULT = None # type: CircuitDiagramInfoArgs

def __init__(self, known_qubits: Optional[Iterable['cirq.Qid']],
known_qubit_count: Optional[int], use_unicode_characters: bool,
def __init__(self,
known_qubits: Optional[Iterable['cirq.Qid']],
known_qubit_count: Optional[int],
use_unicode_characters: bool,
precision: Optional[int],
qubit_map: Optional[Dict['cirq.Qid', int]]) -> None:
qubit_map: Optional[Dict['cirq.Qid', int]],
include_tags: bool = True) -> None:
self.known_qubits = (None
if known_qubits is None else tuple(known_qubits))
self.known_qubit_count = known_qubit_count
self.use_unicode_characters = use_unicode_characters
self.precision = precision
self.qubit_map = qubit_map
self.include_tags = include_tags

def _value_equality_values_(self):
return (self.known_qubits, self.known_qubit_count,
self.use_unicode_characters, self.precision,
None if self.qubit_map is None else tuple(
sorted(self.qubit_map.items(), key=lambda e: e[0])))
sorted(self.qubit_map.items(), key=lambda e: e[0])),
self.include_tags)

def __repr__(self):
return ('cirq.CircuitDiagramInfoArgs('
'known_qubits={!r}, '
'known_qubit_count={!r}, '
'use_unicode_characters={!r}, '
'precision={!r}, '
'qubit_map={!r})'.format(self.known_qubits,
self.known_qubit_count,
self.use_unicode_characters,
self.precision, self.qubit_map))
'qubit_map={!r},'
'include_tags={!r})'.format(self.known_qubits,
self.known_qubit_count,
self.use_unicode_characters,
self.precision, self.qubit_map,
self.include_tags))

def format_real(self, val: Union[sympy.Basic, int, float]) -> str:
if isinstance(val, sympy.Basic):
Expand Down
10 changes: 9 additions & 1 deletion cirq/protocols/circuit_diagram_info_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ def test_circuit_diagram_info_args_eq():
use_unicode_characters=False,
precision=None,
qubit_map=None))
eq.add_equality_group(
cirq.CircuitDiagramInfoArgs(known_qubits=cirq.LineQubit.range(2),
known_qubit_count=2,
use_unicode_characters=False,
precision=None,
qubit_map=None,
include_tags=False))
eq.add_equality_group(
cirq.CircuitDiagramInfoArgs(known_qubits=cirq.LineQubit.range(2),
known_qubit_count=2,
Expand All @@ -170,7 +177,8 @@ def test_circuit_diagram_info_args_repr():
qubit_map={
cirq.LineQubit(0): 5,
cirq.LineQubit(1): 7
}))
},
include_tags=False))


def test_formal_real():
Expand Down

0 comments on commit 82d4fb3

Please sign in to comment.