Skip to content

Commit

Permalink
Merge pull request #2425 from Ericgig/data.isequal
Browse files Browse the repository at this point in the history
Add `rtol` to `Qobj.__eq__`
  • Loading branch information
Ericgig committed May 15, 2024
2 parents 4acc9f1 + 23042b3 commit ad05a22
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/changes/2425.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Qobj.__eq__ uses core's settings rtol.
4 changes: 4 additions & 0 deletions qutip/core/data/matmul.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,10 @@ cpdef CSR multiply_csr(CSR left, CSR right):
+ " and "
+ str(right.shape)
)

left = left.sort_indices()
right = right.sort_indices()

cdef idxint col_left, left_nnz = csr.nnz(left)
cdef idxint col_right, right_nnz = csr.nnz(right)
cdef idxint ptr_left, ptr_right, ptr_left_max, ptr_right_max
Expand Down
182 changes: 181 additions & 1 deletion qutip/core/data/properties.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ from cpython cimport mem
from qutip.settings import settings

from qutip.core.data.base cimport idxint
from qutip.core.data cimport csr, dense, CSR, Dense, Dia
from qutip.core.data cimport csr, dense, dia, CSR, Dense, Dia
from qutip.core.data.adjoint cimport transpose_csr
import numpy as np

cdef extern from *:
# Not defined in cpython.mem for some reason, but is in pymem.h.
Expand All @@ -18,6 +19,7 @@ __all__ = [
'isherm', 'isherm_csr', 'isherm_dense', 'isherm_dia',
'isdiag', 'isdiag_csr', 'isdiag_dense', 'isdiag_dia',
'iszero', 'iszero_csr', 'iszero_dense', 'iszero_dia',
'isequal', 'isequal_csr', 'isequal_dense', 'isequal_dia',
]

cdef inline bint _conj_feq(double complex a, double complex b, double tol) nogil:
Expand All @@ -36,6 +38,30 @@ cdef inline bint _feq_zero(double complex a, double tol) nogil:
cdef inline double _abssq(double complex x) nogil:
return x.real*x.real + x.imag*x.imag

cdef inline bint _feq(double complex a, double complex b, double atol, double rtol) nogil:
"""
Follow numpy.allclose tolerance equation:
|a - b| <= (atol + rtol * |b|)
Avoid slow sqrt.
"""
cdef double diff = (a.real - b.real)**2 + (a.imag - b.imag)**2 - atol * atol
if diff <= 0:
# Early exit if under atol.
# |a - b|**2 <= atol**2
return True
cdef double normb_sq = b.real * b.real + b.imag * b.imag
if normb_sq == 0. or rtol == 0.:
# No rtol term, the previous computation was final.
return False
diff -= rtol * rtol * normb_sq
if diff <= 0:
# Early exit if under atol + rtol without cross term.
# |a - b|**2 <= atol**2 + (rtol * |b|)**2
return True
# Full computation
# (|a - b|**2 - atol**2 * (rtol * |b|)**2)**2 <= (2* atol * rtol * |b|)**2
return diff**2 <= 4 * atol * atol * rtol * rtol * normb_sq


