Skip to content

Commit

Permalink
Fix TensorFlow checkpoint and trackable imports.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452684705
  • Loading branch information
k-w-w authored and tensorflower-gardener committed Jun 3, 2022
1 parent 12e0f6b commit 0a10dd4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
Expand Up @@ -134,6 +134,11 @@ def gen_module(module_name):
'from tensorflow.python.ops import variables',
'from tensorflow_probability.python.internal.backend.numpy '
'import variables')
code = code.replace(
'from tensorflow.python.trackable '
'import data_structures',
'from tensorflow_probability.python.internal.backend.numpy '
'import data_structures')
code = code.replace(
'from tensorflow.python.training.tracking '
'import data_structures',
Expand Down
Expand Up @@ -24,7 +24,7 @@
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.util.deferred_tensor import TensorMetaClass
from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training.tracking import data_structures # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.trackable import data_structures # pylint: disable=g-direct-tensorflow-import


__all__ = [] # We intend nothing public.
Expand Down

0 comments on commit 0a10dd4

Please sign in to comment.