Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add show_dtype support for plot_model, update related tests. #40601

Merged
merged 11 commits into from Jun 26, 2020
29 changes: 25 additions & 4 deletions tensorflow/python/keras/utils/vis_utils.py
Expand Up @@ -69,6 +69,7 @@ def add_edge(dot, src, dst):
@keras_export('keras.utils.model_to_dot')
def model_to_dot(model,
show_shapes=False,
show_dtype=False,
show_layer_names=True,
rankdir='TB',
expand_nested=False,
Expand All @@ -79,6 +80,7 @@ def model_to_dot(model,
Arguments:
model: A Keras model instance.
show_shapes: whether to display shape information.
show_dtype: whether to display layer dtypes.
show_layer_names: whether to display layer names.
rankdir: `rankdir` argument passed to PyDot,
a string specifying the format of the plot:
Expand Down Expand Up @@ -150,8 +152,11 @@ def model_to_dot(model,
if isinstance(layer, wrappers.Wrapper):
if expand_nested and isinstance(layer.layer,
functional.Functional):
submodel_wrapper = model_to_dot(layer.layer, show_shapes,
show_layer_names, rankdir,
submodel_wrapper = model_to_dot(layer.layer,
show_shapes,
show_dtype,
show_layer_names,
rankdir,
expand_nested,
subgraph=True)
# sub_w : submodel_wrapper
Expand All @@ -165,8 +170,11 @@ def model_to_dot(model,
class_name = '{}({})'.format(class_name, child_class_name)

if expand_nested and isinstance(layer, functional.Functional):
submodel_not_wrapper = model_to_dot(layer, show_shapes,
show_layer_names, rankdir,
submodel_not_wrapper = model_to_dot(layer,
show_shapes,
show_dtype,
show_layer_names,
rankdir,
expand_nested,
subgraph=True)
# sub_n : submodel_not_wrapper
Expand All @@ -180,6 +188,16 @@ def model_to_dot(model,
label = '{}: {}'.format(layer_name, class_name)
else:
label = class_name

# Rebuild the label as a table including the layer's dtype.
if show_dtype:
def format_dtype(dtype):
jonah-kohn marked this conversation as resolved.
Show resolved Hide resolved
if dtype is None:
return '?'
else:
return str(dtype)

label = '%s|%s' % (label, format_dtype(layer.dtype))

# Rebuild the label as a table including input/output shapes.
if show_shapes:
Expand Down Expand Up @@ -260,6 +278,7 @@ def format_shape(shape):
def plot_model(model,
to_file='model.png',
show_shapes=False,
show_dtype=False,
show_layer_names=True,
rankdir='TB',
expand_nested=False,
Expand All @@ -286,6 +305,7 @@ def plot_model(model,
model: A Keras model instance
to_file: File name of the plot image.
show_shapes: whether to display shape information.
show_dtype: whether to display layer dtypes.
show_layer_names: whether to display layer names.
rankdir: `rankdir` argument passed to PyDot,
a string specifying the format of the plot:
Expand All @@ -300,6 +320,7 @@ def plot_model(model,
"""
dot = model_to_dot(model,
show_shapes=show_shapes,
show_dtype=show_dtype,
show_layer_names=show_layer_names,
rankdir=rankdir,
expand_nested=expand_nested,
Expand Down
12 changes: 8 additions & 4 deletions tensorflow/python/keras/utils/vis_utils_test.py
Expand Up @@ -36,7 +36,8 @@ def test_plot_model_cnn(self):
model.add(keras.layers.Dense(5, name='dense'))
dot_img_file = 'model_1.png'
try:
vis_utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
vis_utils.plot_model(model, to_file=dot_img_file,
show_shapes=True, show_dtype=True)
self.assertTrue(file_io.file_exists(dot_img_file))
file_io.delete_file(dot_img_file)
except ImportError:
Expand All @@ -62,7 +63,8 @@ def test_plot_model_with_wrapped_layers_and_models(self):
dot_img_file = 'model_2.png'
try:
vis_utils.plot_model(
model, to_file=dot_img_file, show_shapes=True, expand_nested=True)
model, to_file=dot_img_file, show_shapes=True,
show_dtype=True, expand_nested=True)
self.assertTrue(file_io.file_exists(dot_img_file))
file_io.delete_file(dot_img_file)
except ImportError:
Expand All @@ -76,7 +78,8 @@ def test_plot_model_with_add_loss(self):
dot_img_file = 'model_3.png'
try:
vis_utils.plot_model(
model, to_file=dot_img_file, show_shapes=True, expand_nested=True)
model, to_file=dot_img_file, show_shapes=True,
show_dtype=True, expand_nested=True)
self.assertTrue(file_io.file_exists(dot_img_file))
file_io.delete_file(dot_img_file)
except ImportError:
Expand All @@ -88,7 +91,8 @@ def test_plot_model_with_add_loss(self):
dot_img_file = 'model_4.png'
try:
vis_utils.plot_model(
model, to_file=dot_img_file, show_shapes=True, expand_nested=True)
model, to_file=dot_img_file, show_shapes=True,
show_dtype=True expand_nested=True)
jonah-kohn marked this conversation as resolved.
Show resolved Hide resolved
self.assertTrue(file_io.file_exists(dot_img_file))
file_io.delete_file(dot_img_file)
except ImportError:
Expand Down