Skip to content

Commit

Permalink
Merge pull request #1254 from josef-pkt/fix_predict_1032
Browse files Browse the repository at this point in the history
REF: Results.predict convert to array and adjust shape
  • Loading branch information
josef-pkt committed Dec 19, 2013
2 parents 01fd4a2 + ada08b9 commit 0b5ed74
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
8 changes: 8 additions & 0 deletions statsmodels/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,14 @@ def predict(self, exog=None, transform=True, *args, **kwargs):
from patsy import dmatrix
exog = dmatrix(self.model.data.orig_exog.design_info.builder,
exog)

if exog is not None:
exog = np.asarray(exog)
if exog.ndim == 1 and (self.model.exog.ndim == 1 or
self.model.exog.shape[1] == 1):
exog = exog[:, None]
exog = np.atleast_2d(exog) # needed in count model shape[1]

return self.model.predict(self.params, exog, *args, **kwargs)


Expand Down
56 changes: 53 additions & 3 deletions statsmodels/base/tests/test_generic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def test_ttest_tvalues(self):
tt = res.t_test(mat)

assert_allclose(tt.effect, res.params, rtol=1e-12)
assert_allclose(tt.sd, res.bse, rtol=1e-10)
assert_allclose(tt.tvalue, res.tvalues, rtol=1e-12)
# TODO: tt.sd and tt.tvalue are 2d also for single regressor, squeeze
assert_allclose(np.squeeze(tt.sd), res.bse, rtol=1e-10)
assert_allclose(np.squeeze(tt.tvalue), res.tvalues, rtol=1e-12)
assert_allclose(tt.pvalue, res.pvalues, rtol=5e-10)
assert_allclose(tt.conf_int(), res.conf_int(), rtol=1e-10)

Expand Down Expand Up @@ -99,6 +100,39 @@ def test_fitted(self):
assert_allclose(res.model.endog - fitted, res.resid, rtol=1e-12)
assert_allclose(fitted, res.predict(), rtol=1e-12)

def test_predict_types(self):
res = self.results
# squeeze to make 1d for single regressor test case
p_exog = np.squeeze(np.asarray(res.model.exog[:2]))

# ignore wrapper for isinstance check
from statsmodels.genmod.generalized_linear_model import GLMResults
from statsmodels.discrete.discrete_model import DiscreteResults
results = self.results._results
if (isinstance(results, GLMResults) or
isinstance(results, DiscreteResults)):
# SMOKE test only TODO
res.predict(p_exog)
res.predict(p_exog.tolist())
res.predict(p_exog[0].tolist())
else:
fitted = res.fittedvalues[:2]
assert_allclose(fitted, res.predict(p_exog), rtol=1e-12)
# this needs reshape to column-vector:
assert_allclose(fitted, res.predict(np.squeeze(p_exog).tolist()),
rtol=1e-12)
# only one prediction:
assert_allclose(fitted[:1], res.predict(p_exog[0].tolist()),
rtol=1e-12)
assert_allclose(fitted[:1], res.predict(p_exog[0]),
rtol=1e-12)

# predict doesn't preserve DataFrame, e.g. dot converts to ndarray
# import pandas
# predicted = res.predict(pandas.DataFrame(p_exog))
# assert_(isinstance(predicted, pandas.DataFrame))
# assert_allclose(predicted, fitted, rtol=1e-12)


######### subclasses for individual models, unchanged from test_shrink_pickle
# TODO: check if setup_class is faster than setup
Expand All @@ -113,6 +147,17 @@ def setup(self):
self.results = sm.OLS(y, self.exog).fit()


class TestGenericOLSOneExog(CheckGenericMixin):
# check with single regressor (no constant)

def setup(self):
#fit for each test, because results will be changed by test
x = self.exog[:, 1]
np.random.seed(987689)
y = x + np.random.randn(x.shape[0])
self.results = sm.OLS(y, x).fit()


class TestGenericWLS(CheckGenericMixin):

def setup(self):
Expand Down Expand Up @@ -147,7 +192,12 @@ def setup(self):
data = sm.datasets.randhie.load()
exog = sm.add_constant(data.exog, prepend=False)
mod = sm.NegativeBinomial(data.endog, data.exog)
self.results = mod.fit(disp=0)
start_params = np.array([-0.0565406 , -0.21213599, 0.08783076,
-0.02991835, 0.22901974, 0.0621026,
0.06799283, 0.08406688, 0.18530969,
1.36645452])
self.results = mod.fit(start_params=start_params, disp=0)


class TestGenericLogit(CheckGenericMixin):

Expand Down

0 comments on commit 0b5ed74

Please sign in to comment.