Skip to content
Merged
10 changes: 6 additions & 4 deletions docs/tutorials/10-Hyperparameter Tuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1396,11 +1396,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Result is a namedtuple with trials_df, best_params, and best_score\\\n",
"Result is a namedtuple with trials_df, best_params, best_score, and best_model\\\n",
"\n",
"- trials_df: A dataframe with all the hyperparameter combinations and their corresponding scores\n",
"- best_params: The best hyperparameter combination\n",
"- best_score: The best score"
"- best_score: The best score\n",
"- best_model: The best model if return_best_model is True; otherwise None"
]
},
{
Expand Down Expand Up @@ -1895,11 +1896,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Result is a namedtuple with trials_df, best_params, and best_score\\\n",
"Result is a namedtuple with trials_df, best_params, best_score, and best_model\\\n",
"\n",
"- trials_df: A dataframe with all the hyperparameter combinations and their corresponding scores\n",
"- best_params: The best hyperparameter combination\n",
"- best_score: The best score"
"- best_score: The best score\n",
"- best_model: The best model if return_best_model is True; otherwise None"
]
},
{
Expand Down
7 changes: 6 additions & 1 deletion src/pytorch_tabular/tabular_model_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ def model_sweep(
verbose (bool, optional): If True, will print the progress. Defaults to True.

suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.

Returns:
results: Training results.

best_model: The best model if return_best_model is True; otherwise None.
"""
_validate_args(
task=task,
Expand Down Expand Up @@ -386,4 +391,4 @@ def _init_tabular_model(m):
best_model.datamodule = datamodule
return results, best_model
else:
return results
return results, None
37 changes: 35 additions & 2 deletions src/pytorch_tabular/tabular_model_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class TabularModelTuner:
"""

ALLOWABLE_STRATEGIES = ["grid_search", "random_search"]
OUTPUT = namedtuple("OUTPUT", ["trials_df", "best_params", "best_score"])
OUTPUT = namedtuple("OUTPUT", ["trials_df", "best_params", "best_score", "best_model"])

def __init__(
self,
Expand Down Expand Up @@ -88,6 +88,8 @@ def __init__(
if trainer_config.fast_dev_run:
warnings.warn("fast_dev_run is turned on. Tuning results won't be accurate.")
if trainer_config.progress_bar != "none":
# If config and tuner have progress bar enabled, it will result in a bug within the library (rich.progress)
trainer_config.progress_bar = "none"
warnings.warn("Turning off progress bar. Set progress_bar='none' in TrainerConfig to disable this warning.")
trainer_config.trainer_kwargs.update({"enable_model_summary": False})
self.data_config = data_config
Expand Down Expand Up @@ -153,6 +155,7 @@ def tune(
cv: Optional[Union[int, Iterable, BaseCrossValidator]] = None,
cv_agg_func: Optional[Callable] = np.mean,
cv_kwargs: Optional[Dict] = {},
return_best_model: bool = True,
verbose: bool = False,
progress_bar: bool = True,
random_state: Optional[int] = 42,
Expand Down Expand Up @@ -200,6 +203,8 @@ def tune(
cv_kwargs (Optional[Dict], optional): Additional keyword arguments to be passed to the cross validation
method. Defaults to {}.

return_best_model (bool, optional): If True, will return the best model. Defaults to True.

verbose (bool, optional): Whether to print the results of each trial. Defaults to False.

progress_bar (bool, optional): Whether to show a progress bar. Defaults to True.
Expand All @@ -215,6 +220,7 @@ def tune(
trials_df (DataFrame): A dataframe with the results of each trial
best_params (Dict): The best parameters found
best_score (float): The best score found
best_model (TabularModel or None): The best model if return_best_model is True; otherwise None
"""
assert strategy in self.ALLOWABLE_STRATEGIES, f"tuner must be one of {self.ALLOWABLE_STRATEGIES}"
assert mode in ["max", "min"], "mode must be one of ['max', 'min']"
Expand Down Expand Up @@ -270,6 +276,8 @@ def tune(
metric_str = metric.__name__
del temp_tabular_model
trials = []
best_model = None
best_score = 0.0
for i, params in enumerate(iterator):
# Copying the configs as a base
# Make sure all default parameters that you want to be set for all
Expand Down Expand Up @@ -334,6 +342,22 @@ def tune(
else:
result = tabular_model_t.evaluate(validation, verbose=False)
params.update({k.replace("test_", ""): v for k, v in result[0].items()})

if return_best_model:
tabular_model_t.datamodule = None
if best_model is None:
best_model = deepcopy(tabular_model_t)
best_score = params[metric_str]
else:
if mode == "min":
if params[metric_str] < best_score:
best_model = deepcopy(tabular_model_t)
best_score = params[metric_str]
elif mode == "max":
if params[metric_str] > best_score:
best_model = deepcopy(tabular_model_t)
best_score = params[metric_str]

params.update({"trial_id": i})
trials.append(params)
if verbose:
Expand All @@ -349,4 +373,13 @@ def tune(
best_params = trials_df.iloc[best_idx].to_dict()
best_score = best_params.pop(metric_str)
trials_df.insert(0, "trial_id", trials)
return self.OUTPUT(trials_df, best_params, best_score)

if verbose:
logger.info("Model Tuner Finished")
logger.info(f"Best Score ({metric_str}): {best_score}")

if return_best_model and best_model is not None:
best_model.datamodule = datamodule
return self.OUTPUT(trials_df, best_params, best_score, best_model)
else:
return self.OUTPUT(trials_df, best_params, best_score, None)