In [None]:
from keras.applications import Xception
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D, BatchNormalization
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from keras import layers


In [None]:
# Image geometry. Xception's default input size is 299x299, and the same size
# must be used for the generators' `target_size` and the model's Input layer,
# so every size name below is derived from the single IMG_SIZE constant.
IMG_SIZE = 299
img_width = IMG_SIZE   # kept as separate names because the generator cells use them
img_height = IMG_SIZE
batch_size = 16
num_of_classes = 70    # must match the number of class sub-directories in each split

# Dataset splits; `flow_from_directory` treats each sub-directory as one class.
DATASET_DIR = "/content/drive/MyDrive/Datasets/polishedGemstones"
training_set_directory = f"{DATASET_DIR}/train"
validation_set_directory = f"{DATASET_DIR}/val"
test_set_directory = f"{DATASET_DIR}/test"


In [None]:
# Training-time augmentation only. There is deliberately no `rescale` on any
# split: the model applies `xception.preprocess_input` internally, which expects
# raw [0, 255] pixels and maps them to [-1, 1] itself.
train_data = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)

# Validation and test images are fed unmodified: random flips here would make
# `val_accuracy` (the metric the checkpoints monitor) noisy and the test score
# non-reproducible. A `rescale=1/255` here would be applied *before* the model's
# own preprocess_input and squash every test pixel to roughly -1.
validation_data = ImageDataGenerator()
test_data = ImageDataGenerator()


In [None]:
# Settings shared by all three splits.
_flow_kwargs = dict(
    batch_size=batch_size,
    class_mode='categorical',
    target_size=(img_height, img_width),
)

# Only the training stream is shuffled. Evaluation streams keep a fixed file order
# so that `test_gen.classes` / `test_gen.filenames` line up row-for-row with
# `model.predict(test_gen)`.
training_gen = train_data.flow_from_directory(
    training_set_directory, shuffle=True, **_flow_kwargs
)
validation_gen = validation_data.flow_from_directory(
    validation_set_directory, shuffle=False, **_flow_kwargs
)
test_gen = test_data.flow_from_directory(
    test_set_directory, shuffle=False, **_flow_kwargs
)

# Each split derives its class->index mapping from its own sub-directories; if one
# split is missing a class, labels would silently shift between train and eval.
assert training_gen.class_indices == validation_gen.class_indices == test_gen.class_indices
assert training_gen.num_classes == num_of_classes

Found 6542 images belonging to 70 classes.
Found 1857 images belonging to 70 classes.
Found 1000 images belonging to 70 classes.


In [None]:
def build_model(num_classes, img_size=None, learning_rate=1e-2):
    """Build and compile an Xception transfer-learning classifier.

    The ImageNet-pretrained Xception backbone is created frozen
    (``trainable=False``) and is always called with ``training=False``, so its
    BatchNormalization layers run in inference mode even if the backbone is
    unfrozen later for fine-tuning. On top of it sits a head of
    GlobalAveragePooling -> BatchNorm -> [Dense(512) -> BatchNorm -> Dropout]
    -> [Dense(256) -> BatchNorm -> Dropout] -> softmax.

    Args:
        num_classes: number of softmax output units.
        img_size: height/width of the square RGB input. ``None`` (default) uses
            the notebook-level ``IMG_SIZE``, looked up at call time.
        learning_rate: Adam learning rate used when compiling (head-training phase).

    Returns:
        A compiled ``tf.keras.Model`` named "Xception" that takes raw [0, 255]
        RGB images (``xception.preprocess_input`` is applied inside the model)
        and is trained with categorical cross-entropy and accuracy.
    """
    if img_size is None:
        img_size = IMG_SIZE
    input_shape = (img_size, img_size, 3)

    inputs = layers.Input(shape=input_shape)
    # Scales raw [0, 255] pixels to [-1, 1], the range the ImageNet weights expect.
    x = tf.keras.applications.xception.preprocess_input(inputs)

    base_model = tf.keras.applications.Xception(
        weights='imagenet', include_top=False, input_shape=input_shape
    )
    # Freeze the pretrained backbone so only the new head trains in phase 1.
    base_model.trainable = False
    # training=False keeps the backbone's BatchNormalization statistics fixed,
    # also after a later `trainable = True` for fine-tuning.
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
    x = layers.BatchNormalization()(x)

    # Two identical Dense/BatchNorm/Dropout stages of decreasing width.
    for units in (512, 256):
        x = layers.Dense(units, activation="relu")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(0.5)(x)

    outputs = layers.Dense(num_classes, activation="softmax", name="pred")(x)

    model = tf.keras.Model(inputs, outputs, name="Xception")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

In [None]:
# Phase-1 checkpoint: after each epoch, overwrite the saved weights only when
# validation accuracy reaches a new maximum.
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

In [None]:
# Phase 1: train only the new classification head on top of the frozen backbone.
# No layers are unfrozen here: the Xception backbone is a single nested layer of
# `model`, so any loop over `model.layers` that sets `trainable = True` would
# unfreeze all of Xception and update the pretrained weights at the head's
# high learning rate (1e-2).
model = build_model(num_of_classes)
hist = model.fit(
    training_gen,
    epochs=30,
    validation_data=validation_gen,
    verbose=1,
    callbacks=[model_checkpoint_callback],
)

In [None]:
# Phase-2 (fine-tuning) checkpoint, kept separate from phase 1 so the best head-only
# weights are not overwritten; again only the best val_accuracy epoch is kept.
checkpoint_filepath2 = '/tmp/checkpoint2'
model_checkpoint_callback2 = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath2,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

In [None]:
# Phase 2 (fine-tuning): unfreeze the Xception backbone while keeping its
# BatchNormalization layers frozen. The backbone is one nested layer of `model`,
# so filtering `model.layers` by name cannot reach its internal BN layers; the
# change has to be applied to the backbone and its sub-layers directly.
# Must run before the `model.compile(...)` below so the new trainability is used.
backbone = next(
    (layer for layer in model.layers if isinstance(layer, tf.keras.Model)), None
)
if backbone is None:
    raise ValueError("Expected `model` to contain a nested backbone Model (Xception).")

backbone.trainable = True  # recursively marks every backbone sub-layer trainable
for layer in backbone.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

In [None]:
# Recompile so the trainability changes take effect, with a learning rate 100x
# smaller than phase 1 to avoid wrecking the pretrained backbone weights.
FINETUNE_LR = 1e-4
optimizer = tf.keras.optimizers.Adam(learning_rate=FINETUNE_LR)
model.compile(
    loss="categorical_crossentropy",
    optimizer=optimizer,
    metrics=["accuracy"],
)

In [None]:
# Start fine-tuning from the best phase-1 weights (highest val_accuracy).
model.load_weights(filepath=checkpoint_filepath)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f164005f7f0>

In [None]:
# Phase 2: fine-tune at the low learning rate, checkpointing the best epoch
# to `checkpoint_filepath2`. Note: `hist` is rebound to the phase-2 history.
epochs = 50
hist = model.fit(
    training_gen,
    validation_data=validation_gen,
    epochs=epochs,
    callbacks=[model_checkpoint_callback2],
    verbose=1,
)