In [None]:
import cv2
import numpy as np
import random
import os
import pickle
from glob import glob
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from tqdm import tqdm
import albumentations as A
import time
from pathlib import Path

class FacialDataBalancer:
    """Balance a one-folder-per-class face dataset to ``target_count`` images per class.

    Input layout is ``dataset_path/<class_name>/<image files>``. Each class is
    written to ``output_path/<class_name>/`` as JPEG files resized to
    ``target_size``, which is interpreted as ``(height, width)``. Classes with
    enough readable images are randomly down-sampled; smaller classes keep
    every readable original and are topped up with augmented copies whose
    strength depends on how few originals there are.

    Finished class names are checkpointed to ``output_path/checkpoint.pkl`` so
    an interrupted run can resume without redoing them.
    """

    # File extensions (compared case-insensitively) that are treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, target_count=300, target_size=(112, 112),
                 max_workers=None, use_multiprocessing=True):
        """
        Args:
            target_count: number of images written per class.
            target_size: output image size as (height, width).
            max_workers: threads used to load one class's images in parallel;
                defaults to ``cpu_count() - 2`` (at least 1).
            use_multiprocessing: when True and more than one worker is
                configured, image loading runs in a thread pool; when False,
                images are loaded sequentially.
        """
        self.target_count = target_count
        self.target_size = target_size
        self.use_multiprocessing = use_multiprocessing

        # Set max workers (use CPU cores - 2 for stability); never below 1.
        if max_workers is None:
            self.max_workers = max(1, mp.cpu_count() - 2)
        else:
            self.max_workers = max(1, max_workers)

        print(f"Using {self.max_workers} workers for processing")

        # Define augmentation pipelines
        self.setup_augmentation_pipelines()

        # Running counters reported by print_statistics().
        self.stats = {
            'total_processed': 0,
            'total_augmented': 0,
            'failed_images': 0,
            'start_time': time.time()
        }

        # The single tqdm bar shown while balance_dataset() runs; None otherwise.
        self.main_progress_bar = None

    def setup_augmentation_pipelines(self):
        """Build the light / moderate / aggressive albumentations pipelines.

        NOTE(review): parameter names such as ``var_limit``, ``max_holes`` and
        ``fill_value`` (and ``ShiftScaleRotate`` itself) follow the
        albumentations 1.x API; later releases renamed or deprecated several
        of them - confirm against the installed version.
        """
        # Light augmentations (for nearly sufficient datasets)
        self.light_aug = A.Compose([
            A.HorizontalFlip(p=0.3),
            A.ShiftScaleRotate(
                shift_limit=0.03,
                scale_limit=0.03,
                rotate_limit=3,
                p=0.3
            ),
            A.RandomBrightnessContrast(
                brightness_limit=0.1,
                contrast_limit=0.1,
                p=0.2
            )
        ])

        # Moderate augmentations (for medium-sized datasets)
        self.moderate_aug = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.07,
                scale_limit=0.07,
                rotate_limit=7,
                p=0.5
            ),
            A.RandomBrightnessContrast(
                brightness_limit=0.15,
                contrast_limit=0.15,
                p=0.4
            ),
            A.HueSaturationValue(
                hue_shift_limit=5,
                sat_shift_limit=10,
                val_shift_limit=10,
                p=0.3
            ),
            A.GaussNoise(var_limit=(5.0, 15.0), p=0.2)
        ])

        # Aggressive augmentations (for small datasets)
        self.aggressive_aug = A.Compose([
            A.HorizontalFlip(p=0.7),
            A.ShiftScaleRotate(
                shift_limit=0.1,
                scale_limit=0.1,
                rotate_limit=10,
                p=0.7
            ),
            A.RandomBrightnessContrast(
                brightness_limit=0.2,
                contrast_limit=0.2,
                p=0.5
            ),
            A.HueSaturationValue(
                hue_shift_limit=10,
                sat_shift_limit=20,
                val_shift_limit=20,
                p=0.5
            ),
            A.GaussNoise(var_limit=(10.0, 30.0), p=0.3),
            A.GaussianBlur(blur_limit=(3, 7), p=0.3),
            A.RandomGamma(gamma_limit=(80, 120), p=0.2),
            A.CoarseDropout(
                max_holes=2,
                max_height=8,
                max_width=8,
                min_holes=1,
                min_height=4,
                min_width=4,
                fill_value=0,
                p=0.1
            )
        ])

    def balance_dataset(self, dataset_path, output_path, checkpoint_interval=5):
        """Balance every class under ``dataset_path`` into ``output_path``.

        Classes recorded in an existing checkpoint are skipped. The checkpoint
        is rewritten every ``checkpoint_interval`` classes and always when the
        loop ends - including on KeyboardInterrupt or an error - so finished
        classes are never redone on the next run.
        """
        print(f"\n{'='*60}")
        print(f"Starting dataset balancing")
        print(f"Input: {dataset_path}")
        print(f"Output: {output_path}")
        print(f"Target: {self.target_count} images per class at {self.target_size}")
        print(f"{'='*60}\n")

        # The checkpoint file lives here, so it must exist before the first save.
        Path(output_path).mkdir(parents=True, exist_ok=True)

        class_distribution = self.get_class_distribution(dataset_path)
        total_classes = len(class_distribution)

        print(f"Found {total_classes} classes")

        if not class_distribution:
            print("\nNo class folders with images found - nothing to do.")
            return

        processed_classes = self.load_checkpoint(output_path)
        classes_to_process = [c for c in class_distribution if c not in processed_classes]

        if not classes_to_process:
            print("\nAll classes already processed!")
            return

        num_to_process = len(classes_to_process)
        print(f"Processing {num_to_process} classes (skipping {len(processed_classes)} already processed)")

        # Guard against checkpoint_interval <= 0 (would divide by zero below).
        interval = max(1, checkpoint_interval)

        self.main_progress_bar = tqdm(
            total=num_to_process,
            desc="Overall Progress",
            position=0,
            leave=True,
            bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
        )
        self.main_progress_bar.set_postfix_str(f"Classes: 0/{num_to_process}")

        try:
            for idx, class_name in enumerate(classes_to_process, 1):
                image_paths = class_distribution[class_name]
                self.main_progress_bar.set_description(f"Processing: {class_name}")

                if len(image_paths) >= self.target_count:
                    self.process_sufficient_class_simple(image_paths, class_name, output_path)
                else:
                    self.augment_class_simple(image_paths, class_name, len(image_paths), output_path)

                # Mark the class done only after all of its files were written.
                processed_classes.add(class_name)
                if idx % interval == 0:
                    self.save_checkpoint(output_path, processed_classes)

                self.stats['total_processed'] += 1
                self.main_progress_bar.update(1)
                self.main_progress_bar.set_postfix_str(f"Classes: {idx}/{num_to_process}")
        finally:
            # Runs on success AND on interrupt/error, so progress is never lost.
            self.main_progress_bar.close()
            self.main_progress_bar = None
            self.save_checkpoint(output_path, processed_classes)

        self.print_statistics()

    def process_sufficient_class_simple(self, image_paths, class_name, output_path):
        """Write ``target_count`` resized images for a class that has at least that many.

        A random subset is used when the class has more images than needed.
        Unreadable files are counted in ``stats['failed_images']`` and replaced
        by further candidates; if the class still ends up short, the gap is
        filled with augmentations so it reaches ``target_count`` anyway.
        """
        output_class_dir = Path(output_path) / class_name
        output_class_dir.mkdir(parents=True, exist_ok=True)

        candidates = list(image_paths)
        if len(candidates) > self.target_count:
            # Random order without replacement: leading entries are used first,
            # the rest only serve as replacements for unreadable files.
            random.shuffle(candidates)

        images, failed = self._collect_valid_images(candidates, limit=self.target_count)
        self.stats['failed_images'] += failed

        if not images:
            self._mark_class(class_name, "✗ (no valid images)")
            return

        saved = sum(
            self._write_image(output_class_dir / f"{class_name}_{idx:04d}.jpg", img)
            for idx, img in enumerate(images)
        )

        # Non-zero only when too many files could not be read or written.
        augmented_count = self._generate_augmentations(
            images, self.target_count - saved, class_name, output_class_dir
        )
        self.stats['total_augmented'] += augmented_count

        self._mark_class(class_name, "✓" if augmented_count == 0 else f"✓ (+{augmented_count})")

    def augment_class_simple(self, image_paths, class_name, current_count, output_path):
        """Save every readable original of an under-sized class and top it up with augmentations.

        ``current_count`` is kept for API compatibility only; the number of
        originals actually saved determines how many augmentations are made.
        """
        output_class_dir = Path(output_path) / class_name
        output_class_dir.mkdir(parents=True, exist_ok=True)

        valid_images, failed = self._collect_valid_images(image_paths)
        self.stats['failed_images'] += failed

        if not valid_images:
            self._mark_class(class_name, "✗ (no valid images)")
            return

        saved = sum(
            self._write_image(output_class_dir / f"{class_name}_orig_{i:04d}.jpg", img)
            for i, img in enumerate(valid_images)
        )

        augmented_count = self._generate_augmentations(
            valid_images, self.target_count - saved, class_name, output_class_dir
        )
        self.stats['total_augmented'] += augmented_count

        self._mark_class(class_name, f"✓ (+{augmented_count})")

    def _select_pipeline(self, n_images):
        """Pick stronger augmentation the fewer real images a class has."""
        if n_images < 50:
            return self.aggressive_aug
        if n_images < 200:
            return self.moderate_aug
        return self.light_aug

    def _generate_augmentations(self, valid_images, needed, class_name, output_class_dir):
        """Write up to ``needed`` augmented images, cycling evenly over ``valid_images``.

        Every source image is used once per pass (reshuffled between passes),
        so sources are represented evenly. Returns the number written. Stops
        early once a full pass' worth of consecutive attempts has failed
        (augmentation error or unwritable output) instead of looping forever.
        """
        if needed <= 0 or not valid_images:
            return 0

        pipeline = self._select_pipeline(len(valid_images))
        sources = list(valid_images)
        written = 0
        consecutive_failures = 0
        position = 0

        while written < needed and consecutive_failures < len(sources):
            if position == len(sources):
                random.shuffle(sources)  # vary the order on every later pass
                position = 0
            img = sources[position]
            position += 1

            try:
                augmented = pipeline(image=img)['image']
            except Exception:  # any albumentations failure just skips this attempt
                consecutive_failures += 1
                continue

            out_file = output_class_dir / f"{class_name}_aug_{written:04d}.jpg"
            if self._write_image(out_file, augmented):
                written += 1
                consecutive_failures = 0
            else:
                consecutive_failures += 1

        return written

    def _mark_class(self, class_name, status):
        """Show ``class_name`` and a short status in the main progress bar, if one is active."""
        if self.main_progress_bar is not None:
            prefix = self.main_progress_bar.desc.split(':')[0]
            self.main_progress_bar.set_description(f"{prefix}: {class_name} {status}")

    @staticmethod
    def _write_image(path, rgb_image):
        """Save an RGB image (flipped to BGR for OpenCV); return True on success."""
        return bool(cv2.imwrite(str(path), rgb_image[:, :, ::-1]))

    def _load_many(self, image_paths):
        """Return ``load_and_resize`` results for ``image_paths`` in the same order.

        Unreadable files yield None. A thread pool is used when parallel
        loading is enabled and there is more than one path.
        """
        if self.use_multiprocessing and self.max_workers > 1 and len(image_paths) > 1:
            with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
                return list(pool.map(self.load_and_resize, image_paths))
        return [self.load_and_resize(path) for path in image_paths]

    def _collect_valid_images(self, image_paths, limit=None):
        """Load images in order until ``limit`` of them succeed (or paths run out).

        Returns ``(images, failed)`` where ``failed`` counts the paths that
        could not be loaded. Only as many paths as still needed are loaded per
        round, so unreadable files are replaced by the next candidates.
        """
        images = []
        failed = 0
        remaining = list(image_paths)
        while remaining and (limit is None or len(images) < limit):
            batch_size = len(remaining) if limit is None else limit - len(images)
            batch, remaining = remaining[:batch_size], remaining[batch_size:]
            for img in self._load_many(batch):
                if img is None:
                    failed += 1
                else:
                    images.append(img)
        return images, failed

    def get_class_distribution(self, dataset_path):
        """Map each class folder name to the sorted list of its image paths.

        Files are matched by extension case-insensitively, each file exactly
        once (globbing ``*.jpg`` and ``*.JPG`` separately lists every file
        twice on case-insensitive file systems such as Windows). Folders with
        no images are omitted. The result is ordered by image count
        (ascending), then class name.
        """
        dataset_path = Path(dataset_path)
        distribution = {}

        for class_dir in sorted(d for d in dataset_path.iterdir() if d.is_dir()):
            image_paths = sorted(
                str(p) for p in class_dir.iterdir()
                if p.is_file() and p.suffix.lower() in self.IMAGE_EXTENSIONS
            )
            if image_paths:  # Only include classes with images
                distribution[class_dir.name] = image_paths

        return dict(sorted(distribution.items(), key=lambda item: (len(item[1]), item[0])))

    def load_and_resize(self, image_path):
        """Load an image as RGB uint8 and bring it to ``target_size`` (height, width).

        Returns None when the file cannot be read, decoded or resized.
        """
        image_path = str(image_path)
        try:
            img = cv2.imread(image_path)
            if img is None:
                # Some paths cannot be opened by cv2.imread (commonly non-ASCII
                # paths on Windows); decoding from raw bytes avoids that.
                try:
                    with open(image_path, 'rb') as f:
                        img = cv2.imdecode(np.frombuffer(f.read(), np.uint8), cv2.IMREAD_COLOR)
                except OSError:
                    return None

            if img is None:
                return None

            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            h, w = img.shape[:2]
            target_h, target_w = self.target_size

            if (h, w) == (target_h, target_w):
                return img
            if h < target_h or w < target_w:
                return self.quick_resize_small(img)
            return self.quick_resize_large(img)

        except Exception:
            # cv2.error and similar decode/convert failures: treat as unreadable.
            return None

    def quick_resize_small(self, img):
        """Bring an image with at least one side below target up to ``target_size``.

        Short sides are reflection-padded (centred). If the result still
        differs from the target (the other side was larger), it is resized to
        ``target_size``, which changes its aspect ratio.
        """
        target_h, target_w = self.target_size
        h, w = img.shape[:2]

        pad_h = max(0, target_h - h)
        pad_w = max(0, target_w - w)
        if pad_h or pad_w:
            img = cv2.copyMakeBorder(
                img,
                pad_h // 2, pad_h - pad_h // 2,
                pad_w // 2, pad_w - pad_w // 2,
                cv2.BORDER_REFLECT101
            )

        if img.shape[:2] != (target_h, target_w):
            # cv2.resize takes dsize as (width, height), not (height, width).
            img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

        return img

    def quick_resize_large(self, img):
        """Downscale preserving aspect ratio so the image covers the target, then centre-crop."""
        target_h, target_w = self.target_size
        h, w = img.shape[:2]

        scale = max(target_h / h, target_w / w)
        # Explicit dsize (width, height), clamped so rounding can never make
        # the image smaller than the crop window.
        new_w = max(target_w, int(round(w * scale)))
        new_h = max(target_h, int(round(h * scale)))
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

        y = (new_h - target_h) // 2
        x = (new_w - target_w) // 2
        return img[y:y + target_h, x:x + target_w]

    def load_checkpoint(self, output_path):
        """Return the set of class names stored in ``output_path/checkpoint.pkl``.

        A missing or unreadable checkpoint yields an empty set (fresh start).
        NOTE: unpickling can execute code; this file is only ever written by
        save_checkpoint, so never point output_path at untrusted data.
        """
        checkpoint_file = Path(output_path) / 'checkpoint.pkl'
        if not checkpoint_file.exists():
            return set()
        try:
            with open(checkpoint_file, 'rb') as f:
                processed_classes = set(pickle.load(f))
        except Exception:  # truncated/corrupt pickles raise many exception types
            print("Could not load checkpoint, starting fresh")
            return set()
        print(f"Loaded checkpoint with {len(processed_classes)} processed classes")
        return processed_classes

    def save_checkpoint(self, output_path, processed_classes):
        """Atomically write the processed class names to ``output_path/checkpoint.pkl``.

        The data goes to a temporary file that then replaces the checkpoint,
        so an interruption mid-write cannot leave a truncated checkpoint.
        Failures are reported but not raised: a lost checkpoint only means
        some classes are redone on the next run.
        """
        checkpoint_file = Path(output_path) / 'checkpoint.pkl'
        tmp_file = checkpoint_file.with_name(checkpoint_file.name + '.tmp')
        try:
            with open(tmp_file, 'wb') as f:
                pickle.dump(sorted(processed_classes), f)
            os.replace(tmp_file, checkpoint_file)
        except OSError as e:
            print(f"Warning: could not save checkpoint to {checkpoint_file}: {e}")

    def print_statistics(self):
        """Print processing statistics accumulated since construction."""
        total_time = time.time() - self.stats['start_time']

        print(f"\n{'='*60}")
        print("PROCESSING STATISTICS")
        print(f"{'='*60}")
        print(f"Total classes processed: {self.stats['total_processed']}")
        print(f"Total images augmented: {self.stats['total_augmented']:,}")
        print(f"Failed images: {self.stats['failed_images']}")
        print(f"Total time: {total_time:.1f} seconds")

        if self.stats['total_processed'] > 0:
            print(f"Average time per class: {total_time/self.stats['total_processed']:.1f}s")
            print(f"Images augmented per second: {self.stats['total_augmented']/total_time:.1f}")

        print(f"{'='*60}")