cdef bint _isherm_csr_full(CSR matrix, double tol) except 2:
"""
Expand Down Expand Up @@ -300,6 +326,116 @@ cpdef bint iszero_dense(Dense matrix, double tol=-1) nogil:
return True


cpdef bint isequal_dia(Dia A, Dia B, double atol=-1, double rtol=-1):
if A.shape[0] != B.shape[0] or A.shape[1] != B.shape[1]:
return False
if atol < 0:
atol = settings.core["atol"]
if rtol < 0:
rtol = settings.core["rtol"]

cdef idxint diag_a=0, diag_b=0
cdef double complex *ptr_a
cdef double complex *ptr_b
cdef idxint size=A.shape[1]

# TODO:
# Works only for a sorted offsets list.
# We don't have a check for whether it's already sorted, but it should be
# in most cases. Could be improved by tracking whether it is or not.
A = dia.clean_dia(A)
B = dia.clean_dia(B)

ptr_a = A.data
ptr_b = B.data

with nogil:
while diag_a < A.num_diag and diag_b < B.num_diag:
if A.offsets[diag_a] == B.offsets[diag_b]:
for i in range(size):
if not _feq(ptr_a[i], ptr_b[i], atol, rtol):
return False
ptr_a += size
diag_a += 1
ptr_b += size
diag_b += 1
elif A.offsets[diag_a] <= B.offsets[diag_b]:
for i in range(size):
if not _feq(ptr_a[i], 0., atol, rtol):
return False
ptr_a += size
diag_a += 1
else:
for i in range(size):
if not _feq(0., ptr_b[i], atol, rtol):
return False
ptr_b += size
diag_b += 1
return True


cpdef bint isequal_dense(Dense A, Dense B, double atol=-1, double rtol=-1):
if A.shape[0] != B.shape[0] or A.shape[1] != B.shape[1]:
return False
if atol < 0:
atol = settings.core["atol"]
if rtol < 0:
rtol = settings.core["rtol"]
return np.allclose(A.as_ndarray(), B.as_ndarray(), rtol, atol)


cpdef bint isequal_csr(CSR A, CSR B, double atol=-1, double rtol=-1):
if A.shape[0] != B.shape[0] or A.shape[1] != B.shape[1]:
return False
if atol < 0:
atol = settings.core["atol"]
if rtol < 0:
rtol = settings.core["rtol"]

cdef idxint row, ptr_a, ptr_b, ptr_a_max, ptr_b_max, col_a, col_b
cdef idxint ncols = A.shape[1], prev_col_a, prev_col_b

# TODO:
# Works only for sorted indices.
# We don't have a check for whether it's already sorted, but it should be
# in most cases.
A = A.sort_indices()
B = B.sort_indices()

with nogil:
ptr_a_max = ptr_b_max = 0
for row in range(A.shape[0]):
ptr_a = ptr_a_max
ptr_a_max = A.row_index[row + 1]
ptr_b = ptr_b_max
ptr_b_max = B.row_index[row + 1]
col_a = A.col_index[ptr_a] if ptr_a < ptr_a_max else ncols + 1
col_b = B.col_index[ptr_b] if ptr_b < ptr_b_max else ncols + 1
prev_col_a = -1
prev_col_b = -1
while ptr_a < ptr_a_max or ptr_b < ptr_b_max:

if col_a == col_b:
if not _feq(A.data[ptr_a], B.data[ptr_b], atol, rtol):
return False
ptr_a += 1
ptr_b += 1
col_a = A.col_index[ptr_a] if ptr_a < ptr_a_max else ncols + 1
col_b = B.col_index[ptr_b] if ptr_b < ptr_b_max else ncols + 1
elif col_a < col_b:
if not _feq(A.data[ptr_a], 0., atol, rtol):
return False
ptr_a += 1
col_a = A.col_index[ptr_a] if ptr_a < ptr_a_max else ncols + 1
else:
if not _feq(0., B.data[ptr_b], atol, rtol):
return False
ptr_b += 1
col_b = B.col_index[ptr_b] if ptr_b < ptr_b_max else ncols + 1

return True


from .dispatch import Dispatcher as _Dispatcher
import inspect as _inspect

Expand Down Expand Up @@ -397,4 +533,48 @@ iszero.add_specialisations([
(Dense, iszero_dense),
], _defer=True)

isequal = _Dispatcher(
_inspect.Signature([
_inspect.Parameter('A', _inspect.Parameter.POSITIONAL_ONLY),
_inspect.Parameter('B', _inspect.Parameter.POSITIONAL_ONLY),
_inspect.Parameter('atol', _inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=-1),
_inspect.Parameter('rtol', _inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=-1),
]),
name='isequal',
module=__name__,
inputs=('A', 'B',),
out=False,
)
isequal.__doc__ =\
"""
Test if two matrices are equal up to absolute and relative tolerance:
|A - B| <= atol + rtol * |b|
Similar to ``numpy.allclose``.
Parameters
----------
A, B : Data
Matrices to compare.
atol : real, optional
The absolute tolerance to use. If not given, or
less than 0, use the core setting `atol`.
rtol : real, optional
The relative tolerance to use. If not given, or
less than 0, use the core setting `atol`.
Returns
-------
bool
Whether the matrix are equal.
"""
isequal.add_specialisations([
(CSR, CSR, isequal_csr),
(Dia, Dia, isequal_dia),
(Dense, Dense, isequal_dense),
], _defer=True)

del _inspect, _Dispatcher
4 changes: 2 additions & 2 deletions qutip/core/qobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ def __eq__(self, other) -> bool:
return True
if not isinstance(other, Qobj) or self._dims != other._dims:
return False
return _data.iszero(_data.sub(self._data, other._data),
tol=settings.core['atol'])
# isequal uses both atol and rtol from settings.core
return _data.isequal(self._data, other._data)

def __pow__(self, n: int, m=None) -> Qobj: # calculates powers of Qobj
if (
Expand Down
2 changes: 1 addition & 1 deletion qutip/tests/core/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def random_scipy_csr(shape, density, sorted_):
cols = np.random.choice(np.arange(shape[1]), nnz)
sci = scipy.sparse.coo_matrix((data, (rows, cols)), shape=shape).tocsr()
if not sorted_:
shuffle_indices_scipy_csr(sci)
sci = shuffle_indices_scipy_csr(sci)
return sci


Expand Down
66 changes: 66 additions & 0 deletions qutip/tests/core/data/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from qutip import data as _data
from qutip import CoreOptions
from . import conftest
from qutip.core.data.dia import clean_dia

@pytest.fixture(params=[_data.CSR, _data.Dense, _data.Dia], ids=["CSR", "Dense", "Dia"])
def datatype(request):
Expand Down Expand Up @@ -177,3 +179,67 @@ def test_isdiag(self, shape, datatype):
mat[1, 0] = 1
data = _data.to(datatype, _data.Dense(mat))
assert not _data.isdiag(data)


class TestIsEqual:
def op_numpy(self, left, right, atol, rtol):
return np.allclose(left.to_array(), right.to_array(), rtol, atol)

def rand_dense(shape):
return conftest.random_dense(shape, False)

def rand_diag(shape):
return conftest.random_diag(shape, 0.5, True)

def rand_csr(shape):
return conftest.random_csr(shape, 0.5, True)

@pytest.mark.parametrize("factory", [rand_dense, rand_diag, rand_csr])
@pytest.mark.parametrize("shape", [(1, 20), (20, 20), (20, 2)])
def test_same_shape(self, factory, shape):
atol = 1e-8
rtol = 1e-6
A = factory(shape)
B = factory(shape)
assert _data.isequal(A, A, atol, rtol)
assert _data.isequal(B, B, atol, rtol)
assert (
_data.isequal(A, B, atol, rtol) == self.op_numpy(A, B, atol, rtol)
)

@pytest.mark.parametrize("factory", [rand_dense, rand_diag, rand_csr])
@pytest.mark.parametrize("shapeA", [(1, 10), (9, 9), (10, 2)])
@pytest.mark.parametrize("shapeB", [(1, 9), (10, 10), (10, 1)])
def test_different_shape(self, factory, shapeA, shapeB):
A = factory(shapeA)
B = factory(shapeB)
assert not _data.isequal(A, B, np.inf, np.inf)

@pytest.mark.parametrize("rtol", [1e-6, 100])
@pytest.mark.parametrize("factory", [rand_dense, rand_diag, rand_csr])
@pytest.mark.parametrize("shape", [(1, 20), (20, 20), (20, 2)])
def test_rtol(self, factory, shape, rtol):
mat = factory(shape)
assert _data.isequal(mat + mat * (rtol / 10), mat, 1e-14, rtol)
assert not _data.isequal(mat * (1 + rtol * 10), mat, 1e-14, rtol)

@pytest.mark.parametrize("atol", [1e-14, 1e-6, 100])
@pytest.mark.parametrize("factory", [rand_dense, rand_diag, rand_csr])
@pytest.mark.parametrize("shape", [(1, 20), (20, 20), (20, 2)])
def test_atol(self, factory, shape, atol):
A = factory(shape)
B = factory(shape)
assert _data.isequal(A, A + B * (atol / 10), atol, 0)
assert not _data.isequal(A, A + B * (atol * 10), atol, 0)

@pytest.mark.parametrize("shape", [(1, 20), (20, 20), (20, 2)])
def test_csr_mismatch_sort(self, shape):
A = conftest.random_csr(shape, 0.5, False)
B = A.copy().sort_indices()
assert _data.isequal(A, B)

@pytest.mark.parametrize("shape", [(1, 20), (20, 20), (20, 2)])
def test_dia_mismatch_sort(self, shape):
A = conftest.random_diag(shape, 0.5, False)
B = clean_dia(A)
assert _data.isequal(A, B)
11 changes: 10 additions & 1 deletion qutip/tests/core/test_qobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,15 @@ def test_QobjEquals():
q2 = qutip.Qobj(-data)
assert q1 != q2

# data's entry are of order 1,
with qutip.CoreOptions(atol=10):
assert q1 == q2
assert q1 != q2 * 100

with qutip.CoreOptions(rtol=10):
assert q1 == q2
assert q1 == q2 * 100


def test_QobjGetItem():
"qutip.Qobj getitem"
Expand Down Expand Up @@ -1273,4 +1282,4 @@ def test_qobj_dtype(dtype):
@pytest.mark.parametrize('dtype', ["CSR", "Dense", "Dia"])
def test_dtype_in_info_string(dtype):
obj = qutip.qeye(2, dtype=dtype)
assert dtype.lower() in str(obj).lower()
assert dtype.lower() in str(obj).lower()

0 comments on commit ad05a22

Please sign in to comment.