Skip to content

Commit

Permalink
Add EfficientNet-lite models and set EfficientNet-Lite0 as the defaul…
Browse files Browse the repository at this point in the history
…t model in model maker.

PiperOrigin-RevId: 299329492
  • Loading branch information
ziyeqinghan authored and copybara-github committed Mar 6, 2020
1 parent b8de6f0 commit a83cb97
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 27 deletions.
4 changes: 2 additions & 2 deletions tensorflow_examples/lite/model_maker/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ def image_classification(self,
data_dir,
tflite_filename,
label_filename,
spec='efficientnet_b0',
spec='efficientnet_lite0',
**kwargs):
"""Run Image classification.
Args:
data_dir: str, input directory of training data. (required)
tflite_filename: str, output path to export tflite file. (required)
label_filename: str, output path to export label file. (required)
spec: str, model_name. Valid: {MODELS}, default: efficientnet_b0.
spec: str, model_name. Valid: {MODELS}, default: efficientnet_lite0.
**kwargs: --epochs: int, epoch num to run. More: see `create` function.
"""
# Convert types
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_examples/lite/model_maker/cli/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def test_init(self, tf_opt, expected_tf):
with patch_image() as run, patch_setup() as setup:
cli.main()
setup.assert_called_once_with(expected_tf)
run.assert_called_once_with('data', 'lite', 'label', 'efficientnet_b0')
run.assert_called_once_with('data', 'lite', 'label', 'efficientnet_lite0')

