Skip to content

Commit

Permalink
fix ms gate equality (#6231)
Browse files Browse the repository at this point in the history
* fix ms gate equality

* update test

---------

Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
  • Loading branch information
richrines1 and tanujkhattar committed Aug 7, 2023
1 parent 081afab commit 5c36dc0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/parity_gates.py
Expand Up @@ -28,6 +28,7 @@
import cirq


@value.value_equality
class XXPowGate(gate_features.InterchangeableQubitsGate, eigen_gate.EigenGate):
r"""The X-parity gate, possibly raised to a power.
Expand Down Expand Up @@ -133,6 +134,7 @@ def __repr__(self) -> str:
)


@value.value_equality
class YYPowGate(gate_features.InterchangeableQubitsGate, eigen_gate.EigenGate):
r"""The Y-parity gate, possibly raised to a power.
Expand Down Expand Up @@ -237,6 +239,7 @@ def __repr__(self) -> str:
)


@value.value_equality
class ZZPowGate(gate_features.InterchangeableQubitsGate, eigen_gate.EigenGate):
r"""The Z-parity gate, possibly raised to a power.
Expand Down
20 changes: 18 additions & 2 deletions cirq-core/cirq/ops/parity_gates_test.py
Expand Up @@ -257,8 +257,24 @@ def test_trace_distance():

def test_ms_arguments():
eq_tester = cirq.testing.EqualsTester()
eq_tester.add_equality_group(cirq.ms(np.pi / 2), cirq.ops.MSGate(rads=np.pi / 2))
eq_tester.add_equality_group(cirq.XXPowGate(global_shift=-0.5))
eq_tester.add_equality_group(
cirq.ms(np.pi / 2), cirq.ops.MSGate(rads=np.pi / 2), cirq.XXPowGate(global_shift=-0.5)
)
eq_tester.add_equality_group(
cirq.ms(np.pi / 4), cirq.XXPowGate(exponent=0.5, global_shift=-0.5)
)
eq_tester.add_equality_group(cirq.XX)
eq_tester.add_equality_group(cirq.XX**0.5)


def test_ms_equal_up_to_global_phase():
assert cirq.equal_up_to_global_phase(cirq.ms(np.pi / 2), cirq.XX)
assert cirq.equal_up_to_global_phase(cirq.ms(np.pi / 4), cirq.XX**0.5)
assert not cirq.equal_up_to_global_phase(cirq.ms(np.pi / 4), cirq.XX)

assert cirq.ms(np.pi / 2) in cirq.GateFamily(cirq.XX)
assert cirq.ms(np.pi / 4) in cirq.GateFamily(cirq.XX**0.5)
assert cirq.ms(np.pi / 4) not in cirq.GateFamily(cirq.XX)


def test_ms_str():
Expand Down

0 comments on commit 5c36dc0

Please sign in to comment.