## 1. Kerasモデル情報取得
- **PIPライブラリ追加:** pip3 install tensorflow

In [1]:
import tensorflow as tf

# Load the saved Keras model (HDF5 file) into memory; later cells convert it.
keras_model = tf.keras.models.load_model(
    './outputs/keras_simple.h5',
)

# Print the symbolic input/output tensors (name, shape, dtype) so they can be
# compared with the ONNX model's inputs/outputs printed in section 4.
for prefix, tensors in (("INPUT: ", keras_model.inputs),
                        ("OUTPUT: ", keras_model.outputs)):
    print(prefix, tensors)

INPUT:  [<tf.Tensor 'input_1:0' shape=(None, 150, 150, 3) dtype=float32>]
OUTPUT:  [<tf.Tensor 'dense_1/Identity:0' shape=(None, 1) dtype=float32>]


## 2. keras2onnxでONNXモデル作成
- **PIPライブラリ追加:** pip3 install onnx onnxmltools tf2onnx keras2onnx

In [5]:
# Environment variable setup: keras2onnx consults TF_KERAS to decide whether
# to work with tf.keras models (as loaded above) instead of standalone Keras.
# Presumably read when keras2onnx is imported, so it is set in its own cell
# before the conversion cell.
import os

os.environ.update(TF_KERAS='1')

In [6]:
# Convert to ONNX with keras2onnx and write the result to disk.
import keras2onnx
import onnx

onnx_model_name = 'onnx_simple.onnx'

# The ONNX graph is given the same name as the Keras model.
onnx_model = keras2onnx.convert_keras(keras_model, keras_model.name)

# Serialize the ONNX model so onnxruntime can load it in a later cell.
onnx.save_model(onnx_model, onnx_model_name)

## 3. onnxmltoolsでONNXモデル作成

In [7]:
import onnxmltools

# Alternative conversion path; saved under a different file name so it does
# not overwrite onnx_simple.onnx from the keras2onnx cell.
onnx_model_name = 'onnx_simple2.onnx'

# Convert the in-memory Keras model to ONNX ...
onnx_model = onnxmltools.convert_keras(keras_model)

# ... and serialize it with onnxmltools' own save helper.
onnxmltools.utils.save_model(onnx_model, onnx_model_name)

## 4. ONNXモデル情報取得
- **PIPライブラリ追加:** pip3 install onnx onnxruntime numpy

In [8]:
import numpy
import onnx
import onnxruntime as rt

def load_model(model_path="onnx_simple.onnx", providers=None):
    """Open an ONNX model with onnxruntime and record its I/O tensor names.

    The session and names are stored in the module-level globals ``sess``,
    ``input_name`` and ``label_name``; the printing code and ``run()`` read
    those globals directly.

    Args:
        model_path: ONNX file to load. Defaults to ``onnx_simple.onnx``, the
            file written by the keras2onnx conversion cell; pass e.g.
            ``"onnx_simple2.onnx"`` to inspect the onnxmltools output instead.
        providers: optional list of execution providers forwarded to
            ``InferenceSession`` (e.g. ``["CPUExecutionProvider"]``). Only
            passed when given, so onnxruntime's default is used otherwise;
            GPU builds of onnxruntime >= 1.9 require it to be explicit.

    Returns:
        The created ``InferenceSession`` (the same object assigned to ``sess``).
    """
    global sess, input_name, label_name

    session_kwargs = {} if providers is None else {"providers": providers}
    sess = rt.InferenceSession(model_path, **session_kwargs)
    # Only the first input and first output are used; the printed NodeArgs
    # suggest the model has exactly one of each.
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    return sess

if __name__ == "__main__":
    load_model()
    # Show the first input/output NodeArg: name, element type and shape.
    first_input = sess.get_inputs()[0]
    first_output = sess.get_outputs()[0]
    print("INPUT: ", first_input)
    print("OUTPUT: ", first_output)

INPUT:  NodeArg(name='input_1', type='tensor(float)', shape=['N', 150, 150, 3])
OUTPUT:  NodeArg(name='dense_1', type='tensor(float)', shape=['N', 1])


## 5. ONNXモデル検証
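- **PIPライブラリ追加:** pip3 install pillow
- **注意:** `./dog_cat_images/` 内の画像ファイルは推論結果に応じて `cat_` / `dog_` を先頭に付けた名前にリネームされます（再実行すると接頭辞が重複します）。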

In [9]:
import os
from PIL import Image

def load_image(filename, size=(150, 150)):
    """Load an image file as a one-image float32 batch for the ONNX model.

    Args:
        filename: path to the image file.
        size: target ``(width, height)`` passed to ``PIL.Image.resize``.
            The default matches the 150x150x3 input the model reports.

    Returns:
        numpy float32 array of shape ``(1, height, width, 3)``. Pixel values
        are not normalized (they stay in the 0-255 range).
    """
    width, height = size
    # The context manager closes the underlying file handle once pixels are read.
    with Image.open(filename) as img:
        # Force 3 channels: grayscale ('L'), palette ('P') or RGBA files would
        # otherwise have the wrong channel count and the reshape below fails.
        rgb = img.convert('RGB').resize((width, height))
    arr = numpy.asarray(rgb).astype('float32')
    # PIL sizes are (width, height) but arrays are (rows=height, cols=width);
    # prepend the batch dimension of 1.
    return arr.reshape(1, height, width, 3)
    
def run(image_dir='./dog_cat_images/', threshold=0.5):
    """Classify each image in ``image_dir`` and rename it with a label prefix.

    Uses the module-level ``sess``, ``input_name`` and ``label_name`` set by
    ``load_model()``, which must therefore be called first. Each file is
    renamed in place to ``cat_<name>`` or ``dog_<name>``; running this twice
    prefixes the files again.

    Errors are printed rather than raised. A failure on one file (for example
    a non-image file) is reported with its name, and the remaining files are
    still processed. Sub-directories are skipped.

    Args:
        image_dir: directory whose files are classified.
        threshold: decision boundary applied to the model's single output;
            scores below it are labelled ``cat``, the rest ``dog``.
            NOTE(review): assumes the output is a sigmoid probability of
            "dog" -- confirm against the training code. Saturated outputs of
            exactly 0.0 / 1.0 map to cat / dog respectively.
    """
    try:
        filenames = sorted(os.listdir(image_dir))
    except OSError as e:
        print("EXCEPTION", e)
        return

    for filename in filenames:
        src = os.path.join(image_dir, filename)
        if not os.path.isfile(src):
            continue
        try:
            img = load_image(src)
            result = sess.run([label_name], {input_name: img})[0]
            # Batch of one with a single output unit -> one scalar score.
            score = float(numpy.asarray(result).ravel()[0])
            label = 'cat' if score < threshold else 'dog'
            os.rename(src, os.path.join(image_dir, label + '_' + filename))
        except Exception as e:
            # Best effort: report this file and continue with the rest.
            print("EXCEPTION", filename, e)
    print("FINISH")

if __name__ == "__main__":
    # Build the onnxruntime session first (run() reads the globals it sets),
    # then classify and rename the images.
    for step in (load_model, run):
        step()

FINISH
