Skip to content

Commit

Permalink
freeze_variable: don't add to collection if not originally trainable
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed May 14, 2020
1 parent 9f4154e commit 610ffe3
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tensorpack/tfutils/varreplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def custom_getter(getter, *args, **kwargs):
if skip_collection:
kwargs['trainable'] = False
v = getter(*args, **kwargs)
if skip_collection:
# do not perform unnecessary changes if it's not originally trainable
# otherwise the variable may get added to MODEL_VARIABLES twice
if trainable and skip_collection:
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if trainable and stop_gradient:
v = tf.stop_gradient(v, name='freezed_' + name)
Expand Down

0 comments on commit 610ffe3

Please sign in to comment.