In [None]:
# ---------------------------------------------------------------------------
# Notebook configuration. Every tunable lives here, grouped by topic.
# ---------------------------------------------------------------------------

# Environment / reproducibility
USE_COLAB = False
SHOW_MODEL_SUMMARY = False
SEED = 72121

# Dataset splitting and balancing
VAL_SPLIT = 0.25
TEST_SPLIT = 0.000001
USE_TEST = False
USE_TEST_BIG = True
# Documented choices: "train", "val", "test", "train and val", "train and test",
# "val and test" (use test to not balance)
BALANCE_SET = "TEST"
BALANCE_TRAINING_CLASSES = False  # Deprecated
USE_CLASS_WEIGHTS = False

# Data augmentation
AUGMENTATION = False
MIXUP = True
ALPHA_MIXUP = 0.2
MIXUP_AUGMENT_FACTOR = 2.0

# Optimisation
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-4
PATIENCE = 2  # passed to the EarlyStopping callback
DROPOUT_RATE = 0.5
L2_REGULARIZATION = 1e-3

# Architecture
CONV_LAYERS = 3
DENSE_LAYERS = 3
NODES_PER_LAYER = 512
USE_BATCH_NORMALIZATION = False
USE_BASE_MODEL = False
# Backbone name handed to get_base_model when USE_BASE_MODEL is True. Listed choices:
# 'VGG19', 'ResNet50', 'ResNet50V2', 'ResNet101', 'ResNet101V2', 'ResNet152',
# 'ResNet152V2', 'Xception', 'InceptionV3', 'InceptionResNetV2', 'MobileNet', 'MobileNetV2',
# 'DenseNet121', 'DenseNet169', 'DenseNet201', 'NASNetMobile', 'NASNetLarge',
# 'EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', 'EfficientNetB4',
# 'EfficientNetB5', 'EfficientNetB6', 'EfficientNetB7', 'EfficientNetV2B0', 'EfficientNetV2B1',
# 'EfficientNetV2B2', 'EfficientNetV2B3', 'EfficientNetV2S', 'EfficientNetV2M', 'EfficientNetV2L',
# 'ConvNeXtTiny', 'ConvNeXtSmall', 'ConvNeXtBase', 'ConvNeXtLarge', 'ConvNeXtXLarge'
MODEL = 'ConvNeXtSmall'

# Input preprocessing layer
USE_PREPROCESSING = False
# If the background class has a probability higher than this threshold, the image is
# considered as background (set 1 if you want to disable this feature).
BACKGROUND_THRESHOLD = 0.5

In [None]:
# Optional Google Colab setup: mount Google Drive and switch to the project folder.
# Skipped entirely when USE_COLAB is False.
if USE_COLAB:
    # Imported lazily because google.colab is only needed on this branch.
    from google.colab import drive

    drive.mount('/gdrive')
    # IPython magic: change the working directory so the relative 'data/' and 'models/'
    # paths used later in the notebook resolve inside the mounted Drive folder.
    %cd /gdrive/My Drive/ANN/CHAL1

In [None]:
from libraries import *

In [None]:
# Load data
# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can execute
# arbitrary code. Only load archives from a trusted source, and drop the flag if the
# arrays turn out not to be object-typed.
# The context manager closes the underlying .npz file handle once both arrays are read.
with np.load('data/training_set.npz', allow_pickle=True) as data:
    # Divide data (each array is read from the archive when it is indexed)
    labels = data['labels']
    images = data['images']

In [None]:
# Clean the raw arrays with the project helper `clean_dataset` (not defined in this
# notebook; presumably provided by the `libraries` star import).
# NOTE: this rebinds `images` and `labels`, so re-running this cell alone feeds already
# cleaned data back into the helper. Re-run the loading cell first for a fresh start.
images, labels = clean_dataset(images, labels)

In [None]:
# Split the cleaned data into train / validation / test subsets. Split sizes, the
# balancing mode and the seed all come from the configuration cell.
X_train, X_val, X_test, y_train, y_val, y_test = split_and_balance_distribution(
    images,
    labels,
    val_size=VAL_SPLIT,
    test_size=TEST_SPLIT,
    seed=SEED,
    balance_sets=BALANCE_SET,
    TEST=USE_TEST,
)


In [None]:
# One-hot encode labels. The test labels are only encoded (and rebound) when a
# held-out test split is in use; otherwise None is passed in their place.
if not USE_TEST:
    y_train, y_val = one_hot_encode_labels(y_train, y_val, None)
else:
    y_train, y_val, y_test = one_hot_encode_labels(y_train, y_val, y_test)

