In [1]:
"""
This example demonstrates the workflow to download a publicly available TF 
model, strip part of it for inference, and convert it to CoreML using the 
tfcoreml converter. 

Stripping part of the TF model may be useful when:
(1) the TF model contains input pre-processing ops that are only
suitable for training, or that are unsupported by CoreML
(2) the TF model has ops that are only used at training time

We use an inception v3 model provided by Google, which can be downloaded 
at this URL:

https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip
"""

import urllib, os, sys, zipfile
from os.path import dirname
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

In [2]:
# Download the model and class label package
def download_file_and_unzip(url, dir_path='.'):
    """Download a zip archive (unless already present) and extract it.

    url - URL of the zip archive; the local file name is the part of the
          URL after the last '/'
    dir_path - local directory to download into and extract to; it is
               created if it does not exist

    The archive is only fetched when it is not already present in dir_path.
    It is first downloaded to '<name>.part' and renamed once complete, so an
    interrupted download is never mistaken for a cached archive later.
    """
    # urlretrieve lives in urllib.request on Python 3 and in urllib on Python 2.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    fname = url[url.rfind('/') + 1:]
    fpath = os.path.join(dir_path, fname)

    if not os.path.exists(fpath):
        partial_path = fpath + '.part'
        urlretrieve(url, partial_path)
        # Publish the archive under its final name only after a full download.
        os.rename(partial_path, fpath)

    # `with` guarantees the archive handle is closed even if extraction fails.
    with zipfile.ZipFile(fpath, 'r') as zip_ref:
        zip_ref.extractall(dir_path)

# The archive provides the frozen graph (tensorflow_inception_graph.pb) and
# the label file used below; both end up in the current directory.
inception_v3_url = ('https://storage.googleapis.com/download.tensorflow.org/'
                    'models/inception_dec_2015.zip')
download_file_and_unzip(inception_v3_url, dir_path='.')

In [3]:
# Load the TF graph definition
tf_model_path = './tensorflow_inception_graph.pb'
with open(tf_model_path, 'rb') as f:
    serialized = f.read()
tf.reset_default_graph()
original_gdef = tf.GraphDef()
original_gdef.ParseFromString(serialized)

# For demonstration purposes we show the first 15 ops of the TF model.
# Slicing (instead of indexing 0..14 with Python 2's `xrange`) keeps this
# working on Python 3 and on graphs that have fewer than 15 ops.
NUM_OPS_TO_SHOW = 15
with tf.Graph().as_default() as g:
    tf.import_graph_def(original_gdef, name='')
    ops = g.get_operations()
    for i, op in enumerate(ops[:NUM_OPS_TO_SHOW]):
        print('op id {} : op name: {}, op type: "{}"'.format(i, op.name, op.type))

# This Inception model uses DecodeJpeg op to read from JPEG images
# encoded as string Tensors. You can visualize it with TensorBoard,
# but we're omitting it here. For deployment we need to remove the
# JPEG decoder and related ops, and replace them with a placeholder
# where we can feed image data in.

op id 0 : op name: DecodeJpeg/contents, op type: "Const"
op id 1 : op name: DecodeJpeg, op type: "DecodeJpeg"
op id 2 : op name: Cast, op type: "Cast"
op id 3 : op name: ExpandDims/dim, op type: "Const"
op id 4 : op name: ExpandDims, op type: "ExpandDims"
op id 5 : op name: ResizeBilinear/size, op type: "Const"
op id 6 : op name: ResizeBilinear, op type: "ResizeBilinear"
op id 7 : op name: Sub/y, op type: "Const"
op id 8 : op name: Sub, op type: "Sub"
op id 9 : op name: Mul/y, op type: "Const"
op id 10 : op name: Mul, op type: "Mul"
op id 11 : op name: conv/conv2d_params, op type: "Const"
op id 12 : op name: conv/Conv2D, op type: "Conv2D"
op id 13 : op name: conv/batchnorm/beta, op type: "Const"
op id 14 : op name: conv/batchnorm/gamma, op type: "Const"


In [4]:
# Strip the JPEG decoder and preprocessing part of the TF model.
# In this model, the op that feeds the pre-processed image into
# the network is 'Mul', and the op that produces the per-class
# scores is 'softmax/logits'.
# For your own model you'll need to identify the input/output ops
# yourself, e.g. by inspecting the graph in TensorBoard or with
# TensorFlow's summarize_graph tool.

from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
input_node_names = ['Mul']
output_node_names = ['softmax/logits']

# Keep only the subgraph between the input and output nodes; the input
# node 'Mul' becomes a float32 placeholder that image data is fed into.
gdef = strip_unused_lib.strip_unused(
    input_graph_def=original_gdef,
    input_node_names=input_node_names,
    output_node_names=output_node_names,
    placeholder_type_enum=dtypes.float32.as_datatype_enum,
)

