# Loading a pre-trained model

This notebook was inspired by the *rock, scissors, paper* example in Coursera's excellent
[Convolutional Neural Networks in TensorFlow](https://www.coursera.org/learn/convolutional-neural-networks-tensorflow/) course.

It uses the saved model that students create in that course, and it also makes use of Laurence Moroney's open-source
[rock paper scissors dataset](http://www.laurencemoroney.com/rock-paper-scissors-dataset/).

The model has 12 layers and over 3.4 million parameters. Training it on a Jetson Nano is just about possible, but it would take a very long time.

Instead, you will load a pre-trained model that has been saved as `rps.h5` in the `data` directory.

Once the model has been loaded, you can use it to identify some of the test data images as rock, paper or scissors.

First, you'll download the test data locally and unzip it.

In [None]:
!wget --no-check-certificate \
    https://storage.googleapis.com/laurencemoroney-blog.appspot.com/rps-test-set.zip \
    -O rps-test-set.zip
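If `wget` is not installed, the same file can be fetched with Python's standard library. The sketch below uses `urllib.request.urlretrieve`; the helper name `download_if_missing` is hypothetical. Unlike the `wget` command above, it leaves certificate checking enabled.

```python
import os
import urllib.request

TEST_SET_URL = ('https://storage.googleapis.com/'
                'laurencemoroney-blog.appspot.com/rps-test-set.zip')


def download_if_missing(url, dest):
    """Download url to dest unless dest already exists; return dest."""
    # Hypothetical helper: skip the download when the file is already present
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)
    return dest
```

Calling `download_if_missing(TEST_SET_URL, 'rps-test-set.zip')` produces the same `rps-test-set.zip` that the next cell unzips, and re-running the cell does not download it again.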

In [None]:
import zipfile

# Extract the test images into data/ (this creates data/rps-test-set/)
local_zip = 'rps-test-set.zip'
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
    zip_ref.extractall('data/')
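Before loading the model, it can help to confirm that extraction produced one folder per class. This sketch assumes the layout that `get_image_file` relies on later (`data/rps-test-set/<class>/*.png`); `count_images` is a hypothetical helper, not part of the course material.

```python
from pathlib import Path


def count_images(root='data/rps-test-set'):
    """Return {class_directory_name: number_of_png_files} under root."""
    root = Path(root)
    return {
        class_dir.name: len(list(class_dir.glob('*.png')))
        for class_dir in sorted(root.iterdir())
        if class_dir.is_dir()
    }
```

Running `print(count_images())` should list a `paper`, `rock` and `scissors` entry; a missing key suggests the zip did not extract where the later cells expect it.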

Next, you will load the saved model and look at its structure.

In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing import image
model = tf.keras.models.load_model('data/rps.h5')
model.summary()
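The "over 3.4 million parameters" figure can be checked by hand. The layer list is not defined in this notebook, so the sketch below assumes the architecture built in the course: four 3x3 `Conv2D` layers (64, 64, 128 and 128 filters, default `'valid'` padding), each followed by 2x2 max pooling, then `Flatten`, `Dropout`, `Dense(512)` and `Dense(3)`, on 150x150 RGB input. That is 12 layers. If the total differs from the one printed by `model.summary()`, the saved model uses a different architecture.

```python
def conv2d_params(kernel, in_channels, filters):
    # Each filter has kernel*kernel*in_channels weights plus one bias
    return (kernel * kernel * in_channels + 1) * filters


def dense_params(inputs, units):
    # One weight per input per unit, plus one bias per unit
    return (inputs + 1) * units


def count_params(size=150, channels=3, filters=(64, 64, 128, 128), kernel=3,
                 hidden=512, classes=3):
    total = 0
    for f in filters:
        total += conv2d_params(kernel, channels, f)
        # 'valid' convolution shrinks each side by kernel - 1; 2x2 pooling halves it
        size = (size - kernel + 1) // 2
        channels = f
    flat = size * size * channels  # Flatten output (Dropout adds no parameters)
    total += dense_params(flat, hidden)
    total += dense_params(hidden, classes)
    return total


print(count_params())  # 3473475
```

Almost all of the parameters sit in `Dense(512)`: the flattened 7x7x128 feature map has 6272 values, and each of them connects to all 512 units.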

Now you can load a test image and have the model classify it.

In [None]:
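def get_image_file(i_type, number):
    # Build a test image path, e.g. data/rps-test-set/paper/testpaper01-07.png
    return 'data/rps-test-set/%s/test%s01-%02d.png' % (i_type, i_type, number)

fn = get_image_file('paper', 7)
# Load the image at the 150x150 resolution the model expects
img = image.load_img(fn, target_size=(150, 150))
x = image.img_to_array(img)
# Add a batch dimension: (150, 150, 3) -> (1, 150, 150, 3)
x = np.expand_dims(x, axis=0)

images = np.vstack([x])
# One row of per-class scores for each image in the batch
classes = model.predict(images, batch_size=10)
print(fn)
print(classes)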
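`classes` holds one row of scores per image, with columns in the order of the class indices used during training. The sketch below assumes the model was trained with `ImageDataGenerator.flow_from_directory`, which assigns indices in alphabetical order of the sub-directory names (`paper`, `rock`, `scissors`), and with `rescale=1./255`. If that assumption holds, the raw 0-255 pixel arrays above should be scaled the same way before prediction. `make_batch`, `top_labels` and `CLASS_NAMES` are hypothetical names; the code is plain NumPy and does not depend on TensorFlow.

```python
import numpy as np

# Assumed class order: alphabetical sub-directory names, as flow_from_directory assigns them
CLASS_NAMES = ('paper', 'rock', 'scissors')


def make_batch(arrays, rescale=1.0 / 255):
    """Stack (H, W, 3) image arrays into one (N, H, W, 3) float32 batch.

    rescale should match the preprocessing used at training time
    (assumed here to be 1/255); pass rescale=1.0 to keep raw pixel values.
    """
    return np.stack(arrays).astype('float32') * rescale


def top_labels(scores, class_names=CLASS_NAMES):
    """Return (label, score) for the highest-scoring class in each row."""
    scores = np.asarray(scores)
    best = scores.argmax(axis=1)
    return [(class_names[i], float(row[i])) for i, row in zip(best, scores)]
```

With the objects from the previous cell, `top_labels(model.predict(make_batch([image.img_to_array(img)])))` classifies the same image with rescaling applied, and passing several arrays to `make_batch` classifies them in one call.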

Finally, display the image that was classified above.

In [None]:
%matplotlib inline

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

# Show the same file that was passed to the model, without axes
test_image = mpimg.imread(fn)
plt.imshow(test_image)
plt.axis('off')
plt.show()