Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@
CompilationTargetGateset,
CZTargetGateset,
compute_cphase_exponents_for_fsim_decomposition,
create_transformer_with_kwargs,
decompose_clifford_tableau_to_operations,
decompose_cphase_into_two_fsim,
decompose_multi_controlled_x,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)

from cirq.transformers.target_gatesets import (
create_transformer_with_kwargs,
CompilationTargetGateset,
CZTargetGateset,
SqrtIswapTargetGateset,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/transformers/target_gatesets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Gatesets which can act as compilation targets in Cirq."""

from cirq.transformers.target_gatesets.compilation_target_gateset import (
create_transformer_with_kwargs,
CompilationTargetGateset,
TwoQubitCompilationTargetGateset,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,53 @@
import cirq


def _create_transformer_with_kwargs(func: 'cirq.TRANSFORMER', **kwargs) -> 'cirq.TRANSFORMER':
"""Hack to capture additional keyword arguments to transformers while preserving mypy type."""
def create_transformer_with_kwargs(transformer: 'cirq.TRANSFORMER', **kwargs) -> 'cirq.TRANSFORMER':
"""Method to capture additional keyword arguments to transformers while preserving mypy type.

Returns a `cirq.TRANSFORMER` which, when called with a circuit and transformer context, is
equivalent to calling `transformer(circuit, context=context, **kwargs)`. It is often useful to
capture keyword arguments of a transformer before passing them as an argument to an API that
expects `cirq.TRANSFORMER`. For example:

>>> def run_transformers(transformers: List[cirq.TRANSFORMER]):
>>> for transformer in transformers:
>>> transformer(circuit, context=context)
>>>
>>> transformers: List[cirq.TRANSFORMER] = []
>>> transformers.append(
>>> cirq.create_transformer_with_kwargs(
>>> cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= 2
>>> )
>>> )
>>> transformers.append(cirq.create_transformer_with_kwargs(cirq.merge_k_qubit_unitaries, k=2))
>>> run_transformers(transformers)


Args:
transformer: A `cirq.TRANSFORMER` for which additional kwargs should be captured.
**kwargs: The keyword arguments which should be captured and passed to `transformer`.

Returns:
A `cirq.TRANSFORMER` method `transformer_with_kwargs`, s.t. executing
`transformer_with_kwargs(circuit, context=context)` is equivalent to executing
`transformer(circuit, context=context, **kwargs)`.

Raises:
SyntaxError: if **kwargs contain a 'context'.
"""
if 'context' in kwargs:
raise SyntaxError('**kwargs to be captured must not contain `context`.')

def transformer(
def transformer_with_kwargs(
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
) -> 'cirq.AbstractCircuit':
return func(circuit, context=context, **kwargs) # type: ignore
# Need to ignore mypy type because `cirq.TRANSFORMER` is a callable protocol which only
# accepts circuit and context; and doesn't expect additional keyword arguments. Note
# that transformers with additional keyword arguments with a default value do satisfy the
# `cirq.TRANSFORMER` API.
return transformer(circuit, context=context, **kwargs) # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the type ignore? I guess it's because cirq.TRANSFORMER is a callable protocol whose __call__ signature does not allow for extra kwargs. Might we be able to remove the type ignore by extending the signature with explicit kwargs?

In any case, there should be a comment explaining the reason for the ignore and linking to a mypy issue if there is one open.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's because cirq.TRANSFORMER is a callable protocol whose call signature does not allow for extra kwargs.

Yes, that's right.

Might we be able to remove the type ignore by extending the signature with explicit kwargs?

This is not possible because function arguments are contravariant, so adding **kwargs to the signature of the protocol would mean that every transformer would need to support a **kwargs : Any; which is bad.

In any case, there should be a comment explaining the reason for the ignore and linking to a mypy issue if there is one open.

Added a comment.


return transformer
return transformer_with_kwargs


class CompilationTargetGateset(ops.Gateset, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -93,11 +131,11 @@ def _intermediate_result_tag(self) -> Hashable:
def preprocess_transformers(self) -> List['cirq.TRANSFORMER']:
"""List of transformers which should be run before decomposing individual operations."""
return [
_create_transformer_with_kwargs(
create_transformer_with_kwargs(
expand_composite.expand_composite,
no_decomp=lambda op: protocols.num_qubits(op) <= self.num_qubits,
),
_create_transformer_with_kwargs(
create_transformer_with_kwargs(
merge_k_qubit_gates.merge_k_qubit_unitaries,
k=self.num_qubits,
rewriter=lambda op: op.with_tags(self._intermediate_result_tag),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import List
import pytest
import cirq
from cirq.protocols.decompose_protocol import DecomposeResult

Expand Down Expand Up @@ -219,3 +220,10 @@ def _decompose_single_qubit_operation(self, op: 'cirq.Operation', _) -> Decompos
c_expected = cirq.Circuit(cirq.X.on_each(*q), ops[-2:])
c_new = cirq.optimize_for_target_gateset(c_orig, gateset=DummyTargetGateset())
cirq.testing.assert_same_circuits(c_new, c_expected)


def test_create_transformer_with_kwargs_raises():
with pytest.raises(SyntaxError, match="must not contain `context`"):
cirq.create_transformer_with_kwargs(
cirq.merge_k_qubit_unitaries, k=2, context=cirq.TransformerContext()
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult
from cirq.transformers.target_gatesets.compilation_target_gateset import (
_create_transformer_with_kwargs,
)
from cirq_google import ops
from cirq_google.transformers.analytical_decompositions import two_qubit_to_sycamore

Expand Down Expand Up @@ -137,10 +134,10 @@ def __init__(
@property
def preprocess_transformers(self) -> List[cirq.TRANSFORMER]:
return [
_create_transformer_with_kwargs(
cirq.create_transformer_with_kwargs(
cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits
),
_create_transformer_with_kwargs(
cirq.create_transformer_with_kwargs(
merge_swap_rzz_and_2q_unitaries,
intermediate_result_tag=self._intermediate_result_tag,
),
Expand Down
5 changes: 1 addition & 4 deletions cirq-ionq/cirq_ionq/ionq_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
from typing import List

import cirq
from cirq.transformers.target_gatesets.compilation_target_gateset import (
_create_transformer_with_kwargs,
)


class IonQTargetGateset(cirq.TwoQubitCompilationTargetGateset):
Expand Down Expand Up @@ -85,7 +82,7 @@ def _decompose_two_qubit_operation(self, op: cirq.Operation, _) -> cirq.OP_TREE:
def preprocess_transformers(self) -> List['cirq.TRANSFORMER']:
"""List of transformers which should be run before decomposing individual operations."""
return [
_create_transformer_with_kwargs(
cirq.create_transformer_with_kwargs(
cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits
)
]
Expand Down