Skip to content

Commit

Permalink
Fix issue#42-(1)
Browse files Browse the repository at this point in the history
  • Loading branch information
eggachecat committed Aug 3, 2017
1 parent 946051d commit 25e12cb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
5 changes: 4 additions & 1 deletion gplearn/genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,10 @@ def fit(self, X, y, sample_weight=None):

if isinstance(self, RegressorMixin):
# Find the best individual in the final generation
self._program = self._programs[-1][np.argmin(fitness)]
if self._metric.greater_is_better:
self._program = self._programs[-1][np.argmax(fitness)]
else:
self._program = self._programs[-1][np.argmin(fitness)]

if isinstance(self, TransformerMixin):
# Find the best individuals in the final generation
Expand Down
25 changes: 25 additions & 0 deletions gplearn/tests/test_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,31 @@ def test_warm_start():
assert_equal(cold_program, warm_program)


def test_customizied_regressor_metrics():
"""Check whether parameter greater_is_better works fine"""

x_data = rng.uniform(-1, 1, 100).reshape(50, 2)
y_true = x_data[:, 0] ** 2 + x_data[:, 1] ** 2

est_gp = SymbolicRegressor(metric='mean absolute error', stopping_criteria=0.000001, random_state=415,
parsimony_coefficient=0.001, verbose=0, init_method='full', init_depth=(2, 4))
est_gp.fit(x_data, y_true)
formula = est_gp.__str__()
assert_equal("add(mul(X1, X1), mul(X0, X0))", formula, True)

def neg_mean_absolute_error(y, y_pred, sample_weight):
return -1 * mean_absolute_error(y, y_pred, sample_weight)

customizied_fitness = make_fitness(neg_mean_absolute_error, greater_is_better=True)

c_est_gp = SymbolicRegressor(metric=customizied_fitness, stopping_criteria=-0.000001, random_state=415,
parsimony_coefficient=0.001, verbose=0, init_method='full', init_depth=(2, 4))
c_est_gp.fit(x_data, y_true)
c_formula = c_est_gp.__str__()

assert_equal("add(mul(X1, X1), mul(X0, X0))", c_formula, True)


if __name__ == "__main__":
import nose
nose.runmodule()

0 comments on commit 25e12cb

Please sign in to comment.