In [1]:
USE_COLAB = True
SHOW_MODEL_SUMMARY = False
BATCH_SIZE = 32
EPOCHS = 500
LEARNING_RATE = 1e-4
AUGMENTATION = False
BALANCE_SET = "train" # "train", "val", "test", "train and val", "train and test", "val and test"
MIXUP = True
ALPHA_MIXUP = 0.2
MIXUP_AUGMENT_FACTOR = 2.0
CONV_LAYERS = 3
DENSE_LAYERS = 3
NODES_PER_LAYER = 512
DROPOUT_RATE = 0.5
PATIENCE = 40
L2_REGULARIZATION = 1e-3
USE_BASE_MODEL = True
MODEL = 'ConvNeXtSmall'  # 'VGG19', 'ResNet50', 'ResNet50V2', 'ResNet101', 'ResNet101V2', 'ResNet152',
# 'ResNet152V2', 'Xception', 'InceptionV3', 'InceptionResNetV2', 'MobileNet', 'MobileNetV2',
# 'DenseNet121', 'DenseNet169', 'DenseNet201', 'NASNetMobile', 'NASNetLarge',
# 'EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', 'EfficientNetB4',
# 'EfficientNetB5', 'EfficientNetB6', 'EfficientNetB7', 'EfficientNetV2B0', 'EfficientNetV2B1',
# 'EfficientNetV2B2', 'EfficientNetV2B3', 'EfficientNetV2S', 'EfficientNetV2M', 'EfficientNetV2L',
# 'ConvNeXtTiny', 'ConvNeXtSmall', 'ConvNeXtBase', 'ConvNeXtLarge', 'ConvNeXtXLarge'
USE_BATCH_NORMALIZATION = False
USE_CLASS_WEIGHTS = False
BALANCE_TRAINING_CLASSES = False # Deprecated
USE_PREPROCESSING = False
BACKGROUND_THRESHOLD = 0.5 # if the background class has a probability higher than this threshold, the image is considered as background (set 1 if you want to disable this feature)
SEED = 72121

In [None]:
if USE_COLAB:
    from google.colab import drive

    drive.mount('/gdrive')
    %cd /gdrive/My Drive/ANN/CHAL1

Mounted at /gdrive
/gdrive/My Drive/ANN/058_CNS_UNFREEZED


In [3]:
from libraries import *
from preprocess import balance_classes, one_hot_encode_labels, clean_dataset
from data_partitioning import split_and_balance_distribution, print_class_distribution, apply_mixup
from custom_layer import PreprocessLayer, ConditionalAugmentation
from utils import get_base_model, analyze_mixup_distribution

In [4]:
# Load data
data = np.load('data/training_set.npz', allow_pickle=True)

# Divide data
labels = data['labels']
images = data['images']

if (0):
  # Print unique labels and their counts
  unique_labels = np.unique(labels)
  print("Unique labels in the dataset:", unique_labels)
  for label in unique_labels:
      count = np.sum(labels == label)
      print(f"Class {label}: {count} images")

  # Create a dictionary to store the first index for each class
  first_indices = {label: np.where(labels == label)[0][0] for label in unique_labels}

  # Plot the first image per class
  num_classes = len(unique_labels)
  num_cols = 4  # You can adjust this for layout
  num_rows = (num_classes + num_cols - 1) // num_cols

  plt.figure(figsize=(15, 4 * num_rows))
  for i, label in enumerate(unique_labels):
      idx = first_indices[label]
      plt.subplot(num_rows, num_cols, i + 1)
      plt.imshow(np.clip(images[idx] / 255.0, 0, 1))
      plt.title(f"Class {labels[idx]}")
      plt.axis('off')

  plt.tight_layout()
  plt.show()



In [5]:
# show that the background and the subject have different brightness
sample_size = 10
sample_images = images[:sample_size]  # Adjust as needed

