Skip to content

Commit

Permalink
change n_obs -> nobs, add a test for nobs
Browse files Browse the repository at this point in the history
  • Loading branch information
yl565 committed Dec 11, 2017
1 parent 4fa1309 commit 47ef9d2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
8 changes: 4 additions & 4 deletions statsmodels/multivariate/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ class Factor(Model):
endog_names: str
Names of endogeous variables.
If specified, it will be used instead of the column names in endog
n_obs: int
nobs: int
The number of observations. To be used together with `corr`
Should be equals to the number of rows in `endog`.
"""
def __init__(self, endog, n_factor, corr=None, method='pa', smc=True,
missing='drop', endog_names=None, n_obs=None):
missing='drop', endog_names=None, nobs=None):
if endog is not None:
k_endog = endog.shape[1]
elif corr is not None:
Expand Down Expand Up @@ -101,8 +101,8 @@ def __init__(self, endog, n_factor, corr=None, method='pa', smc=True,
self.corr = None

# Check validity of n_obs
if n_obs is not None:
if endog is not None and endog.shape[0] != n_obs:
if nobs is not None:
if endog is not None and endog.shape[0] != nobs:
raise ValueError('n_obs must be equal to the number of rows in endog')

# Do not preprocess endog if None
Expand Down
10 changes: 8 additions & 2 deletions statsmodels/multivariate/tests/test_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,19 @@
columns=['Loc', 'Basal', 'Occ', 'Max', 'id', 'alt'])


def test_specify_nobs():
# Test specifying nobs
Factor(np.zeros([10, 3]), 2, nobs=10)
assert_raises(ValueError, Factor, np.zeros([10, 3]), 2, nobs=9)


def test_auto_col_name():
# Test auto generated variable names when endog_names is None
mod = Factor(None, 2, corr=np.zeros([11, 11]),endog_names=None,
mod = Factor(None, 2, corr=np.zeros([11, 11]), endog_names=None,
smc=False)
assert_array_equal(mod.endog_names,
['var00', 'var01', 'var02', 'var03', 'var04', 'var05',
'var06', 'var07', 'var08', 'var09', 'var10',])
'var06', 'var07', 'var08', 'var09', 'var10'])


def test_direct_corr_matrix():
Expand Down

0 comments on commit 47ef9d2

Please sign in to comment.