# VGG19 Transfer Learning on FMS Images for Carbonate Rock Classification 

In [None]:
import os
import numpy as np
from keras.models import Model
from keras.optimizers import Adam
from tensorflow.keras import optimizers
from keras.applications.vgg19 import VGG19, preprocess_input
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import Dense, Dropout, Flatten
from pathlib import Path
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
def create_image_generators(train_dir, val_dir, test_dir, image_size, batch_size):
    """ Create image data generators for training, validation, and testing.

    Args:
        train_dir: Directory of training images, one sub-directory per class.
        val_dir: Directory of validation images, one sub-directory per class.
        test_dir: Directory of test images, one sub-directory per class.
        image_size: (height, width) every image is resized to.
        batch_size: Batch size for the training and validation generators.

    Returns:
        Tuple (train_gen, val_gen, test_gen). The training generator is
        shuffled; the validation and test generators keep directory order.
        The test generator yields images only (class_mode=None), one per batch.
    """
    # VGG19's preprocess_input expects raw 0-255 pixel values: it converts RGB
    # to BGR and subtracts the ImageNet channel means without scaling.
    # ImageDataGenerator applies `rescale` AFTER the preprocessing function, so
    # combining it with rescale=1./255 shrinks the centred values by 255 and
    # feeds the pretrained convolutional base inputs it was never trained on.
    # Only the VGG19 preprocessing is applied here.
    datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
    train_gen = datagen.flow_from_directory(
        train_dir, target_size=image_size, class_mode='categorical', batch_size=batch_size, shuffle=True, seed=42)
    val_gen = datagen.flow_from_directory(
        val_dir, target_size=image_size, class_mode='categorical', batch_size=batch_size, shuffle=False, seed=42)
    # Unshuffled, batch_size=1: predictions line up one-to-one with test_gen.classes.
    test_gen = datagen.flow_from_directory(
        test_dir, target_size=image_size, class_mode=None, batch_size=1, shuffle=False, seed=42)
    return train_gen, val_gen, test_gen

def create_model(input_shape, n_classes, optimizer='rmsprop', fine_tune=0):
    """ Assemble and compile a VGG19-based classifier.

    Args:
        input_shape: Shape of one input image, e.g. (height, width, 3).
        n_classes: Number of units in the softmax output layer.
        optimizer: Optimizer name or instance passed to `compile`.
        fine_tune: Number of trailing VGG19 layers left trainable;
            0 (or a negative value) freezes the whole convolutional base.

    Returns:
        The compiled Keras model.
    """
    backbone = VGG19(include_top=False, weights='imagenet', input_shape=input_shape)

    # Freeze everything except the last `fine_tune` layers (all layers when
    # no fine-tuning is requested).
    frozen_layers = backbone.layers[:-fine_tune] if fine_tune > 0 else backbone.layers
    for frozen in frozen_layers:
        frozen.trainable = False

    # Classification head stacked on the convolutional features.
    head = Flatten(name="flatten")(backbone.output)
    for units in (4096, 1072):
        head = Dense(units, activation='relu')(head)
    head = Dropout(0.2)(head)
    predictions = Dense(n_classes, activation='softmax')(head)

    classifier = Model(inputs=backbone.input, outputs=predictions)
    classifier.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    return classifier

def train_model(model, train_gen, val_gen, n_steps, n_val_steps, n_epochs, weights_path, plot_losses=True):
    """ Fit the model with checkpointing and early stopping.

    Args:
        model: Compiled Keras model to train.
        train_gen: Generator supplying training batches.
        val_gen: Generator supplying validation batches.
        n_steps: Training steps per epoch.
        n_val_steps: Validation steps per epoch.
        n_epochs: Maximum number of epochs.
        weights_path: File path handed to ModelCheckpoint; only the best
            checkpoint is kept.
        plot_losses: When true, attach livelossplot's live-plotting callback.

    Returns:
        The history object returned by `model.fit`.
    """
    checkpoint = ModelCheckpoint(filepath=weights_path, save_best_only=True, verbose=1)
    # Stop once val_loss has not improved for 5 epochs and roll back to the best weights.
    early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, mode='min')
    active_callbacks = [checkpoint, early_stop]

    if plot_losses:
        # Imported lazily so livelossplot is only required when plotting is requested.
        from livelossplot.inputs.keras import PlotLossesCallback
        active_callbacks.append(PlotLossesCallback())

    return model.fit(
        train_gen,
        epochs=n_epochs,
        validation_data=val_gen,
        steps_per_epoch=n_steps,
        validation_steps=n_val_steps,
        callbacks=active_callbacks,
        verbose=1,
    )

