In [None]:
# Install required packages
!pip install pillow matplotlib seaborn opencv-python

In [None]:
# Import libraries
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance, ImageFilter
import random
from collections import Counter
import seaborn as sns
from google.colab import drive, files
import zipfile
import cv2

In [None]:
# Attach Google Drive to the Colab runtime so the dataset folders are reachable.
DRIVE_MOUNT_POINT = '/content/drive'
print("Mounting Google Drive...")
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
class ImageDatasetBalancer:
    """Balance an image dataset by augmenting under-represented classes.

    The input folder is expected to contain one sub-folder per class. Every
    original image is copied to the matching sub-folder of the output folder,
    and randomly augmented JPEG copies are added until each class holds
    ``target_count`` images. Classes that already exceed the target are copied
    unchanged (no down-sampling is performed).
    """

    # Lower-case file extensions treated as images throughout the class.
    VALID_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'})

    # Class sub-folders analysed when the caller does not pass `class_names`.
    DEFAULT_CLASS_NAMES = ('benign', 'malignant', 'normal')

    # Augmentations that `balance_dataset` picks from at random.
    AUGMENTATION_TECHNIQUES = (
        'rotation', 'flip_horizontal', 'flip_vertical', 'brightness',
        'contrast', 'sharpness', 'blur', 'zoom', 'noise',
    )

    def __init__(self, input_folder, output_folder, target_count=1050,
                 class_names=None, seed=None):
        """
        Initialize the Image Dataset Balancer for Google Colab

        Args:
            input_folder (str): Path to the input dataset folder
            output_folder (str): Path to save the balanced dataset
            target_count (int): Target number of images per class (default: 1050)
            class_names (iterable of str, optional): Names of the class
                sub-folders to process. Defaults to DEFAULT_CLASS_NAMES.
            seed (int, optional): When given, seeds both `random` and
                `numpy.random` so the chosen images and augmentations are
                repeatable. Left unseeded when None.
        """
        self.input_folder = input_folder
        self.output_folder = output_folder
        self.target_count = target_count
        self.class_names = (
            tuple(class_names) if class_names is not None else self.DEFAULT_CLASS_NAMES
        )
        self.class_counts_before = {}
        self.class_counts_after = {}

        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        # Set matplotlib style for better plots in Colab
        plt.style.use('default')

    def _list_images(self, folder_path):
        """Return the sorted names in `folder_path` that have an image extension.

        Sorting makes the later random selection repeatable for a given seed,
        independent of the order the file system lists entries in.
        """
        return sorted(
            entry for entry in os.listdir(folder_path)
            if os.path.splitext(entry.lower())[1] in self.VALID_EXTENSIONS
        )

    def count_images_in_folder(self, folder_path):
        """Count images in a folder.

        Returns 0 when the folder does not exist or cannot be listed (the
        listing error is printed rather than raised).
        """
        if not os.path.exists(folder_path):
            return 0

        try:
            return len(self._list_images(folder_path))
        except OSError as e:
            print(f"Error counting files in {folder_path}: {e}")
            return 0

    def analyze_dataset(self):
        """Analyze the current dataset and count images per class.

        Returns:
            dict: class name -> image count (also stored on
            `self.class_counts_before`), or None when the input folder does
            not exist.
        """
        print("🔍 === Dataset Analysis ===")

        # Check if input folder exists
        if not os.path.exists(self.input_folder):
            print(f"❌ Input folder not found: {self.input_folder}")
            return None

        # Start from a clean slate so a repeated call cannot keep stale entries.
        self.class_counts_before = {}

        for class_name in self.class_names:
            class_path = os.path.join(self.input_folder, class_name)
            if os.path.exists(class_path):
                count = self.count_images_in_folder(class_path)
                self.class_counts_before[class_name] = count
                print(f"📁 {class_name.capitalize()}: {count} images")
            else:
                print(f"⚠️  Warning: {class_name} folder not found at {class_path}")
                self.class_counts_before[class_name] = 0

        total_before = sum(self.class_counts_before.values())
        print(f"\n📊 Total images before augmentation: {total_before}")

        # Show class distribution
        if total_before > 0:
            print("\n📈 Current class distribution:")
            for class_name, count in self.class_counts_before.items():
                percentage = (count / total_before) * 100
                print(f"   {class_name.capitalize()}: {percentage:.1f}%")

        return self.class_counts_before

    @staticmethod
    def _add_gaussian_noise(img_array, sigma=5):
        """Return a uint8 copy of `img_array` with zero-mean Gaussian noise added.

        The noise is kept in a signed type until after clipping: casting it to
        uint8 first would wrap negative samples around to values near 255 and
        turn a subtle perturbation into heavy saturation.
        """
        noise = np.rint(np.random.normal(0, sigma, img_array.shape)).astype(np.int16)
        noisy = img_array.astype(np.int16) + noise
        return np.clip(noisy, 0, 255).astype(np.uint8)

    def augment_image(self, image_path, output_path, augmentation_type):
        """
        Apply one augmentation technique to an image and save it as JPEG.

        Args:
            image_path (str): Image to read.
            output_path (str): Where the augmented JPEG is written; its parent
                directory is created when missing.
            augmentation_type (str): One of AUGMENTATION_TECHNIQUES. Any other
                value saves the image without a transformation.

        Returns:
            bool: True when the file was written, False when any step failed
            (the error is printed, not raised).
        """
        try:
            with Image.open(image_path) as img:
                # Convert to RGB if necessary
                if img.mode not in ['RGB', 'L']:
                    img = img.convert('RGB')

                # Apply different augmentation techniques
                if augmentation_type == 'rotation':
                    angle = random.uniform(-25, 25)
                    img = img.rotate(angle, expand=True, fillcolor='white')

                elif augmentation_type == 'flip_horizontal':
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)

                elif augmentation_type == 'flip_vertical':
                    img = img.transpose(Image.FLIP_TOP_BOTTOM)

                elif augmentation_type == 'brightness':
                    factor = random.uniform(0.8, 1.2)
                    img = ImageEnhance.Brightness(img).enhance(factor)

                elif augmentation_type == 'contrast':
                    factor = random.uniform(0.8, 1.2)
                    img = ImageEnhance.Contrast(img).enhance(factor)

                elif augmentation_type == 'sharpness':
                    factor = random.uniform(0.7, 1.3)
                    img = ImageEnhance.Sharpness(img).enhance(factor)

                elif augmentation_type == 'blur':
                    radius = random.uniform(0.5, 1.5)
                    img = img.filter(ImageFilter.GaussianBlur(radius=radius))

                elif augmentation_type == 'zoom':
                    # Zoom in by cropping the borders and resizing back
                    width, height = img.size
                    crop_factor = random.uniform(0.05, 0.2)
                    left = int(width * crop_factor)
                    top = int(height * crop_factor)
                    # Keep the box at least one pixel wide/tall so very small
                    # images do not produce an empty crop.
                    right = max(left + 1, int(width * (1 - crop_factor)))
                    bottom = max(top + 1, int(height * (1 - crop_factor)))
                    img = img.crop((left, top, right, bottom))
                    img = img.resize((width, height), Image.Resampling.LANCZOS)

                elif augmentation_type == 'noise':
                    img = Image.fromarray(self._add_gaussian_noise(np.array(img)))

                # Ensure the output directory exists (a bare filename has none)
                output_dir = os.path.dirname(output_path)
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)

                # Save the augmented image
                img.save(output_path, 'JPEG', quality=90)

            return True

        except Exception as e:
            print(f"❌ Error augmenting image {image_path}: {str(e)}")
            return False

    def _generate_augmented(self, input_class_path, output_class_path,
                            original_files, needed_count):
        """Write up to `needed_count` augmented images for one class.

        Failed augmentations and already-existing target names are skipped and
        retried with a fresh random choice, bounded by three attempts per
        needed image. Returns the number of images actually written.
        """
        print(f"   🎨 Generating {needed_count} augmented images...")

        progress_step = max(1, needed_count // 10)
        max_attempts = needed_count * 3
        generated = 0
        attempts = 0

        while generated < needed_count and attempts < max_attempts:
            attempts += 1

            # Select a random original image and augmentation technique
            original_file = random.choice(original_files)
            original_path = os.path.join(input_class_path, original_file)
            aug_type = random.choice(self.AUGMENTATION_TECHNIQUES)

            # Create augmented filename
            name, _ext = os.path.splitext(original_file)
            aug_filename = f"{name}_aug_{attempts:04d}_{aug_type}.jpg"
            aug_path = os.path.join(output_class_path, aug_filename)

            # Overwriting an existing file would not raise the image count.
            if os.path.exists(aug_path):
                continue

            if not self.augment_image(original_path, aug_path, aug_type):
                continue

            generated += 1
            if generated % progress_step == 0 or generated == needed_count:
                progress = (generated / needed_count) * 100
                print(f"      🔄 Progress: {generated}/{needed_count} ({progress:.1f}%)")

        if generated < needed_count:
            print(f"      ⚠️  Only generated {generated}/{needed_count} images "
                  f"after {attempts} attempts")

        return generated

    def balance_dataset(self):
        """Balance the dataset by augmenting underrepresented classes.

        Works on the classes recorded by `analyze_dataset`. For each class the
        originals are copied to the output folder, then augmented images are
        generated until the output folder holds `target_count` images. The
        number to generate is based on what the output folder contains after
        copying, so running the method again tops the class up instead of
        overshooting the target. Final counts are stored in
        `self.class_counts_after`.
        """
        print("\n⚖️  === Starting Dataset Balancing ===")

        # Create output directory structure
        os.makedirs(self.output_folder, exist_ok=True)

        for class_name in list(self.class_counts_before.keys()):
            input_class_path = os.path.join(self.input_folder, class_name)
            output_class_path = os.path.join(self.output_folder, class_name)

            os.makedirs(output_class_path, exist_ok=True)

            current_count = self.class_counts_before[class_name]

            print(f"\n🔄 Processing {class_name}:")
            print(f"   📋 Current images: {current_count}")
            print(f"   🎯 Target images: {self.target_count}")

            if current_count == 0:
                print(f"   ⚠️  Skipping {class_name} - no images found")
                self.class_counts_after[class_name] = 0
                continue

            # Copy original images
            try:
                original_files = self._list_images(input_class_path)
                for file_name in original_files:
                    shutil.copy2(os.path.join(input_class_path, file_name),
                                 os.path.join(output_class_path, file_name))

                print(f"   ✅ Copied {len(original_files)} original images")

            except Exception as e:
                print(f"   ❌ Error copying files: {e}")
                # Record what did reach the output folder so the chart and
                # summary always find an entry for this class.
                self.class_counts_after[class_name] = \
                    self.count_images_in_folder(output_class_path)
                continue

            # Base the shortfall on the output folder, which may already hold
            # augmented images from an earlier run.
            existing_count = self.count_images_in_folder(output_class_path)
            needed_count = max(0, self.target_count - existing_count)
            print(f"   ➕ Need to generate: {needed_count}")

            # Generate augmented images if needed
            if needed_count > 0 and original_files:
                self._generate_augmented(input_class_path, output_class_path,
                                         original_files, needed_count)

            # Count final images
            final_count = self.count_images_in_folder(output_class_path)
            self.class_counts_after[class_name] = final_count
            print(f"   ✅ Final count: {final_count}")

    def create_comparison_chart(self, save_chart=True):
        """Create a bar chart comparing before and after counts.

        Args:
            save_chart (bool): When True, also write the chart to
                'dataset_comparison_chart.png' inside the output folder.
        """
        print("\n📊 === Creating Comparison Chart ===")

        classes = list(self.class_counts_before.keys())
        before_counts = [self.class_counts_before[cls] for cls in classes]
        # A class without a recorded "after" count is drawn as zero.
        after_counts = [self.class_counts_after.get(cls, 0) for cls in classes]

        # Set up the plot with larger size for Colab
        fig, ax = plt.subplots(figsize=(14, 8))

        # Create bar positions
        x = np.arange(len(classes))
        width = 0.35

        # Create bars with better colors
        bars_before = ax.bar(x - width/2, before_counts, width, label='Before Augmentation',
                             color='#FF6B6B', alpha=0.8, edgecolor='black', linewidth=0.5)
        bars_after = ax.bar(x + width/2, after_counts, width, label='After Augmentation',
                            color='#4ECDC4', alpha=0.8, edgecolor='black', linewidth=0.5)

        # Customize the plot
        ax.set_xlabel('Classes', fontsize=14, fontweight='bold')
        ax.set_ylabel('Number of Images', fontsize=14, fontweight='bold')
        ax.set_title('Dataset Balancing: Before vs After Augmentation\n(Medical Image Classification)',
                     fontsize=16, fontweight='bold', pad=20)
        ax.set_xticks(x)
        ax.set_xticklabels([cls.capitalize() for cls in classes], fontsize=12)
        ax.legend(fontsize=12, loc='upper left')
        ax.grid(axis='y', alpha=0.3, linestyle='--')

        # Axis height with a floor of 1 so an all-zero (or empty) dataset
        # still gets a valid y-range.
        max_count = max(before_counts + after_counts, default=0)
        y_top = max(max_count, 1) * 1.2
        # Offset labels relative to the axis height so they stay inside the
        # plot for small targets as well as large ones.
        label_offset = y_top * 0.012

        # Add value labels on bars
        for bar in (*bars_before, *bars_after):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + label_offset,
                    f'{int(height):,}', ha='center', va='bottom',
                    fontweight='bold', fontsize=10)

        # Set y-axis limit
        ax.set_ylim(0, y_top)

        # Add a subtle background color
        ax.set_facecolor('#F8F9FA')

        fig.tight_layout()

        if save_chart:
            # Save the chart
            os.makedirs(self.output_folder, exist_ok=True)
            chart_path = os.path.join(self.output_folder, 'dataset_comparison_chart.png')
            fig.savefig(chart_path, dpi=300, bbox_inches='tight', facecolor='white')
            print(f"📊 Chart saved to: {chart_path}")

        plt.show()
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)

    def print_summary(self):
        """Print a detailed summary of the balancing process."""
        print("\n" + "="*60)
        print("🎯 DATASET BALANCING SUMMARY")
        print("="*60)

        print("\n📊 BEFORE AUGMENTATION:")
        for class_name, count in self.class_counts_before.items():
            print(f"   {class_name.capitalize()}: {count:,} images")
        total_before = sum(self.class_counts_before.values())
        print(f"   📈 Total: {total_before:,} images")

        print("\n📊 AFTER AUGMENTATION:")
        for class_name, count in self.class_counts_after.items():
            print(f"   {class_name.capitalize()}: {count:,} images")
        total_after = sum(self.class_counts_after.values())
        print(f"   📈 Total: {total_after:,} images")

        images_added = total_after - total_before
        print(f"\n➕ Images added: {images_added:,}")
        print(f"🎯 Target per class: {self.target_count:,}")

        # Balanced means every class ended with the same non-zero count;
        # having no counts at all does not qualify.
        unique_counts = set(self.class_counts_after.values())
        is_balanced = len(unique_counts) == 1 and 0 not in unique_counts

        print(f"⚖️  Dataset is balanced: {'✅ Yes' if is_balanced else '❌ No'}")

        if images_added > 0 and total_before > 0:
            print(f"📈 Dataset size increased by: {((images_added/total_before)*100):.1f}%")

        # Show final distribution
        if total_after > 0:
            print(f"\n📊 Final class distribution:")
            for class_name, count in self.class_counts_after.items():
                percentage = (count / total_after) * 100
                print(f"   {class_name.capitalize()}: {percentage:.1f}%")

    def create_zip_download(self):
        """Create a zip file of the balanced dataset for easy download.

        The archive is written next to the output folder as
        '<output_folder>_balanced.zip' and its entries are stored relative to
        the output folder's parent, so they start with the folder's own name.

        Returns:
            str: Path of the zip file.
        """
        print("\n📦 Creating zip file for download...")

        # Normalise first: with a trailing slash the archive would otherwise
        # be created inside the very folder that is being walked.
        root_dir = os.path.normpath(self.output_folder)
        parent_dir = os.path.dirname(root_dir) or os.curdir
        zip_path = f"{root_dir}_balanced.zip"

        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for current_dir, _subdirs, filenames in os.walk(root_dir):
                for filename in sorted(filenames):
                    file_path = os.path.join(current_dir, filename)
                    arc_name = os.path.relpath(file_path, parent_dir)
                    zipf.write(file_path, arc_name)

        print(f"✅ Zip file created: {zip_path}")
        return zip_path

    def run_complete_process(self):
        """Run the complete dataset balancing process.

        Analyses the input, balances it, shows and saves the comparison chart,
        prints the summary and finally tries to zip the output. Stops right
        after the analysis when the input folder is missing; a failure while
        zipping is reported but does not abort.
        """
        print("🚀 Starting Image Dataset Balancing Process...")
        print("="*70)

        # Step 1: Analyze current dataset
        if self.analyze_dataset() is None:
            return

        # Step 2: Balance the dataset
        self.balance_dataset()

        # Step 3: Create comparison chart
        self.create_comparison_chart()

        # Step 4: Print summary
        self.print_summary()

        print("\n" + "="*70)
        print("🎉 Dataset balancing process completed successfully!")
        print(f"📁 Balanced dataset saved to: {self.output_folder}")

        # Optional: Create zip for download
        try:
            zip_path = self.create_zip_download()
            print(f"📦 Zip file ready for download: {zip_path}")
        except Exception as e:
            print(f"⚠️  Could not create zip file: {e}")

