In [1]:
import tensorflow as tf 
import os 
import numpy as np
import glob 

In [2]:
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of devices: 4


In [3]:
datadir = "/data/fourview"
print(datadir + '/*/train*')

/data/fourview/*/train*


In [5]:
IMAGE_SIZE = 224
BATCH_SIZE = 64

datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255, 
    validation_split=0.2)

train_generator = datagen.flow_from_directory(
    datadir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE, 
    subset='training')

val_generator = datagen.flow_from_directory(
    datadir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE, 
    subset='validation')

Found 18624 images belonging to 49 classes.
Found 4656 images belonging to 49 classes.


In [6]:
for imgs, labels in train_generator:
    print(imgs.shape)
    print(labels.shape)
    print(type(labels))
    print(type(imgs))

(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(64, 224, 224, 3)
(64, 49

KeyboardInterrupt: 

In [7]:
print (train_generator.class_indices)

labels = '\n'.join(sorted(train_generator.class_indices.keys()))

with open('labels.txt', 'w') as f:
    f.write(labels)

{'029470000105': 0, '089686180657': 1, '5011157996592': 2, '80042532': 3, '8886467103506': 4, '8901042955988': 5, '9300617074410': 6, '9300619513085': 7, '9300631448754': 8, '9300632012268': 9, '9300633005184': 10, '9300633292355': 11, '9300633320775': 12, '9300633557072': 13, '9300633719234': 14, '9300633945138': 15, '9300644706704': 16, '9300657820251': 17, '9300701880866': 18, '9300830057504': 19, '9310072001777': 20, '9310088013184': 21, '9310140283807': 22, '9310155414814': 23, '9310432001454': 24, '9310998101018': 25, '9312631168853': 26, '9339687011254': 27, '9400097041145': 28, '9400547020232': 29, '9400574005158': 30, '9400574005752': 31, '9400593002411': 32, '9400597018715': 33, '9400597027335': 34, '9403102000922': 35, '9403110063544': 36, '9403110065944': 37, '9415022032990': 38, '9415107116447': 39, '9415107225309': 40, '9415187009318': 41, '9415317003131': 42, '9415767677203': 43, '9416050524914': 44, '9416050588718': 45, '9418315127380': 46, '9420039513295': 47, '9421902

In [17]:
IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)

with strategy.scope():
    base_model = tf.keras.applications.EfficientNetB0(input_shape=IMG_SHAPE,
                                              include_top=False, 
                                              weights='imagenet')
#     base_model.trainable = False
    model = tf.keras.Sequential([
      base_model,
      tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu'),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.GlobalAveragePooling2D(),
      tf.keras.layers.Dense(units=49, activation='softmax')
    ])
    
    model.compile(optimizer='adam', 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5


In [18]:
model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
efficientnetb0 (Functional)  (None, 7, 7, 1280)        4049571   
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 5, 5, 32)          368672    
_________________________________________________________________
dropout_7 (Dropout)          (None, 5, 5, 32)          0         
_________________________________________________________________
global_average_pooling2d_3 ( (None, 32)                0         
_________________________________________________________________
dense_9 (Dense)              (None, 49)                1617      
Total params: 4,419,860
Trainable params: 4,377,837
Non-trainable params: 42,023
_________________________________________________________________


In [19]:
with strategy.scope():
    history = model.fit(train_generator,
                        steps_per_epoch=len(train_generator), 
                        epochs=2,
                        validation_data=val_generator,
                        validation_steps=len(val_generator))

Epoch 1/2
INFO:tensorflow:batch_all_reduce: 215 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 215 all-reduces with algorithm = nccl, num_packs = 1
Epoch 2/2


# Convert to uint8 tflite model

In [20]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('mobilenet_v2_1.0_224.tflite', 'wb') as f:
    f.write(tflite_model)

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/tmpszqvv5mk/assets


In [21]:
# A generator that provides a representative dataset
def representative_data_gen():
    dataset_list = tf.data.Dataset.list_files(datadir + '/*/*')
    for i in range(100):
        image = next(iter(dataset_list))
        image = tf.io.read_file(image)
        image = tf.io.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
        image = tf.cast(image / 255., tf.float32)
        image = tf.expand_dims(image, 0)
        yield [image]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# This enables quantization
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# This sets the representative dataset for quantization
converter.representative_dataset = representative_data_gen
# This ensures that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# For full integer quantization, though supported types defaults to int8 only, we explicitly declare it for clarity.
converter.target_spec.supported_types = [tf.int8]
# These set the input and output tensors to uint8 (added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()

