Skip to content

Commit

Permalink
Use Epsilon instead of an arbitrary small number
Browse files Browse the repository at this point in the history
  • Loading branch information
Ericgig committed Jul 6, 2023
1 parent 9e4bb6c commit 932ad81
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 21 deletions.
44 changes: 40 additions & 4 deletions qutip/core/data/convert.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,44 @@ from qutip.core.data.base cimport Data
__all__ = ['to', 'create']


class _Epsilon:
"""
Constant for an small weight non-null weight.
Use to set `Data` specialisation just over direct specialisation.
"""
def __repr__(self):
return "EPSILON"

def __eq__(self, other):
if isinstance(other, _Epsilon):
return True
return NotImplemented

def __add__(self, other):
if isinstance(other, _Epsilon):
return self
return other

def __radd__(self, other):
if isinstance(other, _Epsilon):
return self
return other

def __lt__(self, other):
""" positive number > _Epsilon > 0 """
if isinstance(other, _Epsilon):
return False
return other > 0.

def __gt__(self, other):
if isinstance(other, _Epsilon):
return False
return other <= 0.


EPSILON = _Epsilon()


def _raise_if_unconnected(dtype_list, weights):
unconnected = {}
for i, type_ in enumerate(dtype_list):
Expand Down Expand Up @@ -152,7 +190,6 @@ cdef class _to:
cdef dict _convert
cdef readonly dict weight
cdef readonly dict _str2type
cdef readonly float anydataweight

def __init__(self):
self._direct_convert = {}
Expand All @@ -161,7 +198,6 @@ cdef class _to:
self.weight = {}
self.dispatchers = []
self._str2type = {}
self.anydataweight = 0.001

def add_conversions(self, converters):
"""
Expand Down Expand Up @@ -269,9 +305,9 @@ cdef class _to:
_converter(convert[::-1], to_t, from_t)
for dtype in self.dtypes:
self.weight[(dtype, Data)] = 1.
self.weight[(Data, dtype)] = self.anydataweight
self.weight[(Data, dtype)] = EPSILON
self._convert[(dtype, Data)] = _partial_converter(self, dtype)
self._convert[(Data, dtype)] = dummyconverter
self._convert[(Data, dtype)] = identity_converter
for dispatcher in self.dispatchers:
dispatcher.rebuild_lookup()

Expand Down
7 changes: 4 additions & 3 deletions qutip/core/data/dispatch.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import itertools
import warnings

from .convert import to as _to
from .convert import EPSILON

cimport cython
from libc cimport math
Expand All @@ -33,9 +34,9 @@ cdef double _conversion_weight(tuple froms, tuple tos, dict weight_map, bint out
)
if out:
n = n - 1
weight += weight_map[froms[n], tos[n]]
weight = weight + weight_map[froms[n], tos[n]]
for i in range(n):
weight += weight_map[tos[i], froms[i]]
weight = weight + weight_map[tos[i], froms[i]]
return weight


Expand Down Expand Up @@ -307,7 +308,7 @@ cdef class Dispatcher:
if cur == math.INFINITY:
raise ValueError("No valid specialisations found")

if weight <= 0.01 and not (output and types[-1] is Data):
if weight in [EPSILON, 0.] and not (output and types[-1] is Data):
self._lookup[in_types] = function
else:
if output:
Expand Down
28 changes: 14 additions & 14 deletions qutip/tests/core/data/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,29 +77,29 @@ def f(a=None, b=None, c=None, /):
if not output:
out = dispatched(*ins)
out = dispatched[in_types](*ins)
continue

out = dispatched(*ins)
assert out is not None
else:
out = dispatched(*ins)
assert out is not None

out = dispatched[in_types](*ins)
assert out is not None
out = dispatched[in_types](*ins)
assert out is not None

if output:
for out_dtype in _data.to.dtypes:
if output:
for out_dtype in _data.to.dtypes:

out = dispatched[in_types + (out_dtype,)](*ins)
if output:
assert isinstance(out, out_dtype)
out = dispatched[in_types + (out_dtype,)](*ins)
if output:
assert isinstance(out, out_dtype)

out = dispatched(*ins, dtype=out_dtype)
if output:
assert isinstance(out, out_dtype)
out = dispatched(*ins, dtype=out_dtype)
if output:
assert isinstance(out, out_dtype)


def test_Data_low_priority_one_dispatch():
class func():
__name__ = ""
__name__ = "dummy name"
def __call__(self, a, /):
return _data.zeros[_data.Dense](1, 1)

Expand Down

0 comments on commit 932ad81

Please sign in to comment.