Skip to content

Commit

Permalink
DRY radian formatting (#2907)
Browse files Browse the repository at this point in the history
  • Loading branch information
viathor committed Apr 16, 2020
1 parent 3132f97 commit beaa782
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 52 deletions.
30 changes: 6 additions & 24 deletions cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,8 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Union[str, 'protocols.CircuitDiagramInfo']:
if self._global_shift == -0.5:
return _rads_func_symbol(
'Rx',
args,
self._diagram_exponent(args, ignore_global_phase=False))
angle_str = self._format_exponent_as_angle(args)
return f'Rx({angle_str})'

return protocols.CircuitDiagramInfo(
wire_symbols=('X',),
Expand Down Expand Up @@ -307,10 +305,8 @@ def _pauli_expansion_(self) -> value.LinearDict[str]:
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Union[str, 'protocols.CircuitDiagramInfo']:
if self._global_shift == -0.5:
return _rads_func_symbol(
'Ry',
args,
self._diagram_exponent(args, ignore_global_phase=False))
angle_str = self._format_exponent_as_angle(args)
return f'Ry({angle_str})'

return protocols.CircuitDiagramInfo(
wire_symbols=('Y',),
Expand Down Expand Up @@ -492,10 +488,8 @@ def _has_stabilizer_effect_(self) -> Optional[bool]:
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Union[str, 'protocols.CircuitDiagramInfo']:
if self._global_shift == -0.5:
return _rads_func_symbol(
'Rz',
args,
self._diagram_exponent(args, ignore_global_phase=False))
angle_str = self._format_exponent_as_angle(args)
return f'Rz({angle_str})'

e = self._diagram_exponent(args)
if e in [-0.25, 0.25]:
Expand Down Expand Up @@ -835,18 +829,6 @@ def __repr__(self) -> str:
).format(proper_repr(self._exponent), self._global_shift)


def _rads_func_symbol(func_name: str, args: 'protocols.CircuitDiagramInfoArgs',
half_turns: Any) -> str:
if protocols.is_parameterized(half_turns):
return '{}({})'.format(func_name, sympy.pi * half_turns)
unit = 'π' if args.use_unicode_characters else 'pi'
if half_turns == 1:
return '{}({})'.format(func_name, unit)
if half_turns == -1:
return '{}(-{})'.format(func_name, unit)
return '{}({}{})'.format(func_name, half_turns, unit)


class CXPowGate(eigen_gate.EigenGate, gate_features.TwoQubitGate):
"""A gate that applies a controlled power of an X gate.
Expand Down
19 changes: 19 additions & 0 deletions cirq/ops/eigen_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,25 @@ def _diagram_exponent(self,

return result

def _format_exponent_as_angle(
self,
args: 'protocols.CircuitDiagramInfoArgs',
order: int = 2,
) -> str:
"""Returns string with exponent expressed as angle in radians.
Args:
args: CircuitDiagramInfoArgs describing the desired drawing style.
order: Exponent corresponding to full rotation by 2π.
Returns:
Angle in radians corresponding to the exponent of self and
formatted according to style described by args.
"""
exponent = self._diagram_exponent(args, ignore_global_phase=False)
pi = sympy.pi if protocols.is_parameterized(exponent) else np.pi
return args.format_radians(radians=2 * pi * exponent / order)

# virtual method
def _eigen_shifts(self) -> List[float]:
"""Describes the eigenvalues of the gate's matrix.
Expand Down
22 changes: 3 additions & 19 deletions cirq/ops/fsim_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ def _decompose_(self, qubits) -> 'cirq.OP_TREE':
yield cirq.CZ(a, b)**(-self.phi / np.pi)

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
t = _format_rads(args, self.theta)
p = _format_rads(args, self.phi)
return 'fsim({}, {})'.format(t, p), '#2'
t = args.format_radians(self.theta)
p = args.format_radians(self.phi)
return f'fsim({t}, {p})', '#2'

def __pow__(self, power):
return FSimGate(cirq.mul(self.theta, power), cirq.mul(self.phi, power))
Expand All @@ -156,19 +156,3 @@ def __repr__(self):

def _json_dict_(self):
return protocols.obj_to_dict_helper(self, ['theta', 'phi'])


def _format_rads(args: 'cirq.CircuitDiagramInfoArgs', radians: float) -> str:
if cirq.is_parameterized(radians):
return str(radians)
unit = 'π' if args.use_unicode_characters else 'pi'
if radians == np.pi:
return unit
if radians == 0:
return '0'
if radians == -np.pi:
return '-' + unit
if args.precision is not None:
quantity = args.format_real(radians / np.pi)
return quantity + unit
return repr(radians)
5 changes: 2 additions & 3 deletions cirq/ops/parity_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'
) -> Union[str, 'protocols.CircuitDiagramInfo']:
if self._global_shift == -0.5:
# Mølmer–Sørensen gate.
symbol = common_gates._rads_func_symbol(
'MS', args,
self._diagram_exponent(args, ignore_global_phase=False) / 2)
angle_str = self._format_exponent_as_angle(args, order=4)
symbol = f'MS({angle_str})'
return protocols.CircuitDiagramInfo(
wire_symbols=(symbol, symbol))

Expand Down
11 changes: 7 additions & 4 deletions cirq/ops/parity_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_xx_str():

ms = cirq.XXPowGate(global_shift=-0.5)
assert str(ms) == 'MS(π/2)'
assert str(ms**0.5) == 'MS(0.5π/2)'
assert str(ms**2) == 'MS(2.0π/2)'
assert str(ms**-1) == 'MS(-1.0π/2)'

Expand Down Expand Up @@ -108,11 +109,13 @@ def test_xx_diagrams():
cirq.XX(a, b),
cirq.XX(a, b)**3,
cirq.XX(a, b)**0.5,
cirq.XXPowGate(global_shift=-0.5).on(a, b),
)
cirq.testing.assert_has_diagram(circuit, """
a: ───XX───XX───XX───────
│ │ │
b: ───XX───XX───XX^0.5───
cirq.testing.assert_has_diagram(
circuit, """
a: ───XX───XX───XX───────MS(0.5π)───
│ │ │ │
b: ───XX───XX───XX^0.5───MS(0.5π)───
""")


Expand Down
19 changes: 18 additions & 1 deletion cirq/protocols/circuit_diagram_info_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from typing import (Any, TYPE_CHECKING, Optional, Union, TypeVar, Dict,
overload, Iterable)

import numpy as np
import sympy
from typing_extensions import Protocol

from cirq import value
from cirq import protocols, value

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -143,6 +144,22 @@ def format_real(self, val: Union[sympy.Basic, int, float]) -> str:
return str(val)
return f'{float(val):.{self.precision}}'

def format_radians(self, radians: Union[sympy.Basic, int, float]) -> str:
"""Returns angle in radians as a human-readable string."""
if protocols.is_parameterized(radians):
return str(radians)
unit = 'π' if self.use_unicode_characters else 'pi'
if radians == np.pi:
return unit
if radians == 0:
return '0'
if radians == -np.pi:
return '-' + unit
if self.precision is not None:
quantity = self.format_real(radians / np.pi)
return quantity + unit
return repr(radians)

def copy(self):
return self.__class__(
known_qubits=self.known_qubits,
Expand Down
53 changes: 52 additions & 1 deletion cirq/protocols/circuit_diagram_info_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import sympy

Expand Down Expand Up @@ -181,7 +182,7 @@ def test_circuit_diagram_info_args_repr():
include_tags=False))


def test_formal_real():
def test_format_real():
args = cirq.CircuitDiagramInfoArgs.UNINFORMED_DEFAULT
assert args.format_real(1) == '1'
assert args.format_real(1.1) == '1.1'
Expand All @@ -197,3 +198,53 @@ def test_formal_real():
assert args.format_real(1 / 7) == repr(1 / 7)
assert args.format_real(sympy.Symbol('t')) == 't'
assert args.format_real(sympy.Symbol('t') * 2 + 1) == '2*t + 1'


def test_format_radians_without_precision():
args = cirq.CircuitDiagramInfoArgs(known_qubits=None,
known_qubit_count=None,
use_unicode_characters=False,
precision=None,
qubit_map=None)
assert args.format_radians(np.pi) == 'pi'
assert args.format_radians(-np.pi) == '-pi'
assert args.format_radians(1.1) == '1.1'
assert args.format_radians(1.234567) == '1.234567'
assert args.format_radians(1 / 7) == repr(1 / 7)
assert args.format_radians(sympy.Symbol('t')) == 't'
assert args.format_radians(sympy.Symbol('t') * 2 + 1) == '2*t + 1'

args.use_unicode_characters = True
assert args.format_radians(np.pi) == 'π'
assert args.format_radians(-np.pi) == '-π'
assert args.format_radians(1.1) == '1.1'
assert args.format_radians(1.234567) == '1.234567'
assert args.format_radians(1 / 7) == repr(1 / 7)
assert args.format_radians(sympy.Symbol('t')) == 't'
assert args.format_radians(sympy.Symbol('t') * 2 + 1) == '2*t + 1'


def test_format_radians_with_precision():
args = cirq.CircuitDiagramInfoArgs(known_qubits=None,
known_qubit_count=None,
use_unicode_characters=False,
precision=3,
qubit_map=None)
assert args.format_radians(np.pi) == 'pi'
assert args.format_radians(-np.pi) == '-pi'
assert args.format_radians(np.pi / 2) == '0.5pi'
assert args.format_radians(-3 * np.pi / 4) == '-0.75pi'
assert args.format_radians(1.1) == '0.35pi'
assert args.format_radians(1.234567) == '0.393pi'
assert args.format_radians(sympy.Symbol('t')) == 't'
assert args.format_radians(sympy.Symbol('t') * 2 + 1) == '2*t + 1'

args.use_unicode_characters = True
assert args.format_radians(np.pi) == 'π'
assert args.format_radians(-np.pi) == '-π'
assert args.format_radians(np.pi / 2) == '0.5π'
assert args.format_radians(-3 * np.pi / 4) == '-0.75π'
assert args.format_radians(1.1) == '0.35π'
assert args.format_radians(1.234567) == '0.393π'
assert args.format_radians(sympy.Symbol('t')) == 't'
assert args.format_radians(sympy.Symbol('t') * 2 + 1) == '2*t + 1'

0 comments on commit beaa782

Please sign in to comment.