In [1]:
import tensorflow as tf
from tensorflow import keras
import os

# Root directory of the training images. flow_from_directory (called in the next
# cell) treats each sub-directory here as one class.
base_dir = './dataset/train'

In [2]:
IMAGE_SIZE = 224
BATCH_SIZE = 64
VALIDATION_SPLIT = 0.2  # fraction of each class directory held out for validation
SEED = 42  # fixes the batch shuffle order so a Restart & Run All sees the same batches

# A single generator serves both subsets: because `validation_split` is set,
# `subset=` picks which slice of each class directory a flow iterates over.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,  # presumably 8-bit pixels: maps 0-255 into [0, 1]
    validation_split=VALIDATION_SPLIT,
)


def _flow(subset):
    """Return a `flow_from_directory` iterator over `base_dir` for one subset.

    Args:
        subset: 'training' or 'validation'.

    All other settings are shared so the training and validation iterators
    cannot drift apart.
    """
    return datagen.flow_from_directory(
        base_dir,
        target_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=BATCH_SIZE,
        subset=subset,
        seed=SEED,
    )


train_generator = _flow('training')
val_generator = _flow('validation')

Found 829 images belonging to 5 classes.
Found 205 images belonging to 5 classes.


In [3]:
print(train_generator.class_indices)

# Order class names by their class index, not alphabetically, so that line i of
# labels.txt is guaranteed to name output unit i of the model. Alphabetical order
# only matches today because of how the class directories happen to sort.
class_indices = train_generator.class_indices
labels = '\n'.join(sorted(class_indices, key=class_indices.get))
with open('labels.txt', 'w') as f:
    f.write(labels)

{'Elephant': 0, 'Kangaroo': 1, 'Panda': 2, 'Penguin': 3, 'Tiger': 4}


In [4]:
# Square input of IMAGE_SIZE pixels with 3 channels.
IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE) + (3,)

# ImageNet-pretrained MobileNetV2 used as a feature extractor. include_top=False
# drops its own classifier so a custom head can be stacked on top.
base_model = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=IMG_SHAPE,
)

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [5]:
# Freeze the pretrained backbone so only the new head layers are trained.
base_model.trainable = False

# Derive the output width from the data instead of hard-coding it, so that adding
# or removing a class directory cannot silently mismatch the softmax layer.
NUM_CLASSES = len(train_generator.class_indices)

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
])

In [6]:
# flow_from_directory is called without class_mode; its Keras default is
# 'categorical' (one-hot labels), which is what categorical_crossentropy expects.
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),  # Adam with its default hyper-parameters
    metrics=['accuracy'],
)

In [7]:
epochs = 10

# Train the head; validation_data makes Keras report validation loss/accuracy
# after each epoch. The returned History object is kept for later inspection.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [8]:
keras_file = 'predict.h5'

# Persist the trained model as a single HDF5 file (selected by the .h5 extension);
# the TFLite conversion below reads it back from this path.
model.save(keras_file)

In [9]:
# TF 1.x exposes `from_keras_model_file` (takes the saved .h5 path). TF 2.x removed
# it and only offers `from_keras_model`, which takes a model object. Support both so
# this cell keeps working after a TensorFlow upgrade.
if hasattr(tf.lite.TFLiteConverter, 'from_keras_model_file'):
    converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
else:
    # Reload from disk so that, as on TF 1.x, the converted model is exactly the
    # one that was saved to keras_file.
    converter = tf.lite.TFLiteConverter.from_keras_model(
        keras.models.load_model(keras_file)
    )



In [10]:
tflite_model = converter.convert()

# Use a context manager so the file is flushed and closed deterministically,
# rather than whenever the anonymous file object happens to be garbage-collected.
with open('model.tflite', 'wb') as tflite_file:
    n_bytes_written = tflite_file.write(tflite_model)

# Display the number of bytes written, as the original bare write() expression did.
n_bytes_written

10328044