Skip to content

Commit

Permalink
correct TF version in saliency example
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Jan 7, 2020
1 parent 758ae94 commit d2f9564
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions examples/Saliency/saliency-maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import tensorpack as tp
import tensorpack.utils.viz as viz
from tensorpack.tfutils import get_tf_version_tuple

IMAGE_SIZE = 224

Expand All @@ -29,7 +30,7 @@ def guided_relu():
@tf.RegisterGradient("GuidedReLU")
def GuidedReluGrad(op, grad):
return tf.where(0. < grad,
gen_nn_ops._relu_grad(grad, op.outputs[0]),
gen_nn_ops.relu_grad(grad, op.outputs[0]),
tf.zeros(grad.get_shape()))

g = tf.get_default_graph()
Expand Down Expand Up @@ -59,9 +60,9 @@ def inputs(self):
def build_graph(self, orig_image):
mean = tf.get_variable('resnet_v1_50/mean_rgb', shape=[3])
with guided_relu():
with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training=False)):
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
image = tf.expand_dims(orig_image - mean, 0)
logits, _ = resnet_v1.resnet_v1_50(image, 1000)
logits, _ = resnet_v1.resnet_v1_50(image, 1000, is_training=False)
saliency_map(logits, orig_image, name="saliency")


Expand Down Expand Up @@ -103,4 +104,5 @@ def run(model_path, image_path):
if len(sys.argv) != 2:
tp.logger.error("Usage: {} image.jpg".format(sys.argv[0]))
sys.exit(1)
assert get_tf_version_tuple() >= (1, 7), "requires TF >= 1.7"
run("resnet_v1_50.ckpt", sys.argv[1])

0 comments on commit d2f9564

Please sign in to comment.