In [2]:
### Be aware of how you load a custom model that was saved in Keras format

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import layers, models
import os
import numpy as np
import albumentations as A
import random
import pickle
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Lambda


# ---- configuration ----

# Input resolution and training hyperparameters.
img_height = 160
img_width = 160
batch_size = 64
epochs = 50
initial_learning_rate = 1e-5

# Dataset folders handed to flow_from_directory.
train_dir = 'O:/project/dataset/train'
val_dir = 'O:/project/dataset/val'

# Saved model that is loaded in full to provide the network architecture.
base_model_path = 'O:/project/epochs/check_mobilenetv2_adam_4_e50_fined_50/model_epoch_50.keras'
# File whose weights are loaded into the rebuilt (masked) model.
model_path = 'O:/project/epochs/mobile_adam_efined50_newmodel_freezed/model_epoch_300.keras'
# Folder that receives the per-epoch checkpoints and the pickled history.
checkpoint_dir = 'O:/project/epochs/mobile_adam_efined50_newmodel_final'

# Array that is negated with `~` further down to build the multiplicative mask.
# NOTE(review): there is no '/' between 'freezed' and 'ignored_nodes.npy', so
# this points at a file next to the 'freezed' folder rather than inside it --
# confirm that this is where the array was actually saved.
npy_path = 'O:/project/epochs/mobile_adam_efined50_newmodel_freezedignored_nodes.npy'
ignored_nodes = np.load(npy_path)

# functions

def albumentations_preprocessing(image):
    """Apply one randomly chosen Albumentations transform to a single image.

    Used as ``preprocessing_function`` of ``ImageDataGenerator``. Keras calls
    that hook with a float32 array whose values are still in the 0-255 range
    (``rescale`` is applied afterwards). Albumentations, however, interprets
    float32 input as lying in [0, 1]; transforms that clip their result
    (e.g. ``GaussNoise``, ``RandomBrightnessContrast``) would therefore
    saturate a 0-255 float image. The image is converted to uint8 for the
    augmentation and converted back to float32 on return.

    Parameters
    ----------
    image : array-like
        One image with pixel values in the 0-255 range.

    Returns
    -------
    numpy.ndarray
        The augmented image as float32, values in the 0-255 range.
    """
    aug_list = [
        # Two no-op entries: the untouched image is returned in 2 of 7 draws.
        A.GaussianBlur(p=0.0),  # original image
        A.GaussianBlur(p=0.0),  # original image
        A.GaussianBlur(p=1.0),
        A.CoarseDropout(max_holes=2, max_height=16, max_width=16, p=1.0),
        A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
        A.ChannelShuffle(p=1.0)
    ]

    augmentation = random.choice(aug_list)
    transform = A.Compose([augmentation])

    # Work on uint8 so every transform sees the value range it expects.
    pixels = np.asarray(image, dtype=np.float32)
    image_uint8 = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

    augmented = transform(image=image_uint8)

    # Return float32: the generator multiplies the result by `rescale`
    # afterwards, which requires a floating-point array.
    return augmented['image'].astype(np.float32)

def plot_history(history):
    """Show training/validation accuracy and loss curves side by side.

    Parameters
    ----------
    history : mapping
        Must provide the keys 'accuracy', 'val_accuracy', 'loss' and
        'val_loss', each holding one value per epoch.
    """
    # Fetch every series up front so a missing key fails before any drawing.
    train_acc, valid_acc, train_loss, valid_loss = (
        history[key] for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
    )
    x_epochs = range(1, len(train_acc) + 1)

    # (training series, validation series, metric label, legend location)
    panels = (
        (train_acc, valid_acc, 'Accuracy', 'lower right'),
        (train_loss, valid_loss, 'Loss', 'upper right'),
    )

    plt.figure(figsize=(14, 5))
    for position, (train_values, valid_values, metric, legend_loc) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(x_epochs, train_values, label='Training ' + metric)
        plt.plot(x_epochs, valid_values, label='Validation ' + metric)
        plt.legend(loc=legend_loc)
        plt.title('Training and Validation ' + metric)

    plt.show()

# main code

# Settings shared by the training and validation directory iterators.
_flow_options = dict(
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
)

# Training images: random Albumentations transform, then scaling by 1/255.
train_datagen = ImageDataGenerator(
    preprocessing_function=albumentations_preprocessing,
    rescale=1./255,
)
train_generator = train_datagen.flow_from_directory(train_dir, **_flow_options)

# Validation images: scaling only, no augmentation.
val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory(val_dir, **_flow_options)

# Load the previously saved network that supplies the architecture.
model = tf.keras.models.load_model(base_model_path)

# Mark every layer of the loaded network as trainable.
for layer in model.layers:
    layer.trainable = True

# Drop the classifier head: the truncated model ends at the second-to-last
# layer of the loaded network.
feature_layer = model.layers[-2]
layer_name = feature_layer.name  # name of the layer just before the final classifier
model_without_classifier = models.Model(
    inputs=model.input,
    outputs=feature_layer.output,
)


