Skip to content

Commit

Permalink
use TowerHandle.get_tensor to access variables (fix #1409)
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Mar 14, 2020
1 parent 6f0ba59 commit d89b6f0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions examples/DynamicFilterNetwork/steering-filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.utils import logger
from tensorpack.utils.gpu import change_gpu
from tensorpack.utils.argtools import shape2d, shape4d
from tensorpack.utils.viz import *

Expand Down
7 changes: 4 additions & 3 deletions tensorpack/tfutils/tower.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,9 @@ def ns_name(self):

def get_tensor(self, name):
"""
Get a tensor in this tower. The name can be:
Get a tensor in this tower. The name argument can be:
1. The name of the tensor without any tower prefix.
1. The name of a tensor/variable without any tower prefix.
2. A name in the input signature, if it is used when building the tower.
Expand All @@ -405,7 +405,6 @@ def get_tensor(self, name):
except KeyError:
if name in self._extra_tensor_names:
return self._extra_tensor_names[name]
raise
else:
if name in self._extra_tensor_names:
mapped_tensor = self._extra_tensor_names[name]
Expand All @@ -415,6 +414,8 @@ def get_tensor(self, name):
" Assuming it is the input '{}'.".format(mapped_tensor.name))
return mapped_tensor
return ret
# should also allow variables in get_tensor
return self.get_variable(name)

def get_tensors(self, names):
"""
Expand Down

0 comments on commit d89b6f0

Please sign in to comment.