In [4]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from keras import layers
from PIL import Image

In [5]:
# Fetch CIFAR-10 through Keras (the first call downloads the archive, as the
# output below shows) and unpack it into a training pair and a test pair of
# (images, labels) arrays.
(train_data, train_labels), (test_data, test_labels) = (
    keras.datasets.cifar10.load_data()
)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [6]:
# Scale raw 0-255 pixel values to float32 in [0, 1], then one-hot encode the
# integer labels into 10-element vectors for categorical_crossentropy.
# NOTE: the names are rebound in place, so re-running this cell without first
# reloading the data rescales the images twice and re-encodes one-hot labels.
train_data, test_data = (
    images.astype('float32') / 255.0 for images in (train_data, test_data)
)
train_labels, test_labels = (
    keras.utils.to_categorical(labels, 10) for labels in (train_labels, test_labels)
)

In [9]:
# Seed Python, NumPy and the backend RNGs so weight initialisation, dropout
# masks and fit()'s shuffling repeat across Restart & Run All
# (some GPU kernels may still be nondeterministic).
SEED = 42
keras.utils.set_random_seed(SEED)

# Small CNN: two conv/pool stages extract features, dropout regularises the
# flattened features, and a 10-way softmax outputs class probabilities.
# Layers are taken from the same `keras` namespace as `Sequential`
# (tensorflow.keras). Mixing in layers from the standalone `keras` package can
# fail ("The added layer must be an instance of class Layer") when the two
# installed Keras versions differ.
model = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax'),
])

In [10]:
# Adam with default settings. Categorical cross-entropy matches the one-hot
# labels produced by to_categorical, and accuracy is reported each epoch.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [12]:
# Train for a fixed 10 epochs in mini-batches of 64. Keep the returned History
# so per-epoch loss/accuracy curves can be inspected later without retraining.
# NOTE(review): the test split doubles as validation data. That is fine for
# monitoring, but tuning against these scores leaks test information into model
# selection; consider holding out part of the training data instead
# (e.g. `validation_split=`).
history = model.fit(
    train_data,
    train_labels,
    batch_size=64,
    epochs=10,
    validation_data=(test_data, test_labels),
)

# Persist the trained model; the reload cell below reads this same path.
# NOTE(review): newer Keras versions treat '.h5' as a legacy format. If
# switching to '.keras', update the load cell too.
model.save('cifar10_model.h5')

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [13]:
# Reload the trained network from disk so the inference cells below can run
# without re-executing the training cell.
model = keras.models.load_model(
    'cifar10_model.h5',
    compile=True,  # also restore the optimizer/loss/metrics setup (the default)
)

In [18]:
# Human-readable class names: entry i names label index i, which is how the
# prediction cells index this list with the model's arg-max output.
classes = (
    'airplane automobile bird cat deer '
    'dog frog horse ship truck'
).split()

In [19]:
# Load the sample image and turn it into a one-image batch shaped like the
# training data. convert('RGB') guarantees three channels: grayscale ('L'),
# palette ('P') or RGBA files would otherwise give arrays of shape (32, 32) or
# (32, 32, 4), which the model's (32, 32, 3) input rejects. The `with` block
# closes the underlying file once the converted copy has been made.
with Image.open('./dog.jpg') as img:
    img = img.convert('RGB').resize((32, 32))
img_array = np.asarray(img, dtype='float32') / 255.0  # same [0, 1] scaling as training
img_array = np.expand_dims(img_array, axis=0)         # add batch axis -> (1, 32, 32, 3)

In [20]:
# Classify the single-image batch and print the most probable class name.
# verbose=0 hides Keras's per-call progress bar so only the prediction prints.
pred = model.predict(img_array, verbose=0)
# Take the arg-max over the class scores of the first (only) sample. A bare
# np.argmax(pred) flattens the whole batch and returns an out-of-range index
# as soon as the batch holds more than one image.
class_label = int(np.argmax(pred[0]))
print("Prediction:", classes[class_label])

Prediction: dog


In [21]:
# Load the sample image and turn it into a one-image batch shaped like the
# training data. convert('RGB') guarantees three channels: grayscale ('L'),
# palette ('P') or RGBA files would otherwise give arrays of shape (32, 32) or
# (32, 32, 4), which the model's (32, 32, 3) input rejects. The `with` block
# closes the underlying file once the converted copy has been made.
with Image.open('./cat.jpg') as img:
    img = img.convert('RGB').resize((32, 32))
img_array = np.asarray(img, dtype='float32') / 255.0  # same [0, 1] scaling as training
img_array = np.expand_dims(img_array, axis=0)         # add batch axis -> (1, 32, 32, 3)

In [22]:
# Classify the single-image batch and print the most probable class name.
# verbose=0 hides Keras's per-call progress bar so only the prediction prints.
pred = model.predict(img_array, verbose=0)
# Take the arg-max over the class scores of the first (only) sample. A bare
# np.argmax(pred) flattens the whole batch and returns an out-of-range index
# as soon as the batch holds more than one image.
class_label = int(np.argmax(pred[0]))
print("Prediction:", classes[class_label])

Prediction: cat
