Skip to content

Commit

Permalink
add tests for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
trevorstephens committed Aug 28, 2016
1 parent a3d0f9f commit ac6b602
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
48 changes: 48 additions & 0 deletions gplearn/tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Testing the Genetic Programming functions module."""

# Author: Trevor Stephens <trevorstephens.com>
#
# License: BSD 3 clause

import numpy as np

from numpy import maximum

from gplearn.functions import _protected_sqrt, make_function
from gplearn.skutils.testing import assert_raises


def test_validate_function():
"""Check that valid functions are accepted & invalid ones raise error"""

# Check arity tests
fun = make_function(function=_protected_sqrt, name='sqrt', arity=1)
# non-integer arity
assert_raises(ValueError, make_function, _protected_sqrt, 'sqrt', '1')
assert_raises(ValueError, make_function, _protected_sqrt, 'sqrt', 1.0)
# non-matching arity
assert_raises(ValueError, make_function, _protected_sqrt, 'sqrt', 2)
assert_raises(ValueError, make_function, maximum, 'max', 1)

# Check name test
assert_raises(ValueError, make_function, _protected_sqrt, 2, '1')

# Check return type tests
def bad_fun1(x1, x2):
return 'ni'
assert_raises(ValueError, make_function, bad_fun1, 'ni', 2)

# Check return shape tests
def bad_fun2(x1):
return np.ones((2, 1))
assert_raises(ValueError, make_function, bad_fun2, 'ni', 1)

# Check closure for negatives test
def _unprotected_sqrt(x1):
return np.sqrt(x1)
assert_raises(ValueError, make_function, _unprotected_sqrt, 'sqrt', 1)

# Check closure for zeros test
def _unprotected_div(x1, x2):
return np.divide(x1, x2)
assert_raises(ValueError, make_function, _unprotected_div, 'div', 2)
27 changes: 27 additions & 0 deletions gplearn/tests/test_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,33 @@ def test_print_overloading_estimator():
assert_true(output_fitted == output_program)


def test_validate_functions():
"""Check that valid functions are accepted & invalid ones raise error"""

random_state = check_random_state(415)
X = np.reshape(random_state.uniform(size=50), (5, 10))
y = random_state.uniform(size=5)

for Symbolic in (SymbolicRegressor, SymbolicTransformer):
# These should be fine
est = Symbolic(generations=2, random_state=0,
function_set=(add2, sub2, mul2, div2))
est.fit(boston.data, boston.target)
est = Symbolic(generations=2, random_state=0,
function_set=('add', 'sub', 'mul', div2))
est.fit(boston.data, boston.target)

# These should fail
est = Symbolic(generations=2, random_state=0,
function_set=('ni', 'sub', 'mul', div2))
assert_raises(ValueError, est.fit, boston.data, boston.target)
est = Symbolic(generations=2, random_state=0,
function_set=(7, 'sub', 'mul', div2))
assert_raises(ValueError, est.fit, boston.data, boston.target)
est = Symbolic(generations=2, random_state=0, function_set=())
assert_raises(ValueError, est.fit, boston.data, boston.target)


if __name__ == "__main__":
import nose
nose.runmodule()

0 comments on commit ac6b602

Please sign in to comment.