Skip to content

Commit

Permalink
MatrixGate names don't survive serialization (#6026)
Browse files Browse the repository at this point in the history
  • Loading branch information
markedmiston committed Mar 3, 2023
1 parent 18991b5 commit c8f7a02
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
10 changes: 7 additions & 3 deletions cirq-core/cirq/ops/matrix_gates.py
Expand Up @@ -114,11 +114,15 @@ def with_name(self, name: str) -> 'MatrixGate':
return MatrixGate(self._matrix, name=name, qid_shape=self._qid_shape, unitary_check=False)

def _json_dict_(self) -> Dict[str, Any]:
return {'matrix': self._matrix.tolist(), 'qid_shape': self._qid_shape}
return {
'matrix': self._matrix.tolist(),
'qid_shape': self._qid_shape,
**({'name': self._name} if self._name is not None else {}),
}

@classmethod
def _from_json_dict_(cls, matrix, qid_shape, **kwargs):
return cls(matrix=np.array(matrix), qid_shape=qid_shape)
def _from_json_dict_(cls, matrix, qid_shape, name=None, **kwargs):
return cls(matrix=np.array(matrix), qid_shape=qid_shape, name=name)

def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape
Expand Down
22 changes: 22 additions & 0 deletions cirq-core/cirq/ops/matrix_gates_test.py
Expand Up @@ -388,3 +388,25 @@ def test_matrixgate_unitary_tolerance():
# very low atol -> the check never converges
with pytest.raises(ValueError):
_ = cirq.MatrixGate(np.array([[0.707, 0.707], [-0.707, 0.707]]), unitary_check_rtol=1e-10)


def test_matrixgate_name_serialization():
# https://github.com/quantumlib/Cirq/issues/5999

# Test name serialization
gate1 = cirq.MatrixGate(np.eye(2), name='test_name')
gate_after_serialization1 = cirq.read_json(json_text=cirq.to_json(gate1))
assert gate1._name == 'test_name'
assert gate_after_serialization1._name == 'test_name'

# Test name backwards compatibility
gate2 = cirq.MatrixGate(np.eye(2))
gate_after_serialization2 = cirq.read_json(json_text=cirq.to_json(gate2))
assert gate2._name is None
assert gate_after_serialization2._name is None

# Test empty name
gate3 = cirq.MatrixGate(np.eye(2), name='')
gate_after_serialization3 = cirq.read_json(json_text=cirq.to_json(gate3))
assert gate3._name == ''
assert gate_after_serialization3._name == ''

0 comments on commit c8f7a02

Please sign in to comment.