Skip to content

Commit

Permalink
Allow creating new gatesets with added gates (#2458)
Browse files Browse the repository at this point in the history
This will let us create "pre-release" gatesets to deploy server-side support for new gates that are not yet supported by the API.
  • Loading branch information
maffoo authored and CirqBot committed Nov 6, 2019
1 parent 58a794f commit 94f7fd2
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 28 deletions.
2 changes: 1 addition & 1 deletion cirq/google/api/v2/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Type, TYPE_CHECKING
from typing import TYPE_CHECKING

from cirq import devices, ops

Expand Down
35 changes: 28 additions & 7 deletions cirq/google/serializable_gate_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# limitations under the License.
"""Support for serializing and deserializing cirq.api.google.v2 protos."""

from collections import defaultdict

from typing import cast, Dict, Iterable, List, Optional, Tuple, Type, Union, \
TYPE_CHECKING
from typing import (cast, Dict, Iterable, List, Optional, Tuple, Type, Union,
TYPE_CHECKING)

from google.protobuf import json_format

Expand Down Expand Up @@ -51,12 +49,35 @@ def __init__(self, gate_set_name: str,
forms of gates to GateOperations.
"""
self.gate_set_name = gate_set_name
self.serializers = defaultdict(
list) # type: Dict[Type, List[op_serializer.GateOpSerializer]]
self.serializers: Dict[Type, List[op_serializer.GateOpSerializer]] = {}
for s in serializers:
self.serializers[s.gate_type].append(s)
self.serializers.setdefault(s.gate_type, []).append(s)
self.deserializers = {d.serialized_gate_id: d for d in deserializers}

def with_added_gates(
self,
*,
gate_set_name: Optional[str] = None,
serializers: Iterable[op_serializer.GateOpSerializer] = (),
deserializers: Iterable[op_deserializer.GateOpDeserializer] = (),
) -> 'SerializableGateSet':
"""Creates a new gateset with additional (de)serializers.
Args:
gate_set_name: Optional new name of the gateset. If not given, use
the same name as this gateset.
serializers: Serializers to add to those in this gateset.
deserializers: Deserializers to add to those in this gateset.
"""
# Iterate over all serializers in this gateset.
curr_serializers = (serializer
for serializers in self.serializers.values()
for serializer in serializers)
return SerializableGateSet(
gate_set_name or self.gate_set_name,
serializers=[*curr_serializers, *serializers],
deserializers=[*self.deserializers.values(), *deserializers])

def supported_gate_types(self) -> Tuple:
return tuple(self.serializers.keys())

Expand Down
91 changes: 71 additions & 20 deletions cirq/google/serializable_gate_set_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,57 @@
import cirq
import cirq.google as cg

X_SERIALIZER = cg.GateOpSerializer(gate_type=cirq.XPowGate,
serialized_gate_id='x_pow',
args=[
cg.SerializingArg(
serialized_name='half_turns',
serialized_type=float,
gate_getter='exponent')
])

X_DESERIALIZER = cg.GateOpDeserializer(serialized_gate_id='x_pow',
gate_constructor=cirq.XPowGate,
args=[
cg.DeserializingArg(
serialized_name='half_turns',
constructor_arg_name='exponent')
])

MY_GATE_SET = cg.SerializableGateSet(gate_set_name='my_gate_set',
serializers=[X_SERIALIZER],
deserializers=[X_DESERIALIZER])
X_SERIALIZER = cg.GateOpSerializer(
gate_type=cirq.XPowGate,
serialized_gate_id='x_pow',
args=[
cg.SerializingArg(
serialized_name='half_turns',
serialized_type=float,
gate_getter='exponent',
)
],
)

X_DESERIALIZER = cg.GateOpDeserializer(
serialized_gate_id='x_pow',
gate_constructor=cirq.XPowGate,
args=[
cg.DeserializingArg(
serialized_name='half_turns',
constructor_arg_name='exponent',
)
],
)

Y_SERIALIZER = cg.GateOpSerializer(
gate_type=cirq.YPowGate,
serialized_gate_id='y_pow',
args=[
cg.SerializingArg(
serialized_name='half_turns',
serialized_type=float,
gate_getter='exponent',
)
],
)

Y_DESERIALIZER = cg.GateOpDeserializer(
serialized_gate_id='y_pow',
gate_constructor=cirq.XPowGate,
args=[
cg.DeserializingArg(
serialized_name='half_turns',
constructor_arg_name='exponent',
)
],
)

MY_GATE_SET = cg.SerializableGateSet(
gate_set_name='my_gate_set',
serializers=[X_SERIALIZER],
deserializers=[X_DESERIALIZER],
)


def test_supported_gate_types():
Expand Down Expand Up @@ -338,6 +369,26 @@ def test_multiple_serializers():
assert gate_set.serialize_op(cirq.X(q0)**0.5).gate.id == 'x_pow'


def test_gateset_with_added_gates():
x_gateset = cg.SerializableGateSet(
gate_set_name='x',
serializers=[X_SERIALIZER],
deserializers=[X_DESERIALIZER],
)
xy_gateset = x_gateset.with_added_gates(
gate_set_name='xy',
serializers=[Y_SERIALIZER],
deserializers=[Y_DESERIALIZER],
)
assert x_gateset.gate_set_name == 'x'
assert x_gateset.is_supported_gate(cirq.X)
assert not x_gateset.is_supported_gate(cirq.Y)

assert xy_gateset.gate_set_name == 'xy'
assert xy_gateset.is_supported_gate(cirq.X)
assert xy_gateset.is_supported_gate(cirq.Y)


def test_deserialize_op_invalid_gate():
proto = {
'gate': {},
Expand Down

0 comments on commit 94f7fd2

Please sign in to comment.