# Image classification model export to ONNX

Let's load a pretrained deep learning image classification model (ResNet50) with Keras / TensorFlow, check it on a sample image, and convert it to ONNX format.

In [None]:
import json
from pathlib import Path

import keras
import numpy as np
import onnx
import onnxruntime as ort
import requests
import tf2onnx

Load a ResNet50 image classification model with weights pretrained on ImageNet:

In [None]:
# ResNet50 with its classification head and ImageNet-pretrained weights.
# Both arguments are the Keras defaults; they are spelled out for the reader.
model = keras.applications.resnet50.ResNet50(include_top=True, weights="imagenet")

Load an example image, resized to 224×224, as a batch tensor of shape (1, 224, 224, 3):

In [None]:
img_path = "cat.jpg"
img = keras.utils.load_img(img_path, target_size=(224, 224))
x = keras.utils.img_to_array(img)  # array of shape (224, 224, 3)
x = np.expand_dims(x, axis=0)  # add a batch axis -> (1, 224, 224, 3)
# Keras' ResNet50 expects inputs passed through its own `preprocess_input`
# (RGB->BGR plus ImageNet per-channel mean subtraction). Feeding raw pixels
# degrades predictions. The exported ONNX graph does not include this step,
# so the same preprocessed `x` must also be fed to the ONNX model.
x = keras.applications.resnet50.preprocess_input(x)
img

Try the model on the sample image and decode the top 3 predicted classes:

In [None]:
# Forward pass with the Keras model; `preds` holds per-class scores for the batch.
preds = model.predict(x)
# Map the three highest-scoring class indices to (WordNet ID, label, score) tuples.
top_3 = keras.applications.resnet50.decode_predictions(preds, top=3)
top_3

Download the ImageNet class index, which maps each class index to its WordNet ID and human-readable label, and save it as a local JSON file:

In [None]:
IMAGENET_CLASS_INDEX_URL = (
    "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
)
# A timeout (seconds) keeps the cell from hanging forever on a stalled connection.
rv = requests.get(IMAGENET_CLASS_INDEX_URL, timeout=30)
# Fail loudly on HTTP errors instead of trying to parse an error page as JSON.
rv.raise_for_status()
imagenet_class_index = rv.json()
# Trailing `;` suppresses display of the character count returned by write_text.
(Path() / "imagenet_class_index.json").write_text(json.dumps(imagenet_class_index));

Export the model to ONNX format using `tf2onnx`:

In [None]:
# `from_keras` converts the Keras model and returns (ModelProto, external tensor
# storage); the second value is only used for large-model export, so it is discarded.
onnx_model, _ = tf2onnx.convert.from_keras(model)
onnx_model_path = Path() / "model.onnx"
# Save to the same path the inference cell loads from, rather than repeating the literal.
onnx.save(onnx_model, str(onnx_model_path))

Load the ONNX model, run inference on the sample image, and look up the predicted class name:

In [None]:
# Pass a plain string: older onnxruntime releases reject pathlib.Path objects.
session = ort.InferenceSession(
    str(onnx_model_path), providers=ort.get_available_providers()
)
# Read the graph's input name instead of hard-coding "input_1". The name tf2onnx
# emits follows the Keras input layer name, which can differ between Keras versions.
input_name = session.get_inputs()[0].name
output = session.run(None, {input_name: x})[0]
predicted_index = int(np.argmax(output[0]))
# Sanity check: the exported model should agree with the Keras model on the top class.
assert predicted_index == int(np.argmax(preds[0])), "ONNX and Keras predictions disagree"
imagenet_class_index[str(predicted_index)][1]