# ============ VERIFICATION FUNCTION ============

def verify_dataset(dataset_path, target_count, target_size, sample_size=5):
    """Quick verification of a balanced dataset with a single progress bar.

    Args:
        dataset_path: root directory with one sub-folder per class.
        target_count: minimum number of images each class must contain.
        target_size: expected image size as (height, width), the same
            convention FacialDataBalancer uses.
        sample_size: number of images per class (first in sorted order) that
            are decoded to check their size; 0 disables the size check.

    Returns:
        True when every class has at least ``target_count`` images and every
        sampled image decodes to ``target_size``; False otherwise, including
        when no class folders exist.
    """
    dataset_path = Path(dataset_path)
    expected_hw = tuple(target_size)
    image_extensions = {'.jpg', '.jpeg', '.png'}

    def has_expected_size(image_file):
        """Decode one image and report whether its (height, width) equals expected_hw."""
        try:
            # np.fromfile + imdecode also copes with non-ASCII Windows paths.
            data = np.fromfile(str(image_file), dtype=np.uint8)
        except OSError:
            return False
        if data.size == 0:  # imdecode rejects empty buffers
            return False
        img = cv2.imdecode(data, cv2.IMREAD_COLOR)
        return img is not None and img.shape[:2] == expected_hw

    print(f"\n{'='*60}")
    print("VERIFYING DATASET")
    print(f"{'='*60}")

    all_classes = sorted(d for d in dataset_path.iterdir() if d.is_dir())
    if not all_classes:
        # Avoids the division by zero in the average below.
        print("\n✗ No class folders found - nothing to verify")
        return False

    total_images = 0
    passed_classes = 0
    bad_size_classes = 0

    # Single progress bar for verification
    with tqdm(total=len(all_classes), desc="Verifying classes", position=0, leave=True) as pbar:
        for class_dir in all_classes:
            # Case-insensitive extension match, each file counted once.
            image_files = sorted(
                p for p in class_dir.iterdir()
                if p.is_file() and p.suffix.lower() in image_extensions
            )
            image_count = len(image_files)
            total_images += image_count

            if image_count >= target_count:
                passed_classes += 1

            if not all(has_expected_size(p) for p in image_files[:max(0, sample_size)]):
                bad_size_classes += 1

            pbar.update(1)
            pbar.set_postfix_str(f"Passed: {passed_classes}/{len(all_classes)}")

    print(f"\nSUMMARY:")
    print(f"  Classes: {len(all_classes)}")
    print(f"  Classes with ≥{target_count} images: {passed_classes}/{len(all_classes)}")
    print(f"  Total images: {total_images:,}")
    print(f"  Average images per class: {total_images/len(all_classes):.1f}")
    if sample_size > 0:
        print(f"  Classes with sampled images not {expected_hw[0]}x{expected_hw[1]}: "
              f"{bad_size_classes}/{len(all_classes)}")

    if passed_classes == len(all_classes):
        print(f"\n✓ All classes have at least {target_count} images!")
    else:
        print(f"\n✗ {len(all_classes) - passed_classes} classes are below target count")
    if bad_size_classes:
        print(f"✗ {bad_size_classes} classes contain unreadable or wrongly sized images")

    return passed_classes == len(all_classes) and bad_size_classes == 0


