Skip to content

Commit

Permalink
Merge pull request #3 from tlranda/master
Browse files Browse the repository at this point in the history
Optimizer fixes for ConfigurationSpace interactions
  • Loading branch information
pbalapra committed Feb 24, 2023
2 parents f9805df + 0738a2e commit f321c3d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
Expand Up @@ -190,6 +190,8 @@ The core ``ytopt`` team is at Argonne National Laboratory:

Modules, patches (code, documentation, etc.) contributed by:

* Thomas Randall <tlranda@clemson.edu>

# How can I participate?

Questions, comments, feature requests, bug reports, etc. can be directed to:
Expand Down
16 changes: 10 additions & 6 deletions ytopt/search/optimizer/optimizer.py
Expand Up @@ -11,21 +11,22 @@
logger = util.conf_logger('ytopt.search.hps.optimizer.optimizer')

class Optimizer:
SEED = 12345
KAPPA = 1.96

def __init__(self, num_workers: int, space, learner, acq_func, liar_strategy, **kwargs):
def __init__(self, num_workers: int, space, learner, acq_func, liar_strategy, KAPPA=1.96, SEED=12345, **kwargs):
assert learner in ["RF", "ET", "GBRT", "GP", "DUMMY"], f"Unknown scikit-optimize base_estimator: {learner}"
assert liar_strategy in "cl_min cl_mean cl_max".split()

self.space = space
self.learner = learner
self.acq_func = acq_func
self.liar_strategy = liar_strategy
self.KAPPA = KAPPA
self.SEED = SEED

n_init = inf if learner=='DUMMY' else num_workers

if isinstance(self.space, CS.ConfigurationSpace):
# Pass on seed for replicable RNG
self.space.seed(self.SEED)
self._optimizer = SkOptimizer(
dimensions=self.space,
base_estimator=self.learner,
Expand Down Expand Up @@ -110,12 +111,15 @@ def ask_initial(self, n_points):
self.evals[key] = y
return [self.to_dict(x) for x in XX]

def tell(self, xy_data):
def tell(self, xy_data, require_requested=True):
assert isinstance(xy_data, list), f"where type(xy_data)=={type(xy_data)}"
maxval = max(self._optimizer.yi) if self._optimizer.yi else 0.0
for x,y in xy_data:
key = tuple(x.values()) # * tuple(x[k] for k in self.space)
assert key in self.evals, f"where key=={key} and self.evals=={self.evals}"
if require_requested:
assert key in self.evals, f"where key=={key} and self.evals=={self.evals}"
else:
self.counter += 1
logger.debug(f'tell: {x} --> {key}: evaluated objective: {y}')
self.evals[key] = (y if y < float_info.max else maxval)

Expand Down

0 comments on commit f321c3d

Please sign in to comment.