# Load a MobileNetV1 TF1 Checkpoint into a TF2 Keras Model

In [14]:
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

In [11]:
import os
import sys
from typing import Text, List, Dict, Tuple, Callable

import tensorflow as tf

In [12]:
# Extend the module search path so the `research.mobilenet` imports in the
# next cell can be resolved.
# NOTE(review): '../../../' is resolved against the kernel's current working
# directory and is assumed to be the repository root — confirm when running
# the notebook from a different location.
root = os.path.abspath('../../../')
# Guard against duplicates: re-running this cell must not keep appending the
# same entry to sys.path.
if root not in sys.path:
    sys.path.append(root)

In [17]:
from research.mobilenet.configs import archs
from research.mobilenet.mobilenet_v1_model import mobilenet_v1
from research.mobilenet.tf1_loader import mobilenet_v1_loader
from research.mobilenet.mobilenet_trainer import _get_dataset_config, _get_metrics, get_dataset

## Download Checkpoint

In [15]:
## Download the mobilenet_v1_1.0_224 TF1 checkpoint archive into the current
## directory (curl: -L follows redirects, -O keeps the remote file name).
! curl -LO http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 89.9M  100 89.9M    0     0  8528k      0  0:00:10  0:00:10 --:--:-- 8749k:02  0:00:14 5443k
x ./
x ./mobilenet_v1_1.0_224.tflite
x ./mobilenet_v1_1.0_224.ckpt.meta
x ./mobilenet_v1_1.0_224.ckpt.index
x ./mobilenet_v1_1.0_224.ckpt.data-00000-of-00001
x ./mobilenet_v1_1.0_224_info.txt
x ./mobilenet_v1_1.0_224_frozen.pb
x ./mobilenet_v1_1.0_224_eval.pbtxt


In [22]:
## Unpack the tar ball
! mkdir ./checkpoints
! tar -xvf mobilenet_v1_1.0_224.tgz -C ./checkpoints

x ./
x ./mobilenet_v1_1.0_224.tflite
x ./mobilenet_v1_1.0_224.ckpt.meta
x ./mobilenet_v1_1.0_224.ckpt.index
x ./mobilenet_v1_1.0_224.ckpt.data-00000-of-00001
x ./mobilenet_v1_1.0_224_info.txt
x ./mobilenet_v1_1.0_224_frozen.pb
x ./mobilenet_v1_1.0_224_eval.pbtxt


In [23]:
# Checkpoint prefix rather than a single file: the archive extracted into
# ./checkpoints contains mobilenet_v1_1.0_224.ckpt.index,
# .ckpt.data-00000-of-00001 and .ckpt.meta.
source_checkpoint = './checkpoints/mobilenet_v1_1.0_224.ckpt'

## Restore TF2 Keras Model from TF1 Checkpoint

In [24]:
# Model settings for MobileNetV1 and the dataset settings registered under
# the "imagenette" key.
m_config = archs.MobileNetV1Config()
d_config = _get_dataset_config().get("imagenette")()

# Build the TF2 Keras model and fill it with the weights stored in the
# TF1 checkpoint.
keras_model = mobilenet_v1_loader.load_mobilenet_v1(
    config=m_config,
    checkpoint_path=source_checkpoint)

# The loss follows the label encoding declared by the dataset config:
# categorical cross-entropy when labels are one-hot, the sparse variant
# otherwise.
loss_obj = (
    tf.keras.losses.CategoricalCrossentropy()
    if d_config.one_hot
    else tf.keras.losses.SparseCategoricalCrossentropy()
)

# Compile with RMSprop and the 'acc' metric selected for the same encoding.
keras_model.compile(
    loss=loss_obj,
    optimizer='rmsprop',
    metrics=[_get_metrics(one_hot=d_config.one_hot)['acc']],
)

In [25]:
# Print the layer-by-layer overview of the restored model (layer name/type,
# output shape, parameter count) as a quick sanity check of the architecture.
keras_model.summary()

