Undefined behaviour for LSML #176

@wdevazelhes

Description

LSML raises an error in some particular cases. We should decide what to do in those cases: either leave the behaviour as is, or raise a specific, more informative error. The error originates at these lines:
https://github.com/metric-learn/metric-learn/blob/bf5c7224cc7ad4c025e15b247a80e076b7f75062/metric_learn/lsml.py#L129-L130

Sometimes the denominator of one of the quotients above can be zero (for example, if the input contains quadruplets whose last two points are identical).
I'm not sure this would happen in a real-life case (e.g. if the data is clean and contains no duplicates). I need to think more about it.
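
To make the mechanism concrete, here is a minimal sketch that reproduces the failing term outside of LSML for the first quadruplet of the example below. It assumes an identity matrix as the metric (an assumption for illustration; the metric actually used inside `lsml.py` at that point may differ) and copies only the expression visible in the warning message.

```python
import numpy as np

# First quadruplet from the example: the last two points are identical.
a, b, c, d = np.array([[3, 4.52], [3.5, 2], [1, 2], [1, 2]])

M = np.eye(2)  # assumption: identity metric, for illustration only

vab, vcd = a - b, c - d
dab = vab @ M @ vab  # squared distance between a and b
dcd = vcd @ M @ vcd  # squared distance between c and d, zero here

# Silence the RuntimeWarnings that the real run prints.
with np.errstate(divide="ignore", invalid="ignore"):
    ratio = dab / dcd  # division by zero gives inf
    term = (1 - np.sqrt(ratio)) * np.outer(vcd, vcd)  # -inf * 0 gives nan

print(dcd, ratio)
print(term)
```

The `-inf * 0` products are what produce the NaNs that `scipy.linalg.norm` later rejects.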

Steps/Code to Reproduce

Example:

from metric_learn import LSML
import numpy as np
# The first quadruplet has identical last two points ([1, 2] and [1, 2]).
quadruplets = np.array([[[3,4.52], [3.5, 2], [1, 2], [1, 2]],
                        [[3.1234, 2.6526], [3.13451, 2.572], [31.346, 2.13451632], [3.13461, 2.727]],
                        [[3.725712, 2.1577126], [3.135717, 2.17517], [3.175472, 6.137], [6.13571, 6.13471]]])
lsml = LSML()
lsml.fit(quadruplets)

Expected Results

To be defined
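
One possible option, if we go for a dedicated error, could look like the sketch below. `check_degenerate_quadruplets` is a hypothetical helper, not an existing metric-learn function, and the error message is only a suggestion. It flags quadruplets in which the first two or the last two points are exactly identical, since such a pair has zero distance under any metric.

```python
import numpy as np


def check_degenerate_quadruplets(quadruplets):
    """Hypothetical pre-check, not part of metric-learn.

    Raises a ValueError if any quadruplet has an identical first pair
    or an identical last pair of points.
    """
    quadruplets = np.asarray(quadruplets, dtype=float)
    # Compare the two points of each pair feature by feature.
    same_ab = np.all(quadruplets[:, 0] == quadruplets[:, 1], axis=1)
    same_cd = np.all(quadruplets[:, 2] == quadruplets[:, 3], axis=1)
    bad = np.flatnonzero(same_ab | same_cd)
    if bad.size:
        raise ValueError(
            "Quadruplets at indices %s contain a pair of identical points, "
            "which leads to a division by zero." % bad.tolist()
        )
    return quadruplets


# Usage with a shortened version of the example input.
quadruplets = np.array([[[3, 4.52], [3.5, 2], [1, 2], [1, 2]],
                        [[3.1, 2.6], [3.2, 2.5], [31.3, 2.1], [3.1, 2.7]]])
try:
    check_degenerate_quadruplets(quadruplets)
    message = None
except ValueError as exc:
    message = str(exc)
print(message)
```

This only catches exact duplicates. Points that are distinct but end up at zero distance under the learned metric would need a separate check.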

Actual Results

/home/will/Code/metric-learn/metric_learn/lsml.py:130: RuntimeWarning: divide by zero encountered in double_scalars
  (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
/home/will/Code/metric-learn/metric_learn/lsml.py:130: RuntimeWarning: invalid value encountered in multiply
  (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
Traceback (most recent call last):
  File "/home/will/.PyCharmCE2018.3/config/scratches/scratch_51.py", line 7, in <module>
    lsml.fit(quadruplets)
  File "/home/will/Code/metric-learn/metric_learn/lsml.py", line 167, in fit
    return self._fit(quadruplets, weights=weights)
  File "/home/will/Code/metric-learn/metric_learn/lsml.py", line 79, in _fit
    grad_norm = scipy.linalg.norm(grad)
  File "/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/scipy/linalg/misc.py", line 137, in norm
    a = np.asarray_chkfinite(a)
  File "/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/numpy/lib/function_base.py", line 461, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs

Versions

Linux-4.4.0-142-generic-x86_64-with-debian-stretch-sid
Python 3.7.1 (default, Dec 14 2018, 19:28:38) 
[GCC 7.3.0]
NumPy 1.15.4
SciPy 1.2.0
Scikit-Learn 0.20.2
Metric-Learn 0.4.0