# ============ MAIN EXECUTION ============

def main():
    """Balance the raw dataset into OUTPUT_PATH, then verify the result.

    Edit the configuration constants below before running.
    """
    # Configuration - CHANGE THESE PATHS
    INPUT_PATH = r"D:\Final_Semester_Project\AI_Attendance_System\ai-ml-model\DataSets\raw"
    OUTPUT_PATH = r"D:\Final_Semester_Project\AI_Attendance_System\ai-ml-model\DataSets\FullProcessedAgumented"
    TARGET_SIZE = (112, 112)
    TARGET_COUNT = 500

    print("Facial Recognition Dataset Balancer")
    print("="*50)

    # Initialize balancer
    balancer = FacialDataBalancer(
        target_count=TARGET_COUNT,
        target_size=TARGET_SIZE,
        use_multiprocessing=False,  # Set to False to avoid multiprocessing issues
        max_workers=2  # Use fewer workers for stability
    )

    # Create output directory
    Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

    # Start balancing
    try:
        balancer.balance_dataset(INPUT_PATH, OUTPUT_PATH)
    except KeyboardInterrupt:
        # Only classes recorded at the last checkpoint write are guaranteed to
        # be saved; re-running resumes from there.
        print("\n\nProcess interrupted by user.")
        print(f"Classes completed up to the last checkpoint are recorded in: "
              f"{Path(OUTPUT_PATH) / 'checkpoint.pkl'} (re-run to resume)")
        return
    except Exception as e:
        import traceback

        # Top-level boundary: report, keep the traceback, and skip verification.
        print(f"\nError during processing: {e}")
        traceback.print_exc()
        return

    # Verify results
    verify_dataset(OUTPUT_PATH, TARGET_COUNT, TARGET_SIZE)