Model: "MobileNetV1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Input (InputLayer)           [(None, 224, 224, 3)]     0         
_________________________________________________________________
Conv2d_0 (Conv2D)            (None, 112, 112, 32)      864       
_________________________________________________________________
Conv2d_0/batch_norm (BatchNo (None, 112, 112, 32)      128       
_________________________________________________________________
Conv2d_0/relu6 (Activation)  (None, 112, 112, 32)      0         
_________________________________________________________________
Conv2d_1/depthwise (Depthwis (None, 112, 112, 32)      288       
_________________________________________________________________
Conv2d_1/depthwise/batch_nor (None, 112, 112, 32)      128       
_________________________________________________________________
Conv2d_1/depthwise/relu6 (Ac (None, 112, 112, 32)      

## Save TF2 Compatible Checkpoint

In [26]:
# Directory that receives the TF2-compatible checkpoint files.
save_path = './mobilenet_v1_ck'

# Track the Keras model under the `model` key; the manager retains only the
# most recent checkpoint in `save_path`.
checkpoint = tf.train.Checkpoint(model=keras_model)
manager = tf.train.CheckpointManager(
    checkpoint,
    max_to_keep=1,
    directory=save_path,
)

# Kept as the last expression so the cell displays the written checkpoint path.
manager.save()

'./mobilenet_v1_ck/ckpt-1'

## Run Evaluation

In [13]:
# Build the evaluation dataset by mutating the d_config object that was
# created when the model was restored.
d_config.split = 'validation'
d_config.batch_size = 128
# NOTE(review): the loss and metric were selected from d_config.one_hot when
# the model was compiled; if that value was True at compile time, forcing
# False here would make the labels disagree with the compiled loss — confirm
# the config default.
d_config.one_hot = False
# Placeholder: replace '[data_dir]' with the directory holding the data in
# TFRecords format; the data must be downloaded beforehand.
# NOTE(review): the original remark says "imagenet" although d_config comes
# from the "imagenette" entry — confirm which dataset is intended.
d_config.data_dir = '[data_dir]'

# The checkpoint was trained using slim, so slim_preprocess=True is passed
# (presumably selecting the matching input preprocessing — confirm in
# get_dataset).
eval_dataset = get_dataset(d_config, slim_preprocess=True)

In [None]:
# Run evaluation of the restored model over eval_dataset, using the loss and
# metric configured in compile(); the outcome is kept in eval_result for
# later inspection.
eval_result = keras_model.evaluate(eval_dataset)

## Test Prediction on imagenette

In [27]:
# Rebuild a fresh "imagenette" dataset config and point it at the validation
# split. Note that this rebinds d_config and eval_dataset.
d_config = _get_dataset_config().get("imagenette")()
d_config.split = 'validation'
# NOTE(review): unlike the evaluation cell, slim_preprocess=True is not passed
# here — confirm which preprocessing the restored checkpoint expects.
eval_dataset = get_dataset(d_config)

# Fetch exactly one batch. A `for` loop over `take(1)` silently does nothing
# on an empty dataset, leaving `data`/`label` undefined or stale from an
# earlier run; fail loudly instead.
batch = next(iter(eval_dataset.take(1)), None)
if batch is None:
    raise ValueError(
        "The 'validation' split produced no batches; check the dataset config.")
data, label = batch[0], batch[1]

In [28]:
# Run inference on the single batch fetched above; as the last expression of
# the cell, the returned float32 array of model outputs is displayed.
keras_model.predict(data)

array([[1.9183963e-09, 4.1394728e-06, 3.9897436e-08, ..., 7.5235895e-10,
        3.7643774e-06, 4.6054534e-09],
       [5.5527597e-08, 1.3980379e-07, 1.3633577e-06, ..., 5.2619344e-09,
        4.1898606e-06, 2.4916830e-05],
       [2.8108006e-13, 1.0193365e-11, 1.5710981e-11, ..., 4.1460992e-13,
        4.6441356e-10, 1.6905852e-10],
       ...,
       [3.1768695e-11, 6.0316262e-12, 8.5546084e-11, ..., 3.0574855e-11,
        1.1098102e-09, 2.0510100e-09],
       [4.1738550e-11, 5.0156116e-12, 1.9647657e-10, ..., 1.1356447e-11,
        3.9844417e-09, 3.6098619e-08],
       [2.0708221e-07, 9.9771933e-07, 6.3004092e-07, ..., 1.4773001e-07,
        1.2141989e-05, 2.9931611e-05]], dtype=float32)