Skip to content

Commit

Permalink
add shape assertions to test
Browse files Browse the repository at this point in the history
  • Loading branch information
ljwolf committed Oct 20, 2016
1 parent 1acd7f7 commit f25eacf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion sklearn/mixture/base.py
Expand Up @@ -385,7 +385,7 @@ def sample(self, n_samples=1):

_, n_features = self.means_.shape
rng = check_random_state(self.random_state)
n_samples_comp = rng.multinomial(n_samples, self.weights_).astype(int)
n_samples_comp = rng.multinomial(n_samples, self.weights_)

if self.covariance_type == 'full':
X = np.vstack([
Expand Down
10 changes: 9 additions & 1 deletion sklearn/mixture/tests/test_gaussian_mixture.py
Expand Up @@ -935,7 +935,8 @@ def test_sample():
gmm.sample, 0)

# Just to make sure the class samples correctly
X_s, y_s = gmm.sample(20000)
n_samples = 20000
X_s, y_s = gmm.sample(n_samples)
for k in range(n_features):
if covar_type == 'full':
assert_array_almost_equal(gmm.covariances_[k],
Expand All @@ -956,6 +957,13 @@ def test_sample():
for k in range(n_features)])
assert_array_almost_equal(gmm.means_, means_s, decimal=1)

# Check that sizes that are drawn match what is requested
assert_equal(X_s.shape, (n_samples, n_components))
for sample_size in range(1, 50):
X_s, _ = gmm.sample(sample_size)
assert_equal(X_s.shape, (sample_size, n_components))



@ignore_warnings(category=ConvergenceWarning)
def test_init():
Expand Down

0 comments on commit f25eacf

Please sign in to comment.