Skip to content

Commit

Permalink
Fixes regression for issue #5548: Avoid attempting to convert dtypes …
Browse files Browse the repository at this point in the history
…from "mixed precision" policy types. (#6859)

Following-up on PR #6857, which seems to have introduced a regression
for issue #5548.

This change just gracefully degrades the functionality to avoid crashing
on an error (as it was before the recent change in #6857), but it might
not be a proper fix for the scenario described in that issue.
  • Loading branch information
arcra committed May 21, 2024
1 parent cbeecb7 commit ae7d0b9
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
1 change: 1 addition & 0 deletions tensorboard/plugins/graph/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ py_library(
deps = [
"//tensorboard/compat/proto:protos_all_py_pb2",
"//tensorboard/compat/tensorflow_stub",
"//tensorboard/util:tb_logging",
],
)

Expand Down
37 changes: 28 additions & 9 deletions tensorboard/plugins/graph/keras_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
"""
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.tensorflow_stub import dtypes
from tensorboard.util import tb_logging


logger = tb_logging.get_logger()


def _walk_layers(keras_layer):
Expand Down Expand Up @@ -259,19 +263,34 @@ def keras_model_to_graph_def(keras_layer):

dtype_or_policy = layer_config.get("dtype")
dtype = None
has_unsupported_value = False
# If this is a dict, try and extract the dtype string from
# `config.name`; keras will export like this for non-input layers. If
# we can't find `config.name`, we skip it as it's presumably a instance
# of tf/keras/mixed_precision/Policy rather than a single dtype.
# TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
if isinstance(dtype_or_policy, dict):
if "config" in dtype_or_policy:
dtype = dtype_or_policy.get("config").get("name")
# `config.name`. Keras will export like this for non-input layers and
# some other cases (e.g. tf/keras/mixed_precision/Policy, as described
# in issue #5548).
if isinstance(dtype_or_policy, dict) and "config" in dtype_or_policy:
dtype = dtype_or_policy.get("config").get("name")
elif dtype_or_policy is not None:
dtype = dtype_or_policy

if dtype is not None:
tf_dtype = dtypes.as_dtype(dtype)
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
try:
tf_dtype = dtypes.as_dtype(dtype)
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
except TypeError:
has_unsupported_value = True
elif dtype_or_policy is not None:
has_unsupported_value = True

if has_unsupported_value:
# There's at least one known case when this happens, which is when
# mixed precision dtype policies are used, as described in issue
# #5548. (See https://keras.io/api/mixed_precision/).
# There might be a better way to handle this, but here we are.
logger.warning(
"Unsupported dtype value in graph model config (json):\n%s",
dtype_or_policy,
)
if layer.get("inbound_nodes") is not None:
for name, size, index in _get_inbound_nodes(layer):
inbound_name = _scoped_name(name_scope, name)
Expand Down
14 changes: 14 additions & 0 deletions tensorboard/plugins/graph/keras_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,20 @@ def test_keras_model_to_graph_def_functional_multiple_inbound_nodes_from_same_no

self.assertGraphDefToModel(expected_proto, model)

def test__keras_model_to_graph_def__does_not_crash_with_mixed_precision_dtype_policy(
self,
):
# See https://keras.io/api/mixed_precision/ for more info.
# Test to avoid regression on issue #5548
first_layer = tf.keras.layers.Dense(
1, input_shape=(1,), dtype="mixed_float16"
)
model = tf.keras.Sequential([first_layer])

model_config = json.loads(model.to_json())
# This line should not raise errors:
keras_util.keras_model_to_graph_def(model_config)


class _DoublingLayer(tf.keras.layers.Layer):
def call(self, inputs):
Expand Down

0 comments on commit ae7d0b9

Please sign in to comment.