Skip to content

Commit

Permalink
CheckNumerics callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Jul 16, 2020
1 parent 07e28ee commit dbc0b36
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions tensorpack/callbacks/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,27 @@ def _dump_image(self, im, idx=None):
cv2.imwrite(fname, res.astype('uint8'))


class CheckNumerics(Callback):
class CheckNumerics(RunOp):
"""
When triggered, check variables in the graph for NaN and Inf.
Raise exceptions if such an error is found.
Check variables in the graph for NaN and Inf.
Raise an exception if such an error is found.
"""
def _setup_graph(self):
_chief_only = True

def __init__(self, run_as_trigger=True, run_step=False):
"""
Args: same as in :class:`RunOp`.
"""
super().__init__(
self._get_op,
run_as_trigger=run_as_trigger,
run_step=run_step)

def _get_op(self):
vars = tf.trainable_variables()
ops = [tf.check_numerics(v, "CheckNumerics['{}']".format(v.op.name)).op for v in vars]
self._check_op = tf.group(*ops)

def _trigger(self):
self._check_op.run()
check_op = tf.group(*ops, name="CheckAllNumerics")
return check_op


try:
Expand Down

0 comments on commit dbc0b36

Please sign in to comment.