In [6]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Read YOLO data & process

In [2]:
# DATA PATH
path = r'D:\NOTEBOOK\FYP\final_dataset_withnon'
train_img_path = os.path.join(path, 'images', 'train')
train_lbl_path = os.path.join(path, 'labels', 'train')

valid_img_path = os.path.join(path, 'images', 'val')
valid_lbl_path = os.path.join(path, 'labels', 'val')

test_img_path = os.path.join(path, 'images', 'test')
test_lbl_path = os.path.join(path, 'labels', 'test')

IMG_SIZE = 128
NUM_CLASSES = 4

# path to save processed images
processed_path = os.path.join(path, 'processed')
os.makedirs(processed_path, exist_ok=True)

In [3]:
# process a dataset by extracting craters from images based on YOLO-format labels and saving them in class-specific directories. 
def process_dataset(img_dir, lbl_dir, output_dir):
    for img_file in os.listdir(img_dir):
        if not img_file.endswith('.jpg'):
            continue # skip non-image files
            
        # Get the corresponding label file
        base_name = os.path.splitext(img_file)[0]
        lbl_file = os.path.join(lbl_dir, f"{base_name}.txt")
        
        # Process single image
        img = Image.open(os.path.join(img_dir, img_file))
        img_w, img_h = img.size
        
        with open(lbl_file, 'r') as f:
            for idx, line in enumerate(f.readlines()):
                class_id, xc, yc, w, h = map(float, line.strip().split())
                # Ensure image bounds
                x1 = int((xc - w/2) * img_w)
                y1 = int((yc - h/2) * img_h)
                x2 = int((xc + w/2) * img_w)
                y2 = int((yc + h/2) * img_h)
                                
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(img_w, x2), min(img_h, y2)
                
                # Crop the crater and resize it
                crater = img.crop((x1, y1, x2, y2))
                crater = crater.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.LANCZOS)
                # save
                class_dir = os.path.join(output_dir, str(int(class_id)))
                os.makedirs(class_dir, exist_ok=True)
                crater.save(os.path.join(class_dir, f"{base_name}_{idx}.jpg"))

# process all dataset
process_dataset(train_img_path, train_lbl_path, os.path.join(processed_path, 'train'))
process_dataset(valid_img_path, valid_lbl_path, os.path.join(processed_path, 'val'))
process_dataset(test_img_path, test_lbl_path, os.path.join(processed_path, 'test'))

In [4]:
# define a PyTorch dataset for loading crater images and their corresponding labels.
class CraterDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data = []
        self.transform = transform
        
        for class_id in range(NUM_CLASSES):
            class_dir = os.path.join(data_dir, str(class_id))
            for img_file in os.listdir(class_dir):
                self.data.append((os.path.join(class_dir, img_file), class_id))
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        img = Image.open(img_path).convert('RGB') 
        
        if self.transform:
            img = self.transform(img)
            
        return img, label

In [5]:
# Define data enhancement
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=15,
    horizontal_flip=True,
    width_shift_range=0.1,
    height_shift_range=0.1
)

# Loading data from category folder
train_generator = train_datagen.flow_from_directory(
    os.path.join(processed_path, 'train'),
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=32,
    class_mode='categorical'
)

val_generator = train_datagen.flow_from_directory(
    os.path.join(processed_path, 'val'),
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=32,
    class_mode='categorical'
)

Found 1880 images belonging to 4 classes.
Found 413 images belonging to 4 classes.


CNN Model

In [7]:
model = Sequential([
    Conv2D(32, (3,3), activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 3)),
    MaxPooling2D((2,2)),

    Conv2D(64, (3,3), activation='relu'),
    MaxPooling2D((2,2)),

    Conv2D(128, (3,3), activation='relu'),
    MaxPooling2D((2,2)),

    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(NUM_CLASSES, activation='softmax')
])

model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# ModelCheckpoint
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'best_model.keras',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max'
)

# early_stop 
early_stop = EarlyStopping(
    monitor='val_accuracy', 
    patience=5,
    restore_best_weights=True,
    verbose=1
)

# train
history = model.fit(
    train_generator,
    epochs=30,
    validation_data=val_generator,
    callbacks=[checkpoint, early_stop] 
)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


Epoch 1/30
[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 490ms/step - accuracy: 0.7910 - loss: 0.8551 - val_accuracy: 0.8959 - val_loss: 0.4531
Epoch 2/30
[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 462ms/step - accuracy: 0.8544 - loss: 0.4736 - val_accuracy: 0.9080 - val_loss: 0.2486
Epoch 3/30
[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 443ms/step - accuracy: 0.8708 - loss: 0.3867 - val_accuracy: 0.9056 - val_loss: 0.2873
Epoch 4/30
[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 454ms/step - accuracy: 0.8712 - loss: 0.3919 - val_accuracy: 0.9080 - val_loss: 0.2623
Epoch 5/30
[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 464ms/step - accuracy: 0.8761 - loss: 0.3496 - val_accuracy: 0.9225 - val_loss: 0.1998
Epoch 6/30
[1m59/59[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 457ms/step - accuracy: 0.8668 - loss: 0.3739 - val_accuracy: 0.9274 - val_loss: 0.2162
Epoch 7/30
[1m59/59[

Model Test & Prediction

In [10]:
# test data
test_datagen = ImageDataGenerator(rescale=1./255)

test_generator = test_datagen.flow_from_directory(
    os.path.join(processed_path, 'test'),
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=32,
    class_mode='categorical', 
    shuffle=False
)

Found 872 images belonging to 4 classes.


In [11]:
# load best data
best_model = load_model('best_model.keras')

# Evaluating Model Performance
test_loss, test_acc = best_model.evaluate(test_generator)
print(f'\nTest accuracy: {test_acc:.4f}')
print(f'Test loss: {test_loss:.4f}')

[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 76ms/step - accuracy: 0.8915 - loss: 0.2350

Test accuracy: 0.9128
Test loss: 0.2465


In [12]:
y_pred = best_model.predict(test_generator)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = test_generator.classes

class_names = list(test_generator.class_indices.keys())
print(classification_report(y_true, y_pred_classes, target_names=class_names))

conf_mat = confusion_matrix(y_true, y_pred_classes)
print("Confusion Matrix:\n", conf_mat)

[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 86ms/step
              precision    recall  f1-score   support

           0       0.55      0.35      0.43        31
           1       0.96      0.99      0.97       746
           2       0.63      0.60      0.61        65
           3       0.48      0.37      0.42        30

    accuracy                           0.91       872
   macro avg       0.65      0.58      0.61       872
weighted avg       0.90      0.91      0.91       872

Confusion Matrix:
 [[ 11   4  10   6]
 [  0 735  11   0]
 [  7  13  39   6]
 [  2  15   2  11]]
