From d805d82375d221237d5dfe44d7c089c6911a0462 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 19 Sep 2023 20:03:13 +0000 Subject: [PATCH] Make InternalGate hashable if all gate args are hashable (#6294) Review: @NoureldinYosri --- cirq-google/cirq_google/ops/internal_gate.py | 15 +++++++++-- .../cirq_google/ops/internal_gate_test.py | 26 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/cirq-google/cirq_google/ops/internal_gate.py b/cirq-google/cirq_google/ops/internal_gate.py index 5822aa1fefa..f5e1f37d498 100644 --- a/cirq-google/cirq_google/ops/internal_gate.py +++ b/cirq-google/cirq_google/ops/internal_gate.py @@ -43,7 +43,7 @@ def __init__( self.gate_module = gate_module self.gate_name = gate_name self._num_qubits = num_qubits - self.gate_args = {arg: val for arg, val in kwargs.items()} + self.gate_args = kwargs def _num_qubits_(self) -> int: return self._num_qubits @@ -72,4 +72,15 @@ def _json_dict_(self) -> Dict[str, Any]: ) def _value_equality_values_(self): - return (self.gate_module, self.gate_name, self._num_qubits, self.gate_args) + hashable = True + for arg in self.gate_args.values(): + try: + hash(arg) + except TypeError: + hashable = False + return ( + self.gate_module, + self.gate_name, + self._num_qubits, + frozenset(self.gate_args.items()) if hashable else self.gate_args, + ) diff --git a/cirq-google/cirq_google/ops/internal_gate_test.py b/cirq-google/cirq_google/ops/internal_gate_test.py index 00fd480ccaa..b212d4f6151 100644 --- a/cirq-google/cirq_google/ops/internal_gate_test.py +++ b/cirq-google/cirq_google/ops/internal_gate_test.py @@ -14,6 +14,7 @@ import cirq import cirq_google +import pytest def test_internal_gate(): @@ -39,7 +40,30 @@ def test_internal_gate_with_no_args(): g = cirq_google.InternalGate(gate_name="GateWithNoArgs", gate_module='test', num_qubits=3) assert str(g) == 'test.GateWithNoArgs()' want_repr = ( - "cirq_google.InternalGate(gate_name='GateWithNoArgs', " "gate_module='test', num_qubits=3)" + "cirq_google.InternalGate(gate_name='GateWithNoArgs', gate_module='test', num_qubits=3)" ) assert repr(g) == want_repr assert cirq.qid_shape(g) == (2, 2, 2) + + +def test_internal_gate_with_hashable_args_is_hashable(): + hashable = cirq_google.InternalGate( + gate_name="GateWithHashableArgs", + gate_module='test', + num_qubits=3, + foo=1, + bar="2", + baz=(("a", 1),), + ) + _ = hash(hashable) + + unhashable = cirq_google.InternalGate( + gate_name="GateWithHashableArgs", + gate_module='test', + num_qubits=3, + foo=1, + bar="2", + baz={"a": 1}, + ) + with pytest.raises(TypeError, match="unhashable"): + _ = hash(unhashable)