def apply_mask(x, mask):
    """Multiply `x` element-wise by `mask`.

    `mask` is converted to a constant tensor, given a leading axis of size 1
    and cast to the dtype of `x` before the multiplication.
    """
    mask_row = tf.expand_dims(tf.constant(mask), 0)
    return x * tf.cast(mask_row, x.dtype)
# Mark every layer of the truncated model as trainable.
# NOTE(review): the truncated model is built from `model`, whose layers were
# already set trainable above, so this loop looks redundant -- confirm.
for layer in model_without_classifier.layers:
    layer.trainable = True
# Apply the mask using a Lambda layer: the truncated model's output is
# multiplied by `~ignored_nodes` through apply_mask.
# Assumes `ignored_nodes` is a boolean array so that `~` acts as a logical NOT
# (on an integer array `~` is a bitwise NOT) -- TODO confirm.
# NOTE(review): the inline lambda reads `apply_mask` and `ignored_nodes` as
# module-level names; reloading a model saved with this layer presumably needs
# both to be available again -- confirm against the loading code.
masked_output = Lambda(lambda x: apply_mask(x, ~ignored_nodes))(model_without_classifier.output)

# Define the number of classes for your new Dense layer
# (taken from the training directory iterator).
num_classes = train_generator.num_classes  # Make sure this is defined correctly

# Add the Dense layer to the masked output
final_output = Dense(num_classes, activation='softmax', name='classifier_dense')(masked_output)

# Create the new model
model_with_masked_output = models.Model(inputs=model_without_classifier.input, outputs=final_output)


# Load the weights stored at `model_path` into the rebuilt architecture.
model_with_masked_output.load_weights(model_path)


# Compile with plain SGD (no momentum) at `initial_learning_rate`.
# NOTE(review): the directory names mention "adam" while SGD is used here --
# confirm the optimizer choice is intentional.
model_with_masked_output.compile(optimizer=SGD(learning_rate=initial_learning_rate, momentum=0.0),
              loss='categorical_crossentropy',
              metrics=['accuracy', 'categorical_accuracy', 'Precision', 'Recall'])

# Make sure the output folder exists before anything is written into it.
os.makedirs(checkpoint_dir, exist_ok=True)

# Save the complete model (not only the weights) at the end of every epoch.
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, 'model_epoch_{epoch:02d}.keras'),
    save_weights_only=False,
    save_freq='epoch')

# The pickled history below is written only after fit() has returned, so an
# interrupted run would leave no record of its metrics. CSVLogger streams the
# results of each finished epoch to disk while training is still running.
csv_logger_callback = tf.keras.callbacks.CSVLogger(
    os.path.join(checkpoint_dir, 'training_log.csv'))

history = model_with_masked_output.fit(
    train_generator,
    epochs=epochs,
    validation_data=val_generator,
    callbacks=[checkpoint_callback, csv_logger_callback])

# Persist the full per-epoch metric history of the completed run.
history_path = f'{checkpoint_dir}/training_history.pkl'
with open(history_path, 'wb') as f:
    pickle.dump(history.history, f)

# Read the file back (it was written just above by this cell) and plot from
# the stored copy.
with open(history_path, 'rb') as f:
    saved_history = pickle.load(f)

plot_history(saved_history)



Found 5690 images belonging to 24 classes.
Found 1208 images belonging to 24 classes.
Epoch 1/50


  self._warn_if_super_not_called()


[1m89/89[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m318s[0m 3s/step - Precision: 0.9623 - Recall: 0.6948 - accuracy: 0.7304 - categorical_accuracy: 0.7304 - loss: 0.9126 - val_Precision: 0.8435 - val_Recall: 0.8121 - val_accuracy: 0.8237 - val_categorical_accuracy: 0.8237 - val_loss: 0.6934
Epoch 2/50
[1m89/89[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m247s[0m 3s/step - Precision: 0.9618 - Recall: 0.6869 - accuracy: 0.7324 - categorical_accuracy: 0.7324 - loss: 0.9363 - val_Precision: 0.8415 - val_Recall: 0.8129 - val_accuracy: 0.8228 - val_categorical_accuracy: 0.8228 - val_loss: 0.6946
Epoch 3/50
[1m89/89[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m234s[0m 3s/step - Precision: 0.9746 - Recall: 0.7172 - accuracy: 0.7616 - categorical_accuracy: 0.7616 - loss: 0.7990 - val_Precision: 0.8412 - val_Recall: 0.8113 - val_accuracy: 0.8245 - val_categorical_accuracy: 0.8245 - val_loss: 0.6961
Epoch 4/50
[1m89/89[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m237s[0m 3s/step

[1m60/89[0m [32m━━━━━━━━━━━━━[0m[37m━━━━━━━[0m [1m1:19[0m 3s/step - Precision: 0.9725 - Recall: 0.7007 - accuracy: 0.7559 - categorical_accuracy: 0.7559 - loss: 0.8047