Skip to content

Commit

Permalink
BalancingLearner: test the "cycle" strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Apr 29, 2019
1 parent 5c8ff1d commit 858ca89
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions adaptive/tests/test_balancing_learner.py
Expand Up @@ -6,6 +6,9 @@
from adaptive.runner import simple


strategies = ['loss', 'loss_improvements', 'npoints', 'cycle']


def test_balancing_learner_loss_cache():
learner = Learner1D(lambda x: x, bounds=(-1, 1))
learner.tell(-1, -1)
Expand All @@ -26,7 +29,7 @@ def test_balancing_learner_loss_cache():
assert bl.loss(real=True) == real_loss


@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
@pytest.mark.parametrize('strategy', strategies)
def test_distribute_first_points_over_learners(strategy):
for initial_points in [0, 3]:
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
Expand All @@ -41,7 +44,7 @@ def test_distribute_first_points_over_learners(strategy):
assert len(set(i_learner)) == len(learners)


@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
@pytest.mark.parametrize('strategy', strategies)
def test_ask_0(strategy):
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
learner = BalancingLearner(learners, strategy=strategy)
Expand All @@ -53,6 +56,7 @@ def test_ask_0(strategy):
('loss', lambda l: l.loss() < 0.1),
('loss_improvements', lambda l: l.loss() < 0.1),
('npoints', lambda bl: all(l.npoints > 10 for l in bl.learners)),
('cycle', lambda l: l.loss() < 0.1),
])
def test_strategies(strategy, goal):
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
Expand Down

0 comments on commit 858ca89

Please sign in to comment.