ITML raises error on toy example #177

Open
wdevazelhes opened this issue Mar 6, 2019 · 0 comments

Description

ITML raises an error on this example (the iris dataset truncated to its first feature, i.e. 1D iris). We should investigate why, and whether it is the symptom of a deeper problem.

Note that if we deduplicate the pairs (i.e. remove degenerate pairs like [a, a]), the error disappears, so this may be the direction to look in.
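
A minimal sketch of that deduplication, assuming the pairs are stored as an array of shape (n_pairs, 2, n_features) as in the reproduction script. `drop_degenerate_pairs` is a hypothetical helper for illustration, not part of metric-learn:

```python
import numpy as np


def drop_degenerate_pairs(pairs, y):
    """Remove pairs whose two points are identical (hypothetical helper).

    pairs: array of shape (n_pairs, 2, n_features)
    y: array of shape (n_pairs,) with +1 / -1 pair labels
    """
    pairs = np.asarray(pairs)
    y = np.asarray(y)
    # A pair [a, a] has a zero difference in every feature.
    diff = pairs[:, 0, :] - pairs[:, 1, :]
    keep = np.any(diff != 0, axis=1)
    return pairs[keep], y[keep]


# Toy 1D example: the second pair is degenerate.
toy_pairs = np.array([[[5.1], [4.9]],
                      [[5.0], [5.0]],
                      [[6.3], [5.8]]])
toy_y = np.array([1, -1, 1])
clean_pairs, clean_y = drop_degenerate_pairs(toy_pairs, toy_y)
print(clean_pairs.shape)  # (2, 2, 1)
print(clean_y)            # [1 1]
```

In the script, this would presumably be applied to `trunc_data` and `target` right before `itml.fit`.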

Steps/Code to Reproduce

from metric_learn import ITML
from sklearn.datasets import load_iris
from sklearn.utils import shuffle
from metric_learn import Constraints
from sklearn.utils import check_random_state
import numpy as np

SEED = 42

input_data, labels = load_iris(return_X_y=True)
X, y = shuffle(input_data, labels, random_state=SEED)
num_constraints = 50
constraints = Constraints(y)
pairs = (constraints
         .positive_negative_pairs(num_constraints, same_length=True,
                                  random_state=check_random_state(SEED)))
c = np.vstack([np.column_stack(pairs[:2]), np.column_stack(pairs[2:])])
target = np.concatenate([np.ones(pairs[0].shape[0]),
                         -np.ones(pairs[0].shape[0])])
c, target = shuffle(c, target, random_state=SEED)
pairs = X[c]  # shape (n_pairs, 2, n_features)
trunc_data = pairs[..., :1]  # keep only the first feature: the "1D iris" case
itml = ITML()
itml.fit(trunc_data, target)
print(itml.get_mahalanobis_matrix())
print(itml.predict(trunc_data))
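
One likely reason the 1D case is special: two samples that differ in 4D can coincide once only the first feature is kept, and iris feature values repeat often because they are recorded with one decimal. A quick way to check for this, sketched with illustrative values rather than actual dataset rows; `n_degenerate` is a hypothetical helper:

```python
import numpy as np


def n_degenerate(pairs):
    """Count pairs whose two points coincide in every kept feature (hypothetical helper).

    pairs: array of shape (n_pairs, 2, n_features)
    """
    pairs = np.asarray(pairs)
    return int(np.sum(np.all(pairs[:, 0, :] == pairs[:, 1, :], axis=1)))


# Illustrative 4D pairs (not taken from the dataset): the points of the
# first pair differ in 4D but share their first feature.
pairs_4d = np.array([[[5.1, 3.5, 1.4, 0.2],
                      [5.1, 3.8, 1.5, 0.3]],
                     [[6.3, 2.9, 5.6, 1.8],
                      [5.8, 2.7, 5.1, 1.9]]])
pairs_1d = pairs_4d[..., :1]  # same truncation as trunc_data in the script

print(n_degenerate(pairs_4d), n_degenerate(pairs_1d))  # 0 1
```

Calling `n_degenerate(pairs)` and `n_degenerate(trunc_data)` on the script's arrays would show whether the failing fit actually contains such pairs.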

Expected Results

No error is thrown.

Actual Results

/home/will/anaconda3/envs/standard/bin/python /home/will/.PyCharmCE2018.3/config/scratches/scratch_50.py
/home/will/Code/metric-learn/metric_learn/itml.py:95: RuntimeWarning: divide by zero encountered in double_scalars
  alpha = min(_lambda[i], gamma_proj*(1./wtw - 1./pos_bhat[i]))
/home/will/Code/metric-learn/metric_learn/itml.py:105: RuntimeWarning: divide by zero encountered in double_scalars
  alpha = min(_lambda[i+num_pos], gamma_proj*(1./neg_bhat[i] - 1./wtw))
/home/will/Code/metric-learn/metric_learn/itml.py:107: RuntimeWarning: invalid value encountered in double_scalars
  beta = -alpha/(1 + alpha*wtw)
/home/will/Code/metric-learn/metric_learn/itml.py:116: RuntimeWarning: invalid value encountered in double_scalars
  conv = np.abs(lambdaold - _lambda).sum() / normsum
/home/will/Code/metric-learn/metric_learn/itml.py:106: RuntimeWarning: invalid value encountered in double_scalars
  _lambda[i+num_pos] -= alpha
/home/will/Code/metric-learn/metric_learn/itml.py:108: RuntimeWarning: divide by zero encountered in double_scalars
  neg_bhat[i] = 1./((1 / neg_bhat[i]) - (alpha / gamma))
/home/will/Code/metric-learn/metric_learn/itml.py:108: RuntimeWarning: invalid value encountered in double_scalars
  neg_bhat[i] = 1./((1 / neg_bhat[i]) - (alpha / gamma))
Traceback (most recent call last):
  File "/home/will/.PyCharmCE2018.3/config/scratches/scratch_50.py", line 24, in <module>
    itml.fit(trunc_data, target)
  File "/home/will/Code/metric-learn/metric_learn/itml.py", line 179, in fit
    return self._fit(pairs, y, bounds=bounds)
  File "/home/will/Code/metric-learn/metric_learn/itml.py", line 127, in _fit
    self.transformer_ = transformer_from_metric(A)
  File "/home/will/Code/metric-learn/metric_learn/_util.py", line 356, in transformer_from_metric
    return np.linalg.cholesky(metric).T
  File "/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/numpy/linalg/linalg.py", line 733, in cholesky
    r = gufunc(a, signature=signature, extobj=extobj)
  File "/home/will/anaconda3/envs/standard/lib/python3.7/site-packages/numpy/linalg/linalg.py", line 92, in _raise_linalgerror_nonposdef
    raise LinAlgError("Matrix is not positive definite")
numpy.linalg.linalg.LinAlgError: Matrix is not positive definite
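
The warnings are consistent with a pair whose difference vector is zero. The sketch below replays only the expressions quoted in the warnings for one negative constraint, with made-up values standing in for `_lambda`, `gamma_proj` and `neg_bhat`. It assumes (not checked against itml.py) that `wtw` is the squared Mahalanobis norm vᵀAv of the pair difference v, and that the metric then receives the ITML rank-one update A + beta·(Av)(Av)ᵀ:

```python
import numpy as np

# Made-up state for a single negative constraint; only the arithmetic
# mirrors the lines quoted in the RuntimeWarnings.
A = np.eye(1)                          # current 1x1 metric
v = np.array([5.0]) - np.array([5.0])  # difference of a pair [a, a]
wtw = float(v @ A @ v)                 # squared distance: 0.0
lam = 0.0                              # stands in for _lambda[i+num_pos]
gamma_proj = 1.0                       # stands in for gamma_proj
neg_bhat = 2.0                         # stands in for neg_bhat[i]

with np.errstate(divide="ignore", invalid="ignore"):
    inv_wtw = np.float64(1.0) / np.float64(wtw)                # inf
    alpha = min(lam, gamma_proj * (1.0 / neg_bhat - inv_wtw))  # -inf
    beta = -alpha / (1 + alpha * wtw)                          # inf / nan -> nan
    Av = A @ v
    A_new = A + beta * np.outer(Av, Av)                        # nan * 0 -> nan

print(wtw, inv_wtw, alpha, beta)  # 0.0 inf -inf nan
print(A_new)                      # [[nan]]
```

Once `beta` is nan, the metric picks up nan entries, which would be consistent with `np.linalg.cholesky` in `transformer_from_metric` rejecting the matrix.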

Versions

Linux-4.4.0-142-generic-x86_64-with-debian-stretch-sid
Python 3.7.1 (default, Dec 14 2018, 19:28:38)
[GCC 7.3.0]
NumPy 1.15.4
SciPy 1.2.0
Scikit-Learn 0.20.2
Metric-Learn 0.4.0