In [None]:
print("X_train shape:", X_train.shape)
# Honour the MIXUP switch from the configuration cell. Previously mixup ran
# unconditionally, so setting MIXUP = False had no effect even though the flag is
# written to the parameter log.
if MIXUP:
    X_train, y_train = apply_mixup(X_train, y_train, alpha=ALPHA_MIXUP, factor=MIXUP_AUGMENT_FACTOR)
    print("X_train shape after mixup:", X_train.shape)

In [None]:
# Preprocess function suited for ConvNeXt models
# X_val is passed to model.fit as validation data, so it must go through the same
# preprocessing as X_train. It was previously skipped, which makes validation metrics
# incomparable to training metrics whenever preprocess_input is not an identity.
# NOTE: this cell rebinds its inputs; re-run the split cell before re-running it.
X_train = preprocess_input(X_train)
X_val = preprocess_input(X_val)
X_test = preprocess_input(X_test)

# Model input/output sizes are derived from the first training sample and its label
# vector (the length of one label row is used as the number of output units).
input_shape = X_train[0].shape
output_shape = y_train[0].shape[0]

print(f"Input shape: {input_shape}")
print(f"Output shape: {output_shape}")


In [None]:
def build_model(
    input_shape=input_shape,
    output_shape=output_shape,
    learning_rate=LEARNING_RATE,
    augmentation=None,
    seed=SEED,
    conv_layers=CONV_LAYERS,
    dense_layers=DENSE_LAYERS,
    dropout_rate=DROPOUT_RATE,
    l2_regularization=L2_REGULARIZATION,
    use_base_model=USE_BASE_MODEL,
    background_threshold=BACKGROUND_THRESHOLD,
    use_batch_normalization=USE_BATCH_NORMALIZATION,
    nodes_per_layer=NODES_PER_LAYER,
    use_preprocessing=USE_PREPROCESSING,
    model_name=None
    ):
    """Build and compile the image classifier.

    The network is assembled as: input -> optional PreprocessLayer -> optional
    augmentation -> feature extractor -> dense head -> softmax output. The feature
    extractor is either the backbone returned by ``get_base_model`` or a small
    convolutional stack, and both end in global average pooling.

    Args:
        input_shape: Shape of a single input sample (without the batch dimension).
        output_shape: Number of units of the softmax output layer.
        learning_rate: Learning rate of the Adam optimizer.
        augmentation: Optional callable/layer applied to the input tensor before the
            feature extractor; skipped when falsy.
        seed: Seed for ``tf.random.set_seed`` and the kernel initializers.
        conv_layers: Total number of Conv2D layers in the custom stack (the first
            16-filter layer included). Ignored when ``use_base_model`` is True.
        dense_layers: Number of hidden Dense layers in the head.
        dropout_rate: Dropout rate after each hidden Dense layer; no Dropout layer is
            added when it is not greater than 0.
        l2_regularization: L2 factor for the output layer kernel; no regularizer is
            attached when it is not greater than 0.
        use_base_model: Use the ``get_base_model`` backbone instead of the custom
            convolutional stack.
        background_threshold: Threshold forwarded to ``PreprocessLayer``.
        use_batch_normalization: Add BatchNormalization after each hidden Dense layer.
        nodes_per_layer: Units of the first hidden Dense layer; each following layer
            has half the units of the previous one.
        use_preprocessing: Insert ``PreprocessLayer`` right after the input.
        model_name: Backbone name passed to ``get_base_model``. Defaults to the
            module-level ``MODEL`` constant, looked up at call time.

    Returns:
        The compiled ``tfk.Model`` (Adam, categorical cross-entropy, accuracy metric).
    """
    tf.random.set_seed(seed)

    relu_initialiser = tfk.initializers.HeNormal(seed=seed)
    output_initialiser = tfk.initializers.GlorotNormal(seed=seed)
    regularizer = tfk.regularizers.l2(l2_regularization)

    # Define the input layer with original input shape
    input_layer = tfk.Input(shape=input_shape, name='input_layer')

    # Preprocess the input image
    if use_preprocessing:
        x = PreprocessLayer(threshold=background_threshold)(input_layer)

    else:
        x = input_layer

    if use_base_model:
        # Load the selected backbone for this input shape. The name resolves to the
        # global MODEL unless the caller passed one explicitly.
        backbone_name = MODEL if model_name is None else model_name
        base_model = get_base_model(backbone_name, input_shape=input_shape)

        # Apply augmentation if specified
        x = augmentation(x) if augmentation else x

        x = base_model(x)
        x = tfkl.GlobalAveragePooling2D(name='avg_pool')(x)

    else:
        # Apply augmentation if specified
        x = augmentation(x) if augmentation else x

        # Add Conv layers: a fixed 16-filter stem followed by (conv_layers - 1) layers
        # whose filter count doubles each time, starting at 32.
        x = tfkl.Conv2D(filters=16, kernel_size=3, activation='relu',
                       padding='same', name='first_conv')(x)
        x = tfkl.MaxPooling2D((2, 2), name='first_maxpool')(x)

        for i in range(conv_layers - 1):
            num_filters = 32 * (2 ** i)
            x = tfkl.Conv2D(
                filters=num_filters,
                kernel_size=3,
                activation='relu',
                padding='same',
                name=f'conv_{num_filters}')(x)

            if i < conv_layers - 2:  # Apply MaxPooling except for last conv layer
                x = tfkl.MaxPooling2D((2, 2), name=f'maxpool_{num_filters}')(x)

        # Apply GlobalAveragePooling2D after all conv layers
        x = tfkl.GlobalAveragePooling2D(name='global_avg_pool')(x)

    # Add Dense layers; layer i has nodes_per_layer / 2**i units (truncated to int)
    for i in range(dense_layers):
        x = tfkl.Dense(int(nodes_per_layer/(2**i)),
                      activation='relu',
                      name=f'dense_{i+1}',
                      kernel_initializer=relu_initialiser)(x)

        if use_batch_normalization:
            x = tfkl.BatchNormalization()(x)

        if dropout_rate > 0:
            x = tfkl.Dropout(dropout_rate, name=f'dropout_{i+1}')(x)

    # Only the output layer carries the (optional) L2 kernel regularizer
    output_layer = tfkl.Dense(output_shape,
                             activation='softmax',
                             name='output_layer',
                             kernel_initializer=output_initialiser,
                             kernel_regularizer=regularizer
                             if l2_regularization > 0 else None)(x)

    # Create model
    model = tfk.Model(input_layer, output_layer)


    # Compile the model
    model.compile(optimizer=tfk.optimizers.Adam(learning_rate=learning_rate),
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])


    return model