In [None]:
# Dataset locations on the mounted Google Drive. Edit both before running:
# the input folder is read for one sub-folder per class, and the balanced copy
# is written to the output folder.
INPUT_DATASET_PATH = "/content/drive/MyDrive/your_dataset_folder"  # Update this path
OUTPUT_DATASET_PATH = "/content/drive/MyDrive/balanced_dataset"    # Update this path

In [None]:
# ===== RUN THE BALANCING PROCESS =====
def main():
    """Run the dataset balancing against the configured Drive folders.

    When INPUT_DATASET_PATH does not exist, a hint is printed and nothing
    else happens.
    """
    if os.path.exists(INPUT_DATASET_PATH):
        # Build the balancer with a target of 1050 images per class and run
        # the whole pipeline.
        dataset_balancer = ImageDatasetBalancer(
            input_folder=INPUT_DATASET_PATH,
            output_folder=OUTPUT_DATASET_PATH,
            target_count=1050,
        )
        dataset_balancer.run_complete_process()
        return

    # Input path is missing: tell the user how to fix the configuration.
    print(f"❌ Please update INPUT_DATASET_PATH. Current path not found: {INPUT_DATASET_PATH}")
    print("💡 Right-click on your dataset folder in Google Drive and copy the path")

# Uncomment the line below to run the process
# main()