# Data augmentation

In [1]:
"""
DATA AUGMENTATION SCRIPT
========================
Optimized for square images (1:1 aspect ratio)
Creates balanced dataset with 500 samples per class

Note: image analysis showed that all source images are square (1:1 aspect ratio)
Usage: python augmentation_script.py
"""

import numpy as np
import os
import shutil
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img

# Every class is resampled (augmented up or randomly downsampled) to this size.
target_samples_per_class = 500

# Source images were found to be square (1:1), so resizing them to a square
# output (150x150 in generate_augmented_images) does not distort them.
print("🎯 DATA AUGMENTATION FOR BALANCED DATASET")
print("=" * 70)
print(f"Target: {target_samples_per_class} samples per class")
print()

# Define paths
base_dir = os.path.join(os.getcwd(), 'data')
train_dir = os.path.join(base_dir, 'train')
balanced_train_dir = os.path.join(base_dir, 'balanced_train')

# Images per class observed in data/train when this plan was written.
_observed_class_counts = {
    'MEL': 43,
    'SCC': 130,
    'SEK': 161,
    'NEV': 177,
    'BCC': 588,
    'ACK': 506,
}

# Per-class resampling plan, derived from the counts above so the ratios and
# sample counts can never drift out of sync with the target:
#   minority classes -> 'augment_ratio' and 'samples_needed' (new images to generate)
#   majority classes -> 'downsample_ratio' and 'samples_to_remove'
class_augmentation_info = {}
for _class_name, _current in _observed_class_counts.items():
    _plan = {'current': _current, 'target': target_samples_per_class}
    if _current < target_samples_per_class:
        _plan['augment_ratio'] = round(target_samples_per_class / _current, 2)
        _plan['samples_needed'] = target_samples_per_class - _current
    elif _current > target_samples_per_class:
        _plan['downsample_ratio'] = round(target_samples_per_class / _current, 2)
        _plan['samples_to_remove'] = _current - target_samples_per_class
    class_augmentation_info[_class_name] = _plan