# Serialize the stripped graph and write it out for the converter.
frozen_model_file = './inception_v3.pb'
stripped_bytes = gdef.SerializeToString()
with gfile.GFile(frozen_model_file, "wb") as out_file:
    out_file.write(stripped_bytes)

In [5]:
# Now we have a TF model ready to be converted to CoreML
import tfcoreml

# Input tensor name -> shape, including the batch axis (batch size is 1).
input_tensor_shapes = {"Mul:0": [1, 299, 299, 3]}
# Destination path of the converted CoreML model
coreml_model_file = './inception_v3.mlmodel'
# The TF model's output tensor name(s)
output_tensor_names = ['softmax/logits:0']

# Expose 'Mul:0' as an image input. The converter scales each pixel by
# image_scale and then adds the per-channel bias, so 2/255 and -1 map the
# 8-bit range [0, 255] onto [-1, 1].
# NOTE(review): this is meant to reproduce the Sub/Mul preprocessing that
# was stripped; confirm against the 'Sub/y' and 'Mul/y' constants.
image_preprocessing = dict(
    image_input_names=['Mul:0'],
    red_bias=-1,
    green_bias=-1,
    blue_bias=-1,
    image_scale=2.0 / 255.0,
)

# Call the converter
coreml_model = tfcoreml.convert(
    tf_model_path=frozen_model_file,
    mlmodel_path=coreml_model_file,
    input_name_shape_dict=input_tensor_shapes,
    output_feature_names=output_tensor_names,
    **image_preprocessing)

# MLModel saved at location: ./inception_v3.mlmodel



Shapes not found for 506 tensors. Executing graph to determine shapes. 
Automatic shape interpretation succeeded for input blob Mul:0
485/993: Converting op name: Mul ( type:  Placeholder )
Skipping name of placeholder
486/993: Converting op name: conv/Conv2D ( type:  Conv2D )
487/993: Converting op name: conv/batchnorm ( type:  BatchNormWithGlobalNormalization )
488/993: Converting op name: conv/control_dependency ( type:  Identity )
489/993: Converting op name: conv ( type:  Relu )
490/993: Converting op name: conv_1/Conv2D ( type:  Conv2D )
491/993: Converting op name: conv_1/batchnorm ( type:  BatchNormWithGlobalNormalization )
492/993: Converting op name: conv_1/control_dependency ( type:  Identity )
493/993: Converting op name: conv_1 ( type:  Relu )
494/993: Converting op name: conv_2/Conv2D ( type:  Conv2D )
495/993: Converting op name: conv_2/batchnorm ( type:  BatchNormWithGlobalNormalization )
496/993: Converting op name: conv_2/control_dependency ( type:  Identity )
497/993

603/993: Converting op name: mixed_2/tower/conv/Conv2D ( type:  Conv2D )
604/993: Converting op name: mixed_2/tower/conv/batchnorm ( type:  BatchNormWithGlobalNormalization )
605/993: Converting op name: mixed_2/tower/conv/control_dependency ( type:  Identity )
606/993: Converting op name: mixed_2/tower/conv ( type:  Relu )
607/993: Converting op name: mixed_2/tower/conv_1/Conv2D ( type:  Conv2D )
608/993: Converting op name: mixed_2/tower/conv_1/batchnorm ( type:  BatchNormWithGlobalNormalization )
609/993: Converting op name: mixed_2/tower/conv_1/control_dependency ( type:  Identity )
610/993: Converting op name: mixed_2/tower/conv_1 ( type:  Relu )
611/993: Converting op name: mixed_2/tower/conv_1/CheckNumerics ( type:  CheckNumerics )
612/993: Converting op name: mixed_2/tower/conv/CheckNumerics ( type:  CheckNumerics )
613/993: Converting op name: mixed_2/conv/Conv2D ( type:  Conv2D )
614/993: Converting op name: mixed_2/conv/batchnorm ( type:  BatchNormWithGlobalNormalization )
6

703/993: Converting op name: mixed_5/tower_1/conv_1/batchnorm ( type:  BatchNormWithGlobalNormalization )
704/993: Converting op name: mixed_5/tower_1/conv_1/control_dependency ( type:  Identity )
705/993: Converting op name: mixed_5/tower_1/conv_1 ( type:  Relu )
706/993: Converting op name: mixed_5/tower_1/conv_2/Conv2D ( type:  Conv2D )
707/993: Converting op name: mixed_5/tower_1/conv_2/batchnorm ( type:  BatchNormWithGlobalNormalization )
708/993: Converting op name: mixed_5/tower_1/conv_2/control_dependency ( type:  Identity )
709/993: Converting op name: mixed_5/tower_1/conv_2 ( type:  Relu )
710/993: Converting op name: mixed_5/tower_1/conv_3/Conv2D ( type:  Conv2D )
711/993: Converting op name: mixed_5/tower_1/conv_3/batchnorm ( type:  BatchNormWithGlobalNormalization )
712/993: Converting op name: mixed_5/tower_1/conv_3/control_dependency ( type:  Identity )
713/993: Converting op name: mixed_5/tower_1/conv_3 ( type:  Relu )
714/993: Converting op name: mixed_5/tower_1/conv_4

