# Cross validation

This is a first demonstration of how we will use cross-validation in other models. Performance is not yet the main concern: the aim is to clearly define what cross-validation is and how to apply it to any type of model.

The `Model` class defines methods for loading the dataset, building the model, and training it.

The `__init__` method loads the training set from the `data/train` directory, retrieves the class names and the batch shapes, and displays some information about the dataset: sample images, the number of batches, the class distribution, and the mean of the data.

The dataset is normalized, and the features (images) and labels (one-hot encoded classes) are stored separately so that they can be indexed fold by fold.
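
The internals of the `Dataset` helper are not shown in this notebook. As a rough illustration only, assuming images arrive as `uint8` arrays with values in [0, 255] and labels are one-hot encoded, the class distribution, the normalization, and the mean could be computed with NumPy as follows (`summarize_and_normalize` is a hypothetical name, not part of `Dataset`):

```python
import numpy as np


def summarize_and_normalize(images, one_hot_labels):
    """Hypothetical sketch: summarize a dataset and scale pixels to [0, 1].

    Assumes `images` holds uint8 pixel values in [0, 255] and
    `one_hot_labels` has shape (n_samples, n_classes).
    """
    # Number of samples per class: sum each one-hot column.
    distribution = one_hot_labels.sum(axis=0).astype(int)
    # Scale pixel values from [0, 255] to [0, 1].
    x = images.astype(np.float32) / 255.0
    # Mean pixel value of the normalized data.
    mean = float(x.mean())
    return x, one_hot_labels, distribution, mean


# Toy data: four 2x2 grayscale "images" (channel axis added) and two classes.
images = np.array(
    [[[0, 255], [255, 0]],
     [[255, 255], [255, 255]],
     [[0, 0], [0, 0]],
     [[51, 51], [51, 51]]],
    dtype=np.uint8,
)[..., np.newaxis]
labels = np.array([[1, 0], [0, 1], [0, 1], [0, 1]])

x, y, distribution, mean = summarize_and_normalize(images, labels)
print(distribution)  # [1 3]
print(mean)
```

Keeping images and labels as separate arrays is what later allows the folds to be selected with plain index arrays.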

```python
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import KFold
from sklearn.utils import class_weight

from scripts.x_ray_dataset_builder import Dataset


class Model:
    def __init__(self, image_size=(512, 512)):
        train_dir = pathlib.Path("data/train")

        # Project-specific helper that loads the training images in batches of 64.
        train_ds = Dataset(train_dir, batch_size=64, image_size=image_size)

        # AUTOTUNE lets tf.data pick buffer sizes and parallelism at runtime.
        AUTOTUNE = tf.data.AUTOTUNE

        train_ds.build(AUTOTUNE, False)

        class_names = train_ds.get_class_names()
        print("\nClass names:")
        print(class_names)

        train_x_batch_shape = train_ds.get_x_batch_shape()
        print("\nTraining dataset's images batch shape is:")
        print(train_x_batch_shape)

        train_y_batch_shape = train_ds.get_y_batch_shape()
        print("\nTraining dataset's labels batch shape is:")
        print(train_y_batch_shape)

        # Visual and statistical overview of the training data.
        train_ds.display_images_in_batch(2, "Training dataset")
        train_ds.display_batch_number("Training dataset")
        train_ds.display_distribution("Training dataset")
        train_ds.display_mean("Training dataset")

        self.class_names = class_names
        self.train_ds = train_ds.normalized_dataset
        # Features (images) and one-hot labels, indexed per fold in `train`.
        self.x_train = train_ds.x_dataset
        self.y_train = train_ds.y_dataset
```

The `build` method defines the architecture: a simple feedforward neural network that flattens each image into a vector, passes it through a hidden layer of 128 neurons with ReLU activation, and ends in an output layer with one neuron per class and a softmax activation.
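
Flattening a full-resolution image makes the first dense layer very large. The arithmetic below counts the trainable parameters of this architecture, using the fact that a `Dense` layer has `inputs × units` weights plus `units` biases. The number of classes depends on the dataset, so it is a parameter here; the default of two classes is only an illustrative assumption.

```python
def dense_params(inputs: int, units: int) -> int:
    """Weights (inputs * units) plus one bias per unit."""
    return inputs * units + units


def count_params(height=512, width=512, channels=1, hidden=128, num_classes=2):
    # Flatten has no parameters; it only reshapes the image into a vector.
    flat = height * width * channels
    return dense_params(flat, hidden) + dense_params(hidden, num_classes)


print(count_params())  # 33554818
```

Almost all of the roughly 33.5 million parameters sit in the first hidden layer, which is why this model is meant as a baseline for demonstrating cross-validation rather than for performance.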

The model is compiled with the Adam optimizer, the categorical cross-entropy loss, and three metrics: categorical accuracy, precision, and recall.
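
To make the loss and the main metric concrete, here is a NumPy sketch of what categorical cross-entropy and categorical accuracy compute for one-hot labels and softmax outputs. It is a simplified re-implementation for illustration, not the Keras source; like Keras, it clips predictions to avoid `log(0)` and averages the loss over the batch.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip to avoid log(0), then take -sum(y * log(p)) per sample and average.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_sample = -np.sum(y_true * np.log(y_pred), axis=1)
    return float(per_sample.mean())


def categorical_accuracy(y_true, y_pred):
    # A prediction is correct when the most probable class matches the label.
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))


y_true = np.array([[1, 0], [0, 1]])
y_pred = np.array([[0.8, 0.2], [0.6, 0.4]])

print(categorical_accuracy(y_true, y_pred))  # 0.5
print(categorical_crossentropy(y_true, y_pred))
```

The second sample is misclassified, so accuracy is 0.5, and it also contributes the larger term `-log(0.4)` to the loss.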

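```python
    # Method of the Model class (continued from the previous cell).
    def build(self, input_shape=(512, 512, 1)):
        model = tf.keras.Sequential(
            [
                tf.keras.Input(shape=input_shape),
                # Flatten each image into a single feature vector.
                tf.keras.layers.Flatten(),
                # Hidden layer: 128 fully connected units with ReLU activation.
                tf.keras.layers.Dense(128, activation="relu"),
                # Output layer: one unit per class; softmax gives class probabilities.
                tf.keras.layers.Dense(len(self.class_names), activation="softmax"),
            ]
        )

        model.compile(
            optimizer="adam",
            # Softmax outputs are probabilities, not logits.
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
            metrics=[
                tf.keras.metrics.CategoricalAccuracy(),
                tf.keras.metrics.Precision(),
                tf.keras.metrics.Recall(),
            ],
        )

        model.summary()

        return model
```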

The train method implements k-fold cross-validation for training and evaluating the model. 

The dataset is split into `k` folds. Each fold serves once as the validation set while the remaining `k - 1` folds are used for training, so every sample is used for validation exactly once.
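
The splitting itself can be inspected on a toy array with scikit-learn's `KFold`, using the same `shuffle=True, random_state=1` settings as the `train` method. The ten dummy samples are an assumption for illustration; the checks confirm that training and validation indices never overlap and that the validation folds together cover every sample exactly once.

```python
import numpy as np
from sklearn.model_selection import KFold

samples = np.arange(10)  # ten dummy samples, identified by their index
kfold = KFold(n_splits=5, shuffle=True, random_state=1)

seen = []
for fold, (train_index, val_index) in enumerate(kfold.split(samples), start=1):
    # Training and validation indices of a fold never overlap.
    assert set(train_index).isdisjoint(val_index)
    print(f"Fold {fold}: train={train_index.tolist()} val={val_index.tolist()}")
    seen.extend(val_index.tolist())

# Every sample lands in exactly one validation fold.
print(sorted(seen) == samples.tolist())  # True
```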

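Class weights are computed with scikit-learn's `"balanced"` strategy to compensate for imbalanced classes, and the weight of class 0 is then multiplied by an extra factor of 4.25.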
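
scikit-learn documents the `"balanced"` heuristic as `n_samples / (n_classes * np.bincount(y))`. The NumPy sketch below reproduces that formula from one-hot labels and then applies the extra factor of 4.25 to class 0, as the `train` method does. The function name `balanced_class_weights` and the toy label counts are hypothetical.

```python
import numpy as np


def balanced_class_weights(one_hot_labels, boost_class=0, boost=4.25):
    # Convert one-hot rows to integer class indices.
    y = np.argmax(one_hot_labels, axis=1)
    counts = np.bincount(y, minlength=one_hot_labels.shape[1])
    # "balanced" heuristic: rarer classes get proportionally larger weights.
    weights = len(y) / (len(counts) * counts)
    class_weights = {i: float(w) for i, w in enumerate(weights)}
    # Extra manual boost, mirroring the 4.25 factor applied to class 0.
    class_weights[boost_class] *= boost
    return class_weights


# Toy labels: 2 samples of class 0, 6 samples of class 1.
labels = np.eye(2)[[0, 0, 1, 1, 1, 1, 1, 1]]
print(balanced_class_weights(labels))  # {0: 8.5, 1: 0.6666666666666666}
```

Without the boost, class 0 would get a weight of 2.0; the extra factor pushes the loss further toward that class.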

For each fold, a new model is built from scratch and trained, and its training and validation accuracy and loss curves are plotted.
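
Building a new model inside the loop matters: reusing one model would carry weights learned on earlier folds, which include the current validation samples, into later folds. The framework-agnostic sketch below shows the pattern with a model factory. The tiny nearest-centroid classifier and the `cross_validate` function are hypothetical stand-ins for the Keras model and the `train` method.

```python
import numpy as np
from sklearn.model_selection import KFold


class NearestCentroid:
    """Hypothetical stand-in model: predicts the class with the closest mean."""

    def fit(self, x, y):
        self.classes_ = np.unique(y)
        self.centroids_ = np.stack([x[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, x):
        # Distance from every sample to every class centroid: shape (n, n_classes).
        distances = np.linalg.norm(x[:, None, :] - self.centroids_[None, :, :], axis=2)
        return self.classes_[np.argmin(distances, axis=1)]


def cross_validate(make_model, x, y, k=5, seed=1):
    scores = []
    kfold = KFold(n_splits=k, shuffle=True, random_state=seed)
    for train_index, val_index in kfold.split(x):
        # A fresh model per fold: nothing learned on other folds leaks in.
        model = make_model().fit(x[train_index], y[train_index])
        scores.append(float(np.mean(model.predict(x[val_index]) == y[val_index])))
    return scores


rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0, 1, (20, 2)), rng.normal(5, 1, (20, 2))])
y = np.array([0] * 20 + [1] * 20)
print(cross_validate(NearestCentroid, x, y))
```

Passing a factory (here the class itself) rather than a model instance makes it impossible to accidentally reuse trained weights across folds.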

After training, only the model from the last fold is saved, both as `model_2.keras` and as an export in the `notebooks/2_cross_validation` directory.

```python
    # Method of the Model class (continued from the previous cells).
    def train(self, epochs, k=5, input_shape=(512, 512, 1)):
        print("\nStarting training...")

        # Shuffle once, then split the samples into k folds (seeded for reproducibility).
        kfold = KFold(n_splits=k, shuffle=True, random_state=1)

        # Labels are one-hot encoded; convert them to class indices for scikit-learn.
        y_indices = np.argmax(self.y_train, axis=1)
        # "balanced" weights: n_samples / (n_classes * number of samples per class).
        class_weights = class_weight.compute_class_weight(
            "balanced", classes=np.unique(y_indices), y=y_indices
        )
        class_weights = dict(enumerate(class_weights))
        # Extra manual boost for class 0.
        class_weights[0] = class_weights[0] * 4.25

        for fold, (train_index, val_index) in enumerate(kfold.split(self.x_train), start=1):
            # Build a fresh, untrained model for every fold.
            model = self.build(input_shape=input_shape)

            print(f"Processing fold {fold}/{k}")
            train_images, val_images = self.x_train[train_index], self.x_train[val_index]
            train_labels, val_labels = self.y_train[train_index], self.y_train[val_index]

            history = model.fit(
                train_images,
                train_labels,
                class_weight=class_weights,
                batch_size=64,
                epochs=epochs,
                validation_data=(val_images, val_labels),
            )

            categorical_accuracy = history.history["categorical_accuracy"]
            val_categorical_accuracy = history.history["val_categorical_accuracy"]

            loss = history.history["loss"]
            val_loss = history.history["val_loss"]

            epochs_range = range(epochs)

            # Learning curves of the current fold.
            plt.figure(figsize=(8, 8))
            plt.suptitle(f"Fold {fold}")
            plt.subplot(1, 2, 1)
            plt.plot(epochs_range, categorical_accuracy, label="Training Accuracy")
            plt.plot(epochs_range, val_categorical_accuracy, label="Validation Accuracy")
            plt.legend(loc="lower right")
            plt.title("Training and Validation Accuracy")

            plt.subplot(1, 2, 2)
            plt.plot(epochs_range, loss, label="Training Loss")
            plt.plot(epochs_range, val_loss, label="Validation Loss")
            plt.legend(loc="upper right")
            plt.title("Training and Validation Loss")
            plt.show()

        print("\n\033[92mTraining done!\033[0m")

        print("\nSaving...")
        # Only the model trained on the last fold is saved.
        model.save("notebooks/2_cross_validation/model_2.keras")
        # SavedModel directory export: supported by Keras 2 (TF <= 2.15); Keras 3
        # requires a .keras/.h5 path here and offers model.export() instead.
        model.save("notebooks/2_cross_validation")
        print("\n\033[92mSaving done!\033[0m")
```

This script is an example of how to implement k-fold cross-validation. Compared with a single train/validation split, cross-validation gives a more reliable estimate of the model's performance, because the metrics can be averaged over `k` different train/validation splits.

It is particularly useful when the available data is limited, since every sample is used for both training and validation.
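
The `train` method above plots each fold's curves but does not combine them into a single estimate. One possible way to do so, sketched under the assumption that each fold's `history.history` dictionary is collected in a list (which the current code does not do), is to average the last-epoch validation metrics across folds and report their spread. The function name `summarize_folds` and the values in the example are hypothetical.

```python
import statistics


def summarize_folds(histories, metrics=("val_loss", "val_categorical_accuracy")):
    """Mean and sample standard deviation of last-epoch metrics across folds.

    `histories` is a list of dicts shaped like Keras `History.history`
    (metric name -> list of per-epoch values), one dict per fold.
    """
    summary = {}
    for name in metrics:
        final_values = [h[name][-1] for h in histories]
        summary[name] = (statistics.mean(final_values), statistics.stdev(final_values))
    return summary


# Made-up values for three folds of a two-epoch run.
histories = [
    {"val_loss": [0.9, 0.5], "val_categorical_accuracy": [0.6, 0.8]},
    {"val_loss": [0.8, 0.4], "val_categorical_accuracy": [0.7, 0.9]},
    {"val_loss": [1.0, 0.6], "val_categorical_accuracy": [0.5, 0.7]},
]
for name, (mean, std) in summarize_folds(histories).items():
    print(f"{name}: {mean:.3f} ± {std:.3f}")
```

Reporting the standard deviation alongside the mean shows how much the result depends on the particular split, which is exactly the information a single split cannot provide.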