Skip to content

Commit

Permalink
Merge pull request #107 from Habush/master
Browse files Browse the repository at this point in the history
Add the support to pass parameters needed by the estimator during fitting
  • Loading branch information
rodrigo-arenas authored Jan 19, 2023
2 parents 12e35ad + 6b31057 commit 9f18449
Showing 1 changed file with 40 additions and 12 deletions.
52 changes: 40 additions & 12 deletions sklearn_genetic/genetic_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,16 @@ def __init__(
error_score=error_score,
)

def _register(self):
def _register(self, fit_params):
"""
This function is the responsible for registering the DEAPs necessary methods
and create other objects to hold the hof, logbook and stats.
Parameters
----------
fit_params : dict, default=None
Parameters to pass to the fit method of the estimator.
----------
"""

self.creator.create("FitnessMax", base.Fitness, weights=[self.criteria_sign])
Expand Down Expand Up @@ -364,7 +370,9 @@ def _register(self):
else:
self.toolbox.register("select", tools.selRoulette)

self.toolbox.register("evaluate", self.evaluate)
evaluate = lambda ind: self.evaluate(ind, fit_params)

self.toolbox.register("evaluate", evaluate)

self._pop = self.toolbox.population(n=self.population_size)
self._hof = tools.HallOfFame(self.keep_top_k)
Expand Down Expand Up @@ -401,14 +409,15 @@ def mutate(self, individual):

return [individual]

def evaluate(self, individual):
def evaluate(self, individual, fit_params):
"""
Compute the cross-validation scores and record the logbook and mlflow (if specified)
Parameters
----------
individual: Individual object
The individual (set of hyperparameters) that is being evaluated
fit_params : dict, default=None
Parameters to pass to the fit method of the estimator.
Returns
-------
The fitness value of the estimator candidate, corresponding to the cv-score
Expand All @@ -429,6 +438,7 @@ def evaluate(self, individual):
self.X_,
self.y_,
cv=self.cv,
fit_params=fit_params,
scoring=self.scoring,
n_jobs=self.n_jobs,
pre_dispatch=self.pre_dispatch,
Expand Down Expand Up @@ -469,7 +479,7 @@ def evaluate(self, individual):

return [score]

def fit(self, X, y, callbacks=None):
def fit(self, X, y, callbacks=None, **fit_params):
"""
Main method of GASearchCV, starts the optimization
procedure with the hyperparameters of the given estimator
Expand All @@ -488,6 +498,9 @@ def fit(self, X, y, callbacks=None):
:class:`~sklearn_genetic.callbacks`.
The callback is evaluated after fitting the estimators from the generation 1.
fit_params : dict, default=None
Parameters to pass to the fit method of the estimator.
"""

self.X_ = X
Expand All @@ -514,7 +527,7 @@ def fit(self, X, y, callbacks=None):
self.n_splits_ = cv_orig.get_n_splits(X, y)

# Set the DEAPs necessary methods
self._register()
self._register(fit_params)

# Optimization routine from the selected evolutionary algorithm
pop, log, n_gen = self._select_algorithm(
Expand Down Expand Up @@ -552,7 +565,7 @@ def fit(self, X, y, callbacks=None):
self.estimator.set_params(**self.best_params_)

refit_start_time = time.time()
self.estimator.fit(self.X_, self.y_)
self.estimator.fit(self.X_, self.y_, **fit_params)
refit_end_time = time.time()
self.refit_time_ = refit_end_time - refit_start_time

Expand Down Expand Up @@ -989,10 +1002,17 @@ def __init__(
error_score=error_score,
)

def _register(self):
def _register(self, fit_params):
"""
This function is the responsible for registering the DEAPs necessary methods
and create other objects to hold the hof, logbook and stats.
Parameters
----------
fit_params : dict, default=None
Parameters to pass to the fit method of the estimator.
----------
"""

# Criteria sign to set max or min problem
Expand Down Expand Up @@ -1032,6 +1052,8 @@ def _register(self):
else:
self.toolbox.register("select", tools.selRoulette)

evaluate = lambda ind: self.evaluate(ind, fit_params)

self.toolbox.register("evaluate", self.evaluate)

self._pop = self.toolbox.population(n=self.population_size)
Expand All @@ -1047,14 +1069,17 @@ def _register(self):

self.logbook = tools.Logbook()

def evaluate(self, individual):
def evaluate(self, individual, fit_params):
"""
Compute the cross-validation scores and record the logbook and mlflow (if specified)
Parameters
----------
individual: Individual object
The individual (set of features) that is being evaluated
fit_params : dict, default=None
Parameters to pass to the fit method of the estimator.
Returns
-------
fitness: List
Expand All @@ -1077,6 +1102,7 @@ def evaluate(self, individual):
self.X_[:, bool_individual],
self.y_,
cv=self.cv,
fit_params=fit_params,
scoring=self.scoring,
n_jobs=self.n_jobs,
pre_dispatch=self.pre_dispatch,
Expand Down Expand Up @@ -1124,7 +1150,7 @@ def evaluate(self, individual):

return [score, n_selected_features]

def fit(self, X, y, callbacks=None):
def fit(self, X, y, callbacks=None, **fit_params):
"""
Main method of GAFeatureSelectionCV, starts the optimization
procedure with to find the best features set
Expand All @@ -1140,6 +1166,8 @@ def fit(self, X, y, callbacks=None):
One or a list of the callbacks methods available in
:class:`~sklearn_genetic.callbacks`.
The callback is evaluated after fitting the estimators from the generation 1.
fit_params : dict, default=None
Parameters to pass to the fit method of the estimator.
"""

self.X_, self.y_ = check_X_y(X, y)
Expand Down Expand Up @@ -1169,7 +1197,7 @@ def fit(self, X, y, callbacks=None):
self.n_splits_ = cv_orig.get_n_splits(X, y)

# Set the DEAPs necessary methods
self._register()
self._register(fit_params)

# Optimization routine from the selected evolutionary algorithm
pop, log, n_gen = self._select_algorithm(
Expand Down Expand Up @@ -1199,7 +1227,7 @@ def fit(self, X, y, callbacks=None):
bool_individual = np.array(self.best_features_, dtype=bool)

refit_start_time = time.time()
self.estimator.fit(self.X_[:, bool_individual], self.y_)
self.estimator.fit(self.X_[:, bool_individual], self.y_, fit_params)
refit_end_time = time.time()
self.refit_time_ = refit_end_time - refit_start_time

Expand Down

0 comments on commit 9f18449

Please sign in to comment.