def analyze_brightness(image, threshold=BACKGROUND_THRESHOLD):
    avg_intensity = np.mean(image)
    darker_pixels = image[image < threshold]
    lighter_pixels = image[image >= threshold]

    avg_darker = np.mean(darker_pixels) if darker_pixels.size > 0 else 0.0
    avg_lighter = np.mean(lighter_pixels) if lighter_pixels.size > 0 else 0.0

    return avg_intensity, avg_darker, avg_lighter

# Run analysis on the sample images
for i, image in enumerate(sample_images):
    avg_intensity, dark_avg, light_avg = analyze_brightness(image / 255.0)  # Normalize if needed
    print(f"Image {i + 1} - Avg Intensity: {avg_intensity:.4f}, Darker Avg: {dark_avg:.4f}, Lighter Avg: {light_avg:.4f}")

Image 1 - Avg Intensity: 0.8041, Darker Avg: 0.3032, Lighter Avg: 0.8206
Image 2 - Avg Intensity: 0.6285, Darker Avg: 0.3020, Lighter Avg: 0.7325
Image 3 - Avg Intensity: 0.7496, Darker Avg: 0.2688, Lighter Avg: 0.8029
Image 4 - Avg Intensity: 0.6996, Darker Avg: 0.2885, Lighter Avg: 0.7695
Image 5 - Avg Intensity: 0.7811, Darker Avg: 0.3534, Lighter Avg: 0.8025
Image 6 - Avg Intensity: 0.7382, Darker Avg: 0.2656, Lighter Avg: 0.7991
Image 7 - Avg Intensity: 0.7948, Darker Avg: 0.2735, Lighter Avg: 0.8306
Image 8 - Avg Intensity: 0.7702, Darker Avg: 0.3202, Lighter Avg: 0.8219
Image 9 - Avg Intensity: 0.7945, Darker Avg: 0.3568, Lighter Avg: 0.8018
Image 10 - Avg Intensity: 0.7157, Darker Avg: 0.2702, Lighter Avg: 0.7531


In [6]:
images, labels = clean_dataset(images, labels)

Removed 1808 duplicate and unwanted images.


In [7]:
X_train, X_val, X_test, y_train, y_val, y_test = split_and_balance_distribution(
    images, labels, val_size=0.15, test_size=0.01, seed=SEED, balance_sets=BALANCE_SET
)


Data Set Sizes:
--------------------
Train:        5704
Validation:   1793
Test:          120

Train Set Distribution:
Class          Count     Percentage
-----------------------------------
0                713         12.50%
1                713         12.50%
2                713         12.50%
3                713         12.50%
4                713         12.50%
5                713         12.50%
6                713         12.50%
7                713         12.50%

Validation Set Distribution:
Class          Count     Percentage
-----------------------------------
0                127          7.08%
1                327         18.24%
2                163          9.09%
3                303         16.90%
4                127          7.08%
5                149          8.31%
6                350         19.52%
7                247         13.78%

Test Set Distribution:
Class          Count     Percentage
-----------------------------------
0                  9          7.50%

In [8]:
if BALANCE_TRAINING_CLASSES:
    X_train, y_train = balance_classes(X_train, y_train, target_class_size=np.mean(np.bincount(labels)))
    print_class_distribution(y_train, "Train")
    print_class_distribution(y_val, "Validation")
    print_class_distribution(y_test, "Test")

In [9]:
# One-hot encode labels
y_train, y_val, y_test = one_hot_encode_labels(y_train, y_val, y_test)

In [10]:
print("X_train shape:", X_train.shape)
X_train, y_train = apply_mixup(X_train, y_train, alpha=ALPHA_MIXUP, factor=MIXUP_AUGMENT_FACTOR)
print("X_train shape after mixup:", X_train.shape)

X_train shape: (5704, 96, 96, 3)
Generating 5704 additional samples using Mixup
Generating 101 samples per class pair
X_train shape after mixup: (11360, 96, 96, 3)


In [11]:
X_train = preprocess_input(X_train)
X_test = preprocess_input(X_test)

input_shape = X_train[0].shape

output_shape = y_train[0].shape[0]

print(f"Input shape: {input_shape}")
print(f"Output shape: {output_shape}")


Input shape: (96, 96, 3)
Output shape: 8