with open('mobilenet_v2_1.0_224_quant.tflite', 'wb') as f:
    f.write(tflite_model)

INFO:tensorflow:Assets written to: /tmp/tmpeto15a_d/assets


INFO:tensorflow:Assets written to: /tmp/tmpeto15a_d/assets


# Validate the raw model and tflite quantized model

In [22]:
batch_images, batch_labels = next(val_generator)

logits = model(batch_images)
prediction = np.argmax(logits, axis=1)
truth = np.argmax(batch_labels, axis=1)

keras_accuracy = tf.keras.metrics.Accuracy()
keras_accuracy(prediction, truth)

print("Raw model accuracy: {:.3%}".format(keras_accuracy.result()))

Raw model accuracy: 3.125%


In [23]:
def set_input_tensor(interpreter, input):
    input_details = interpreter.get_input_details()[0]
    tensor_index = input_details['index']
    input_tensor = interpreter.tensor(tensor_index)()[0]
    # Inputs for the TFLite model must be uint8, so we quantize our input data.
    # NOTE: This step is necessary only because we're receiving input data from
    # ImageDataGenerator, which rescaled all image data to float [0,1]. When using
    # bitmap inputs, they're already uint8 [0,255] so this can be replaced with:
    #   input_tensor[:, :] = input
    scale, zero_point = input_details['quantization']
    input_tensor[:, :] = np.uint8(input / scale + zero_point)

def classify_image(interpreter, input):
    set_input_tensor(interpreter, input)
    interpreter.invoke()
    output_details = interpreter.get_output_details()[0]
    output = interpreter.get_tensor(output_details['index'])
    # Outputs from the TFLite model are uint8, so we dequantize the results:
    scale, zero_point = output_details['quantization']
    output = scale * (output - zero_point)
    top_1 = np.argmax(output)
    return top_1

interpreter = tf.lite.Interpreter('mobilenet_v2_1.0_224_quant.tflite')
interpreter.allocate_tensors()

# Collect all inference predictions in a list
batch_prediction = []
batch_truth = np.argmax(batch_labels, axis=1)

for i in range(len(batch_images)):
    prediction = classify_image(interpreter, batch_images[i])
    batch_prediction.append(prediction)

# Compare all predictions to the ground truth
tflite_accuracy = tf.keras.metrics.Accuracy()
tflite_accuracy(batch_prediction, batch_truth)
print("Quant TF Lite accuracy: {:.3%}".format(tflite_accuracy.result()))

Quant TF Lite accuracy: 1.562%


# Export to edgetpu model

In [24]:
! edgetpu_compiler mobilenet_v2_1.0_224_quant.tflite

Edge TPU Compiler version 16.0.384591198
Started a compilation timeout timer of 180 seconds.
ERROR: Attempting to use a delegate that only supports static-sized tensors with a graph that has dynamic-sized tensors.
Compilation failed: Model failed in Tflite interpreter. Please ensure model can be loaded/run in Tflite interpreter.
Compilation child process completed within timeout period.
Compilation failed! 


In [24]:
!python3 classify_image.py \
  --model mobilenet_v2_1.0_224_quant_edgetpu.tflite \
  --label labels.txt \
  --input /data/9415767677203/val_72.jpg

----INFERENCE TIME----
Note: The first inference on Edge TPU is slow because it includes loading the model into Edge TPU memory.
9.4ms
2.4ms
2.5ms
2.5ms
2.4ms
-------RESULTS--------
9300657820251: 0.19141
