Skip to content

Commit

Permalink
Use Epsilon instead of an arbitrary small number
Browse files Browse the repository at this point in the history
  • Loading branch information
Ericgig committed Jul 5, 2023
1 parent 9e4bb6c commit a01128d
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 29 deletions.
2 changes: 1 addition & 1 deletion qutip/core/_brtensor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ cdef class _BlochRedfieldElement(_BaseElement):
return BR_eig
return self.H.from_eigbasis(t, BR_eig)

cdef Data matmul_data_t(self, t, Data state, Data out=None):
cpdef Data matmul_data_t(self, t, Data state, Data out=None):
cdef size_t i
cdef double cutoff = self.sec_cutoff * self._compute_spectrum(t)
cdef Data A_eig, BR_eig
Expand Down
4 changes: 2 additions & 2 deletions qutip/core/cy/_element.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cdef class _BaseElement:
cpdef Data data(self, t)
cpdef object qobj(self, t)
cpdef object coeff(self, t)
cdef Data matmul_data_t(_BaseElement self, t, Data state, Data out=?)
cpdef Data matmul_data_t(_BaseElement self, t, Data state, Data out=?)


cdef class _ConstantElement(_BaseElement):
Expand All @@ -32,7 +32,7 @@ cdef class _FuncElement(_BaseElement):
cdef class _MapElement(_BaseElement):
cdef readonly _FuncElement _base
cdef readonly list _transform
cdef readonly double complex _coeff
cdef readonly object _coeff


cdef class _ProdElement(_BaseElement):
Expand Down
10 changes: 5 additions & 5 deletions qutip/core/cy/_element.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ cdef class _BaseElement:
"Sub-classes of _BaseElement should implement .coeff(t)."
)

cdef Data matmul_data_t(_BaseElement self, t, Data state, Data out=None):
cpdef Data matmul_data_t(_BaseElement self, t, Data state, Data out=None):
"""
Possibly in-place multiplication and addition. Multiplies a given state
by the elemen's value at time ``t`` and adds the result to ``out``.
Expand Down Expand Up @@ -507,7 +507,7 @@ cdef class _MapElement(_BaseElement):

def __mul__(left, right):
cdef _MapElement out, self
cdef double complex factor
cdef object factor
if type(left) is _MapElement:
self = left
factor = right
Expand All @@ -517,7 +517,7 @@ cdef class _MapElement(_BaseElement):
return _MapElement(
self._base,
self._transform.copy(),
self._coeff*factor
self._coeff * factor
)

def __matmul__(left, right):
Expand Down Expand Up @@ -568,7 +568,7 @@ cdef class _ProdElement(_BaseElement):

def __mul__(left, right):
cdef _ProdElement self
cdef double complex factor
cdef object factor
if type(left) is _ProdElement:
self = left
factor = right
Expand All @@ -594,7 +594,7 @@ cdef class _ProdElement(_BaseElement):
cdef double complex out = self._left.coeff(t) * self._right.coeff(t)
return conj(out) if self._conj else out

cdef Data matmul_data_t(_ProdElement self, t, Data state, Data out=None):
cpdef Data matmul_data_t(_ProdElement self, t, Data state, Data out=None):
cdef Data temp
if not self._transform:
temp = self._right.matmul_data_t(t, state)
Expand Down
44 changes: 40 additions & 4 deletions qutip/core/data/convert.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,44 @@ from qutip.core.data.base cimport Data
__all__ = ['to', 'create']


class _Epsilon:
"""
Constant for an small weight non-null weight.
Use to set `Data` specialisation just over direct specialisation.
"""
def __repr__(self):
return "EPSILON"

def __eq__(self, other):
if isinstance(other, _Epsilon):
return True
return NotImplemented

def __add__(self, other):
if isinstance(other, _Epsilon):
return self
return other

def __radd__(self, other):
if isinstance(other, _Epsilon):
return self
return other

def __lt__(self, other):
""" positive number > _Epsilon > 0 """
if isinstance(other, _Epsilon):
return False
return other > 0.

def __gt__(self, other):
if isinstance(other, _Epsilon):
return False
return other <= 0.


EPSILON = _Epsilon()


def _raise_if_unconnected(dtype_list, weights):
unconnected = {}
for i, type_ in enumerate(dtype_list):
Expand Down Expand Up @@ -152,7 +190,6 @@ cdef class _to:
cdef dict _convert
cdef readonly dict weight
cdef readonly dict _str2type
cdef readonly float anydataweight

def __init__(self):
self._direct_convert = {}
Expand All @@ -161,7 +198,6 @@ cdef class _to:
self.weight = {}
self.dispatchers = []
self._str2type = {}
self.anydataweight = 0.001

def add_conversions(self, converters):
"""
Expand Down Expand Up @@ -269,9 +305,9 @@ cdef class _to:
_converter(convert[::-1], to_t, from_t)
for dtype in self.dtypes:
self.weight[(dtype, Data)] = 1.
self.weight[(Data, dtype)] = self.anydataweight
self.weight[(Data, dtype)] = EPSILON
self._convert[(dtype, Data)] = _partial_converter(self, dtype)
self._convert[(Data, dtype)] = dummyconverter
self._convert[(Data, dtype)] = identity_converter
for dispatcher in self.dispatchers:
dispatcher.rebuild_lookup()

Expand Down
7 changes: 4 additions & 3 deletions qutip/core/data/dispatch.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import itertools
import warnings

from .convert import to as _to
from .convert import EPSILON

cimport cython
from libc cimport math
Expand All @@ -33,9 +34,9 @@ cdef double _conversion_weight(tuple froms, tuple tos, dict weight_map, bint out
)
if out:
n = n - 1
weight += weight_map[froms[n], tos[n]]
weight = weight + weight_map[froms[n], tos[n]]
for i in range(n):
weight += weight_map[tos[i], froms[i]]
weight = weight + weight_map[tos[i], froms[i]]
return weight


Expand Down Expand Up @@ -307,7 +308,7 @@ cdef class Dispatcher:
if cur == math.INFINITY:
raise ValueError("No valid specialisations found")

if weight <= 0.01 and not (output and types[-1] is Data):
if weight in [EPSILON, 0.] and not (output and types[-1] is Data):
self._lookup[in_types] = function
else:
if output:
Expand Down
28 changes: 14 additions & 14 deletions qutip/tests/core/data/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,29 +77,29 @@ def f(a=None, b=None, c=None, /):
if not output:
out = dispatched(*ins)
out = dispatched[in_types](*ins)
continue

out = dispatched(*ins)
assert out is not None
else:
out = dispatched(*ins)
assert out is not None

out = dispatched[in_types](*ins)
assert out is not None
out = dispatched[in_types](*ins)
assert out is not None

if output:
for out_dtype in _data.to.dtypes:
if output:
for out_dtype in _data.to.dtypes:

out = dispatched[in_types + (out_dtype,)](*ins)
if output:
assert isinstance(out, out_dtype)
out = dispatched[in_types + (out_dtype,)](*ins)
if output:
assert isinstance(out, out_dtype)

out = dispatched(*ins, dtype=out_dtype)
if output:
assert isinstance(out, out_dtype)
out = dispatched(*ins, dtype=out_dtype)
if output:
assert isinstance(out, out_dtype)


def test_Data_low_priority_one_dispatch():
class func():
__name__ = ""
__name__ = "dummy name"
def __call__(self, a, /):
return _data.zeros[_data.Dense](1, 1)

Expand Down

0 comments on commit a01128d

Please sign in to comment.