Skip to content

Commit

Permalink
check hasattr on the type, not the instance.
Browse files Browse the repository at this point in the history
hasattr on the instance triggers __getattr__ which carries very undesirable
effects, such as running Ops on a donated buffer.

Long term, we may want to audit all uses of hasattr on TensorFlow instances
that overrides __getattr__ in nontrival (e.g. running tf Ops) ways. They will
almost always cause trouble here and there because TensorFlow is quite far
from being able guarantee if an Op returns or consumes is actually valid in all cases. Things will improve give it time, but if we can avoid such strong assumptions the system tend to get more robust.

PiperOrigin-RevId: 578261984
  • Loading branch information
rainwoodman authored and tensorflower-gardener committed Oct 31, 2023
1 parent 4891584 commit e44f8a0
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tensorflow/python/checkpoint/async_checkpoint_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,10 @@ def _ensure_initialized(self):
# custom __getattr__ code, see b/152031870 for context.
for t in all_trackables:
# Special case 1: TPU Embedding, populate object_map here
# Special case 1: Handle TPU Embedding by addnig a dummy instance to the
# object map. Also add TPUEmbedding to separate list for special handling
# with values copy.
if hasattr(t, _TPU_EMBEDDING_ATTR):
# Special case 1: Handle TPU Embedding by addnig a dummy instance to the
# object map. Also add TPUEmbedding to separate list for special handling
# with values copy.
if hasattr(type(t), _TPU_EMBEDDING_ATTR):
self._handle_tpu_embedding(t)
# Special case 2: handle slot variables. The object_map is populated later
# when the variable values are being copied to host CPU for the first
Expand Down Expand Up @@ -414,9 +414,9 @@ def _handle_tpu_embedding(self, tpu_embedding):
Raises:
AttributeError: if the input trackable is not TPUEmbedding type.
"""
if not hasattr(
tpu_embedding, _TPU_EMBEDDING_ATTR
) or not callable(tpu_embedding._create_copy_for_async_checkpoint): # pylint: disable=protected-access
if not hasattr(type(tpu_embedding), _TPU_EMBEDDING_ATTR) or not callable(
tpu_embedding._create_copy_for_async_checkpoint # pylint: disable=protected-access
):
raise AttributeError(
"Expecting TPUEmbedding type; got %s" % type(tpu_embedding)
)
Expand Down

0 comments on commit e44f8a0

Please sign in to comment.