In [None]:
# Build the augmentation pipeline only when it is enabled in the configuration cell.
if AUGMENTATION:
    # Active steps: random 64x64 crop, random zoom, then resize back to 96x96.
    # Steps left disabled by the author:
    #   tfkl.RandomFlip('horizontal'), tfkl.RandomFlip('vertical'),
    #   tfkl.RandomRotation(0.3), tfkl.RandomTranslation(0.4, 0.4, fill_mode='nearest')
    augmentation_layers = tfk.Sequential(
        layers=[
            tfkl.RandomCrop(64, 64),
            tfkl.RandomZoom(0.3, fill_mode='nearest'),
            tfkl.Resizing(96, 96),
        ],
        name='augmentation',
    )

    # Wrap the pipeline in the project's ConditionalAugmentation layer.
    augmentation = ConditionalAugmentation(augmentation_layers)

In [None]:
# Instantiate the network. `augmentation` only exists when AUGMENTATION is enabled,
# so the conditional keeps the name from being evaluated otherwise.
model = build_model(augmentation=(augmentation if AUGMENTATION else None))

if SHOW_MODEL_SUMMARY:
    model.summary()

# Persist the full model (not just the weights) each time validation accuracy improves.
checkpoint_callback = tfk.callbacks.ModelCheckpoint(
    'models/best_model_restored.keras',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# Stop once validation accuracy has not improved for PATIENCE epochs and roll the
# in-memory model back to its best weights.
early_stopping = tfk.callbacks.EarlyStopping(
    monitor='val_accuracy',
    mode='auto',
    patience=PATIENCE,
    restore_best_weights=True,
)

callbacks = [early_stopping, checkpoint_callback]


In [None]:
# Per-class weights for the loss. They stay None (no weighting) unless enabled in the
# configuration cell.
class_weights = None
if USE_CLASS_WEIGHTS:
    class_weights = compute_class_weights(y_train)

In [None]:
# Train the model with early stopping callback
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val, y_val),
    callbacks=callbacks,
    class_weight=class_weights
).history

# Report the best validation accuracy of the run, i.e. the epoch that the callbacks
# monitoring 'val_accuracy' keep. The previous index, -(PATIENCE+1), only points at
# that epoch when early stopping actually fired. It raised IndexError when fewer than
# PATIENCE+1 epochs were run and picked an arbitrary epoch when training reached EPOCHS.
final_val_acc = max(history['val_accuracy']) * 100
print(f'Final validation accuracy: {final_val_acc:.0f}')