def evaluate_model(model, test_gen, true_classes):
    """ Predict on the test generator, print the accuracy, and return it.

    Args:
        model: Model exposing `predict`.
        test_gen: Generator of test images passed to `model.predict`.
        true_classes: Ground-truth class indices, in the generator's order.

    Returns:
        Accuracy as computed by `accuracy_score`.
    """
    # The most probable class per sample is the arg-max over the class axis.
    predicted_classes = np.argmax(model.predict(test_gen), axis=1)
    accuracy = accuracy_score(true_classes, predicted_classes)
    print(f"Model Accuracy: {accuracy * 100:.2f}%")
    return accuracy

def plot_confusion_matrix(true_classes, pred_classes, class_names):
    """ Plot a row-normalised confusion matrix for model predictions.

    Args:
        true_classes: Ground-truth class labels.
        pred_classes: Predicted class labels.
        class_names: Display names for the classes, ordered like the matrix
            rows. Used as axis tick labels when their count matches the
            matrix size; otherwise the default tick labels are kept.
    """
    cm = confusion_matrix(true_classes, pred_classes)

    # Normalise each row by the number of true samples of that class. A class
    # that shows up only among the predictions has an all-zero row; dividing
    # by 1 there keeps the row at 0 instead of producing NaN.
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    row_totals[row_totals == 0] = 1
    cmn = cm.astype('float') / row_totals

    # Label the axes with the class names when they line up with the matrix.
    if class_names is not None and len(class_names) == cm.shape[0]:
        tick_labels = list(class_names)
    else:
        tick_labels = 'auto'

    sns.heatmap(cmn, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

In [None]:
def main(data_dir, n_classes, batch_size, image_size=(224, 224), n_epochs=200, learning_rate=0.001, fine_tune=0):
    """ Set up the data, train the VGG19 classifier, and evaluate it on the test split.

    Args:
        data_dir: Dataset root (str or Path) holding 'train', 'val' and 'test'
            sub-directories.
        n_classes: Number of target classes.
        batch_size: Batch size for the training and validation generators.
        image_size: (height, width) the images are resized to.
        n_epochs: Maximum number of training epochs.
        learning_rate: Learning rate for the Adam optimizer.
        fine_tune: Number of trailing VGG19 layers to leave trainable.

    Returns:
        The test accuracy reported by `evaluate_model`.
    """
    # Accept plain strings as well as Path objects; the `/` joins need a Path.
    data_dir = Path(data_dir)
    # Accept a list as well as a tuple; the channel axis is appended below.
    image_size = tuple(image_size)

    train_dir = data_dir / 'train'
    val_dir = data_dir / 'val'
    test_dir = data_dir / 'test'

    train_gen, val_gen, test_gen = create_image_generators(train_dir, val_dir, test_dir, image_size, batch_size)
    model = create_model(image_size + (3,), n_classes, Adam(learning_rate=learning_rate), fine_tune)

    # The best checkpoint is written next to the data.
    weights_path = str(data_dir / "model_best.weights.hdf5")
    # len(generator) is the number of batches per pass, i.e. one full epoch.
    train_model(model, train_gen, val_gen, len(train_gen), len(val_gen), n_epochs, weights_path)

    true_classes = test_gen.classes
    accuracy = evaluate_model(model, test_gen, true_classes)

    # evaluate_model returns only the accuracy, so the per-sample predictions
    # needed for the confusion matrix are computed again here.
    preds = model.predict(test_gen)
    pred_classes = np.argmax(preds, axis=1)
    plot_confusion_matrix(true_classes, pred_classes, list(test_gen.class_indices.keys()))

    return accuracy
    
# Script entry point: run the full train/evaluate pipeline when executed directly.
if __name__ == "__main__":
    # Placeholder path: replace with the dataset root that holds the
    # 'train', 'val' and 'test' folders.
    data_dir = Path("/path/to/your/data")
    main(data_dir, n_classes=5, batch_size=128)