Skip to content

Commit

Permalink
Fix keras dtype importing and unpin for CI (#6857)
Browse files Browse the repository at this point in the history
Keras' output format was slightly changed in
keras-team/keras#19711; for non-input layers
dtypes will now be exported as a config map instead of just a string.
This fixes test breakages when using ToT keras.

Alternative to #6855
  • Loading branch information
mloc committed May 21, 2024
1 parent 5f8b019 commit cbeecb7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
- name: 'Install TensorFlow'
run: |
python -m pip install -U pip
pip install "${TENSORFLOW_VERSION}" keras-nightly==3.3.3.dev2024051503
pip install "${TENSORFLOW_VERSION}"
if: matrix.tf_version_id != 'notf'
- name: 'Install Python dependencies'
run: |
Expand Down
18 changes: 12 additions & 6 deletions tensorboard/plugins/graph/keras_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,19 @@ def keras_model_to_graph_def(keras_layer):
node_def.attr["keras_class"].s = keras_cls_name

dtype_or_policy = layer_config.get("dtype")
# Skip dtype processing if this is a dict, since it's presumably a instance of
# tf/keras/mixed_precision/Policy rather than a single dtype.
dtype = None
# If this is a dict, try and extract the dtype string from
# `config.name`; keras will export like this for non-input layers. If
# we can't find `config.name`, we skip it as it's presumably a instance
# of tf/keras/mixed_precision/Policy rather than a single dtype.
# TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
if dtype_or_policy is not None and not isinstance(
dtype_or_policy, dict
):
tf_dtype = dtypes.as_dtype(layer_config.get("dtype"))
if isinstance(dtype_or_policy, dict):
if "config" in dtype_or_policy:
dtype = dtype_or_policy.get("config").get("name")
elif dtype_or_policy is not None:
dtype = dtype_or_policy
if dtype is not None:
tf_dtype = dtypes.as_dtype(dtype)
node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
if layer.get("inbound_nodes") is not None:
for name, size, index in _get_inbound_nodes(layer):
Expand Down

0 comments on commit cbeecb7

Please sign in to comment.