In [12]:
def build_model(
    input_shape=input_shape,
    output_shape=output_shape,
    learning_rate=LEARNING_RATE,
    augmentation=None,
    seed=SEED,
    conv_layers=CONV_LAYERS,
    dense_layers=DENSE_LAYERS,
    dropout_rate=DROPOUT_RATE,
    l2_regularization=L2_REGULARIZATION,
    use_base_model=USE_BASE_MODEL,
    background_threshold=BACKGROUND_THRESHOLD,
    use_batch_normalization=USE_BATCH_NORMALIZATION,
    nodes_per_layer=NODES_PER_LAYER,
    use_preprocessing=USE_PREPROCESSING
    ):
    tf.random.set_seed(seed)

    relu_initialiser = tfk.initializers.HeNormal(seed=seed)
    output_initialiser = tfk.initializers.GlorotNormal(seed=seed)
    regularizer = tfk.regularizers.l2(l2_regularization)

    # Define the input layer with original input shape
    input_layer = tfk.Input(shape=input_shape, name='input_layer')

    # Preprocess the input image
    if use_preprocessing:
        x = PreprocessLayer(threshold=background_threshold)(input_layer)

    else:
        x = input_layer

    if use_base_model:
        # Load the VGG16 model with a custom input shape (96x96x3)
        base_model = get_base_model(MODEL, input_shape=input_shape)

        # Apply augmentation if specified
        x = augmentation(x) if augmentation else x

        x = base_model(x)
        x = tfkl.GlobalAveragePooling2D(name='avg_pool')(x)

    else:
        # Apply augmentation if specified
        x = augmentation(x) if augmentation else x

        # Add Conv layers
        x = tfkl.Conv2D(filters=16, kernel_size=3, activation='relu',
                       padding='same', name='first_conv')(x)
        x = tfkl.MaxPooling2D((2, 2), name='first_maxpool')(x)

        for i in range(conv_layers - 1):
            num_filters = 32 * (2 ** i)
            x = tfkl.Conv2D(
                filters=num_filters,
                kernel_size=3,
                activation='relu',
                padding='same',
                name=f'conv_{num_filters}')(x)

            if i < conv_layers - 2:  # Apply MaxPooling except for last conv layer
                x = tfkl.MaxPooling2D((2, 2), name=f'maxpool_{num_filters}')(x)

        # Apply GlobalAveragePooling2D after all conv layers
        x = tfkl.GlobalAveragePooling2D(name='global_avg_pool')(x)

    # Add Dense layers
    for i in range(dense_layers):
        x = tfkl.Dense(int(nodes_per_layer/(2**i)),
                      activation='relu',
                      name=f'dense_{i+1}',
                      kernel_initializer=relu_initialiser)(x)

        if use_batch_normalization:
            x = tfkl.BatchNormalization()(x)

        if dropout_rate > 0:
            x = tfkl.Dropout(dropout_rate, name=f'dropout_{i+1}')(x)

    output_layer = tfkl.Dense(output_shape,
                             activation='softmax',
                             name='output_layer',
                             kernel_initializer=output_initialiser,
                             kernel_regularizer=regularizer
                             if l2_regularization > 0 else None)(x)

    # Create model
    model = tfk.Model(input_layer, output_layer)


    # Compile the model
    model.compile(optimizer=tfk.optimizers.Adam(learning_rate=learning_rate),
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])


    return model

In [13]:
if AUGMENTATION:
    augmentation_layers = tfk.Sequential([
        #tfkl.RandomFlip('horizontal'),
        #tfkl.RandomFlip('vertical'),
        #tfkl.RandomRotation(0.3),
        #tfkl.RandomTranslation(0.4, 0.4, fill_mode='nearest'),
        tfkl.RandomCrop(64, 64),
        tfkl.RandomZoom(0.3, fill_mode='nearest'),
        tfkl.Resizing(96, 96)
    ], name='augmentation')

    augmentation = ConditionalAugmentation(augmentation_layers)

In [14]:
model = build_model(
    augmentation=augmentation if AUGMENTATION else None
)

if SHOW_MODEL_SUMMARY:
    model.summary()