801/993: Converting op name: mixed_7/tower_1/conv/batchnorm ( type:  BatchNormWithGlobalNormalization )
802/993: Converting op name: mixed_7/tower_1/conv/control_dependency ( type:  Identity )
803/993: Converting op name: mixed_7/tower_1/conv ( type:  Relu )
804/993: Converting op name: mixed_7/tower_1/conv_1/Conv2D ( type:  Conv2D )
805/993: Converting op name: mixed_7/tower_1/conv_1/batchnorm ( type:  BatchNormWithGlobalNormalization )
806/993: Converting op name: mixed_7/tower_1/conv_1/control_dependency ( type:  Identity )
807/993: Converting op name: mixed_7/tower_1/conv_1 ( type:  Relu )
808/993: Converting op name: mixed_7/tower_1/conv_2/Conv2D ( type:  Conv2D )
809/993: Converting op name: mixed_7/tower_1/conv_2/batchnorm ( type:  BatchNormWithGlobalNormalization )
810/993: Converting op name: mixed_7/tower_1/conv_2/control_dependency ( type:  Identity )
811/993: Converting op name: mixed_7/tower_1/conv_2 ( type:  Relu )
812/993: Converting op name: mixed_7/tower_1/conv_3/Conv2

904/993: Converting op name: mixed_9/tower/conv/batchnorm ( type:  BatchNormWithGlobalNormalization )
905/993: Converting op name: mixed_9/tower/conv/control_dependency ( type:  Identity )
906/993: Converting op name: mixed_9/tower/conv ( type:  Relu )
907/993: Converting op name: mixed_9/tower/mixed/conv_1/Conv2D ( type:  Conv2D )
908/993: Converting op name: mixed_9/tower/mixed/conv_1/batchnorm ( type:  BatchNormWithGlobalNormalization )
909/993: Converting op name: mixed_9/tower/mixed/conv_1/control_dependency ( type:  Identity )
910/993: Converting op name: mixed_9/tower/mixed/conv_1 ( type:  Relu )
911/993: Converting op name: mixed_9/tower/mixed/conv_1/CheckNumerics ( type:  CheckNumerics )
912/993: Converting op name: mixed_9/tower/mixed/conv/Conv2D ( type:  Conv2D )
913/993: Converting op name: mixed_9/tower/mixed/conv/batchnorm ( type:  BatchNormWithGlobalNormalization )
914/993: Converting op name: mixed_9/tower/mixed/conv/control_dependency ( type:  Identity )
915/993: Conve

In [None]:
# Now we're ready to test out the CoreML model with a real image!
# Load an image
import PIL
import requests
from io import BytesIO
from matplotlib.pyplot import imshow
# This is an image of a golden retriever from Wikipedia
img_url = 'https://upload.wikimedia.org/wikipedia/commons/9/93/Golden_Retriever_Carlos_%2810581910556%29.jpg'
response = requests.get(img_url)
%matplotlib inline
img = PIL.Image.open(BytesIO(response.content))
imshow(np.asarray(img))


In [13]:
# Run CoreML prediction
# Pay attention to '__0'. We change ':0' to '__0' to make sure
# MLModel's generated Swift/Obj-C code is semantically correct
# Resize into a new name so re-running this cell always starts from the
# original image. Image.LANCZOS is the filter that the old Image.ANTIALIAS
# alias pointed to; ANTIALIAS was removed in Pillow 10. convert('RGB')
# guarantees the 3 channels the model input expects.
img_resized = img.convert('RGB').resize((299, 299), PIL.Image.LANCZOS)
coreml_inputs = {'Mul__0': img_resized}
coreml_output = coreml_model.predict(coreml_inputs, useCPUOnly=False)
# NOTE(review): 'softmax/logits' is presumably the pre-softmax score vector;
# argmax gives the same class either way.
probs = coreml_output['softmax__logits__0'].flatten()
label_idx = np.argmax(probs)

# This label file comes with the model. splitlines() drops the trailing
# newline that readlines() would keep in each label.
label_file = 'imagenet_comp_graph_label_strings.txt'
with open(label_file) as f:
    labels = f.read().splitlines()
print('Label = {}'.format(labels[label_idx]))

Label = golden retriever



In [None]:
# And that's the end