Description
LSML raises an error in some particular cases. We should decide what to do in those cases: either leave the behaviour as it is, or raise a more specific error. The error originates at these lines:
https://github.com/metric-learn/metric-learn/blob/bf5c7224cc7ad4c025e15b247a80e076b7f75062/metric_learn/lsml.py#L129-L130
Sometimes the denominator dcd of the quotient above can be zero (for example, if the input contains quadruplets whose last two points are identical). In that case dab/dcd is infinite and the gradient ends up containing NaNs.
I'm not sure this case would happen in practice (e.g. if the data is clean and contains no duplicates). I need to think more about it.
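To make the failure mode concrete, here is a minimal NumPy sketch of the term on line 130. It assumes, based on the warning output under "Actual Results", that the term has the form (1 - np.sqrt(dab/dcd)) * np.outer(vcd, vcd) with vcd = c - d, and it assumes an identity metric M with distances computed as v^T M v. The points come from the first quadruplet of the reproduction example; M and the variable names are illustrative, not metric-learn's actual code.

```python
import numpy as np

# First quadruplet of the reproduction example: points c and d are identical
a = np.array([3.0, 4.52])
b = np.array([3.5, 2.0])
c = np.array([1.0, 2.0])
d = np.array([1.0, 2.0])

# Assumption: identity metric, used only for illustration
M = np.eye(2)

vab, vcd = a - b, c - d
dab = np.float64(vab @ M @ vab)  # strictly positive
dcd = np.float64(vcd @ M @ vcd)  # exactly zero because c == d

# Silence the two RuntimeWarnings seen in the report
with np.errstate(divide="ignore", invalid="ignore"):
    ratio = dab / dcd                                 # positive / 0 -> inf
    term = (1 - np.sqrt(ratio)) * np.outer(vcd, vcd)  # -inf * 0 -> nan

print(ratio)
print(term)  # every entry is nan
```

Once a single NaN enters the gradient, any later check for finite values (such as the norm computation in the traceback) fails.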
Steps/Code to Reproduce
Example:
from metric_learn import LSML
import numpy as np
quadruplets = np.array([[[3, 4.52], [3.5, 2], [1, 2], [1, 2]],
                        [[3.1234, 2.6526], [3.13451, 2.572], [31.346, 2.13451632], [3.13461, 2.727]],
                        [[3.725712, 2.1577126], [3.135717, 2.17517], [3.175472, 6.137], [6.13571, 6.13471]]])
lsml = LSML()
lsml.fit(quadruplets)
Expected Results
To be defined
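One option mentioned in the description is to raise a specific error instead. A minimal sketch of such a check, run on the input before fitting, could look like the following. check_degenerate_quadruplets is a hypothetical helper, not part of metric-learn, and it assumes that dcd is zero only when the last two points are exactly identical, which holds as long as the metric stays positive definite.

```python
import numpy as np


def check_degenerate_quadruplets(quadruplets):
    """Hypothetical pre-check (not part of metric-learn).

    Raises a ValueError if any quadruplet has identical last two points,
    since their distance dcd would be zero and the gradient would divide by it.
    """
    quadruplets = np.asarray(quadruplets, dtype=float)
    if quadruplets.ndim != 3 or quadruplets.shape[1] != 4:
        raise ValueError("expected an array of shape (n_quadruplets, 4, n_features)")
    # Boolean mask: True where point c equals point d in every feature
    same_cd = np.all(quadruplets[:, 2] == quadruplets[:, 3], axis=1)
    bad = np.flatnonzero(same_cd)
    if bad.size > 0:
        raise ValueError(
            f"quadruplets {bad.tolist()} have identical last two points; "
            "their distance is zero, so the LSML gradient is undefined"
        )
    return quadruplets


# Usage on the first two quadruplets of the reproduction example
example = np.array([[[3, 4.52], [3.5, 2], [1, 2], [1, 2]],
                    [[3.1234, 2.6526], [3.13451, 2.572], [31.346, 2.13451632], [3.13461, 2.727]]])
try:
    check_degenerate_quadruplets(example)
except ValueError as exc:
    print(exc)
```

Whether such a check should reject only exact duplicates or also near-duplicates (which give a very large but finite ratio) is part of what remains to be decided.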
Actual Results
/home/will/Code/metric-learn/metric_learn/lsml.py:130: RuntimeWarning: divide by zero encountered in double_scalars
  (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
/home/will/Code/metric-learn/metric_learn/lsml.py:130: RuntimeWarning: invalid value encountered in multiply
  (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
Traceback (most recent call last):
  File "/home/will/.PyCharmCE2018.3/config/scratches/scratch_51.py", line 7, in <module>
    lsml.fit(quadruplets)
  File "/home/will/Code/metric-learn/metric_learn/lsml.py", line 167, in fit
    return self._fit(quadruplets, weights=weights)
  File "/home/will/Code/metric-learn/metric_learn/lsml.py", line 79, in _fit
    grad_norm = scipy.linalg.norm(grad)
  File "/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/scipy/linalg/misc.py", line 137, in norm
    a = np.asarray_chkfinite(a)
  File "/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/numpy/lib/function_base.py", line 461, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs
Versions
Linux-4.4.0-142-generic-x86_64-with-debian-stretch-sid
Python 3.7.1 (default, Dec 14 2018, 19:28:38)
[GCC 7.3.0]
NumPy 1.15.4
SciPy 1.2.0
Scikit-Learn 0.20.2
Metric-Learn 0.4.0