Skip to content

Commit

Permalink
Make InternalGate hashable if all gate args are hashable (#6294)
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo committed Sep 19, 2023
1 parent 1366494 commit d805d82
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
15 changes: 13 additions & 2 deletions cirq-google/cirq_google/ops/internal_gate.py
Expand Up @@ -43,7 +43,7 @@ def __init__(
self.gate_module = gate_module
self.gate_name = gate_name
self._num_qubits = num_qubits
self.gate_args = {arg: val for arg, val in kwargs.items()}
self.gate_args = kwargs

def _num_qubits_(self) -> int:
return self._num_qubits
Expand Down Expand Up @@ -72,4 +72,15 @@ def _json_dict_(self) -> Dict[str, Any]:
)

def _value_equality_values_(self):
return (self.gate_module, self.gate_name, self._num_qubits, self.gate_args)
hashable = True
for arg in self.gate_args.values():
try:
hash(arg)
except TypeError:
hashable = False
return (
self.gate_module,
self.gate_name,
self._num_qubits,
frozenset(self.gate_args.items()) if hashable else self.gate_args,
)
26 changes: 25 additions & 1 deletion cirq-google/cirq_google/ops/internal_gate_test.py
Expand Up @@ -14,6 +14,7 @@

import cirq
import cirq_google
import pytest


def test_internal_gate():
Expand All @@ -39,7 +40,30 @@ def test_internal_gate_with_no_args():
g = cirq_google.InternalGate(gate_name="GateWithNoArgs", gate_module='test', num_qubits=3)
assert str(g) == 'test.GateWithNoArgs()'
want_repr = (
"cirq_google.InternalGate(gate_name='GateWithNoArgs', " "gate_module='test', num_qubits=3)"
"cirq_google.InternalGate(gate_name='GateWithNoArgs', gate_module='test', num_qubits=3)"
)
assert repr(g) == want_repr
assert cirq.qid_shape(g) == (2, 2, 2)


def test_internal_gate_with_hashable_args_is_hashable():
hashable = cirq_google.InternalGate(
gate_name="GateWithHashableArgs",
gate_module='test',
num_qubits=3,
foo=1,
bar="2",
baz=(("a", 1),),
)
_ = hash(hashable)

unhashable = cirq_google.InternalGate(
gate_name="GateWithHashableArgs",
gate_module='test',
num_qubits=3,
foo=1,
bar="2",
baz={"a": 1},
)
with pytest.raises(TypeError, match="unhashable"):
_ = hash(unhashable)

0 comments on commit d805d82

Please sign in to comment.