Permalink
Browse files

Set the seed of the random_state generators to have nicely aligned re…

…sults
  • Loading branch information...
1 parent d3cb559 commit 47071524dd406a531447c1a38c0e382f8866dbac @NelleV NelleV committed with GaelVaroquaux May 29, 2012
Showing with 16 additions and 6 deletions.
  1. +5 −3 examples/manifold/plot_mds.py
  2. +11 −3 sklearn/manifold/mds.py
@@ -23,7 +23,8 @@
from sklearn.decomposition import PCA
n_samples = 20
-X_true = np.random.randint(0, 20, 2 * n_samples)
+seed = np.random.RandomState(seed=3)
+X_true = seed.randint(0, 20, 2 * n_samples)
X_true = X_true.reshape((n_samples, 2))
# Center the data
X_true -= X_true.mean()
@@ -37,12 +38,13 @@
similarities += noise
mds = manifold.MDS(n_components=2, max_iter=3000,
- eps=1e-9)
+ eps=1e-9, random_state=seed,
+ n_jobs=1)
pos = mds.fit(similarities).positions_
nmds = manifold.MDS(n_components=2, metric=False,
max_iter=3000,
- eps=1e-9)
+ eps=1e-9, random_state=seed, n_jobs=1)
npos = mds.fit(similarities).positions_
# Rotate the data
View
@@ -139,7 +139,7 @@ def _smacof_single(similarities, metric=True, n_components=2, init=None,
raise ValueError("similarities must be a square array (shape=%d)" % \
n_samples)
- if np.any(similarities != similarities.T):
+ if np.any((similarities - similarities.T) > 100 * np.finfo(np.float).resolution):
raise ValueError("similarities must be symmetric")
sim_flat = ((1 - np.tri(n_samples)) * similarities).flatten()
@@ -364,6 +364,11 @@ class MDS(BaseEstimator):
(n_cpus + 1 - n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
are used.
+ random_state: integer or numpy.RandomState, optional
+ The generator used to initialize the centers. If an integer is
+ given, it fixes the seed. Defaults to the global numpy random
+ number generator.
+
Attributes
----------
@@ -388,14 +393,16 @@ class MDS(BaseEstimator):
"""
def __init__(self, n_components=2, metric=True, n_init=8,
- max_iter=300, verbose=0, eps=1e-3, n_jobs=1):
+ max_iter=300, verbose=0, eps=1e-3, n_jobs=1,
+ random_state=None):
self.n_components = n_components
self.metric = metric
self.n_init = n_init
self.max_iter = max_iter
self.eps = eps
self.verbose = verbose
self.n_jobs = n_jobs
+ self.random_state = None
def fit(self, X, init=None, y=None):
"""
@@ -416,7 +423,8 @@ def fit(self, X, init=None, y=None):
n_init=self.n_init,
max_iter=self.max_iter,
verbose=self.verbose,
- eps=self.eps)
+ eps=self.eps,
+ random_state=self.random_state)
return self
def fit_transform(self, X, init=None, y=None):

0 comments on commit 4707152

Please sign in to comment.