diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py index 3720708543f3ac..acfb589f51b5ea 100644 --- a/tensorflow/python/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/utils/vis_utils.py @@ -69,6 +69,7 @@ def add_edge(dot, src, dst): @keras_export('keras.utils.model_to_dot') def model_to_dot(model, show_shapes=False, + show_dtype=False, show_layer_names=True, rankdir='TB', expand_nested=False, @@ -79,6 +80,7 @@ def model_to_dot(model, Arguments: model: A Keras model instance. show_shapes: whether to display shape information. + show_dtype: whether to display layer dtypes. show_layer_names: whether to display layer names. rankdir: `rankdir` argument passed to PyDot, a string specifying the format of the plot: @@ -150,8 +152,11 @@ def model_to_dot(model, if isinstance(layer, wrappers.Wrapper): if expand_nested and isinstance(layer.layer, functional.Functional): - submodel_wrapper = model_to_dot(layer.layer, show_shapes, - show_layer_names, rankdir, + submodel_wrapper = model_to_dot(layer.layer, + show_shapes, + show_dtype, + show_layer_names, + rankdir, expand_nested, subgraph=True) # sub_w : submodel_wrapper @@ -165,8 +170,11 @@ def model_to_dot(model, class_name = '{}({})'.format(class_name, child_class_name) if expand_nested and isinstance(layer, functional.Functional): - submodel_not_wrapper = model_to_dot(layer, show_shapes, - show_layer_names, rankdir, + submodel_not_wrapper = model_to_dot(layer, + show_shapes, + show_dtype, + show_layer_names, + rankdir, expand_nested, subgraph=True) # sub_n : submodel_not_wrapper @@ -180,6 +188,16 @@ def model_to_dot(model, label = '{}: {}'.format(layer_name, class_name) else: label = class_name + + # Rebuild the label as a table including the layer's dtype. + if show_dtype: + def format_dtype(dtype): + if dtype is None: + return '?' + else: + return str(dtype) + + label = '%s|%s' % (label, format_dtype(layer.dtype)) # Rebuild the label as a table including input/output shapes. if show_shapes: @@ -260,6 +278,7 @@ def format_shape(shape): def plot_model(model, to_file='model.png', show_shapes=False, + show_dtype=False, show_layer_names=True, rankdir='TB', expand_nested=False, @@ -286,6 +305,7 @@ def plot_model(model, model: A Keras model instance to_file: File name of the plot image. show_shapes: whether to display shape information. + show_dtype: whether to display layer dtypes. show_layer_names: whether to display layer names. rankdir: `rankdir` argument passed to PyDot, a string specifying the format of the plot: @@ -300,6 +320,7 @@ def plot_model(model, """ dot = model_to_dot(model, show_shapes=show_shapes, + show_dtype=show_dtype, show_layer_names=show_layer_names, rankdir=rankdir, expand_nested=expand_nested, diff --git a/tensorflow/python/keras/utils/vis_utils_test.py b/tensorflow/python/keras/utils/vis_utils_test.py index 984014216beb89..c5e3d18e7340b1 100644 --- a/tensorflow/python/keras/utils/vis_utils_test.py +++ b/tensorflow/python/keras/utils/vis_utils_test.py @@ -36,7 +36,8 @@ def test_plot_model_cnn(self): model.add(keras.layers.Dense(5, name='dense')) dot_img_file = 'model_1.png' try: - vis_utils.plot_model(model, to_file=dot_img_file, show_shapes=True) + vis_utils.plot_model(model, to_file=dot_img_file, + show_shapes=True, show_dtype=True) self.assertTrue(file_io.file_exists(dot_img_file)) file_io.delete_file(dot_img_file) except ImportError: @@ -62,7 +63,8 @@ def test_plot_model_with_wrapped_layers_and_models(self): dot_img_file = 'model_2.png' try: vis_utils.plot_model( - model, to_file=dot_img_file, show_shapes=True, expand_nested=True) + model, to_file=dot_img_file, show_shapes=True, + show_dtype=True, expand_nested=True) self.assertTrue(file_io.file_exists(dot_img_file)) file_io.delete_file(dot_img_file) except ImportError: @@ -76,7 +78,8 @@ def test_plot_model_with_add_loss(self): dot_img_file = 'model_3.png' try: vis_utils.plot_model( - model, to_file=dot_img_file, show_shapes=True, expand_nested=True) + model, to_file=dot_img_file, show_shapes=True, + show_dtype=True, expand_nested=True) self.assertTrue(file_io.file_exists(dot_img_file)) file_io.delete_file(dot_img_file) except ImportError: @@ -88,7 +91,8 @@ def test_plot_model_with_add_loss(self): dot_img_file = 'model_4.png' try: vis_utils.plot_model( - model, to_file=dot_img_file, show_shapes=True, expand_nested=True) + model, to_file=dot_img_file, show_shapes=True, + show_dtype=True, expand_nested=True) self.assertTrue(file_io.file_exists(dot_img_file)) file_io.delete_file(dot_img_file) except ImportError: