Skip to content

Commit

Permalink
Merge pull request #18 from rodrigo-arenas/0.5.xdev
Browse files Browse the repository at this point in the history
plot fitness evolution metric selection
  • Loading branch information
rodrigo-arenas committed Jun 9, 2021
2 parents 8cfdfd8 + a000e3c commit 8d643ec
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion demo/Boson_Houses_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
print("r-squared: ", "{:.2f}".format(r_squared))

print("Best k solutions: ", evolved_estimator.hof)
plot = plot_fitness_evolution(evolved_estimator)
plot = plot_fitness_evolution(evolved_estimator, metric="fitness_sd")
plt.show()

plot_search_space(evolved_estimator)
Expand Down
4 changes: 1 addition & 3 deletions sklearn_genetic/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@

__version__ = "0.4.1"

__version__ = "0.5.0dev0"
13 changes: 10 additions & 3 deletions sklearn_genetic/plots.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,43 @@
import seaborn as sns

from .utils import logbook_to_pandas
from .parameters import Metrics


"""
This module contains some useful function to explore the results of the optimization routines
"""


def plot_fitness_evolution(estimator):
def plot_fitness_evolution(estimator, metric="fitness"):
"""
Parameters
----------
estimator: estimator object
A fitted estimator from :class:`~sklearn_genetic.GASearchCV`
metric: {"fitness", "fitness_std", "fitness_max", "fitness_min"}, default="fitness"
Logged metric into the estimator history to plot
Returns
-------
Lines plot with the fitness value in each generation
"""

if metric not in Metrics.list():
raise ValueError(f"metric must be one of {Metrics.list()}, but got {metric} instead")

sns.set_style("white")

fitness_history = estimator.history["fitness"]
fitness_history = estimator.history[metric]

palette = sns.color_palette("rocket")
sns.set(rc={"figure.figsize": (10, 10)})

ax = sns.lineplot(
x=range(len(estimator)), y=fitness_history, markers=True, palette=palette
)
ax.set_title("Fitness average evolution over generations")
ax.set_title(f"{metric.capitalize()} average evolution over generations")

ax.set(xlabel="generations", ylabel=f"fitness ({estimator.scoring})")
return ax
Expand Down
8 changes: 8 additions & 0 deletions sklearn_genetic/tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
Expand Down Expand Up @@ -44,8 +45,15 @@
def test_plot_evolution():
plot = plot_fitness_evolution(evolved_estimator)

with pytest.raises(Exception) as excinfo:
plot = plot_fitness_evolution(evolved_estimator, metric="accuracy")

assert str(excinfo.value) == "metric must be one of ['fitness', 'fitness_std', 'fitness_max', 'fitness_min'], " \
"but got accuracy instead"


def test_plot_space():
plot = plot_search_space(evolved_estimator)
plot = plot_search_space(evolved_estimator)
plot = plot_search_space(
evolved_estimator, features=["ccp_alpha", "max_depth", "min_samples_split"]
Expand Down

0 comments on commit 8d643ec

Please sign in to comment.