From b225ef33cd14b953b82dc2f2bf764c3fc8555af0 Mon Sep 17 00:00:00 2001
From: mirand863
Date: Mon, 28 Nov 2022 11:12:38 +0100
Subject: [PATCH] Make hyperparameters flexible (#65)

---
 benchmarks/consumer_complaints/README.md      |   2 +-
 .../consumer_complaints/scripts/tune.py       | 124 +++++++-----------
 .../consumer_complaints/scripts/tune_table.py |  33 +----
 .../consumer_complaints/tests/test_tune.py    |  91 ++++++-------
 .../tests/test_tune_table.py                  |  35 -----
 5 files changed, 91 insertions(+), 194 deletions(-)

diff --git a/benchmarks/consumer_complaints/README.md b/benchmarks/consumer_complaints/README.md
index aa04e164..f3dc8c2c 100644
--- a/benchmarks/consumer_complaints/README.md
+++ b/benchmarks/consumer_complaints/README.md
@@ -49,7 +49,7 @@ n_estimators: 1
 criterion: 1
 ```
 
-The intervals for testing can be defined with the functions `range` or `choice`, as described on [Hydra's documentation](https://hydra.cc/docs/plugins/optuna_sweeper/). If you wish to add more parameters for testing, you can simply add the parameter name inside the `params` field and at the end of the file set it to 1 in order to enable its usage in Hydra. Additionally, you would need to modify one of the functions `configure_lightgbm`, `configure_logistic_regression` or `configure_random_forest` (whichever is appropriate) inside the script [tune.py](scripts/tune.py) to enable the new hyperparameter.
+The intervals for testing can be defined with the functions `range` or `choice`, as described in [Hydra's documentation](https://hydra.cc/docs/plugins/optuna_sweeper/). If you wish to test additional parameters, simply add the parameter name inside the `params` field and set it to 1 at the end of the file to enable its usage in Hydra.
 
 ## Running locally
diff --git a/benchmarks/consumer_complaints/scripts/tune.py b/benchmarks/consumer_complaints/scripts/tune.py
index 91c6c520..86e9e2f0 100644
--- a/benchmarks/consumer_complaints/scripts/tune.py
+++ b/benchmarks/consumer_complaints/scripts/tune.py
@@ -15,7 +15,6 @@
 from lightgbm import LGBMClassifier
 from numpy.core._exceptions import _ArrayMemoryError
 from omegaconf import DictConfig, OmegaConf
-from sklearn.base import BaseEstimator
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
 from sklearn.linear_model import LogisticRegression
@@ -30,89 +29,55 @@
 )
 from hiclass.metrics import f1
 
-
 log = logging.getLogger("TUNE")
 
 
-def configure_lightgbm(cfg: DictConfig) -> BaseEstimator:
-    """
-    Configure LightGBM with parameters passed as argument.
+configure_flat = {
+    "lightgbm": LGBMClassifier(),
+    "logistic_regression": LogisticRegression(),
+    "random_forest": RandomForestClassifier(),
+}
 
-    Parameters
-    ----------
-    cfg : DictConfig
-        Dictionary containing all configuration information.
 
-    Returns
-    -------
-    classifier : BaseEstimator
-        Estimator with hyper-parameters configured.
-    """
-    classifier = LGBMClassifier(
-        n_jobs=cfg.n_jobs,
-        num_leaves=cfg.num_leaves,
-        n_estimators=cfg.n_estimators,
-        min_child_samples=cfg.min_child_samples,
-    )
-    return classifier
+configure_hierarchical = {
+    "local_classifier_per_node": LocalClassifierPerNode(),
+    "local_classifier_per_parent_node": LocalClassifierPerParentNode(),
+    "local_classifier_per_level": LocalClassifierPerLevel(),
+}
 
 
-def configure_logistic_regression(cfg: DictConfig) -> BaseEstimator:
-    """
-    Configure LogisticRegression with parameters passed as argument.
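+# Keys in the Hydra config that describe the experiment itself rather than
+# estimator hyperparameters; delete_non_hyperparameters() strips them so that
+# only genuine hyperparameters reach set_params() and the MD5 cache key.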
+non_hyperparameters = [ + "model", + "classifier", + "n_jobs", + "x_train", + "y_train", + "output_dir", + "mem_gb", + "n_splits", +] - Parameters - ---------- - cfg : DictConfig - Dictionary containing all configuration information. - Returns - ------- - classifier : BaseEstimator - Estimator with hyper-parameters configured. - """ - classifier = LogisticRegression( - n_jobs=cfg.n_jobs, - solver=cfg.solver, - max_iter=cfg.max_iter, - ) - return classifier - - -def configure_random_forest(cfg: DictConfig) -> BaseEstimator: +def delete_non_hyperparameters(cfg: OmegaConf) -> dict: """ - Configure RandomForest with parameters passed as argument. + Delete non-hyperparameters from the dictionary. Parameters ---------- - cfg : DictConfig - Dictionary containing all configuration information. + cfg : OmegaConf + Dictionary to delete non-hyperparameters from. Returns ------- - classifier : BaseEstimator - Estimator with hyper-parameters configured. - """ - classifier = RandomForestClassifier( - n_jobs=cfg.n_jobs, - n_estimators=cfg.n_estimators, - criterion=cfg.criterion, - ) - return classifier - - -configure_flat = { - "lightgbm": configure_lightgbm, - "logistic_regression": configure_logistic_regression, - "random_forest": configure_random_forest, -} - + hyperparameters : dict + Dictionary containing only hyperparameters. -configure_hierarchical = { - "local_classifier_per_node": LocalClassifierPerNode(), - "local_classifier_per_parent_node": LocalClassifierPerParentNode(), - "local_classifier_per_level": LocalClassifierPerLevel(), -} + """ + hyperparameters = OmegaConf.to_container(cfg) + for key in non_hyperparameters: + if key in hyperparameters: + del hyperparameters[key] + return hyperparameters def configure_pipeline(cfg: DictConfig) -> Pipeline: @@ -130,9 +95,11 @@ def configure_pipeline(cfg: DictConfig) -> Pipeline: Pipeline with hyper-parameters configured. """ if cfg.model == "flat": - classifier = configure_flat[cfg.classifier](cfg) + classifier = configure_flat[cfg.classifier] + classifier.set_params(**delete_non_hyperparameters(cfg)) else: - local_classifier = configure_flat[cfg.classifier](cfg) + local_classifier = configure_flat[cfg.classifier] + local_classifier.set_params(**delete_non_hyperparameters(cfg)) local_classifier.set_params(n_jobs=1) classifier = configure_hierarchical[cfg.model] classifier.set_params( @@ -148,14 +115,14 @@ def configure_pipeline(cfg: DictConfig) -> Pipeline: return pipeline -def compute_md5(cfg: DictConfig) -> str: +def compute_md5(cfg: dict) -> str: """ Compute MD5 hash of configuration. Parameters ---------- - cfg : DictConfig - Dictionary containing all configuration information. + cfg : dict + Dictionary containing hyperparameters. Returns ------- @@ -163,10 +130,7 @@ def compute_md5(cfg: DictConfig) -> str: MD5 hash of configuration. """ - dictionary = OmegaConf.to_object(cfg) - md5 = hashlib.md5( - json.dumps(dictionary, sort_keys=True).encode("utf-8") - ).hexdigest() + md5 = hashlib.md5(json.dumps(cfg, sort_keys=True).encode("utf-8")).hexdigest() return md5 @@ -181,10 +145,11 @@ def save_trial(cfg: DictConfig, scores: List[float]) -> None: scores : List[float] List of scores for each fold. 
""" - md5 = compute_md5(cfg) + hyperparameters = delete_non_hyperparameters(cfg) + md5 = compute_md5(hyperparameters) filename = f"{cfg.output_dir}/{md5}.sav" with open(filename, "wb") as file: - pickle.dump((cfg, scores), file) + pickle.dump((hyperparameters, scores), file) def load_trial(cfg: DictConfig) -> List[float]: @@ -201,7 +166,8 @@ def load_trial(cfg: DictConfig) -> List[float]: scores : List[float] The cross-validation scores or empty list if file does not exist. """ - md5 = compute_md5(cfg) + hyperparameters = delete_non_hyperparameters(cfg) + md5 = compute_md5(hyperparameters) filename = f"{cfg.output_dir}/{md5}.sav" if os.path.exists(filename): (_, scores) = pickle.load(open(filename, "rb")) diff --git a/benchmarks/consumer_complaints/scripts/tune_table.py b/benchmarks/consumer_complaints/scripts/tune_table.py index b21b6c4f..1da92490 100644 --- a/benchmarks/consumer_complaints/scripts/tune_table.py +++ b/benchmarks/consumer_complaints/scripts/tune_table.py @@ -8,7 +8,6 @@ from typing import Tuple, List import numpy as np -from omegaconf import OmegaConf def parse_args(args: list) -> Namespace: @@ -55,33 +54,6 @@ def parse_args(args: list) -> Namespace: return parser.parse_args(args) -def delete_non_hyperparameters(hyperparameters: OmegaConf) -> dict: - """ - Delete non-hyperparameters from the dictionary. - - Parameters - ---------- - hyperparameters : OmegaConf - Hyperparameters to delete non-hyperparameters from. - - Returns - ------- - hyperparameters : dict - Hyperparameters without non-hyperparameters. - - """ - hyperparameters = OmegaConf.to_container(hyperparameters) - del hyperparameters["model"] - del hyperparameters["classifier"] - del hyperparameters["n_jobs"] - del hyperparameters["x_train"] - del hyperparameters["y_train"] - del hyperparameters["output_dir"] - del hyperparameters["mem_gb"] - del hyperparameters["n_splits"] - return hyperparameters - - def compute( folder: str, ) -> Tuple[List[dict], List[list], List[np.ndarray], List[np.ndarray]]: @@ -104,14 +76,15 @@ def compute( std : List[np.ndarray] Standard deviations of k-fold cross-validation. 
""" - results = glob.glob(f"{folder}/[!trained_model]*.sav") + results = glob.glob(f"{folder}/*.sav") + if "{}/trained_model.sav".format(folder) in results: + results.remove(f"{folder}/trained_model.sav") hyperparameters = [] scores = [] avg = [] std = [] for result in results: parameters, s = pickle.load(open(result, "rb")) - parameters = delete_non_hyperparameters(parameters) hyperparameters.append(parameters) scores.append([round(i, 3) for i in s]) avg.append(np.mean(s)) diff --git a/benchmarks/consumer_complaints/tests/test_tune.py b/benchmarks/consumer_complaints/tests/test_tune.py index 92689a2d..82523d35 100644 --- a/benchmarks/consumer_complaints/tests/test_tune.py +++ b/benchmarks/consumer_complaints/tests/test_tune.py @@ -2,51 +2,21 @@ import pandas as pd import pytest -from lightgbm import LGBMClassifier -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pyfakefs.fake_filesystem_unittest import Patcher from scripts.data import flatten_labels from scripts.tune import ( - configure_lightgbm, - configure_logistic_regression, - configure_random_forest, configure_pipeline, compute_md5, save_trial, load_trial, limit_memory, cross_validate, + delete_non_hyperparameters, ) -from sklearn.ensemble import RandomForestClassifier -from sklearn.linear_model import LogisticRegression from sklearn.pipeline import Pipeline -def test_configure_lightgbm(): - cfg = DictConfig( - { - "n_jobs": 1, - "boosting_type": "gbdt", - "num_leaves": 31, - "learning_rate": 0.1, - "n_estimators": 100, - "subsample_for_bin": 200000, - "class_weight": None, - "min_split_gain": 0.0, - "min_child_weight": 0.001, - "min_child_samples": 20, - "subsample": 1.0, - "subsample_freq": 0, - "colsample_bytree": 1.0, - "reg_alpha": 0.0, - "reg_lambda": 0.0, - } - ) - classifier = configure_lightgbm(cfg) - assert classifier is not None - assert isinstance(classifier, LGBMClassifier) - - @pytest.fixture def random_forest_config(): cfg = DictConfig( @@ -71,12 +41,6 @@ def random_forest_config(): return cfg -def test_configure_random_forest(random_forest_config): - classifier = configure_random_forest(random_forest_config) - assert classifier is not None - assert isinstance(classifier, RandomForestClassifier) - - def test_configure_pipeline_1(random_forest_config): classifier = configure_pipeline(random_forest_config) assert classifier is not None @@ -84,13 +48,13 @@ def test_configure_pipeline_1(random_forest_config): def test_compute_md5_1(random_forest_config): - md5 = compute_md5(random_forest_config) - expected = "407b164147fd7e78c41acc215d9c28fe" + hyperparameters = delete_non_hyperparameters(random_forest_config) + md5 = compute_md5(hyperparameters) + expected = "14a11881777bc9589583efba8af9b752" assert expected == md5 def test_save_and_load_trial_1(random_forest_config): - random_forest_config.output_dir = "." 
with Patcher(): save_trial(random_forest_config, [1, 2, 3]) scores = load_trial(random_forest_config) @@ -121,12 +85,6 @@ def logistic_regression_config(): return cfg -def test_configure_logistic_regression(logistic_regression_config): - classifier = configure_logistic_regression(logistic_regression_config) - assert classifier is not None - assert isinstance(classifier, LogisticRegression) - - def test_configure_pipeline_2(logistic_regression_config): classifier = configure_pipeline(logistic_regression_config) assert classifier is not None @@ -134,8 +92,9 @@ def test_configure_pipeline_2(logistic_regression_config): def test_compute_md5_2(logistic_regression_config): - md5 = compute_md5(logistic_regression_config) - expected = "739b6346f16b738eb36967e6d82ade41" + hyperparameters = delete_non_hyperparameters(logistic_regression_config) + md5 = compute_md5(hyperparameters) + expected = "f393e84d6faa5dc4c62aacbf70ef5865" assert expected == md5 @@ -213,3 +172,37 @@ def test_cross_validate_4(logistic_regression_config, X, y): assert [0.5, 1.0] == scores scores = load_trial(logistic_regression_config) assert [0.5, 1.0] == scores + + +@pytest.fixture +def hyperparameters(): + hp = OmegaConf.create( + { + "model": "flat", + "classifier": "lightgbm", + "n_jobs": 1, + "x_train": "x_train.csv", + "y_train": "y_train.csv", + "output_dir": ".", + "mem_gb": 1, + "n_splits": 2, + "reg_alpha": 0.0, + "reg_lambda": 0.0, + "num_leaves": 31, + "learning_rate": 0.1, + "n_estimators": 100, + } + ) + return hp + + +def test_delete_non_hyperparameters(hyperparameters): + hyperparameters = delete_non_hyperparameters(hyperparameters) + expected = { + "reg_alpha": 0.0, + "reg_lambda": 0.0, + "num_leaves": 31, + "learning_rate": 0.1, + "n_estimators": 100, + } + assert expected == hyperparameters diff --git a/benchmarks/consumer_complaints/tests/test_tune_table.py b/benchmarks/consumer_complaints/tests/test_tune_table.py index 802788d0..3bc7759c 100644 --- a/benchmarks/consumer_complaints/tests/test_tune_table.py +++ b/benchmarks/consumer_complaints/tests/test_tune_table.py @@ -6,7 +6,6 @@ from pyfakefs.fake_filesystem_unittest import Patcher from scripts.tune_table import ( parse_args, - delete_non_hyperparameters, compute, create_table, ) @@ -37,40 +36,6 @@ def test_parser(): assert "output.md" == parser.output -@pytest.fixture -def hyperparameters(): - hp = OmegaConf.create( - { - "model": "flat", - "classifier": "lightgbm", - "n_jobs": 1, - "x_train": "x_train.csv", - "y_train": "y_train.csv", - "output_dir": ".", - "mem_gb": 1, - "n_splits": 2, - "reg_alpha": 0.0, - "reg_lambda": 0.0, - "num_leaves": 31, - "learning_rate": 0.1, - "n_estimators": 100, - } - ) - return hp - - -def test_delete_non_hyperparameters(hyperparameters): - hyperparameters = delete_non_hyperparameters(hyperparameters) - expected = { - "reg_alpha": 0.0, - "reg_lambda": 0.0, - "num_leaves": 31, - "learning_rate": 0.1, - "n_estimators": 100, - } - assert expected == hyperparameters - - @pytest.fixture def lightgbm_config(): cfg = OmegaConf.create(