Skip to content

Commit

Permalink
Propagate global phase in diagram from sub circuits (#3892)
Browse files Browse the repository at this point in the history
* propagate global phase from sub circuits

* unlist

* type annotations and test

* type fixes
  • Loading branch information
smitsanghavi committed Mar 10, 2021
1 parent 3f90c62 commit 05a499c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 8 deletions.
34 changes: 26 additions & 8 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

from cirq import devices, ops, protocols, qis
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
from cirq.circuits.circuit_operation import CircuitOperation
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.circuits.text_diagram_drawer import TextDiagramDrawer
from cirq.circuits.qasm_output import QasmOutput
Expand Down Expand Up @@ -2271,14 +2272,7 @@ def _draw_moment_in_diagram(
if x > max_x:
max_x = x

global_phase: Optional[complex] = None
tags: List[Any] = []
for op in moment:
if isinstance(op.untagged, ops.GlobalPhaseOperation):
tags.extend(op.tags)
if global_phase is None:
global_phase = complex(1)
global_phase *= complex(op.untagged.coefficient)
global_phase, tags = _get_global_phase_and_tags_for_ops(moment)

# Print out global phase, unless it's 1 (phase of 0pi) or it's the only op.
if global_phase and (global_phase != 1 or not non_global_ops):
Expand All @@ -2297,6 +2291,30 @@ def _draw_moment_in_diagram(
moment_groups.append((x0, max_x))


def _get_global_phase_and_tags_for_op(op: 'cirq.Operation') -> Tuple[Optional[complex], List[Any]]:
if isinstance(op.untagged, ops.GlobalPhaseOperation):
return complex(op.untagged.coefficient), list(op.tags)
elif isinstance(op.untagged, CircuitOperation):
op_phase, op_tags = _get_global_phase_and_tags_for_ops(op.untagged.circuit.all_operations())
return op_phase, list(op.tags) + op_tags
else:
return None, []


def _get_global_phase_and_tags_for_ops(op_list: Any) -> Tuple[Optional[complex], List[Any]]:
global_phase: Optional[complex] = None
tags: List[Any] = []
for op in op_list:
op_phase, op_tags = _get_global_phase_and_tags_for_op(op)
if op_phase:
if global_phase is None:
global_phase = complex(1)
global_phase *= op_phase
if op_tags:
tags.extend(op_tags)
return global_phase, tags


def _formatted_phase(coefficient: complex, unicode: bool, precision: Optional[int]) -> str:
h = math.atan2(coefficient.imag, coefficient.real) / math.pi
unit = 'π' if unicode else 'pi'
Expand Down
17 changes: 17 additions & 0 deletions cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,23 @@ def test_string_format():
[ ]"""
)

fc0_global_phase_inner = cirq.FrozenCircuit(
cirq.GlobalPhaseOperation(1j), cirq.GlobalPhaseOperation(1j)
)
op0_global_phase_inner = cirq.CircuitOperation(fc0_global_phase_inner)
fc0_global_phase_outer = cirq.FrozenCircuit(
op0_global_phase_inner, cirq.GlobalPhaseOperation(1j)
)
op0_global_phase_outer = cirq.CircuitOperation(fc0_global_phase_outer)
assert (
str(op0_global_phase_outer)
== f"""\
{op0_global_phase_outer.circuit.diagram_name()}:
[ ]
[ ]
[ global phase: -0.5π ]"""
)

fc1 = cirq.FrozenCircuit(cirq.X(x), cirq.H(y), cirq.CX(y, z), cirq.measure(x, y, z, key='m'))
op1 = cirq.CircuitOperation(fc1)
assert (
Expand Down
12 changes: 12 additions & 0 deletions cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,6 +2417,18 @@ def test_diagram_global_phase(circuit_cls):
precision=2,
)

c = circuit_cls(
cirq.X(cirq.LineQubit(2)),
cirq.CircuitOperation(circuit_cls(cirq.GlobalPhaseOperation(-1).with_tags("tag")).freeze()),
)
cirq.testing.assert_has_diagram(
c,
"""\
2: ───X──────────
π['tag']""",
)


@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_has_unitary(circuit_cls):
Expand Down

0 comments on commit 05a499c

Please sign in to comment.