Skip to content

Commit

Permalink
Improve compatibility to tensorflow 2.3 (#1487)
Browse files Browse the repository at this point in the history
* fix gfile not found error occurring with tensorflow 2.3

* fix as_list() not found error occurring with tensorflow 2.3

Co-authored-by: Philipp Werner <pw_post@gmx.de>
  • Loading branch information
philippwerner and philippwerner committed Sep 30, 2020
1 parent a12872d commit 1b98fe5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
20 changes: 14 additions & 6 deletions tensorpack/input_source/input_source_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,21 @@ def build_or_reuse_placeholder(tensor_spec):
assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
assert tensor_spec.is_compatible_with(tensor), \
"Tensor {} exists but is not compatible with the signature!".format(tensor)
if tensor.shape.as_list() == tensor_spec.shape.as_list():
# It might be desirable to use a placeholder of a different shape in some tower
# (e.g., a less specific shape)

# Comparing `tensor.shape` directly doesn't work, because
# tensorflow thinks `tf.Dimension(None)` and `tf.Dimension(None)` are not equal.
return tensor
# It might be desirable to use a placeholder of a different shape in some tower
# (e.g., a less specific shape)
try:
if tensor.shape.as_list() == tensor_spec.shape.as_list():
# Comparing `tensor.shape` directly doesn't work in older versions of tensorflow,
# because tensorflow thinks `tf.Dimension(None)` and `tf.Dimension(None)` are not
# equal. Newer versions of tensorflow, e.g. 2.3, do not support as_list() for
# `tf.Dimension(None)` and raise a `ValueError`
return tensor
except ValueError:
if tensor.shape == tensor_spec.shape:
# With the newer version of tensorflow, comparing `tensor.shape` directly seems
# to work fine.
return tensor
except KeyError:
pass
with tfv1.name_scope(None): # clear any name scope it might get called in
Expand Down
4 changes: 2 additions & 2 deletions tensorpack/train/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# File: config.py

import os
import tensorflow as tf

from ..compat import tfv1
from ..callbacks import (
JSONWriter, MergeAllSummaries, MovingAverageSummary, ProgressBar, RunUpdateOps, ScalarPrinter, TFEventWriter)
from ..dataflow.base import DataFlow
Expand Down Expand Up @@ -237,6 +237,6 @@ def get_sessinit_resume(dir=None):
if not dir:
return None
path = os.path.join(dir, 'checkpoint')
if not tf.gfile.Exists(path):
if not tfv1.gfile.Exists(path):
return None
return SaverRestore(path)

0 comments on commit 1b98fe5

Please sign in to comment.