Skip to content

Commit

Permalink
Type annotations for optimization_pass (#3962)
Browse files Browse the repository at this point in the history
Part of: #554
  • Loading branch information
vtomole committed Mar 24, 2021
1 parent 966fc02 commit ce9fde6
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 24 deletions.
10 changes: 8 additions & 2 deletions cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
from collections import defaultdict
from random import randint, random, sample, randrange
from typing import Tuple
from typing import Optional, Tuple, TYPE_CHECKING

import numpy as np
import pytest
Expand Down Expand Up @@ -65,6 +65,10 @@ def can_add_operation_into_moment(
)


if TYPE_CHECKING:
import cirq


class _MomentAndOpTypeValidatingDeviceType(cirq.Device):
def validate_operation(self, operation):
if not isinstance(operation, cirq.Operation):
Expand Down Expand Up @@ -958,7 +962,9 @@ def __init__(self, replacer=(lambda x: x)):
super().__init__()
self.replacer = replacer

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
new_ops = self.replacer(op)
return cirq.PointOptimizationSummary(
clear_span=1, clear_qubits=op.qubits, new_operations=new_ops
Expand Down
33 changes: 22 additions & 11 deletions cirq/circuits/optimization_pass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, TYPE_CHECKING, Set, List

import pytest
import cirq
from cirq import PointOptimizer, PointOptimizationSummary
from cirq import PointOptimizer, PointOptimizationSummary, Operation
from cirq.testing import EqualsTester

if TYPE_CHECKING:
import cirq


def test_equality():
a = cirq.NamedQubit('a')
Expand Down Expand Up @@ -57,18 +61,20 @@ class ReplaceWithXGates(PointOptimizer):
operation's qubits.
"""

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
end = index + 1
new_ops = [cirq.X(q) for q in op.qubits]
done = False
while not done:
n = circuit.next_moment_operating_on(op.qubits, end)
if n is None:
break
next_ops = {circuit.operation_at(q, n) for q in op.qubits}
next_ops = [e for e in next_ops if e]
next_ops = sorted(next_ops, key=lambda e: str(e.qubits))
for next_op in next_ops:
next_ops: Set[Optional[Operation]] = {circuit.operation_at(q, n) for q in op.qubits}
next_ops_list: List[Operation] = [e for e in next_ops if e]
next_ops_sorted = sorted(next_ops_list, key=lambda e: str(e.qubits))
for next_op in next_ops_sorted:
if next_op:
if set(next_op.qubits).issubset(op.qubits):
end = n + 1
Expand Down Expand Up @@ -149,14 +155,19 @@ def test_point_optimizer_raises_on_gates_changing_qubits():
class EverythingIs42(cirq.PointOptimizer):
"""Changes all single qubit operations to act on LineQubit(42)"""

def optimization_at(self, circuit, index, op):
if len(op.qubits) == 1:
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
new_op = op
if len(op.qubits) == 1 and isinstance(op, cirq.GateOperation):
new_op = op.gate(cirq.LineQubit(42))
return cirq.PointOptimizationSummary(
clear_span=1, clear_qubits=op.qubits, new_operations=new_op
)

return cirq.PointOptimizationSummary(
clear_span=1, clear_qubits=op.qubits, new_operations=new_op
)

c = cirq.Circuit(cirq.X(cirq.LineQubit(0)), cirq.X(cirq.LineQubit(1)))

with pytest.raises(ValueError, match='new qubits'):
EverythingIs42().optimize_circuit(c)

Expand Down
8 changes: 5 additions & 3 deletions cirq/contrib/acquaintance/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import DefaultDict, Dict, Sequence, TYPE_CHECKING
from typing import DefaultDict, Dict, Sequence, TYPE_CHECKING, Optional

import abc
from collections import defaultdict
Expand Down Expand Up @@ -81,7 +81,9 @@ def __call__(self, strategy: 'cirq.Circuit'):
super().optimize_circuit(strategy)
return self.mapping.copy()

def optimization_at(self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
if isinstance(op.gate, AcquaintanceOpportunityGate):
logical_indices = tuple(self.mapping[q] for q in op.qubits)
logical_operations = self.execution_strategy.get_operations(logical_indices, op.qubits)
Expand All @@ -93,7 +95,7 @@ def optimization_at(self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operati

if isinstance(op, ops.GateOperation) and isinstance(op.gate, PermutationGate):
op.gate.update_mapping(self.mapping, op.qubits)
return
return None

raise TypeError(
'Can only execute a strategy consisting of gates that '
Expand Down
5 changes: 4 additions & 1 deletion cirq/google/optimizers/convert_to_sqrt_iswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from cirq import ops, circuits, protocols


if TYPE_CHECKING:
import cirq

Expand Down Expand Up @@ -104,7 +105,9 @@ def convert(self, op: 'cirq.Operation') -> List['cirq.Operation']:
)
return a

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
if isinstance(op.gate, ops.MatrixGate) and len(op.qubits) == 1:
return None

Expand Down
7 changes: 6 additions & 1 deletion cirq/google/optimizers/convert_to_sycamore_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import scipy.linalg
from cirq import circuits, google, linalg, ops, optimizers, protocols

from cirq.google.ops import SycamoreGate
from cirq.google.optimizers.two_qubit_gates.gate_compilation import GateTabulation

Expand Down Expand Up @@ -129,7 +130,9 @@ def on_stuck_raise(bad):
on_stuck_raise=None if self.ignore_failures else on_stuck_raise,
)

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
if not isinstance(op, ops.GateOperation):
return None

Expand All @@ -144,6 +147,8 @@ def optimization_at(self, circuit, index, op):
ops_in_front = list({circuit.operation_at(q, next_index) for q in op.qubits})
if len(ops_in_front) == 1 and isinstance(ops_in_front[0], ops.GateOperation):
gate2 = ops_in_front[0].gate
else:
next_index = 0

if isinstance(gate, ops.SwapPowGate) and isinstance(gate2, ops.ZZPowGate):
rads = gate2.exponent * np.pi / 2
Expand Down
6 changes: 4 additions & 2 deletions cirq/google/optimizers/convert_to_xmon_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, TYPE_CHECKING
from typing import List, TYPE_CHECKING, Optional

from cirq import ops, protocols
from cirq.circuits.optimization_pass import (
Expand Down Expand Up @@ -90,7 +90,9 @@ def on_stuck_raise(bad):
on_stuck_raise=None if self.ignore_failures else on_stuck_raise,
)

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
converted = self.convert(op)
if len(converted) == 1 and converted[0] is op:
return None
Expand Down
9 changes: 7 additions & 2 deletions cirq/neutral_atoms/convert_to_neutral_atom_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import List, Optional, TYPE_CHECKING

from cirq import ops, protocols
from cirq.circuits.optimization_pass import (
Expand All @@ -20,6 +20,9 @@
)
from cirq import optimizers

if TYPE_CHECKING:
import cirq


class ConvertToNeutralAtomGates(PointOptimizer):
"""Attempts to convert gates into native Atom gates.
Expand Down Expand Up @@ -76,7 +79,9 @@ def on_stuck_raise(bad):
on_stuck_raise=None if self.ignore_failures else on_stuck_raise,
)

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
converted = self.convert(op)
if len(converted) == 1 and converted[0] is op:
return None
Expand Down
9 changes: 7 additions & 2 deletions cirq/optimizers/expand_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

"""An optimizer that expands composite operations via `cirq.decompose`."""

from typing import Callable
from typing import Callable, Optional, TYPE_CHECKING

from cirq import ops, protocols
from cirq.circuits.optimization_pass import (
PointOptimizer,
PointOptimizationSummary,
)

if TYPE_CHECKING:
import cirq


class ExpandComposite(PointOptimizer):
"""An optimizer that expands composite operations via `cirq.decompose`.
Expand All @@ -41,7 +44,9 @@ def __init__(self, no_decomp: Callable[[ops.Operation], bool] = (lambda _: False
super().__init__()
self.no_decomp = no_decomp

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
decomposition = protocols.decompose(op, keep=self.no_decomp, on_stuck_raise=None)
if decomposition == [op]:
return None
Expand Down

0 comments on commit ce9fde6

Please sign in to comment.