@parameterized.parameters(
([], ['efficientnet_b0'], {}),
([], ['efficientnet_lite0'], {}),
(['--spec=mobilenet_v2'], ['mobilenet_v2'], {}),
(['--spec=mobilenet_v2', '--epochs=1'], ['mobilenet_v2'], dict(epochs=1)),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

def create(train_data,
model_export_format=mef.ModelExportFormat.TFLITE,
model_spec=ms.mobilenet_v2_spec,
model_spec=ms.efficientnet_lite0_spec,
shuffle=False,
validation_data=None,
batch_size=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,23 @@ def test_mobilenetv2_model_create_v1_incompatible(self):
model_spec.mobilenet_v2_spec)

@test_util.test_in_tf_1and2
def test_efficientnetb0_model(self):
def test_efficientnetlite0_model(self):
model = image_classifier.create(
self.train_data,
mef.ModelExportFormat.TFLITE,
model_spec.efficientnet_b0_spec,
model_spec.efficientnet_lite0_spec,
epochs=2,
batch_size=4,
shuffle=True)
self._test_accuracy(model)
self._test_export_to_tflite(model)

@test_util.test_in_tf_1and2
def test_efficientnetlite4_model(self):
model = image_classifier.create(
self.train_data,
mef.ModelExportFormat.TFLITE,
model_spec.efficientnet_lite4_spec,
epochs=2,
batch_size=4,
shuffle=True)
Expand Down
47 changes: 40 additions & 7 deletions tensorflow_examples/lite/model_maker/core/task/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from official.utils.misc import distribution_utils


DEFAULT_INPUT_IMAGE_SHAPE = [224, 224]


def create_int_feature(values):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return feature
Expand Down Expand Up @@ -66,18 +69,17 @@ def get_num_gpus(num_gpus):
class ImageModelSpec(object):
"""A specification of image model."""

input_image_shape = [224, 224]
mean_rgb = [0, 0, 0]
stddev_rgb = [255, 255, 255]

def __init__(self, uri, compat_tf_versions=None):
def __init__(self, uri, compat_tf_versions=None, input_image_shape=None):
self.uri = uri
self.compat_tf_versions = _get_compat_tf_versions(compat_tf_versions)

self.input_image_shape = DEFAULT_INPUT_IMAGE_SHAPE
if input_image_shape is not None:
self.input_image_shape = input_image_shape

efficientnet_b0_spec = ImageModelSpec(
uri='https://tfhub.dev/google/efficientnet/b0/feature-vector/1',
compat_tf_versions=[1, 2])

mobilenet_v2_spec = ImageModelSpec(
uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
Expand All @@ -87,6 +89,30 @@ def __init__(self, uri, compat_tf_versions=None):
uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
compat_tf_versions=2)

efficientnet_lite0_spec = ImageModelSpec(
uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/1',
compat_tf_versions=[1, 2])

efficientnet_lite1_spec = ImageModelSpec(
uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/1',
compat_tf_versions=[1, 2],
input_image_shape=[240, 240])

efficientnet_lite2_spec = ImageModelSpec(
uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/1',
compat_tf_versions=[1, 2],
input_image_shape=[260, 260])

efficientnet_lite3_spec = ImageModelSpec(
uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/1',
compat_tf_versions=[1, 2],
input_image_shape=[280, 280])

efficientnet_lite4_spec = ImageModelSpec(
uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/1',
compat_tf_versions=[1, 2],
input_image_shape=[300, 300])


class TextModelSpec(abc.ABC):
"""The abstract base class that constains the specification of text model."""
Expand Down Expand Up @@ -538,15 +564,22 @@ def set_shape(self, model):

# A dict for model specs to make it accessible by string key.
MODEL_SPECS = {
'efficientnet_b0': efficientnet_b0_spec,
'efficientnet_lite0': efficientnet_lite0_spec,
'efficientnet_lite1': efficientnet_lite1_spec,
'efficientnet_lite2': efficientnet_lite2_spec,
'efficientnet_lite3': efficientnet_lite3_spec,
'efficientnet_lite4': efficientnet_lite4_spec,
'mobilenet_v2': mobilenet_v2_spec,
'resnet_50': resnet_50_spec,
'average_word_vec': AverageWordVecModelSpec,
'bert': BertModelSpec,
}

# List constants for supported models.
IMAGE_CLASSIFICATION_MODELS = ['efficientnet_b0', 'mobilenet_v2', 'resnet_50']
IMAGE_CLASSIFICATION_MODELS = [
'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2',
'efficientnet_lite3', 'efficientnet_lite4', 'mobilenet_v2', 'resnet_50'
]
TEXT_CLASSIFICATION_MODELS = ['bert', 'average_word_vec']


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
"\n",
"from tensorflow_examples.lite.model_maker.core.data_util.image_dataloader import ImageClassifierDataLoader\n",
"from tensorflow_examples.lite.model_maker.core.task import image_classifier\n",
"from tensorflow_examples.lite.model_maker.core.task.model_spec import efficientnet_b0_spec\n",
"from tensorflow_examples.lite.model_maker.core.task.model_spec import mobilenet_v2_spec\n",
"from tensorflow_examples.lite.model_maker.core.task.model_spec import ImageModelSpec\n",
"\n",
"import matplotlib.pyplot as plt"
Expand Down Expand Up @@ -289,7 +289,7 @@
"source": [
"## Detailed Process\n",
"\n",
"Currently, we only include MobileNetV2 and EfficientNetB0 models as pre-trained models for image classification. But it is very flexible to add new pre-trained models to this library with just a few lines of code.\n",
"Currently, we support several models such as EfficientNet-Lite* models, MobileNetV2, ResNet50 as pre-trained models for image classification. But it is very flexible to add new pre-trained models to this library with just a few lines of code.\n",
"\n",
"\n",
"The following walks through this end-to-end example step by step to show more detail."
Expand Down Expand Up @@ -439,7 +439,7 @@
"source": [
"### Step 2: Customize the TensorFlow Model\n",
"\n",
"Create a custom image classifier model based on the loaded data. The default model is MobileNetV2.\n"
"Create a custom image classifier model based on the loaded data. The default model is EfficientNet-Lite0.\n"
]
},
{
Expand Down Expand Up @@ -665,7 +665,7 @@
"id": "fuHB-NFqpKTD"
},
"source": [
"Note that preprocessing for inference should be the same as training. Currently, preprocessing contains normalizing each pixel value and resizing the image according to the model's specification. For MobileNetV2, input image should be normalized to `[0, 1]` and resized to `[224, 224, 3]`."
"Note that preprocessing for inference should be the same as training. Currently, preprocessing contains normalizing each pixel value and resizing the image according to the model's specification. For EfficientNet-Lite0, input image should be normalized to `[0, 1]` and resized to `[224, 224, 3]`."
]
},
{
Expand All @@ -682,9 +682,9 @@
"The `create`function contains the following steps:\n",
"\n",
"1. Split the data into training, validation, testing data according to parameter `validation_ratio` and `test_ratio`. The default value of `validation_ratio` and `test_ratio` are `0.1` and `0.1`.\n",
"2. Download a [Image Feature Vector](https://www.tensorflow.org/hub/common_signatures/images#image_feature_vector) as the base model from TensorFlow Hub. The default pre-trained model is MobileNetV2.\n",
"2. Download a [Image Feature Vector](https://www.tensorflow.org/hub/common_signatures/images#image_feature_vector) as the base model from TensorFlow Hub. The default pre-trained model is EfficientNet-Lite0.\n",
"3. Add a classifier head with a Dropout Layer with `dropout_rate` between head layer and pre-trained model. The default `dropout_rate` is the default `dropout_rate` value from [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L55) by TensorFlow Hub.\n",
"4. Preprocess the raw input data. Currently, preprocessing steps including normalizing the value of each image pixel to model input scale and resizing it to model input size. MobileNetV2 have the input scale `[0, 1]` and the input image size `[224, 224, 3]`.\n",
"4. Preprocess the raw input data. Currently, preprocessing steps including normalizing the value of each image pixel to model input scale and resizing it to model input size. EfficientNet-Lite0 have the input scale `[0, 1]` and the input image size `[224, 224, 3]`.\n",
"5. Feed the data into the classifier model. By default, the training parameters such as training epochs, batch size, learning rate, momentum are the default values from [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L55) by TensorFlow Hub. Only the classifier head is trained.\n",
"\n",
"\n",
Expand All @@ -711,11 +711,9 @@
"source": [
"### Change to the model that's supported in this library.\n",
"\n",
"This library supports MobileNetV2 and EfficientNetB0 model by now. The default model is MobileNetV2.\n",
"This library supports EfficientNet-Lite models, MobileNetV2, ResNet50 by now. [EfficientNet-Lite](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) are a family of image classification models that could acheive state-of-art accuracy and suitable for Edge devices. The default model is EfficientNet-Lite0.\n",
"\n",
"[EfficientNets](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) are a family of image classification models that could acheive state-of-art accuracy. EfficinetNetB0 is one of the EfficientNet models that's small and suitable for on-device applications. It's larger than MobileNetV2 while might achieve better performance.\n",
"\n",
"We could switch model to EfficientNetB0 by just setting parameter `model_spec` to `efficientnet_b0_spec` in `create` method."
"We could switch model to MobileNetV2 by just setting parameter `model_spec` to `mobilenet_v2_spec` in `create` method."
]
},
{
Expand All @@ -728,7 +726,7 @@
},
"outputs": [],
"source": [
"model = image_classifier.create(train_data, model_spec=efficientnet_b0_spec, validation_data=validation_data)"
"model = image_classifier.create(train_data, model_spec=mobilenet_v2_spec, validation_data=validation_data)"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def download_demo_data(**kwargs):
def run(data_dir,
tflite_filename,
label_filename,
spec='efficientnet_b0',
spec='efficientnet_lite0',
**kwargs):
"""Runs demo."""
spec = model_spec.get(spec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_image_classification_demo(self):
data_dir,
tflite_filename,
label_filename,
spec='efficientnet_b0',
spec='efficientnet_lite0',
epochs=1,
batch_size=1)
self.assertTrue(tf.io.gfile.exists(tflite_filename))
Expand Down

0 comments on commit a83cb97

Please sign in to comment.