early_stopping = tfk.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=PATIENCE,
    restore_best_weights=True,
    mode='auto'
)

checkpoint_callback = tfk.callbacks.ModelCheckpoint(
    'models/best_model_restored.keras',  # Path where the model will be saved
    monitor='val_accuracy',  # Metric to monitor
    save_best_only=True,  # Save only the best model
    verbose=1,  # Print messages when saving the model
    save_weights_only=False,  # Save the entire model (including architecture)
    mode='max'  # 'max' to save the model with the highest validation accuracy
)

callbacks = [early_stopping, checkpoint_callback]


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/convnext/convnext_small_notop.h5
[1m198551472/198551472[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [15]:
# Calculate the total number of samples
num_samples = len(labels)

# Convert labels to numpy array if not already
labels = np.array(labels)

# Get unique classes
unique_classes = np.unique(labels)
num_classes = len(unique_classes)

# Initialize a dictionary to count occurrences of each class
class_counts = {cls: 0 for cls in unique_classes}

# Count occurrences of each class
for cls in unique_classes:
    class_counts[cls] = np.sum(labels == cls)
    print(f"Class {cls}: {class_counts[cls]} samples")

# Calculate class weights
class_weight = {i: 1.0 for i in unique_classes}  # Default weights set to 1 for each class

if USE_CLASS_WEIGHTS:
    # Calculate balanced weights
    class_weight = {
        i: num_samples / (num_classes * class_counts[i])
        for i in unique_classes
    }

print(f"Class weights: {class_weight}")

Class 0: 850 samples
Class 1: 2179 samples
Class 2: 1085 samples
Class 3: 2023 samples
Class 4: 849 samples
Class 5: 992 samples
Class 6: 2330 samples
Class 7: 1643 samples
Class weights: {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 6: 1.0, 7: 1.0}


In [16]:
# Train the model with early stopping callback
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val, y_val),
    callbacks=callbacks,
    class_weight=class_weight
).history

final_val_acc = history['val_accuracy'][-(PATIENCE+1)] * 100
print(f'Final validation accuracy: {final_val_acc:.0f}')

