Skip to content

Commit

Permalink
[MRG+1] Fixes scikit-learn#8198 - error in datasets.make_moons (sciki…
Browse files Browse the repository at this point in the history
  • Loading branch information
levy5674 authored and sergeyf committed Feb 28, 2017
1 parent 5aadcb4 commit c43f5a7
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -140,6 +140,10 @@ Enhancements
Bug fixes
.........

- Fixed a bug where :func:`sklearn.datasets.make_moons` gives an
incorrect result when ``n_samples`` is odd.
:issue:`8198` by :user:`Josh Levy <levy5674>`.

- Fixed a bug where :class:`sklearn.linear_model.LassoLars` does not give
the same result as the LassoLars implementation available
in R (lars library). :issue:`7849` by :user:`Jair Montoya Martinez <jmontoyam>`
Expand Down
4 changes: 2 additions & 2 deletions sklearn/datasets/samples_generator.py
Expand Up @@ -665,8 +665,8 @@ def make_moons(n_samples=100, shuffle=True, noise=None, random_state=None):

X = np.vstack((np.append(outer_circ_x, inner_circ_x),
np.append(outer_circ_y, inner_circ_y))).T
y = np.hstack([np.zeros(n_samples_in, dtype=np.intp),
np.ones(n_samples_out, dtype=np.intp)])
y = np.hstack([np.zeros(n_samples_out, dtype=np.intp),
np.ones(n_samples_in, dtype=np.intp)])

if shuffle:
X, y = util_shuffle(X, y, random_state=generator)
Expand Down
10 changes: 10 additions & 0 deletions sklearn/datasets/tests/test_samples_generator.py
Expand Up @@ -24,6 +24,7 @@
from sklearn.datasets import make_friedman2
from sklearn.datasets import make_friedman3
from sklearn.datasets import make_low_rank_matrix
from sklearn.datasets import make_moons
from sklearn.datasets import make_sparse_coded_signal
from sklearn.datasets import make_sparse_uncorrelated
from sklearn.datasets import make_spd_matrix
Expand Down Expand Up @@ -360,3 +361,12 @@ def test_make_checkerboard():
X2, _, _ = make_checkerboard(shape=(100, 100), n_clusters=2,
shuffle=True, random_state=0)
assert_array_equal(X1, X2)


def test_make_moons():
X, y = make_moons(3, shuffle=False)
for x, label in zip(X, y):
center = [0.0, 0.0] if label == 0 else [1.0, 0.5]
dist_sqr = ((x - center) ** 2).sum()
assert_almost_equal(dist_sqr, 1.0,
err_msg="Point is not on expected unit circle")

0 comments on commit c43f5a7

Please sign in to comment.