def count_images_per_class(directory):
    """Map each class sub-folder of ``directory`` to its number of image files.

    Only immediate sub-directories are considered (plain files at the top
    level are ignored). A file counts as an image when its lower-cased name
    ends in ``.png``, ``.jpg`` or ``.jpeg``.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    counts_by_class = {}
    for entry in os.listdir(directory):
        entry_path = os.path.join(directory, entry)
        if not os.path.isdir(entry_path):
            continue
        counts_by_class[entry] = sum(
            1 for name in os.listdir(entry_path)
            if name.lower().endswith(image_suffixes)
        )
    return counts_by_class

def get_augmentation_strategy(class_name):
    """Return an ``(ImageDataGenerator, label)`` pair tuned to ``class_name``.

    Classes that need more synthetic samples get wider transform ranges (more
    rotation, shift, shear, zoom and brightness/colour jitter). Any class not
    listed below falls back to a mild default configuration.

    Args:
        class_name: Class folder name, e.g. ``'MEL'``.

    Returns:
        Tuple of (configured ``ImageDataGenerator``, human-readable strategy label).
    """
    # Growth figures in the comments assume the 500-images-per-class target.
    strategies = {
        # CRITICAL - MEL: 43 -> 500 (~11.6x, 457 new samples).
        'MEL': (dict(
            rotation_range=30,            # More rotation for the critical class
            width_shift_range=0.3,
            height_shift_range=0.3,
            shear_range=0.25,
            zoom_range=0.3,
            horizontal_flip=True,
            vertical_flip=True,           # Safe for skin lesions
            brightness_range=[0.7, 1.3],  # Wider brightness range
            channel_shift_range=0.2,      # Color channel shifts
            fill_mode='nearest',
        ), "🔥 AGGRESSIVE"),
        # HIGH - SCC: 130 -> 500 (~3.85x, 370 new samples).
        'SCC': (dict(
            rotation_range=25,
            width_shift_range=0.25,
            height_shift_range=0.25,
            shear_range=0.2,
            zoom_range=0.25,
            horizontal_flip=True,
            vertical_flip=True,
            brightness_range=[0.8, 1.2],
            channel_shift_range=0.15,
            fill_mode='nearest',
        ), "🔥 HIGH"),
        # MODERATE - SEK: 161 -> 500 (~3.11x, 339 new samples).
        'SEK': (dict(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.15,
            zoom_range=0.2,
            horizontal_flip=True,
            brightness_range=[0.85, 1.15],
            channel_shift_range=0.1,
            fill_mode='nearest',
        ), "🔶 MODERATE"),
        # MODERATE - NEV: 177 -> 500 (~2.82x, 323 new samples).
        'NEV': (dict(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.15,
            zoom_range=0.2,
            horizontal_flip=True,
            brightness_range=[0.85, 1.15],
            fill_mode='nearest',
        ), "🔶 MODERATE"),
    }
    # Mild fallback for any class without a dedicated strategy.
    default_strategy = (dict(
        rotation_range=15,
        width_shift_range=0.15,
        height_shift_range=0.15,
        zoom_range=0.15,
        horizontal_flip=True,
        fill_mode='nearest',
    ), "⚪ DEFAULT")

    generator_kwargs, label = strategies.get(class_name, default_strategy)
    return ImageDataGenerator(**generator_kwargs), label

def generate_augmented_images(source_dir, target_dir, num_augmented, class_name,
                              target_size=(150, 150)):
    """Write up to ``num_augmented`` randomly augmented images derived from ``source_dir``.

    Each output is produced as follows:
      1. Pick a random source image (with replacement).
      2. Resize it to ``target_size``.
      3. Apply one random transform from the generator returned by
         ``get_augmentation_strategy(class_name)``.
      4. Save the result as ``aug_<class_name>_<NNNN>.png`` in ``target_dir``.

    A source image that raises while loading, transforming or saving is
    reported and removed from the candidate pool. The loop therefore always
    terminates, even when every file is unreadable.

    Args:
        source_dir: Folder holding the original images (only read).
        target_dir: Existing folder that receives the augmented PNG files.
        num_augmented: Number of augmented images to produce.
        class_name: Class label; selects the strategy and prefixes file names.
        target_size: Size passed to ``load_img`` as (height, width); defaults to 150x150.

    Returns:
        Number of augmented images actually written. This is 0 if
        ``source_dir`` has no images, and less than ``num_augmented`` if
        every source image failed before the quota was reached.
    """
    # Get optimized augmentation strategy for this class
    datagen, strategy_name = get_augmentation_strategy(class_name)

    candidates = [f for f in os.listdir(source_dir)
                  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    if not candidates:
        print(f"No images found in {source_dir}")
        return 0

    height, width = target_size
    print(f"     Strategy: {strategy_name}")
    print(f"     📏 Square input → {height}×{width} output")
    print(f"     🎯 Generating {num_augmented} new samples...")

    # Report progress every 50 images for large jobs, every 25 for small ones.
    progress_every = 50 if num_augmented >= 100 else 25
    generated_count = 0

    while generated_count < num_augmented and candidates:
        pick = np.random.randint(len(candidates))
        img_file = candidates[pick]
        try:
            img = load_img(os.path.join(source_dir, img_file), target_size=target_size)
            batch = np.expand_dims(img_to_array(img), axis=0)
            # flow() over a one-image batch yields a single randomly transformed copy.
            augmented = datagen.flow(batch, batch_size=1)[0][0]
            out_name = f"aug_{class_name}_{generated_count:04d}.png"
            array_to_img(augmented).save(os.path.join(target_dir, out_name))
        except Exception as e:
            # Failures can come from PIL, Keras or the filesystem. Dropping the
            # file guarantees the loop cannot spin forever on bad inputs.
            print(f"        ⚠️ Error processing {img_file}: {e} (excluded from further picks)")
            candidates.pop(pick)
            continue

        generated_count += 1
        if generated_count % progress_every == 0:
            progress = (generated_count / num_augmented) * 100
            print(f"        🔄 Progress: {generated_count}/{num_augmented} ({progress:.1f}%)")

    if generated_count < num_augmented:
        print(f"     ⚠️ INCOMPLETE: {generated_count}/{num_augmented} augmented images for "
              f"{class_name} (no usable source images left)")
    else:
        print(f"     ✅ COMPLETED: {generated_count} augmented images for {class_name}")
    return generated_count

def _list_image_files(directory):
    """Return the sorted .png/.jpg/.jpeg file names directly inside ``directory``."""
    # Sorted so that the seeded downsampling selects the same files regardless
    # of the (filesystem-dependent) order returned by os.listdir.
    return sorted(f for f in os.listdir(directory)
                  if f.lower().endswith(('.png', '.jpg', '.jpeg')))


def _copy_images(file_names, source_dir, target_dir):
    """Copy ``file_names`` from ``source_dir`` to ``target_dir``; return how many succeeded."""
    copied_count = 0
    for img_file in file_names:
        try:
            shutil.copy2(os.path.join(source_dir, img_file),
                         os.path.join(target_dir, img_file))
        except OSError as e:
            print(f"        ⚠️ Error copying {img_file}: {e}")
        else:
            copied_count += 1
    return copied_count


def create_balanced_dataset():
    """Rebuild ``balanced_train_dir`` from ``train_dir`` with each known class at its target size.

    ``train_dir`` is only read. ``balanced_train_dir`` is deleted if present
    and recreated. For every class sub-folder of ``train_dir``:
      * fewer images than its target: copy all originals, then generate
        augmented images to fill the remaining gap;
      * more images than its target: copy a seeded random subset of exactly
        ``target`` images;
      * no entry in ``class_augmentation_info``, or already at target: copy
        all images unchanged.

    Returns:
        True when the dataset was built. False when ``train_dir`` is missing
        or contains no class folders.
    """
    print("🛡️  CREATING BALANCED DATASET")
    print(f"     Source: {train_dir} (READ ONLY - never modified)")
    print(f"     Target: {balanced_train_dir} (NEW DIRECTORY)")
    print()

    if not os.path.exists(train_dir):
        print(f"❌ ERROR: Source directory not found: {train_dir}")
        print("   Please check your data directory structure.")
        return False

    # Start from a clean directory so files from a previous run cannot inflate counts.
    if os.path.exists(balanced_train_dir):
        print(f"     🗑️ Removing existing: {balanced_train_dir}")
        shutil.rmtree(balanced_train_dir)
    os.makedirs(balanced_train_dir)
    print(f"     ✅ Created: {balanced_train_dir}")
    print()

    current_counts = count_images_per_class(train_dir)
    if not current_counts:
        print("❌ ERROR: No class directories found in train directory")
        return False

    print("📊 CURRENT CLASS DISTRIBUTION:")
    for class_name, count in current_counts.items():
        target = class_augmentation_info.get(class_name, {}).get('target')
        if target is None:
            print(f"     ⚪ {class_name}: {count} (unknown class)")
        else:
            change = target - count
            symbol = "🔬" if change > 0 else "📉" if change < 0 else "✅"
            print(f"     {symbol} {class_name}: {count} → {target} ({change:+d})")
    print(f"     📊 TOTAL ORIGINAL: {sum(current_counts.values())} images")
    print()

    for class_name in current_counts:
        target = class_augmentation_info.get(class_name, {}).get('target')
        original_class_dir = os.path.join(train_dir, class_name)           # read-only source
        balanced_class_dir = os.path.join(balanced_train_dir, class_name)  # new output
        os.makedirs(balanced_class_dir)

        image_files = _list_image_files(original_class_dir)
        if not image_files:
            print(f"⚠️ WARNING: No images found in {class_name} directory")
            continue
        n_images = len(image_files)

        if target is None:
            print(f"⚪ UNKNOWN CLASS {class_name}: Copying all images as-is")
            _copy_images(image_files, original_class_dir, balanced_class_dir)

        elif n_images < target:
            print(f"🔬 AUGMENTING {class_name}:")
            print(f"     Current: {n_images} → Target: {target}")
            print(f"     Multiplication factor: {target / n_images:.2f}x")
            copied_count = _copy_images(image_files, original_class_dir, balanced_class_dir)
            print(f"     ✅ Copied {copied_count} original images")
            # Size the gap from what was actually copied, not from the
            # precomputed 'samples_needed'. The class then still reaches its
            # target if the folder changed since the plan was written or
            # some copies failed.
            samples_needed = target - copied_count
            print(f"     New samples to generate: {samples_needed}")
            generate_augmented_images(original_class_dir, balanced_class_dir,
                                      samples_needed, class_name)

        elif n_images > target:
            print(f"📉 DOWNSAMPLING {class_name}:")
            print(f"     Current: {n_images} → Target: {target}")
            print(f"     Keep ratio: {target / n_images:.2f}x")
            print(f"     Samples to remove: {n_images - target}")
            # A private seeded generator keeps the selection reproducible
            # without reseeding the global NumPy RNG. generate_augmented_images
            # draws from that global RNG for the classes processed afterwards.
            rng = np.random.RandomState(42)
            rng.shuffle(image_files)
            copied_count = _copy_images(image_files[:target],
                                        original_class_dir, balanced_class_dir)
            print(f"     ✅ Copied {copied_count} selected images")

        else:
            print(f"✅ {class_name}: already at target ({target}); copying all images")
            _copy_images(image_files, original_class_dir, balanced_class_dir)

        print()

    return True

def verify_balanced_dataset():
    """Check per-class image counts in ``balanced_train_dir`` against their targets.

    A class's target comes from ``class_augmentation_info``. A class without
    an entry is expected to keep whatever count it has. The function prints a
    per-class report, plus a summary compared with ``train_dir``.

    Returns:
        True if the balanced directory exists, contains at least one class
        folder, and every class matches its target; False otherwise.
    """
    print("🔍 FINAL VERIFICATION:")
    print("=" * 70)

    if not os.path.exists(balanced_train_dir):
        print("❌ ERROR: Balanced directory was not created")
        return False

    final_counts = count_images_per_class(balanced_train_dir)
    if not final_counts:
        # An empty output must not be reported as "perfectly balanced".
        print("❌ ERROR: Balanced directory contains no class folders")
        return False
    total_balanced = sum(final_counts.values())

    print("✅ BALANCED DATASET VERIFICATION:")

    all_perfect = True
    for class_name, count in final_counts.items():
        target = class_augmentation_info.get(class_name, {}).get('target', count)
        matches = count == target
        all_perfect = all_perfect and matches
        status = "✅" if matches else "⚠️"
        print(f"     {status} {class_name}: {count} samples (target: {target})")

    print(f"\n📊 SUMMARY:")
    total_original = sum(count_images_per_class(train_dir).values())
    # Read targets from the plan rather than hard-coding a number that can drift.
    targets = sorted({info['target'] for info in class_augmentation_info.values()
                      if 'target' in info})
    target_text = ", ".join(str(t) for t in targets) if targets else "n/a"

    print(f"     📈 Original total: {total_original} images")
    print(f"     📈 Balanced total: {total_balanced} images")
    print(f"     📈 Net change: {total_balanced - total_original:+d} images")
    print(f"     📈 Classes: {len(final_counts)} classes")
    print(f"     📈 Target per class: {target_text} images")

    if all_perfect:
        print(f"\n🎯 PERFECT BALANCE ACHIEVED! ✅")
        print(f"     All classes match their target sample counts")
    else:
        print(f"\n⚠️ Minor discrepancies detected")
        print(f"     Check individual class counts above")

    print(f"\n🛡️ SAFETY CONFIRMATION:")
    print(f"     ✅ Original data preserved at: {train_dir}")
    print(f"     ✅ Balanced data created at: {balanced_train_dir}")
    # Originals are copied unchanged; only the generated aug_* files are resized.
    print(f"     ℹ️ Original images copied unchanged; augmented images saved at 150×150")

    return all_perfect

def main():
    """Entry point: build the balanced dataset, verify it, then print next steps.

    If the dataset cannot be built, only a failure message is printed and
    verification is skipped.
    """
    rule = "=" * 70
    print("🚀 STARTING DATA AUGMENTATION PROCESS...")
    print(rule)

    if create_balanced_dataset():
        # verify_balanced_dataset prints its own per-class report.
        verify_balanced_dataset()
        closing_lines = (
            "\n🎉 DATA AUGMENTATION COMPLETED!",
            rule,
            "📁 Your balanced dataset is ready for training!",
            f"📁 Location: {balanced_train_dir}",
            "🚀 Next step: Run the model training script",
            "",
        )
        for line in closing_lines:
            print(line)
    else:
        print("❌ AUGMENTATION FAILED")


# Runs when executed as a script. A notebook kernel also sets __name__ to
# "__main__", so running this cell executes the pipeline as well.
if __name__ == "__main__":
    main()

🎯 DATA AUGMENTATION FOR BALANCED DATASET
Target: 500 samples per class

🚀 STARTING DATA AUGMENTATION PROCESS...
🛡️  CREATING BALANCED DATASET
     Source: /Users/daperez/Documents/ProjectsSW/DermAI/data/train (READ ONLY - never modified)
     Target: /Users/daperez/Documents/ProjectsSW/DermAI/data/balanced_train (NEW DIRECTORY)

     ✅ Created: /Users/daperez/Documents/ProjectsSW/DermAI/data/balanced_train

📊 CURRENT CLASS DISTRIBUTION:
     🔬 SEK: 161 → 500 (+339)
     📉 BCC: 588 → 500 (-88)
     🔬 NEV: 177 → 500 (+323)
     🔬 SCC: 130 → 500 (+370)
     🔬 MEL: 43 → 500 (+457)
     📉 ACK: 506 → 500 (-6)
     📊 TOTAL ORIGINAL: 1605 images

🔬 AUGMENTING SEK:
     Current: 161 → Target: 500
     Multiplication factor: 3.11x
     New samples to generate: 339
     ✅ Copied 161 original images
     Strategy: 🔶 MODERATE
     📏 Square input → 150×150 output (optimal for your data)
     🎯 Generating 339 new samples...
        🔄 Progress: 50/339 (14.7%)
        🔄 Progress: 100/339 (29.5%)
      