Epoch 1/500
[1m355/355[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 144ms/step - accuracy: 0.2732 - loss: 2.6013
Epoch 1: val_accuracy improved from -inf to 0.89069, saving model to models/best_model_restored.keras
[1m355/355[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m179s[0m 279ms/step - accuracy: 0.2737 - loss: 2.5990 - val_accuracy: 0.8907 - val_loss: 0.3570
Epoch 2/500
[1m355/355[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 150ms/step - accuracy: 0.7908 - loss: 0.9143
Epoch 2: val_accuracy improved from 0.89069 to 0.96877, saving model to models/best_model_restored.keras
[1m355/355[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 242ms/step - accuracy: 0.7909 - loss: 0.9141 - val_accuracy: 0.9688 - val_loss: 0.1566
Epoch 3/500
[1m355/355[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 154ms/step - accuracy: 0.8992 - loss: 0.6680
Epoch 3: val_accuracy improved from 0.96877 to 0.98048, saving model to models/best_model_restored.keras
[1m355/35

KeyboardInterrupt: 

In [None]:
# Create a timestamp for the filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# save model using val acc
model.save(f'models/model_{final_val_acc:.0f}_{timestamp}.keras')

del model

In [None]:
# Function to log model parameters to a text file
def log_model_parameters(final_val_acc, timestamp):
    # Create the log filename with date and time
    log_filename = f'models/model_{final_val_acc:.0f}_params_{timestamp}.txt'

    # Write the parameters to the log file
    with open(log_filename, 'w') as log_file:
        log_file.write("Model Training Parameters:\n\n")
        log_file.write(f"BATCH_SIZE: {BATCH_SIZE}\n")
        log_file.write(f"EPOCHS: {EPOCHS}\n")
        log_file.write(f"LEARNING_RATE: {LEARNING_RATE}\n")
        log_file.write(f"AUGMENTATION: {AUGMENTATION}\n")
        log_file.write(f"MIXUP: {MIXUP}\n")
        log_file.write(f"CONV_LAYERS: {CONV_LAYERS}\n")
        log_file.write(f"DENSE_LAYERS: {DENSE_LAYERS}\n")
        log_file.write(f"NODES_PER_LAYER: {NODES_PER_LAYER}\n")
        log_file.write(f"DROPOUT_RATE: {DROPOUT_RATE}\n")
        log_file.write(f"PATIENCE: {PATIENCE}\n")
        log_file.write(f"L2_REGULARIZATION: {L2_REGULARIZATION}\n")
        log_file.write(f"USE_BASE_MODEL: {USE_BASE_MODEL}\n")
        log_file.write(f"USE_BATCH_NORMALIZATION: {USE_BATCH_NORMALIZATION}\n")
        log_file.write(f"USE_CLASS_WEIGHTS: {USE_CLASS_WEIGHTS}\n")
        log_file.write(f"BALANCE_TRAINING_CLASSES: {BALANCE_TRAINING_CLASSES}\n")
        log_file.write(f"SEED: {SEED}\n")


# Log the model parameters
log_model_parameters(final_val_acc, timestamp)


In [None]:
# plot training loss and accuracy
def plot_training(history):
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))

    axs[0].plot(history['loss'], label='train')
    axs[0].plot(history['val_loss'], label='validation')
    axs[0].set_title('Loss')
    axs[0].legend()

    axs[1].plot(history['accuracy'], label='train')
    axs[1].plot(history['val_accuracy'], label='validation')
    axs[1].set_title('Accuracy')
    axs[1].legend()

    plt.show()

plot_training(history)



# Make inference

In [None]:
# Load the saved model
model = tfk.models.load_model(f'models/model_{final_val_acc:.0f}_{timestamp}.keras', custom_objects={'PreprocessLayer': PreprocessLayer, 'ConditionalAugmentation': ConditionalAugmentation})



In [None]:
# Predict class probabilities and get predicted classes
test_predictions = model.predict(X_test, verbose=0)
test_predictions_classes = np.argmax(test_predictions, axis=-1)

# Extract ground truth classes
test_gt = np.argmax(y_test, axis=-1)

# Calculate and display test set accuracy
test_accuracy = accuracy_score(test_gt, test_predictions_classes)
print(f'Accuracy score over the test set: {round(test_accuracy, 4)}')

# Calculate and display test set precision
test_precision = precision_score(test_gt, test_predictions_classes, average='weighted')
print(f'Precision score over the test set: {round(test_precision, 4)}')

# Calculate and display test set recall
test_recall = recall_score(test_gt, test_predictions_classes, average='weighted')
print(f'Recall score over the test set: {round(test_recall, 4)}')

# Calculate and display test set F1 score
test_f1 = f1_score(test_gt, test_predictions_classes, average='weighted')
print(f'F1 score over the test set: {round(test_f1, 4)}')

# Compute the confusion matrix
cm = confusion_matrix(test_gt, test_predictions_classes)

# Plot the confusion matrix with class labels
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d',
            xticklabels=[f'Class {i}' for i in range(8)],
            yticklabels=[f'Class {i}' for i in range(8)], cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()

# Classification report for detailed metrics per class
print("\nClassification Report:\n")
print(classification_report(test_gt, test_predictions_classes))

# ROC-AUC score for each class (only if this is multilabel or multiclass with probability predictions)
# Binarize the output classes for AUC calculation
y_test_binarized = label_binarize(test_gt, classes=range(8))
roc_auc_scores = []
for i in range(8):
    try:
        roc_auc = roc_auc_score(y_test_binarized[:, i], test_predictions[:, i])
        roc_auc_scores.append(roc_auc)
        print(f"Class {i} ROC-AUC Score: {round(roc_auc, 4)}")
    except ValueError:
        print(f"Class {i} ROC-AUC Score: Unable to calculate (not enough samples).")

# Optional: Display mean ROC-AUC score across classes
if roc_auc_scores:
    print(f"\nMean ROC-AUC Score: {round(np.mean(roc_auc_scores), 4)}")