Skip to content

Commit

Permalink
ScheduleParam: set at beginning
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Nov 1, 2020
1 parent 36bdc18 commit 11ca8b2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
27 changes: 21 additions & 6 deletions tensorpack/callbacks/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ class ScheduledHyperParamSetter(HyperParamSetter):
Set hyperparameters by a predefined epoch-based schedule.
"""

def __init__(self, param, schedule, interp=None, step_based=False):
def __init__(self, param, schedule, interp=None, step_based=False,
set_at_beginning=True):
"""
Args:
param: same as in :class:`HyperParamSetter`.
Expand All @@ -250,6 +251,14 @@ def __init__(self, param, schedule, interp=None, step_based=False):
every time this callback is triggered.
step_based (bool): interpret ``schedule`` as (step, value) instead
of (epoch, value).
set_at_beginning (bool): at the start of training, the current value
may be different from the expected value according to the
schedule.
If this option is True, set the value anyway even though the current
epoch/step is not at the scheduled time.
If False, the value will only be set according to the
schedule, i.e. it will only be set if the current epoch/step
is at the scheduled time.
Example:
.. code-block:: python
Expand All @@ -263,6 +272,7 @@ def __init__(self, param, schedule, interp=None, step_based=False):
assert interp == 'linear'
self.interp = interp
self._step = step_based
self._set_at_beginning = set_at_beginning
super(ScheduledHyperParamSetter, self).__init__(param)

def _get_value_to_set(self): # override parent
Expand All @@ -277,12 +287,17 @@ def _check_value_at_beginning(self):
for p in range(0, self._current_point() + 1):
v = self._get_value_to_set_at_point(p) or v
actual_value = self.param.get_value()
current_point = "step" if self._step else "epoch" + str(self._current_point())
if v is not None and not np.isclose(v, actual_value):
logger.warn("According to scheduler {}, parameter '{}' should become {} at the current point. "
"However its current value is {}. "
"If this is the only scheduler being used, you may want to check whether your "
"initialization of the parameter is as expected".format(
self, self.param.readable_name, v, actual_value))
logger.warn("According to scheduler {}, parameter '{}' should become {:.7g} at the current point ({}). "
"However its current value is {:.7g}. ".format(
self, self.param.readable_name, v, current_point, actual_value))
if self._set_at_beginning:
logger.info("Setting '{}' to {:.7g}.".format(self.param.readable_name, v))
self.param.set_value(v)
else:
logger.warn("If there is no other scheduler being used, you may want to check whether your "
"initialization of the parameter is as expected")

def _get_value_to_set_at_point(self, point):
"""
Expand Down
9 changes: 8 additions & 1 deletion tensorpack/callbacks/param_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,17 @@ def testSchedule(self):
def testStartAfterSchedule(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)])
[(10, 0.3), (20, 0.4), (30, 0.5)], set_at_beginning=False)
history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
self.assertEqual(len(history), 0)

def testStartAfterSchedule_SetAtBeginning(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)], set_at_beginning=True)
history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
self.assertEqual(history, {0: 0.5})

def testWarningStartInTheMiddle(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
Expand Down

0 comments on commit 11ca8b2

Please sign in to comment.