In [2]:
import json
import os
from pathlib import Path

import albumentations as A
import cv2
import mediapipe as mp
import numpy as np
from PIL import Image

class FaceDatasetPreparator:
    """Crop faces out of raw photos and save them with augmented copies.

    Every ``*.jpg`` / ``*.png`` file directly inside ``input_dir`` is read,
    the first detected face is cropped with some padding, resized to
    512x512 and written to ``output_dir`` together with
    ``AUGMENTATIONS_PER_IMAGE`` randomly augmented variants.
    """

    # Number of augmented copies written for each successfully processed image.
    AUGMENTATIONS_PER_IMAGE = 3

    def __init__(self, input_dir, output_dir):
        """Set up directories, the face detector and the augmentation pipeline.

        Args:
            input_dir: Directory searched (non-recursively) for input images.
            output_dir: Directory that receives the processed images; it is
                created (including parents) if it does not exist yet.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Initialize face detection
        self.mp_face_detection = mp.solutions.face_detection
        self.face_detection = self.mp_face_detection.FaceDetection(
            model_selection=1, min_detection_confidence=0.5
        )

        # Define augmentation pipeline
        self.transform = A.Compose([
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.3),
            A.GaussNoise(p=0.2),
            A.RandomRotate90(p=0.2),
            A.Flip(p=0.2),
            A.OneOf([
                A.MotionBlur(p=0.2),
                A.MedianBlur(blur_limit=3, p=0.1),
                A.GaussianBlur(blur_limit=3, p=0.1),
            ], p=0.2),
        ])

    @staticmethod
    def _padded_crop_box(xmin, ymin, box_w, box_h, img_w, img_h, padding=0.2):
        """Convert a relative bounding box into a padded, clamped pixel box.

        Args:
            xmin, ymin: Top-left corner of the box as fractions of the image
                width / height.
            box_w, box_h: Box size as fractions of the image width / height.
            img_w, img_h: Image size in pixels.
            padding: Extra margin added on every side, as a fraction of the
                box's own width / height.

        Returns:
            ``(x0, y0, x1, y1)`` in pixels, with ``x1`` / ``y1`` exclusive and
            every coordinate clamped to the image bounds.
        """
        x = int(xmin * img_w)
        y = int(ymin * img_h)
        width = int(box_w * img_w)
        height = int(box_h * img_h)
        pad_x = int(width * padding)
        pad_y = int(height * padding)

        # Clamp each corner independently. Deriving the far edge from the
        # already-clamped near edge would push the crop past the intended
        # padding whenever the face touches the left or top border.
        x0 = max(0, x - pad_x)
        y0 = max(0, y - pad_y)
        x1 = min(img_w, x + width + pad_x)
        y1 = min(img_h, y + height + pad_y)
        return x0, y0, x1, y1

    def detect_and_crop_face(self, image):
        """Detect a face in ``image`` and return the padded crop around it.

        Args:
            image: Image array in BGR channel order (it is converted to RGB
                before being handed to the detector).

        Returns:
            The cropped image region for the first detection, or ``None`` when
            nothing was detected or the resulting crop would be empty.
        """
        results = self.face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        if not results.detections:
            return None

        detection = results.detections[0]  # Use the first detected face
        bbox = detection.location_data.relative_bounding_box

        h, w = image.shape[:2]
        x0, y0, x1, y1 = self._padded_crop_box(
            bbox.xmin, bbox.ymin, bbox.width, bbox.height, w, h, padding=0.2
        )

        # An empty crop cannot be resized later, so report it as "no face".
        if x1 <= x0 or y1 <= y0:
            return None

        return image[y0:y1, x0:x1]

    def process_image(self, image_path, idx):
        """Process a single image and save the face crop plus augmented copies.

        Args:
            image_path: Path of the image file to read.
            idx: Index used to build the output file names
                (``face_{idx:03d}.png`` and ``face_{idx:03d}_aug_{n}.png``).

        Returns:
            ``True`` when the face crop was written, ``False`` when the image
            was skipped (unreadable, no usable face, or the write failed).
        """
        image = cv2.imread(str(image_path))
        if image is None:
            print(f"Failed to load image: {image_path}")
            return False

        # Detect and crop face
        face = self.detect_and_crop_face(image)
        if face is None:
            print(f"No face detected in: {image_path}")
            return False

        # Resize to standard size
        face = cv2.resize(face, (512, 512))

        # Save original processed face
        original_path = self.output_dir / f"face_{idx:03d}.png"
        if not cv2.imwrite(str(original_path), face):
            print(f"Failed to write image: {original_path}")
            return False

        # Generate augmented versions
        for aug_idx in range(self.AUGMENTATIONS_PER_IMAGE):
            augmented = self.transform(image=face)['image']
            aug_path = self.output_dir / f"face_{idx:03d}_aug_{aug_idx}.png"
            cv2.imwrite(str(aug_path), augmented)

        return True

    def prepare_dataset(self):
        """Process all ``*.jpg`` and ``*.png`` images in the input directory.

        Returns:
            The number of input images for which a face crop was written.
        """
        # Sort within each pattern so the idx -> file mapping is reproducible;
        # glob itself yields entries in arbitrary order.
        image_files = sorted(self.input_dir.glob("*.jpg")) + sorted(self.input_dir.glob("*.png"))

        processed = 0
        for idx, image_path in enumerate(image_files):
            print(f"Processing image {idx+1}/{len(image_files)}: {image_path}")
            if self.process_image(image_path, idx):
                processed += 1

        skipped = len(image_files) - processed
        # Each processed image yields itself plus its augmented copies.
        dataset_size = processed * (1 + self.AUGMENTATIONS_PER_IMAGE)

        print(f"\nDataset preparation completed. Processed images saved to: {self.output_dir}")
        print(f"Total original images: {len(image_files)}")
        print(f"Images skipped (unreadable or no face found): {skipped}")
        print(f"Total dataset size (including augmentations): {dataset_size}")
        return processed

def prepare_metadata(output_dir, prompt_prefix="a photo of a person"):
    """Create a ``metadata.jsonl`` file for fine-tuning.

    One JSON object per line is written for every ``*.png`` file directly
    inside ``output_dir``, each with the keys ``file_name`` (the image's base
    name) and ``text`` (the caption).

    Args:
        output_dir: Directory containing the processed images; the metadata
            file is written into the same directory.
        prompt_prefix: Caption stored as ``text`` for every image.
    """
    metadata_file = Path(output_dir) / "metadata.jsonl"
    # Sorted so repeated runs produce the same file; glob order is arbitrary.
    image_files = sorted(Path(output_dir).glob("*.png"))

    with open(metadata_file, "w", encoding="utf-8") as f:
        for image_file in image_files:
            # json.dumps escapes quotes and backslashes, which hand-built
            # string formatting would emit verbatim and break the JSON line.
            record = {"file_name": image_file.name, "text": prompt_prefix}
            f.write(json.dumps(record) + "\n")

    print(f"Metadata file created: {metadata_file}")

# Example run: crop and augment every photo, then write the caption metadata.
if __name__ == "__main__":
    # Directory containing your original face photos
    INPUT_DIR = "raw_images"
    # Directory where processed images will be saved
    OUTPUT_DIR = "processed_dataset"

    preparator = FaceDatasetPreparator(input_dir=INPUT_DIR, output_dir=OUTPUT_DIR)
    preparator.prepare_dataset()

    # Every image written above gets the same caption in metadata.jsonl.
    prepare_metadata(output_dir=OUTPUT_DIR, prompt_prefix="thenmozhi")

Processing image 1/37: raw_images\photo_2024-12-16_19-09-49 (2).jpg
Processing image 2/37: raw_images\photo_2024-12-16_19-09-49 (3).jpg
Processing image 3/37: raw_images\photo_2024-12-16_19-09-49 (4).jpg
Processing image 4/37: raw_images\photo_2024-12-16_19-09-49 (5).jpg
Processing image 5/37: raw_images\photo_2024-12-16_19-09-49.jpg
Processing image 6/37: raw_images\photo_2024-12-16_19-09-50 (2).jpg
Processing image 7/37: raw_images\photo_2024-12-16_19-09-50 (3).jpg
Processing image 8/37: raw_images\photo_2024-12-16_19-09-50 (4).jpg
Processing image 9/37: raw_images\photo_2024-12-16_19-09-50 (5).jpg
Processing image 10/37: raw_images\photo_2024-12-16_19-09-50 (6).jpg
Processing image 11/37: raw_images\photo_2024-12-16_19-09-50 (7).jpg
Processing image 12/37: raw_images\photo_2024-12-16_19-09-50 (8).jpg
Processing image 13/37: raw_images\photo_2024-12-16_19-09-50.jpg
Processing image 14/37: raw_images\photo_2024-12-16_19-09-51 (2).jpg
Processing image 15/37: raw_images\photo_2024-12-16