Commit
Merge pull request #23 from rodrigo-arenas/0.5.xdev
Base Callback
rodrigo-arenas committed Jun 22, 2021
2 parents 2c47bee + 8b2e200 commit 4961376
Showing 10 changed files with 115 additions and 78 deletions.
6 changes: 5 additions & 1 deletion .coveragerc
@@ -9,4 +9,8 @@ omit =
setup.py
[report]
precision = 2
show_missing = True
show_missing = True
exclude_lines =
# Have to re-enable the standard pragma
pragma: no cover
noqa
4 changes: 4 additions & 0 deletions docs/api/callbacks.rst
@@ -1,6 +1,10 @@
Callbacks
----------

.. autoclass:: sklearn_genetic.callbacks.base.BaseCallback
:members:
:undoc-members: False

.. autoclass:: sklearn_genetic.callbacks.ConsecutiveStopping
:members:
:undoc-members: False
10 changes: 10 additions & 0 deletions docs/release_notes.rst
@@ -26,6 +26,16 @@ Docs:
^^^^^

* Added user guide "Integrating with MLflow"
* Updated the tutorial "Custom Callbacks" for the new API inheritance behavior

^^^^^^^^^
Internal:
^^^^^^^^^

* Added a base class :class:`~sklearn_genetic.callbacks.base.BaseCallback` from
  which all callbacks must inherit
* The coverage report now ignores lines marked with # pragma: no cover
  and # noqa

What's new in 0.4.1
-------------------
11 changes: 7 additions & 4 deletions docs/tutorials/custom_callback.rst
@@ -5,9 +5,10 @@ sklearn-genetic-opt comes with some pre-defined callbacks,
but you can make one of your own by defining a callable with
certain methods.

The callback must be a class that implements the ``__call__`` and
``on_step`` methods, the result of them must be a bool, ``True`` means
that the optimization must stop, ``False``, means it can continue.
The callback must be a class that inherits from
:class:`~sklearn_genetic.callbacks.base.BaseCallback` and implements the
``__call__`` and ``on_step`` methods. Both must return a bool: ``True`` means
that the optimization must stop, ``False`` means it can continue.

In this example, we are going to define a dummy callback that
stops the process if there have been more than `N` fitness values
@@ -49,7 +50,9 @@ that will have all these parameters, so we can rewrite it like this:

.. code-block:: python
class DummyThreshold:
from sklearn_genetic.callbacks.base import BaseCallback
class DummyThreshold(BaseCallback):
def __init__(self, threshold, N, metric='fitness'):
self.threshold = threshold
self.N = N
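For reference, here is a minimal sketch of how the tutorial's DummyThreshold could be completed under the new inheritance rule. It assumes the DEAP-style logbook that GASearchCV exposes (with a select(metric) method); the exact stopping logic is illustrative, not necessarily the tutorial's final version:

from sklearn_genetic.callbacks.base import BaseCallback

class DummyThreshold(BaseCallback):
    def __init__(self, threshold, N, metric='fitness'):
        self.threshold = threshold
        self.N = N
        self.metric = metric

    def on_step(self, record=None, logbook=None, estimator=None):
        if logbook is not None:
            # Count how many recorded metric values fall below the threshold
            stats = logbook.select(self.metric)
            n_below = len([value for value in stats if value < self.threshold])
            # Ask the optimizer to stop once more than N values are below it
            return n_below > self.N
        return False

    def __call__(self, record=None, logbook=None, estimator=None):
        return self.on_step(record, logbook, estimator)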
31 changes: 31 additions & 0 deletions sklearn_genetic/callbacks/base.py
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod


class BaseCallback(ABC):
"""
Base callback from which all callbacks must inherit
"""

@abstractmethod
def on_step(self, record=None, logbook=None, estimator=None):
"""
Parameters
----------
record : dict, default=None
A logbook record
logbook:
Current stream logbook with the stats required
estimator:
:class:`~sklearn_genetic.GASearchCV` Estimator that is being optimized
Returns
-------
decision: False
Always returns False, as this class doesn't make decisions about the optimization
"""

pass # pragma: no cover

@abstractmethod
def __call__(self, record=None, logbook=None, estimator=None):
pass # pragma: no cover
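As a quick illustration of the contract this base class defines, here is a hypothetical subclass; the name PrintFitness and the 'fitness' record key are assumptions for the example, not part of this commit. It logs each generation and never requests an early stop:

from sklearn_genetic.callbacks.base import BaseCallback

class PrintFitness(BaseCallback):
    """Hypothetical callback: print the fitness of each generation, never stop."""

    def on_step(self, record=None, logbook=None, estimator=None):
        if record is not None:
            print(f"generation fitness: {record['fitness']}")
        return False  # never ask the optimizer to stop

    def __call__(self, record=None, logbook=None, estimator=None):
        # Concrete callbacks typically just delegate to on_step
        return self.on_step(record, logbook, estimator)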
52 changes: 4 additions & 48 deletions sklearn_genetic/callbacks/early_stoppers.py
@@ -1,7 +1,8 @@
from .validations import check_stats
from .base import BaseCallback


class ThresholdStopping:
class ThresholdStopping(BaseCallback):
"""
Stop the optimization if the metric from the
cross-validation score is greater than or equal to the defined threshold
@@ -24,21 +25,6 @@ def __init__(self, threshold, metric="fitness"):
self.metric = metric

def on_step(self, record, logbook, estimator):
"""
Parameters
----------
record: dict: default=None
A logbook record
logbook:
Current stream logbook with the stats required
estimator:
:class:`~sklearn_genetic.GASearchCV` Estimator that is being optimized
Returns
-------
decision: bool
True if the optimization algorithm must stop, false otherwise
"""
if record is not None:
return record[self.metric] >= self.threshold
elif logbook is not None:
@@ -54,7 +40,7 @@ def __call__(self, record=None, logbook=None, estimator=None):
return self.on_step(record, logbook, estimator)


class ConsecutiveStopping:
class ConsecutiveStopping(BaseCallback):
"""
Stop the optimization if the current metric value is no greater than at least one metric from the last N generations
"""
@@ -75,21 +61,6 @@ def __init__(self, generations, metric="fitness"):
self.metric = metric

def on_step(self, record=None, logbook=None, estimator=None):
"""
Parameters
----------
record: dict: default=None
A logbook record
logbook:
Current stream logbook with the stats required
estimator:
:class:`~sklearn_genetic.GASearchCV` Estimator that is being optimized
Returns
-------
decision: bool
True if the optimization algorithm must stop, false otherwise
"""
if logbook is not None:
if len(logbook) <= self.generations:
return False
@@ -109,7 +80,7 @@ def __call__(self, record=None, logbook=None, estimator=None):
return self.on_step(record, logbook, estimator)


class DeltaThreshold:
class DeltaThreshold(BaseCallback):
"""
Stop the optimization if the absolute difference between the current and last metric is less than or equal to a threshold
"""
@@ -130,21 +101,6 @@ def __init__(self, threshold, metric: str = "fitness"):
self.metric = metric

def on_step(self, record=None, logbook=None, estimator=None):
"""
Parameters
----------
record: dict: default=None
A logbook record
logbook:
Current stream logbook with the stats required
estimator:
:class:`~sklearn_genetic.GASearchCV` Estimator that is being optimized
Returns
-------
decision: bool
True if the optimization algorithm must stop, false otherwise
"""
if logbook is not None:
if len(logbook) <= 1:
return False
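A sketch of how these stoppers are typically wired into a search, assuming the GASearchCV constructor, the space API, and the fit(..., callbacks=...) signature of this release; the estimator and search space below are illustrative only:

from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn_genetic import GASearchCV
from sklearn_genetic.space import Continuous
from sklearn_genetic.callbacks import ThresholdStopping, ConsecutiveStopping, DeltaThreshold

X, y = load_digits(return_X_y=True)

evolved_estimator = GASearchCV(
    estimator=LogisticRegression(solver="liblinear", max_iter=1000),
    param_grid={"C": Continuous(0.01, 10)},
    scoring="accuracy",
    cv=3,
    generations=20,
)

callbacks = [
    ThresholdStopping(threshold=0.98),   # stop once fitness reaches 0.98
    ConsecutiveStopping(generations=5),  # stop if fitness hasn't improved over the last 5 generations
    DeltaThreshold(threshold=0.001),     # stop when fitness barely changes between generations
]
evolved_estimator.fit(X, y, callbacks=callbacks)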
19 changes: 3 additions & 16 deletions sklearn_genetic/callbacks/loggers.py
@@ -2,8 +2,10 @@
from copy import deepcopy
from joblib import dump

from .base import BaseCallback

class LogbookSaver:

class LogbookSaver(BaseCallback):
"""
Saves the estimator.logbook parameters chapter object to the local file system
"""
@@ -22,21 +24,6 @@ def __init__(self, checkpoint_path, **dump_options):
self.dump_options = dump_options

def on_step(self, record=None, logbook=None, estimator=None):
"""
Parameters
----------
record: dict: default=None
A logbook record
logbook:
Current stream logbook with the stats required
estimator:
:class:`~sklearn_genetic.GASearchCV` Estimator that is being optimized
Returns
-------
decision: False
Always returns False as this class doesn't take decisions over the optimization
"""
try:
dump_logbook = deepcopy(estimator.logbook.chapters["parameters"])
dump(dump_logbook, self.checkpoint_path, **self.dump_options)
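For context, a brief usage sketch of LogbookSaver; the checkpoint path is an arbitrary example:

from sklearn_genetic.callbacks import LogbookSaver

# Persist the logbook's parameters chapter each generation via joblib.dump
saver = LogbookSaver(checkpoint_path="./logbook_checkpoint.pkl")

# Passed to fit alongside any other callbacks, e.g.:
# evolved_estimator.fit(X, y, callbacks=[saver])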
10 changes: 5 additions & 5 deletions sklearn_genetic/callbacks/validations.py
@@ -1,6 +1,5 @@
from typing import Callable

from ..parameters import Metrics
from .base import BaseCallback


def check_stats(metric):
@@ -15,17 +14,18 @@ def check_callback(callback):
Check if callback is a callable or a list of callables.
"""
if callback is not None:
if isinstance(callback, Callable):
if isinstance(callback, BaseCallback):
return [callback]

elif isinstance(callback, list) and all(
[isinstance(c, Callable) for c in callback]
[isinstance(c, BaseCallback) for c in callback]
):
return callback

else:
raise ValueError(
"callback should be either a callable or a list of callables."
"callback should be either a class or a list of classes with inheritance from "
"callbacks.base.BaseCallback"
)
else:
return []
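The behavioral change above can be summarized with a small sketch that mirrors the updated tests (interpreter session assumed):

from sklearn_genetic.callbacks import ThresholdStopping, DeltaThreshold
from sklearn_genetic.callbacks.validations import check_callback

single = ThresholdStopping(threshold=0.9)
assert check_callback(single) == [single]   # a single instance gets wrapped in a list

several = [ThresholdStopping(threshold=0.9), DeltaThreshold(threshold=0.001)]
assert check_callback(several) == several   # lists of BaseCallback instances pass through

assert check_callback(None) == []           # None means "no callbacks"

try:
    check_callback(sum)                     # plain callables are no longer accepted
except ValueError as error:
    print(error)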
2 changes: 1 addition & 1 deletion sklearn_genetic/genetic_search.py
@@ -486,7 +486,7 @@ def __next__(self):
self.n += 1
return result
else:
raise StopIteration
raise StopIteration # pragma: no cover

def __len__(self):
"""
48 changes: 45 additions & 3 deletions sklearn_genetic/tests/test_callbacks.py
@@ -16,6 +16,7 @@
LogbookSaver,
)
from ..callbacks.validations import check_stats, check_callback
from ..callbacks.base import BaseCallback

data = load_digits()
label_names = data["target_names"]
@@ -40,18 +41,59 @@ def test_check_callback():


def test_check_callback():
assert check_callback(sum) == [sum]
callback_threshold = ThresholdStopping(threshold=0.8)
callback_consecutive = ConsecutiveStopping(generations=3)
assert check_callback(callback_threshold) == [callback_threshold]
assert check_callback(None) == []
assert check_callback([sum, min]) == [sum, min]
assert check_callback([callback_threshold, callback_consecutive]) == [
callback_threshold,
callback_consecutive,
]

with pytest.raises(Exception) as excinfo:
check_callback(1)
assert (
str(excinfo.value)
== "callback should be either a callable or a list of callables."
== "callback should be either a class or a list of classes with inheritance from "
"callbacks.base.BaseCallback"
)


def test_wrong_base_callback():
class MyDummyCallback(BaseCallback):
def __init__(self, metric):
self.metric = metric

def validate(self):
print(self.metric)

with pytest.raises(Exception) as excinfo:
callback = MyDummyCallback()
assert (
str(excinfo.value)
== "Can't instantiate abstract class MyDummyCallback with abstract methods __call__, on_step"
)


def test_base_callback_call():
possible_messages = [
"Can't instantiate abstract class MyDummyCallback with abstract methods __call__",
"Can't instantiate abstract class MyDummyCallback with abstract method __call__",
]

class MyDummyCallback(BaseCallback):
def __init__(self, metric):
self.metric = metric

def on_step(self, record=None, logbook=None, estimator=None):
print(record)

with pytest.raises(Exception) as excinfo:
callback = MyDummyCallback(metric="fitness")

assert any([str(excinfo.value) == i for i in possible_messages])


def test_threshold_callback():
callback = ThresholdStopping(threshold=0.8)
assert check_callback(callback) == [callback]
