Permalink
Browse files

BUG: fix GaussianHMM.fit to allow input sequences of different lengths

  • Loading branch information...
1 parent 2377694, commit d1c6e5a657456bd47669b3988a5aed20f6892c7d — @ronw (ronw) committed with fabianp, Jan 2, 2011
Showing with 37 additions and 5 deletions.
  1. +5 −5 scikits/learn/hmm.py
  2. +20 −0 scikits/learn/tests/test_hmm.py
  3. +12 −0 scikits/learn/tests/test_mixture.py
View
@@ -327,8 +327,6 @@ def fit(self, obs, n_iter=10, thresh=1e-2, params=string.letters,
small). You can fix this by getting more training data, or
decreasing `covars_prior`.
"""
- obs = np.asanyarray(obs)
-
self._init(obs, init_params)
logprob = []
@@ -679,11 +677,13 @@ def _generate_sample_from_state(self, state):
def _init(self, obs, params='stmc'):
super(GaussianHMM, self)._init(obs, params=params)
- if hasattr(self, 'n_features') and self.n_features != obs.shape[2]:
+ if (hasattr(self, 'n_features')
+ and self.n_features != obs[0].shape[1]):
raise ValueError('Unexpected number of dimensions, got %s but '
- 'expected %s' % (obs.shape[2], self.n_features))
+ 'expected %s' % (obs[0].shape[1],
+ self.n_features))
- self.n_features = obs.shape[2]
+ self.n_features = obs[0].shape[1]
if 'm' in params:
self._means = cluster.KMeans(
@@ -330,6 +330,16 @@ def test_fit(self, params='stmc', n_iter=15, verbose=False, **kwargs):
% (self.cvtype, params, trainll, np.diff(trainll)))
self.assertTrue(np.all(np.diff(trainll) > -0.5))
+ def test_fit_works_on_sequences_of_different_length(self):
+ obs = [np.random.rand(3, self.n_features),
+ np.random.rand(4, self.n_features),
+ np.random.rand(5, self.n_features)]
+
+ h = hmm.GaussianHMM(self.n_states, self.cvtype)
+ # This shouldn't raise
+ # ValueError: setting an array element with a sequence.
+ h.fit(obs)
+
def test_fit_with_priors(self, params='stmc', n_iter=10,
verbose=False):
startprob_prior = 10 * self.startprob + 2.0
@@ -612,6 +622,16 @@ def test_fit(self, params='stmwc', n_iter=5, verbose=True, **kwargs):
np.diff(trainll))
self.assertTrue(np.all(np.diff(trainll) > -0.5))
+ def test_fit_works_on_sequences_of_different_length(self):
+ obs = [np.random.rand(3, self.n_features),
+ np.random.rand(4, self.n_features),
+ np.random.rand(5, self.n_features)]
+
+ h = hmm.GMMHMM(self.n_states, cvtype=self.cvtype)
+ # This shouldn't raise
+ # ValueError: setting an array element with a sequence.
+ h.fit(obs)
+
class TestGMMHMMWithSphericalCovars(TestGMMHMM):
cvtype = 'spherical'
@@ -169,6 +169,18 @@ def test_GMM_attributes():
assert_raises(ValueError, mixture.GMM, n_states=20, cvtype='badcvtype')
+def test_GMM_fit_works_on_sequences_of_different_length():
+ ndim = 3
+ obs = [np.random.rand(3, ndim),
+ np.random.rand(4, ndim),
+ np.random.rand(5, ndim)]
+
+ gmm = mixture.GMM(n_states=1)
+ # This shouldn't raise
+ # ValueError: setting an array element with a sequence.
+ gmm.fit(obs)
+
+
class GMMTester():
n_states = 10
n_features = 4

0 comments on commit d1c6e5a

Please sign in to comment.