# ============ ALTERNATIVE VERSION WITH STATUS UPDATES ============

class FacialDataBalancerSimple:
    """Even simpler version with just one progress bar.

    For each class folder it resizes up to ``target_count`` images to
    ``target_size`` (height, width) and, when the class is short, writes
    augmented copies cycled over all of its loaded images. Images stay in
    OpenCV's BGR order throughout: they are read and written with cv2 and
    never converted.
    """

    # File extensions (compared case-insensitively) that are treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, target_count=500, target_size=(112, 112)):
        self.target_count = target_count
        self.target_size = target_size

        # Simple augmentation pipeline
        self.augmentation = A.Compose([
            A.HorizontalFlip(p=0.3),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=5, p=0.3),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.2),
        ])

    def balance_simple(self, input_path, output_path):
        """Simple balancing with single progress bar"""
        input_dir = Path(input_path)
        output_path = Path(output_path)

        # Sorted for a deterministic processing order.
        classes = sorted(d for d in input_dir.iterdir() if d.is_dir())

        print(f"Processing {len(classes)} classes...")

        with tqdm(total=len(classes), desc="Processing classes", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as pbar:
            for class_dir in classes:
                class_name = class_dir.name
                pbar.set_description(f"Processing: {class_name[:20]:20s}")
                self._process_class_simple(class_dir, class_name, output_path)
                pbar.update(1)

        print("\nProcessing complete!")

    @classmethod
    def _list_images(cls, class_dir):
        """Return the sorted image files in ``class_dir`` (extension match is case-insensitive)."""
        return sorted(
            p for p in Path(class_dir).iterdir()
            if p.is_file() and p.suffix.lower() in cls.IMAGE_EXTENSIONS
        )

    def _read_resized(self, img_path):
        """Read an image (BGR) and resize it to target_size; None if it cannot be read."""
        img = cv2.imread(str(img_path))
        if img is None:
            return None
        try:
            # cv2.resize takes dsize as (width, height).
            return cv2.resize(img, (self.target_size[1], self.target_size[0]))
        except cv2.error:
            return None

    def _process_class_simple(self, class_dir, class_name, output_path):
        """Write up to target_count resized originals, then top up with augmentations."""
        output_dir = output_path / class_name
        output_dir.mkdir(parents=True, exist_ok=True)

        images = self._list_images(class_dir)
        if not images:
            return

        # Save originals; unreadable files are skipped and replaced by later ones.
        sources = []
        saved = 0
        for img_path in images:
            if saved >= self.target_count:
                break
            img = self._read_resized(img_path)
            if img is None:
                continue
            sources.append(img)
            if cv2.imwrite(str(output_dir / f"{class_name}_{saved:04d}.jpg"), img):
                saved += 1

        # Shortfall is based on what was actually saved, not on the file count.
        needed = self.target_count - saved
        if needed <= 0 or not sources:
            return

        # Cycle over all loaded images so augmentations are not all variants of
        # a single photo. Failed augmentations are simply skipped.
        for i in range(needed):
            base_img = sources[i % len(sources)]
            try:
                augmented = self.augmentation(image=base_img)['image']
            except Exception:
                continue
            cv2.imwrite(str(output_dir / f"{class_name}_aug_{i:04d}.jpg"), augmented)


def run_simple():
    """Balance the raw dataset with FacialDataBalancerSimple (500 images per class at 112x112)."""
    raw_dir = r"D:\Final_Semester_Project\AI_Attendance_System\ai-ml-model\DataSets\raw"
    simple_out_dir = r"D:\Final_Semester_Project\AI_Attendance_System\ai-ml-model\DataSets\ProcessedSimple"

    FacialDataBalancerSimple(target_count=500, target_size=(112, 112)).balance_simple(raw_dir, simple_out_dir)

    print(f"\nDataset saved to: {simple_out_dir}")


if __name__ == "__main__":
    # Default entry point: the full balancer with checkpointing and a single
    # overall progress bar. To use the lightweight FacialDataBalancerSimple
    # pipeline instead, call run_simple() here in place of main().
    main()

Facial Recognition Dataset Balancer
Using 2 workers for processing

Starting dataset balancing
Input: D:\Final_Semester_Project\AI_Attendance_System\ai-ml-model\DataSets\raw
Output: D:\Final_Semester_Project\AI_Attendance_System\ai-ml-model\DataSets\FullProcessedAgumented
Target: 500 images per class at (112, 112)

Found 540 classes
Processing 540 classes (skipping 0 already processed)


Processing: n000270 ✓: 100%|█████████████████████████████████████| 540/540 [1:03:42<00:00,  7.08s/it, Classes: 540/540]



PROCESSING STATISTICS
Total classes processed: 540
Total images augmented: 6,634
Failed images: 0
Total time: 3825.2 seconds
Average time per class: 7.1s
Images augmented per second: 1.7

VERIFYING DATASET


Verifying classes: 100%|███████████████████████████████████████████| 540/540 [00:04<00:00, 116.99it/s, Passed: 540/540]


SUMMARY:
  Classes: 540
  Classes with ≥500 images: 540/540
  Total images: 270,000
  Average images per class: 500.0

✓ All classes have at least 500 images!



