Skip to content

Commit

Permalink
Add approximate equality for all common_channels except ResetChannel (
Browse files Browse the repository at this point in the history
#3887)

* Define _approx_eq_ for cirq.depolarize

* Add other channels

* Remove for reset channel

* Cover the remaining lines

* Correct test

* CHeck for num_qubits
  • Loading branch information
vtomole committed Mar 8, 2021
1 parent 8cef3d9 commit 2ef656c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 6 deletions.
36 changes: 36 additions & 0 deletions cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs') -> st
error_probabilities = [f"{pauli}:{p}" for pauli, p in self._error_probabilities.items()]
return f"A({', '.join(error_probabilities)})"

@property
def p_i(self) -> float:
"""The probability that an Identity I and no other gate occurs."""
if self._num_qubits != 1:
raise ValueError('num_qubits should be 1')
return self._error_probabilities.get('I', 0.0)

@property
def p_x(self) -> float:
"""The probability that a Pauli X and no other gate occurs."""
Expand Down Expand Up @@ -184,6 +191,15 @@ def error_probabilities(self) -> Dict[str, float]:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['error_probabilities'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return (
self.num_qubits == other.num_qubits
and np.isclose(self.p_i, other.p_i, atol=atol)
and np.isclose(self.p_x, other.p_x, atol=atol)
and np.isclose(self.p_y, other.p_y, atol=atol)
and np.isclose(self.p_z, other.p_z, atol=atol)
)


def asymmetric_depolarize(
p_x: Optional[float] = None,
Expand Down Expand Up @@ -339,6 +355,9 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['p'])
return protocols.obj_to_dict_helper(self, ['p', 'n_qubits'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self.p, other.p, atol=atol) and self.n_qubits == other.n_qubits


def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel:
r"""Returns a DepolarizingChannel with given probability of error.
Expand Down Expand Up @@ -478,6 +497,11 @@ def gamma(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['p', 'gamma'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self.gamma, other.gamma, atol=atol) and np.isclose(
self.p, other.p, atol=atol
)


def generalized_amplitude_damp(p: float, gamma: float) -> GeneralizedAmplitudeDampingChannel:
r"""
Expand Down Expand Up @@ -603,6 +627,9 @@ def gamma(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gamma'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self.gamma, other.gamma, atol=atol)


def amplitude_damp(gamma: float) -> AmplitudeDampingChannel:
r"""
Expand Down Expand Up @@ -835,6 +862,9 @@ def gamma(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gamma'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self._gamma, other._gamma, atol=atol)


def phase_damp(gamma: float) -> PhaseDampingChannel:
r"""
Expand Down Expand Up @@ -942,6 +972,9 @@ def p(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['p'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self.p, other.p, atol=atol)


def _phase_flip_Z() -> common_gates.ZPowGate:
"""
Expand Down Expand Up @@ -1095,6 +1128,9 @@ def p(self) -> float:
def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['p'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return np.isclose(self._p, other._p, atol=atol)


def _bit_flip(p: float) -> BitFlipChannel:
r"""
Expand Down
54 changes: 48 additions & 6 deletions cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,14 @@ def test_asymmetric_depolarizing_channel_str():


def test_asymmetric_depolarizing_channel_eq():
et = cirq.testing.EqualsTester()

a = cirq.asymmetric_depolarize(0.0099999, 0.01)
b = cirq.asymmetric_depolarize(0.01, 0.0099999)
c = cirq.asymmetric_depolarize(0.0, 0.0, 0.0)

assert cirq.approx_eq(a, b, atol=1e-2)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
et.add_equality_group(cirq.asymmetric_depolarize(0.0, 0.0, 0.1))
et.add_equality_group(cirq.asymmetric_depolarize(0.0, 0.1, 0.0))
Expand Down Expand Up @@ -254,8 +260,14 @@ def test_asymmetric_depolarizing_channel_apply_two_qubits():


def test_depolarizing_channel_eq():
et = cirq.testing.EqualsTester()
a = cirq.depolarize(p=0.0099999)
b = cirq.depolarize(p=0.01)
c = cirq.depolarize(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)

et = cirq.testing.EqualsTester()

et.make_equality_group(lambda: c)
et.add_equality_group(cirq.depolarize(0.1))
et.add_equality_group(cirq.depolarize(0.9))
Expand Down Expand Up @@ -322,6 +334,11 @@ def test_generalized_amplitude_damping_str():


def test_generalized_amplitude_damping_channel_eq():
a = cirq.generalized_amplitude_damp(0.0099999, 0.01)
b = cirq.generalized_amplitude_damp(0.01, 0.0099999)

assert cirq.approx_eq(a, b, atol=1e-2)

et = cirq.testing.EqualsTester()
c = cirq.generalized_amplitude_damp(0.0, 0.0)
et.make_equality_group(lambda: c)
Expand Down Expand Up @@ -375,8 +392,13 @@ def test_amplitude_damping_channel_str():


def test_amplitude_damping_channel_eq():
et = cirq.testing.EqualsTester()
a = cirq.amplitude_damp(0.0099999)
b = cirq.amplitude_damp(0.01)
c = cirq.amplitude_damp(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
et.add_equality_group(cirq.amplitude_damp(0.1))
et.add_equality_group(cirq.amplitude_damp(0.4))
Expand Down Expand Up @@ -499,8 +521,13 @@ def test_phase_damping_channel_str():


def test_phase_damping_channel_eq():
et = cirq.testing.EqualsTester()
a = cirq.phase_damp(0.0099999)
b = cirq.phase_damp(0.01)
c = cirq.phase_damp(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
et.add_equality_group(cirq.phase_damp(0.1))
et.add_equality_group(cirq.phase_damp(0.4))
Expand Down Expand Up @@ -555,8 +582,13 @@ def test_phase_flip_channel_str():


def test_phase_flip_channel_eq():
et = cirq.testing.EqualsTester()
a = cirq.phase_flip(0.0099999)
b = cirq.phase_flip(0.01)
c = cirq.phase_flip(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
et.add_equality_group(cirq.phase_flip(0.1))
et.add_equality_group(cirq.phase_flip(0.4))
Expand Down Expand Up @@ -611,8 +643,14 @@ def test_bit_flip_channel_str():


def test_bit_flip_channel_eq():
et = cirq.testing.EqualsTester()

a = cirq.bit_flip(0.0099999)
b = cirq.bit_flip(0.01)
c = cirq.bit_flip(0.0)

assert cirq.approx_eq(a, b, atol=1e-2)

et = cirq.testing.EqualsTester()
et.make_equality_group(lambda: c)
et.add_equality_group(cirq.bit_flip(0.1))
et.add_equality_group(cirq.bit_flip(0.4))
Expand Down Expand Up @@ -650,6 +688,7 @@ def test_stabilizer_supports_depolarize():

def test_default_asymmetric_depolarizing_channel():
d = cirq.asymmetric_depolarize()
assert d.p_i == 1.0
assert d.p_x == 0.0
assert d.p_y == 0.0
assert d.p_z == 0.0
Expand Down Expand Up @@ -684,6 +723,9 @@ def test_multi_asymmetric_depolarizing_channel():
)
assert cirq.has_channel(d)
np.testing.assert_equal(d._num_qubits_(), 2)

with pytest.raises(ValueError, match="num_qubits should be 1"):
assert d.p_i == 1.0
with pytest.raises(ValueError, match="num_qubits should be 1"):
assert d.p_x == 0.0
with pytest.raises(ValueError, match="num_qubits should be 1"):
Expand Down

0 comments on commit 2ef656c

Please sign in to comment.