From 0d7620c3bc50b01054804d8c631b7eed45a7925b Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Fri, 20 Dec 2019 11:41:17 -0800 Subject: [PATCH] Add EfficientNetB0 to B7 to Keras Applications. Also add swish activation to keras.activations (used by EfficientNets). PiperOrigin-RevId: 286614744 Change-Id: Ieba8b1f47735bdbb31c4efc84f45f763b9daa9a4 --- tensorflow/python/keras/activations.py | 13 + tensorflow/python/keras/applications/BUILD | 1 + .../keras/applications/applications_test.py | 19 + .../python/keras/applications/efficientnet.py | 654 ++++++++++++++++++ tensorflow/python/keras/layers/__init__.py | 1 + .../python/keras/layers/serialization.py | 4 +- .../v1/tensorflow.keras.activations.pbtxt | 4 + .../v2/tensorflow.keras.activations.pbtxt | 4 + 8 files changed, 699 insertions(+), 1 deletion(-) create mode 100644 tensorflow/python/keras/applications/efficientnet.py diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py index f26c5a117c2eb6..16f60a7dd11289 100644 --- a/tensorflow/python/keras/activations.py +++ b/tensorflow/python/keras/activations.py @@ -182,6 +182,19 @@ def softsign(x): return nn.softsign(x) +@keras_export('keras.activations.swish') +def swish(x): + """Swish activation function. + + Arguments: + x: Input tensor. + + Returns: + The swish activation applied to `x`. + """ + return nn.swish(x) + + @keras_export('keras.activations.relu') def relu(x, alpha=0., max_value=None, threshold=0): """Applies the rectified linear unit activation function. diff --git a/tensorflow/python/keras/applications/BUILD b/tensorflow/python/keras/applications/BUILD index f5faae02a7e527..17998dff2204fb 100644 --- a/tensorflow/python/keras/applications/BUILD +++ b/tensorflow/python/keras/applications/BUILD @@ -15,6 +15,7 @@ py_library( srcs = [ "__init__.py", "densenet.py", + "efficientnet.py", "imagenet_utils.py", "inception_resnet_v2.py", "inception_v3.py", diff --git a/tensorflow/python/keras/applications/applications_test.py b/tensorflow/python/keras/applications/applications_test.py index b790eb83f95c2b..198bebd904cd58 100644 --- a/tensorflow/python/keras/applications/applications_test.py +++ b/tensorflow/python/keras/applications/applications_test.py @@ -22,6 +22,7 @@ from tensorflow.python.keras import backend from tensorflow.python.keras.applications import densenet +from tensorflow.python.keras.applications import efficientnet from tensorflow.python.keras.applications import inception_resnet_v2 from tensorflow.python.keras.applications import inception_v3 from tensorflow.python.keras.applications import mobilenet @@ -52,6 +53,14 @@ (densenet.DenseNet121, 1024), (densenet.DenseNet169, 1664), (densenet.DenseNet201, 1920), + (efficientnet.EfficientNetB0, 1280), + (efficientnet.EfficientNetB1, 1280), + (efficientnet.EfficientNetB2, 1408), + (efficientnet.EfficientNetB3, 1536), + (efficientnet.EfficientNetB4, 1792), + (efficientnet.EfficientNetB5, 2048), + (efficientnet.EfficientNetB6, 2304), + (efficientnet.EfficientNetB7, 2560), ] NASNET_LIST = [ @@ -72,6 +81,16 @@ def assertShapeEqual(self, shape1, shape2): if v1 != v2: raise AssertionError('Shapes differ: %s vs %s' % (shape1, shape2)) + @parameterized.parameters(*MODEL_LIST) + def test_application_base(self, app, _): + # Can be instantiated with default arguments + model = app(weights=None) + # Can be serialized and deserialized + config = model.get_config() + reconstructed_model = model.__class__.from_config(config) + self.assertEqual(len(model.weights), len(reconstructed_model.weights)) + backend.clear_session() + @parameterized.parameters(*MODEL_LIST) def test_application_notop(self, app, last_dim): if 'NASNet' in app.__name__: diff --git a/tensorflow/python/keras/applications/efficientnet.py b/tensorflow/python/keras/applications/efficientnet.py new file mode 100644 index 00000000000000..f3d0f1e5b0eb8f --- /dev/null +++ b/tensorflow/python/keras/applications/efficientnet.py @@ -0,0 +1,654 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=invalid-name +"""EfficientNet models for Keras. + +Reference paper: + - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks] + (https://arxiv.org/abs/1905.11946) (ICML 2019) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import math +import os + +from tensorflow.python.keras import backend +from tensorflow.python.keras import layers +from tensorflow.python.keras.applications import imagenet_utils +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.utils import data_utils +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.util.tf_export import keras_export + + +BASE_WEIGHTS_PATH = 'https://storage.googleapis.com/keras-applications/' + +WEIGHTS_HASHES = { + 'b0': ('902e53a9f72be733fc0bcb005b3ebbac', + '50bc09e76180e00e4465e1a485ddc09d'), + 'b1': ('1d254153d4ab51201f1646940f018540', + '74c4e6b3e1f6a1eea24c589628592432'), + 'b2': ('b15cce36ff4dcbd00b6dd88e7857a6ad', + '111f8e2ac8aa800a7a99e3239f7bfb39'), + 'b3': ('ffd1fdc53d0ce67064dc6a9c7960ede0', + 'af6d107764bb5b1abb91932881670226'), + 'b4': ('18c95ad55216b8f92d7e70b3a046e2fc', + 'ebc24e6d6c33eaebbd558eafbeedf1ba'), + 'b5': ('ace28f2a6363774853a83a0b21b9421a', + '38879255a25d3c92d5e44e04ae6cec6f'), + 'b6': ('165f6e37dce68623721b423839de8be5', + '9ecce42647a20130c1f39a5d4cb75743'), + 'b7': ('8c03f828fec3ef71311cd463b6759d99', + 'cbcfe4450ddf6f3ad90b1b398090fe4a'), +} + +DEFAULT_BLOCKS_ARGS = [{ + 'kernel_size': 3, + 'repeats': 1, + 'filters_in': 32, + 'filters_out': 16, + 'expand_ratio': 1, + 'id_skip': True, + 'strides': 1, + 'se_ratio': 0.25 +}, { + 'kernel_size': 3, + 'repeats': 2, + 'filters_in': 16, + 'filters_out': 24, + 'expand_ratio': 6, + 'id_skip': True, + 'strides': 2, + 'se_ratio': 0.25 +}, { + 'kernel_size': 5, + 'repeats': 2, + 'filters_in': 24, + 'filters_out': 40, + 'expand_ratio': 6, + 'id_skip': True, + 'strides': 2, + 'se_ratio': 0.25 +}, { + 'kernel_size': 3, + 'repeats': 3, + 'filters_in': 40, + 'filters_out': 80, + 'expand_ratio': 6, + 'id_skip': True, + 'strides': 2, + 'se_ratio': 0.25 +}, { + 'kernel_size': 5, + 'repeats': 3, + 'filters_in': 80, + 'filters_out': 112, + 'expand_ratio': 6, + 'id_skip': True, + 'strides': 1, + 'se_ratio': 0.25 +}, { + 'kernel_size': 5, + 'repeats': 4, + 'filters_in': 112, + 'filters_out': 192, + 'expand_ratio': 6, + 'id_skip': True, + 'strides': 2, + 'se_ratio': 0.25 +}, { + 'kernel_size': 3, + 'repeats': 1, + 'filters_in': 192, + 'filters_out': 320, + 'expand_ratio': 6, + 'id_skip': True, + 'strides': 1, + 'se_ratio': 0.25 +}] + +CONV_KERNEL_INITIALIZER = { + 'class_name': 'VarianceScaling', + 'config': { + 'scale': 2.0, + 'mode': 'fan_out', + 'distribution': 'truncated_normal' + } +} + +DENSE_KERNEL_INITIALIZER = { + 'class_name': 'VarianceScaling', + 'config': { + 'scale': 1. / 3., + 'mode': 'fan_out', + 'distribution': 'uniform' + } +} + + +def EfficientNet(width_coefficient, + depth_coefficient, + default_size, + dropout_rate=0.2, + drop_connect_rate=0.2, + depth_divisor=8, + activation='swish', + blocks_args='default', + model_name='efficientnet', + include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000): + """Instantiates the EfficientNet architecture using given scaling coefficients. + + Optionally loads weights pre-trained on ImageNet. + Note that the data format convention used by the model is + the one specified in your Keras config at `~/.keras/keras.json`. + + Arguments: + width_coefficient: float, scaling coefficient for network width. + depth_coefficient: float, scaling coefficient for network depth. + default_size: integer, default input image size. + dropout_rate: float, dropout rate before final classifier layer. + drop_connect_rate: float, dropout rate at skip connections. + depth_divisor: integer, a unit of network width. + activation: activation function. + blocks_args: list of dicts, parameters to construct block modules. + model_name: string, model name. + include_top: whether to include the fully-connected + layer at the top of the network. + weights: one of `None` (random initialization), + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is False. + It should have exactly 3 inputs channels. + pooling: optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is True, and + if no `weights` argument is specified. + + Returns: + A Keras model instance. + + Raises: + ValueError: in case of invalid argument for `weights`, + or invalid input shape. + """ + if blocks_args == 'default': + blocks_args = DEFAULT_BLOCKS_ARGS + + if not (weights in {'imagenet', None} or os.path.exists(weights)): + raise ValueError('The `weights` argument should be either ' + '`None` (random initialization), `imagenet` ' + '(pre-training on ImageNet), ' + 'or the path to the weights file to be loaded.') + + if weights == 'imagenet' and include_top and classes != 1000: + raise ValueError('If using `weights` as `"imagenet"` with `include_top`' + ' as true, `classes` should be 1000') + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 + + def round_filters(filters, divisor=depth_divisor): + """Round number of filters based on depth multiplier.""" + filters *= width_coefficient + new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + def round_repeats(repeats): + """Round number of repeats based on depth multiplier.""" + return int(math.ceil(depth_coefficient * repeats)) + + # Build stem + x = img_input + x = layers.Rescaling(1. / 255.)(x) + x = layers.Normalization(axis=bn_axis)(x) + + x = layers.ZeroPadding2D( + padding=imagenet_utils.correct_pad(x, 3), + name='stem_conv_pad')(x) + x = layers.Conv2D( + round_filters(32), + 3, + strides=2, + padding='valid', + use_bias=False, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name='stem_conv')(x) + x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x) + x = layers.Activation(activation, name='stem_activation')(x) + + # Build blocks + blocks_args = copy.deepcopy(blocks_args) + + b = 0 + blocks = float(sum(args['repeats'] for args in blocks_args)) + for (i, args) in enumerate(blocks_args): + assert args['repeats'] > 0 + # Update block input and output filters based on depth multiplier. + args['filters_in'] = round_filters(args['filters_in']) + args['filters_out'] = round_filters(args['filters_out']) + + for j in range(round_repeats(args.pop('repeats'))): + # The first block needs to take care of stride and filter size increase. + if j > 0: + args['strides'] = 1 + args['filters_in'] = args['filters_out'] + x = block( + x, + activation, + drop_connect_rate * b / blocks, + name='block{}{}_'.format(i + 1, chr(j + 97)), + **args) + b += 1 + + # Build top + x = layers.Conv2D( + round_filters(1280), + 1, + padding='same', + use_bias=False, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name='top_conv')(x) + x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x) + x = layers.Activation(activation, name='top_activation')(x) + if include_top: + x = layers.GlobalAveragePooling2D(name='avg_pool')(x) + if dropout_rate > 0: + x = layers.Dropout(dropout_rate, name='top_dropout')(x) + x = layers.Dense( + classes, + activation='softmax', + kernel_initializer=DENSE_KERNEL_INITIALIZER, + name='probs')(x) + else: + if pooling == 'avg': + x = layers.GlobalAveragePooling2D(name='avg_pool')(x) + elif pooling == 'max': + x = layers.GlobalMaxPooling2D(name='max_pool')(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = layer_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = training.Model(inputs, x, name=model_name) + + # Load weights. + if weights == 'imagenet': + if include_top: + file_suffix = '.h5' + file_hash = WEIGHTS_HASHES[model_name[-2:]][0] + else: + file_suffix = '_notop.h5' + file_hash = WEIGHTS_HASHES[model_name[-2:]][1] + file_name = model_name + file_suffix + weights_path = data_utils.get_file( + file_name, + BASE_WEIGHTS_PATH + file_name, + cache_subdir='models', + file_hash=file_hash) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + return model + + +def block(inputs, + activation='swish', + drop_rate=0., + name='', + filters_in=32, + filters_out=16, + kernel_size=3, + strides=1, + expand_ratio=1, + se_ratio=0., + id_skip=True): + """An inverted residual block. + + Arguments: + inputs: input tensor. + activation: activation function. + drop_rate: float between 0 and 1, fraction of the input units to drop. + name: string, block label. + filters_in: integer, the number of input filters. + filters_out: integer, the number of output filters. + kernel_size: integer, the dimension of the convolution window. + strides: integer, the stride of the convolution. + expand_ratio: integer, scaling coefficient for the input filters. + se_ratio: float between 0 and 1, fraction to squeeze the input filters. + id_skip: boolean. + + Returns: + output tensor for the block. + """ + bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 + + # Expansion phase + filters = filters_in * expand_ratio + if expand_ratio != 1: + x = layers.Conv2D( + filters, + 1, + padding='same', + use_bias=False, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=name + 'expand_conv')( + inputs) + x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x) + x = layers.Activation(activation, name=name + 'expand_activation')(x) + else: + x = inputs + + # Depthwise Convolution + if strides == 2: + x = layers.ZeroPadding2D( + padding=imagenet_utils.correct_pad(x, kernel_size), + name=name + 'dwconv_pad')(x) + conv_pad = 'valid' + else: + conv_pad = 'same' + x = layers.DepthwiseConv2D( + kernel_size, + strides=strides, + padding=conv_pad, + use_bias=False, + depthwise_initializer=CONV_KERNEL_INITIALIZER, + name=name + 'dwconv')(x) + x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x) + x = layers.Activation(activation, name=name + 'activation')(x) + + # Squeeze and Excitation phase + if 0 < se_ratio <= 1: + filters_se = max(1, int(filters_in * se_ratio)) + se = layers.GlobalAveragePooling2D(name=name + 'se_squeeze')(x) + se = layers.Reshape((1, 1, filters), name=name + 'se_reshape')(se) + se = layers.Conv2D( + filters_se, + 1, + padding='same', + activation=activation, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=name + 'se_reduce')( + se) + se = layers.Conv2D( + filters, + 1, + padding='same', + activation='sigmoid', + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=name + 'se_expand')(se) + x = layers.multiply([x, se], name=name + 'se_excite') + + # Output phase + x = layers.Conv2D( + filters_out, + 1, + padding='same', + use_bias=False, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=name + 'project_conv')(x) + x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x) + if id_skip and strides == 1 and filters_in == filters_out: + if drop_rate > 0: + x = layers.Dropout( + drop_rate, noise_shape=(None, 1, 1, 1), name=name + 'drop')(x) + x = layers.add([x, inputs], name=name + 'add') + return x + + +@keras_export('keras.applications.efficientnet.EfficientNetB0', + 'keras.applications.EfficientNetB0') +def EfficientNetB0(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + **kwargs): + return EfficientNet( + 1.0, + 1.0, + 224, + 0.2, + model_name='efficientnetb0', + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + **kwargs) + + +@keras_export('keras.applications.efficientnet.EfficientNetB1', + 'keras.applications.EfficientNetB1') +def EfficientNetB1(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + **kwargs): + return EfficientNet( + 1.0, + 1.1, + 240, + 0.2, + model_name='efficientnetb1', + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + **kwargs) + + +@keras_export('keras.applications.efficientnet.EfficientNetB2', + 'keras.applications.EfficientNetB2') +def EfficientNetB2(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + **kwargs): + return EfficientNet( + 1.1, + 1.2, + 260, + 0.3, + model_name='efficientnetb2', + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + **kwargs) + + +@keras_export('keras.applications.efficientnet.EfficientNetB3', + 'keras.applications.EfficientNetB3') +def EfficientNetB3(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + **kwargs): + return EfficientNet( + 1.2, + 1.4, + 300, + 0.3, + model_name='efficientnetb3', + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + **kwargs) + + +@keras_export('keras.applications.efficientnet.EfficientNetB4', + 'keras.applications.EfficientNetB4') +def EfficientNetB4(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + **kwargs): + return EfficientNet( + 1.4, + 1.8, + 380, + 0.4, + model_name='efficientnetb4', + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + **kwargs) + + +@keras_export('keras.applications.efficientnet.EfficientNetB5', + 'keras.applications.EfficientNetB5') +def EfficientNetB5(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + **kwargs): + return EfficientNet( + 1.6, + 2.2, + 456, + 0.4, + model_name='efficientnetb5', + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + **kwargs) + + +@keras_export('keras.applications.efficientnet.EfficientNetB6', + 'keras.applications.EfficientNetB6') +def EfficientNetB6(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + **kwargs): + return EfficientNet( + 1.8, + 2.6, + 528, + 0.5, + model_name='efficientnetb6', + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + **kwargs) + + +@keras_export('keras.applications.efficientnet.EfficientNetB7', + 'keras.applications.EfficientNetB7') +def EfficientNetB7(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + **kwargs): + return EfficientNet( + 2.0, + 3.1, + 600, + 0.5, + model_name='efficientnetb7', + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + **kwargs) + + +@keras_export('keras.applications.efficientnet.preprocess_input') +def preprocess_input(x, data_format=None): # pylint: disable=unused-argument + return x + + +@keras_export('keras.applications.efficientnet.decode_predictions') +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index 07cb1bdf1b3164..3f648b46bff3ee 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -44,6 +44,7 @@ from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2 TextVectorizationV1 = TextVectorization +from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Rescaling # Advanced activations. from tensorflow.python.keras.layers.advanced_activations import LeakyReLU diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index a7c43c60350b54..afefcc3f040d07 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -40,6 +40,7 @@ from tensorflow.python.keras.layers.normalization import * from tensorflow.python.keras.layers.pooling import * from tensorflow.python.keras.layers.preprocessing.image_preprocessing import * +from tensorflow.python.keras.layers.preprocessing.normalization_v1 import * from tensorflow.python.keras.layers.recurrent import * from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import * from tensorflow.python.keras.layers.wrappers import * @@ -49,7 +50,8 @@ if tf2.enabled(): from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top - from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.layers.preprocessing.normalization import * # pylint: disable=g-import-not-at-top # This deserialization table is added for backward compatibility, as in TF 1.13, # BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt index eb315e356dabc5..ee3d1f3d4a289a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt @@ -52,6 +52,10 @@ tf_module { name: "softsign" argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "swish" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "tanh" argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt index eb315e356dabc5..ee3d1f3d4a289a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt @@ -52,6 +52,10 @@ tf_module { name: "softsign" argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "swish" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "tanh" argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"