In [None]:
# Timestamp (YYYYMMDD_HHMMSS) that makes the saved file name unique
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"

# Persist the in-memory model; the rounded validation accuracy is part of the file name
model.save(f'models/model_{final_val_acc:.0f}_{timestamp}.keras')

# Drop the reference to the trained model; a model is loaded from disk again before inference
del model

In [None]:
# Function to log model parameters to a text file
def log_model_parameters(final_val_acc, timestamp):
    """Write the notebook's configuration constants to a text file next to the saved model.

    The file is ``models/model_<acc>_params_<timestamp>.txt``, mirroring the name of
    the saved ``.keras`` model so the two can be matched.

    Args:
        final_val_acc: Validation accuracy in percent; rounded to an integer in the file name.
        timestamp: Timestamp string used in the file name.
    """
    # Create the log filename with date and time
    log_filename = f'models/model_{final_val_acc:.0f}_params_{timestamp}.txt'

    # Values are read from the module-level constants at call time. The first sixteen
    # entries keep their original order; the rest were missing from the log although
    # they also determine the run (split sizes, mixup strength, backbone, preprocessing).
    parameters = {
        "BATCH_SIZE": BATCH_SIZE,
        "EPOCHS": EPOCHS,
        "LEARNING_RATE": LEARNING_RATE,
        "AUGMENTATION": AUGMENTATION,
        "MIXUP": MIXUP,
        "CONV_LAYERS": CONV_LAYERS,
        "DENSE_LAYERS": DENSE_LAYERS,
        "NODES_PER_LAYER": NODES_PER_LAYER,
        "DROPOUT_RATE": DROPOUT_RATE,
        "PATIENCE": PATIENCE,
        "L2_REGULARIZATION": L2_REGULARIZATION,
        "USE_BASE_MODEL": USE_BASE_MODEL,
        "USE_BATCH_NORMALIZATION": USE_BATCH_NORMALIZATION,
        "USE_CLASS_WEIGHTS": USE_CLASS_WEIGHTS,
        "BALANCE_TRAINING_CLASSES": BALANCE_TRAINING_CLASSES,
        "SEED": SEED,
        "ALPHA_MIXUP": ALPHA_MIXUP,
        "MIXUP_AUGMENT_FACTOR": MIXUP_AUGMENT_FACTOR,
        "BALANCE_SET": BALANCE_SET,
        "VAL_SPLIT": VAL_SPLIT,
        "TEST_SPLIT": TEST_SPLIT,
        "MODEL": MODEL,
        "USE_PREPROCESSING": USE_PREPROCESSING,
        "BACKGROUND_THRESHOLD": BACKGROUND_THRESHOLD,
    }

    # Write the parameters to the log file, one "NAME: value" line each
    with open(log_filename, 'w', encoding='utf-8') as log_file:
        log_file.write("Model Training Parameters:\n\n")
        log_file.writelines(f"{name}: {value}\n" for name, value in parameters.items())


# Log the model parameters
log_model_parameters(final_val_acc, timestamp)


In [None]:
# plot training loss and accuracy
def plot_training(history):
    """Draw loss and accuracy curves side by side for the training and validation sets.

    Args:
        history: Mapping with the per-epoch series 'loss', 'val_loss', 'accuracy'
            and 'val_accuracy'.
    """
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))

    # One panel per metric: (key in `history`, panel title)
    panels = (('loss', 'Loss'), ('accuracy', 'Accuracy'))
    for ax, (metric, title) in zip(axs, panels):
        ax.plot(history[metric], label='train')
        ax.plot(history[f'val_{metric}'], label='validation')
        ax.set_title(title)
        ax.legend()

    plt.show()

plot_training(history)



# Make inference

In [None]:
# Load the checkpoint written by the ModelCheckpoint callback. The custom layers must
# be registered through custom_objects so they can be deserialised.
# Alternative left disabled by the author: load the model saved after training,
# f'models/model_{final_val_acc:.0f}_{timestamp}.keras', with the same custom_objects.
model = tfk.models.load_model(
    'models/best_model_restored.keras',
    custom_objects={
        'PreprocessLayer': PreprocessLayer,
        'ConditionalAugmentation': ConditionalAugmentation,
    },
)

In [None]:
# Main testing logic
# Class ids 0..7 are assumed by the sampling loop, the confusion-matrix labels and the
# per-class ROC-AUC section below.
n_classes = 8


