-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1699 from josef-pkt/reparameterize
- Loading branch information
Showing
3 changed files
with
188 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Tue May 27 13:26:01 2014 | ||
Author: Josef Perktold | ||
License: BSD-3 | ||
""" | ||
|
||
import numpy as np | ||
from numpy.testing import assert_allclose, assert_equal | ||
from scipy import stats | ||
|
||
from statsmodels.regression.linear_model import OLS | ||
from statsmodels.tools.transform_model import StandardizeTransform | ||
|
||
|
||
def test_standardize1(): | ||
|
||
np.random.seed(123) | ||
x = 1 + np.random.randn(5, 4) | ||
|
||
transf = StandardizeTransform(x) | ||
xs1 = transf(x) | ||
|
||
assert_allclose(transf.mean, x.mean(0), rtol=1e-13) | ||
assert_allclose(transf.scale, x.std(0, ddof=1), rtol=1e-13) | ||
|
||
xs2 = stats.zscore(x, ddof=1) | ||
assert_allclose(xs1, xs2, rtol=1e-13, atol=1e-20) | ||
|
||
# check we use stored transformation | ||
xs4 = transf(2 * x) | ||
assert_allclose(xs4, (2*x - transf.mean) / transf.scale, rtol=1e-13, atol=1e-20) | ||
|
||
|
||
# affine transform doesn't change standardized | ||
x2 = 2 * x + np.random.randn(4) | ||
transf2 = StandardizeTransform(x2) | ||
xs3 = transf2(x2) | ||
assert_allclose(xs3, xs1, rtol=1e-13, atol=1e-20) | ||
|
||
# check constant | ||
x5 = np.column_stack((np.ones(x.shape[0]), x)) | ||
transf5 = StandardizeTransform(x5) | ||
xs5 = transf5(x5) | ||
|
||
assert_equal(transf5.const_idx, 0) | ||
assert_equal(xs5[:, 0], np.ones(x.shape[0])) | ||
assert_allclose(xs5[:, 1:], xs1, rtol=1e-13, atol=1e-20) | ||
|
||
|
||
def test_standardize_ols(): | ||
|
||
np.random.seed(123) | ||
nobs = 20 | ||
x = 1 + np.random.randn(nobs, 4) | ||
exog = np.column_stack((np.ones(nobs), x)) | ||
endog = exog.sum(1) + np.random.randn(nobs) | ||
|
||
res2 = OLS(endog, exog).fit() | ||
transf = StandardizeTransform(exog) | ||
exog_st = transf(exog) | ||
res1 = OLS(endog, exog_st).fit() | ||
params = transf.transform_params(res1.params) | ||
assert_allclose(params, res2.params, rtol=1e-13) | ||
|
||
|
||
|
||
test_standardize1() | ||
test_standardize_ols() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Tue May 27 13:23:24 2014 | ||
Author: Josef Perktold | ||
License: BSD-3 | ||
""" | ||
|
||
import numpy as np | ||
|
||
|
||
class StandardizeTransform(object): | ||
"""class to reparameterize a model for standardized exog | ||
Parameters | ||
---------- | ||
data : array_like | ||
data that is standardized along axis=0 | ||
ddof : None or int | ||
degrees of freedom for calculation of standard deviation. | ||
default is 1, in contrast to numpy.std | ||
const_idx : None or int | ||
If None, then the presence of a constant is detected if the standard | ||
deviation of a column is **equal** to zero. A constant column is | ||
not transformed. If this is an integer, then the corresponding column | ||
will not be transformed. | ||
demean : bool, default is True | ||
If demean is true, then the data will be demeaned, otherwise it will | ||
only be rescaled. | ||
Notes | ||
----- | ||
Warning: Not all options are tested and it is written for one use case. | ||
API changes are expected. | ||
This can be used to transform only the design matrix, exog, in a model, | ||
which is required in some discrete models when the endog cannot be rescaled | ||
or demeaned. | ||
The transformation is full rank and does not drop the constant. | ||
""" | ||
|
||
def __init__(self, data, ddof=1, const_idx=None, demean=True): | ||
data = np.asarray(data) | ||
self.mean = data.mean(0) | ||
self.scale = data.std(0, ddof=1) | ||
|
||
# do not transform a constant | ||
if const_idx is None: | ||
const_idx = np.nonzero(self.scale == 0)[0] | ||
if len(const_idx) == 0: | ||
const_idx = 'nc' | ||
|
||
if const_idx != 'nc': | ||
self.mean[const_idx] = 0 | ||
self.scale[const_idx] = 1 | ||
|
||
|
||
if demean is False: | ||
self.mean = None | ||
|
||
self.const_idx = const_idx | ||
|
||
def transform(self, data): | ||
"""standardize the data using the stored transformation | ||
""" | ||
# could use scipy.stats.zscore instead | ||
if self.mean is None: | ||
return np.asarray(data) / self.scale | ||
else: | ||
return (np.asarray(data) - self.mean) / self.scale | ||
|
||
def transform_params(self, params): | ||
"""Transform parameters of the standardized model to the original model | ||
Parameters | ||
---------- | ||
params : ndarray | ||
parameters estimated with the standardized model | ||
Returns | ||
------- | ||
params_new : ndarray | ||
parameters transformed to the parameterization of the original model | ||
""" | ||
|
||
params_new = params / self.scale | ||
if self.const_idx != 'nc': | ||
params_new[self.const_idx] -= (params_new * self.mean).sum() | ||
|
||
return params_new | ||
|
||
|
||
__call__ = transform | ||
|