Skip to content

Commit

Permalink
Add EfficientNetB0 to B7 to Keras Applications.
Browse files Browse the repository at this point in the history
Also add swish activation to keras.activations (used by EfficientNets).

PiperOrigin-RevId: 286614744
Change-Id: Ieba8b1f47735bdbb31c4efc84f45f763b9daa9a4
  • Loading branch information
fchollet authored and tensorflower-gardener committed Dec 20, 2019
1 parent 3dbe4fb commit 0d7620c
Show file tree
Hide file tree
Showing 8 changed files with 699 additions and 1 deletion.
13 changes: 13 additions & 0 deletions tensorflow/python/keras/activations.py
Expand Up @@ -182,6 +182,19 @@ def softsign(x):
return nn.softsign(x)


@keras_export('keras.activations.swish')
def swish(x):
"""Swish activation function.
Arguments:
x: Input tensor.
Returns:
The swish activation applied to `x`.
"""
return nn.swish(x)


@keras_export('keras.activations.relu')
def relu(x, alpha=0., max_value=None, threshold=0):
"""Applies the rectified linear unit activation function.
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/keras/applications/BUILD
Expand Up @@ -15,6 +15,7 @@ py_library(
srcs = [
"__init__.py",
"densenet.py",
"efficientnet.py",
"imagenet_utils.py",
"inception_resnet_v2.py",
"inception_v3.py",
Expand Down
19 changes: 19 additions & 0 deletions tensorflow/python/keras/applications/applications_test.py
Expand Up @@ -22,6 +22,7 @@

from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import densenet
from tensorflow.python.keras.applications import efficientnet
from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
Expand Down Expand Up @@ -52,6 +53,14 @@
(densenet.DenseNet121, 1024),
(densenet.DenseNet169, 1664),
(densenet.DenseNet201, 1920),
(efficientnet.EfficientNetB0, 1280),
(efficientnet.EfficientNetB1, 1280),
(efficientnet.EfficientNetB2, 1408),
(efficientnet.EfficientNetB3, 1536),
(efficientnet.EfficientNetB4, 1792),
(efficientnet.EfficientNetB5, 2048),
(efficientnet.EfficientNetB6, 2304),
(efficientnet.EfficientNetB7, 2560),
]

NASNET_LIST = [
Expand All @@ -72,6 +81,16 @@ def assertShapeEqual(self, shape1, shape2):
if v1 != v2:
raise AssertionError('Shapes differ: %s vs %s' % (shape1, shape2))

@parameterized.parameters(*MODEL_LIST)
def test_application_base(self, app, _):
# Can be instantiated with default arguments
model = app(weights=None)
# Can be serialized and deserialized
config = model.get_config()
reconstructed_model = model.__class__.from_config(config)
self.assertEqual(len(model.weights), len(reconstructed_model.weights))
backend.clear_session()

@parameterized.parameters(*MODEL_LIST)
def test_application_notop(self, app, last_dim):
if 'NASNet' in app.__name__:
Expand Down

0 comments on commit 0d7620c

Please sign in to comment.