def _weighted_scores(y_true, y_pred):
    """Return (accuracy, precision, recall, F1); the last three use average='weighted'."""
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, average='weighted'),
        recall_score(y_true, y_pred, average='weighted'),
        f1_score(y_true, y_pred, average='weighted'),
    )


def _print_scores(accuracy, precision, recall, f1):
    """Print the four summary scores rounded to four decimals."""
    print(f'Accuracy score over the normal test set: {round(accuracy, 4)}')
    print(f'Precision score over the normal test set: {round(precision, 4)}')
    print(f'Recall score over the normal test set: {round(recall, 4)}')
    print(f'F1 score over the normal test set: {round(f1, 4)}')


# NOTE(review): the evaluation images below are fed to the model as loaded, whereas the
# training images went through preprocess_input first. Confirm that this is intended.
# NOTE(review): allow_pickle=True can execute arbitrary code when unpickling object
# arrays; only load trusted archives.
if USE_TEST_BIG:
    test_data = np.load('data/blood_cells_96x96.npz', allow_pickle=True)
    test_images = test_data['data']
    test_labels = test_data['labels']
    # Labels are compared against integer class ids here, so they are presumably
    # stored as class indices rather than one-hot vectors. TODO confirm.

    # Evaluate on M class-balanced random groups of about N images each and average
    N = 500
    M = 10
    images_per_class = N // n_classes

    # Seeded generator so the sampled groups, and therefore the reported averages,
    # are reproducible across runs (the global NumPy RNG was used unseeded before).
    rng = np.random.default_rng(SEED)

    test_accuracy = 0.0
    test_precision = 0.0
    test_recall = 0.0
    test_f1 = 0.0
    for k in range(M):

        print(f'Group {k+1}/{M}')

        # choose images_per_class images for each class
        group_indices = []
        for i in range(n_classes):
            indexes = np.where(test_labels == i)[0]
            rng.shuffle(indexes)
            group_indices.extend(indexes[:images_per_class])

        group_labels = test_labels[group_indices]
        group_images = test_images[group_indices]

        # Predict class probabilities and get predicted classes for this group
        test_predictions = model.predict(group_images, verbose=1)
        test_predictions_classes = np.argmax(test_predictions, axis=-1)

        # Accumulate the metrics of this group
        accuracy, precision, recall, f1 = _weighted_scores(group_labels, test_predictions_classes)
        test_accuracy += accuracy
        test_precision += precision
        test_recall += recall
        test_f1 += f1

    # Average over the M groups
    test_accuracy /= M
    test_precision /= M
    test_recall /= M
    test_f1 /= M

    _print_scores(test_accuracy, test_precision, test_recall, test_f1)


elif USE_TEST:
    test_data = np.load('data/test_set.npz', allow_pickle=True)
    test_images = test_data['data']
    test_labels = test_data['labels']

    # Predict class probabilities and get predicted classes for normal test set
    test_predictions = model.predict(test_images, verbose=0)
    test_predictions_classes = np.argmax(test_predictions, axis=-1)

    # Extract ground truth classes (argmax implies one-hot labels here. TODO confirm)
    test_gt = np.argmax(test_labels, axis=-1)

    # Calculate and display metrics for the normal test set
    test_accuracy, test_precision, test_recall, test_f1 = _weighted_scores(
        test_gt, test_predictions_classes
    )
    _print_scores(test_accuracy, test_precision, test_recall, test_f1)

    # Compute the confusion matrix
    cm = confusion_matrix(test_gt, test_predictions_classes)

    # Plot the confusion matrix with class labels
    class_names = [f'Class {i}' for i in range(n_classes)]
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d',
                xticklabels=class_names,
                yticklabels=class_names, cmap='Blues')
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()

    # Classification report for detailed metrics per class
    print("\nClassification Report:\n")
    print(classification_report(test_gt, test_predictions_classes))

    # One-vs-rest ROC-AUC score for each class, computed from the predicted probabilities
    y_test_binarized = label_binarize(test_gt, classes=range(n_classes))
    roc_auc_scores = []
    for i in range(n_classes):
        try:
            roc_auc = roc_auc_score(y_test_binarized[:, i], test_predictions[:, i])
            roc_auc_scores.append(roc_auc)
            print(f"Class {i} ROC-AUC Score: {round(roc_auc, 4)}")
        except ValueError:
            # Skip this class instead of aborting the whole evaluation
            print(f"Class {i} ROC-AUC Score: Unable to calculate (not enough samples).")

    # Optional: Display mean ROC-AUC score across classes
    if roc_auc_scores:
        print(f"\nMean ROC-AUC Score: {round(np.mean(roc_auc_scores), 4)}")