In [1]:
import os
import cv2
import numpy as np
from collections import defaultdict
import random
import shutil
import logging
import time
import traceback
from typing import Dict, List, Tuple

In [2]:
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [3]:
# Specify input and output paths
input_root = "/Users/tony/Desktop/coffeebeans_classification/original_dataset"
output_root = "/Users/tony/Desktop/coffeebeans_classification/balanced_dataset"
augmentation_strength = 1.0  # Augmentation strength, range 0-1

In [4]:
def count_images(folder: str) -> int:
    """
    Recursively count the number of images in a folder.
    
    This function is essential for determining the class imbalance in the dataset.
    It supports various image formats commonly used in computer vision tasks.
    
    Args:
        folder (str): Path to the folder containing images.
    
    Returns:
        int: Total number of images in the folder and its subfolders.
    """
    return sum(len([f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'))])
               for _, _, files in os.walk(folder))

In [5]:
def get_image_paths(folder: str) -> List[str]:
    """
    Recursively get all image paths in a folder.
    
    This function is crucial for data preprocessing and augmentation tasks,
    as it provides a list of all image files that need to be processed.
    
    Args:
        folder (str): Path to the folder containing images.
    
    Returns:
        List[str]: A list of full paths to all image files in the folder and its subfolders.
    """
    return [os.path.join(root, f) for root, _, files in os.walk(folder)
            for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'))]


In [6]:
def augment_image(image: np.ndarray, method: str, strength: float = 1.0) -> np.ndarray:
    """
    Apply a specific augmentation method to an image.
    
    Data augmentation is a crucial technique in machine learning to increase the diversity of the training set,
    reduce overfitting, and improve model generalization.
    
    Args:
        image (np.ndarray): Input image as a NumPy array.
        method (str): Augmentation method to apply ('rotate', 'flip_horizontal', or 'flip_vertical').
        strength (float): Augmentation strength, currently not used but can be implemented for more fine-grained control.
    
    Returns:
        np.ndarray: Augmented image as a NumPy array.
    """
    if method == 'rotate':
        angle = random.choice([90, 180, 270])
        rows, cols = image.shape[:2]
        M = cv2.getRotationMatrix2D((cols/2, rows/2), angle, 1)
        image = cv2.warpAffine(image, M, (cols, rows))
    elif method == 'flip_horizontal':
        image = cv2.flip(image, 1)
    elif method == 'flip_vertical':
        image = cv2.flip(image, 0)
    return image

In [7]:
def augment_dataset(input_folder: str, output_folder: str, target_count: int, strength: float) -> Dict[str, List[str]]:
    """
    Augment the dataset until the target count is reached.
    
    This function implements a balanced augmentation strategy to ensure that
    each class has an equal number of samples, which is crucial for training
    unbiased machine learning models.
    
    Args:
        input_folder (str): Path to the input folder containing original images.
        output_folder (str): Path to the output folder for augmented images.
        target_count (int): The desired number of images after augmentation.
        strength (float): Augmentation strength, range 0-1.
    
    Returns:
        Dict[str, List[str]]: A dictionary mapping augmentation methods to lists of augmented image paths.
    """
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    
    original_images = get_image_paths(input_folder)
    augmented_images = defaultdict(list)
    
    # Copy original images
    for img_path in original_images:
        rel_path = os.path.relpath(img_path, input_folder)
        output_path = os.path.join(output_folder, rel_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        shutil.copy2(img_path, output_path)
        augmented_images['original'].append(output_path)
    
    # Perform data augmentation
    augmentation_methods = ['rotate', 'flip_horizontal', 'flip_vertical']
    while sum(len(images) for images in augmented_images.values()) < target_count:
        img_path = random.choice(original_images)
        image = cv2.imread(img_path)
        
        if image is None:
            logger.warning(f"Unable to read image: {img_path}")
            continue
        
        method = random.choice(augmentation_methods)
        augmented = augment_image(image, method, strength)
        rel_path = os.path.relpath(img_path, input_folder)
        output_name = f"augmented_{method}_{len(augmented_images[method])}_{os.path.basename(rel_path)}"
        output_path = os.path.join(output_folder, os.path.dirname(rel_path), output_name)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        cv2.imwrite(output_path, augmented)
        augmented_images[method].append(output_path)
    
    return augmented_images

In [8]:
def balance_dataset(input_root: str, output_root: str, strength: float) -> Dict[str, Tuple[str, str]]:
    """
    Balance the dataset by augmenting underrepresented classes and subcategories.
    
    This function implements a hierarchical balancing strategy, ensuring that
    the dataset is balanced at the class, subclass, and sub-subclass levels.
    This is particularly important for multi-level classification tasks and
    for maintaining the hierarchical structure of the data.
    
    Args:
        input_root (str): Path to the root folder of the original dataset.
        output_root (str): Path to the root folder for the balanced dataset.
        strength (float): Augmentation strength, range 0-1.
    
    Returns:
        Dict[str, Tuple[str, str]]: A dictionary mapping class names to tuples of (input_folder, output_folder).
    """
    class_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
    class_folders = {}
    
    # Count images in each class, subclass, and sub-subclass
    for class_name in os.listdir(input_root):
        if class_name.startswith('.'): continue  # Skip hidden files/folders
        class_path = os.path.join(input_root, class_name)
        if os.path.isdir(class_path):
            class_folders[class_name] = class_path
            for subclass_name in os.listdir(class_path):
                if subclass_name.startswith('.'): continue  # Skip hidden files/folders
                subclass_path = os.path.join(class_path, subclass_name)
                if os.path.isdir(subclass_path):
                    for sub_subclass_name in os.listdir(subclass_path):
                        if sub_subclass_name.startswith('.'): continue  # Skip hidden files/folders
                        sub_subclass_path = os.path.join(subclass_path, sub_subclass_name)
                        if os.path.isdir(sub_subclass_path):
                            count = count_images(sub_subclass_path)
                            class_counts[class_name][subclass_name][sub_subclass_name] = count
    
    # Find the maximum count for the deepest level
    max_sub_subclass_count = max(count for class_counts in class_counts.values()
                                 for subclass_counts in class_counts.values()
                                 for count in subclass_counts.values())
    
    augmented_classes = {}
    
    for class_name, subclass_counts in class_counts.items():
        class_output_folder = os.path.join(output_root, class_name)
        class_input_folder = class_folders[class_name]
        try:
            for subclass_name, sub_subclass_counts in subclass_counts.items():
                for sub_subclass_name, count in sub_subclass_counts.items():
                    sub_subclass_input_folder = os.path.join(class_input_folder, subclass_name, sub_subclass_name)
                    sub_subclass_output_folder = os.path.join(class_output_folder, subclass_name, sub_subclass_name)
                    
                    if count < max_sub_subclass_count:
                        logger.info(f"Augmenting {class_name}/{subclass_name}/{sub_subclass_name} from {count} to {max_sub_subclass_count}")
                        augment_dataset(sub_subclass_input_folder, sub_subclass_output_folder, max_sub_subclass_count, strength)
                    else:
                        logger.info(f"Copying {class_name}/{subclass_name}/{sub_subclass_name} without augmentation")
                        shutil.copytree(sub_subclass_input_folder, sub_subclass_output_folder, dirs_exist_ok=True)
            
            augmented_classes[class_name] = (class_input_folder, class_output_folder)
        except Exception as e:
            logger.error(f"Error processing class {class_name}: {str(e)}")
            logger.error(traceback.format_exc())
    
    return augmented_classes

In [9]:
# Main program
try:
    start_time = time.time()

    logger.info("Starting dataset balancing and augmentation process...")
    augmented_classes = balance_dataset(input_root, output_root, augmentation_strength)

    # Print final counts for each class, subclass, and sub-subclass
    logger.info("Dataset balancing completed. Final image counts:")
    for class_name, (original_folder, augmented_folder) in augmented_classes.items():
        logger.info(f"Class: {class_name}")
        for subclass in os.listdir(original_folder):
            if subclass.startswith('.'): continue  # Skip hidden files/folders
            subclass_path = os.path.join(original_folder, subclass)
            if os.path.isdir(subclass_path):
                logger.info(f"  Subclass: {subclass}")
                for sub_subclass in os.listdir(subclass_path):
                    if sub_subclass.startswith('.'): continue  # Skip hidden files/folders
                    original_sub_subclass_folder = os.path.join(original_folder, subclass, sub_subclass)
                    augmented_sub_subclass_folder = os.path.join(augmented_folder, subclass, sub_subclass)
                    if os.path.isdir(original_sub_subclass_folder):
                        original_count = count_images(original_sub_subclass_folder)
                        augmented_count = count_images(augmented_sub_subclass_folder)
                        logger.info(f"    Sub-subclass {sub_subclass}: Original: {original_count}, Augmented: {augmented_count}")

    end_time = time.time()
    logger.info(f"Total runtime: {end_time - start_time:.2f} seconds")
    logger.info("Dataset balancing and augmentation process completed successfully.")

except Exception as e:
    logger.error(f"An error occurred during execution: {str(e)}")
    logger.error(traceback.format_exc())

2024-10-14 14:40:39,212 - INFO - Starting dataset balancing and augmentation process...
2024-10-14 14:40:39,218 - INFO - Augmenting defect/dry/medium from 200 to 400
2024-10-14 14:40:39,735 - INFO - Augmenting defect/dry/light from 200 to 400
2024-10-14 14:40:40,324 - INFO - Augmenting defect/dry/dark from 200 to 400
2024-10-14 14:40:40,889 - INFO - Augmenting defect/honey/medium from 200 to 400
2024-10-14 14:40:41,275 - INFO - Augmenting defect/honey/light from 200 to 400
2024-10-14 14:40:41,664 - INFO - Augmenting defect/honey/dark from 200 to 400
2024-10-14 14:40:42,118 - INFO - Augmenting defect/wet/medium from 200 to 400
2024-10-14 14:40:42,527 - INFO - Augmenting defect/wet/light from 200 to 400
2024-10-14 14:40:42,901 - INFO - Augmenting defect/wet/dark from 200 to 400
2024-10-14 14:40:43,265 - INFO - Copying normal/dry/medium without augmentation
2024-10-14 14:40:43,425 - INFO - Copying normal/dry/light without augmentation
2024-10-14 14:40:43,594 - INFO - Copying normal/dry/da