# In [2]:
import os
import shutil
import random
import logging
from pathlib import Path
from typing import List, Tuple, Dict
from sklearn.model_selection import train_test_split
from PIL import Image

# Module-level logger used by DataLoader for progress messages.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO level so logger.info(...) output is emitted.
logging.basicConfig(level=logging.INFO)

# Image file extensions (with leading dot) that DataLoader searches for.
SUPPORTED_IMAGE_FORMATS = [
    ".bmp",
    ".jpg",
    ".jpeg",
    ".png",
]


class DataLoader:
    """Organise an image-classification dataset into train/validation/test folders.

    The raw dataset is expected to contain one sub-directory per class. Every
    supported image found (recursively) under a class directory is labelled
    with that directory's lower-cased name.
    """

    def __init__(self, dataset_path: Path, processed_path: Path, image_size: Tuple[int, int], seed: int = 42):
        """
        Args:
            dataset_path: Root of the raw dataset (one sub-folder per class).
            processed_path: Output root. ``prepare_and_split_data`` deletes and
                recreates it.
            image_size: Stored on the instance; not used by the methods of this class.
            seed: ``random_state`` for the train/validation/test splits.
        """
        self.dataset_path = Path(dataset_path)
        self.processed_path = Path(processed_path)
        self.image_size = image_size
        self.seed = seed
        # NOTE(review): this reseeds the process-wide ``random`` module as a side
        # effect. It is kept for backward compatibility even though nothing in
        # this class uses ``random``.
        random.seed(seed)

    def _load_image_paths_and_labels(self) -> List[Tuple[Path, str]]:
        """Find all supported image files and label them by their class folder.

        Returns:
            ``(image_path, label)`` pairs, ordered by class folder and then by
            path. The order is deterministic, so the seeded split is
            reproducible across machines and filesystems.
        """
        logger.info(f"Loading data from {self.dataset_path}")
        # Compare suffixes case-insensitively so that e.g. "IMG.JPG" is also
        # found on case-sensitive filesystems. This also means one tree walk
        # per class instead of one per extension.
        extensions = {ext.lower() for ext in SUPPORTED_IMAGE_FORMATS}
        image_paths = []
        # iterdir()/rglob() order is filesystem-dependent; sort for determinism.
        for class_dir in sorted(self.dataset_path.iterdir()):
            if not class_dir.is_dir():
                continue
            label = class_dir.name.lower()
            image_paths.extend(
                (p, label)
                for p in sorted(class_dir.rglob("*"))
                if p.is_file() and p.suffix.lower() in extensions
            )
        logger.info(f"Found {len(image_paths)} images")
        return image_paths

    def _split_dataset(
        self,
        data: List[Tuple[Path, str]],
        val_split: float,
        test_split: float
    ) -> Tuple[List, List, List]:
        """Split data into stratified train, validation and test sets.

        Args:
            data: ``(image_path, label)`` pairs.
            val_split: Fraction of ``data`` to put in the validation set (may be 0).
            test_split: Fraction of ``data`` to put in the test set (may be 0).

        Returns:
            ``(train, validation, test)`` lists of ``(image_path, label)`` pairs.

        Raises:
            ValueError: if ``data`` is empty, or if the fractions are not in
                ``[0, 1)`` or sum to 1 or more.
        """
        if len(data) == 0:
            raise ValueError("No images found to split. Please check the dataset path and format.")
        if (
            not 0.0 <= val_split < 1.0
            or not 0.0 <= test_split < 1.0
            or val_split + test_split >= 1.0
        ):
            raise ValueError(
                f"Invalid split fractions: val_split={val_split}, test_split={test_split}; "
                "each must be in [0, 1) and their sum must be < 1."
            )

        def labels_of(items: List[Tuple[Path, str]]) -> List[str]:
            return [label for _, label in items]

        # train_test_split rejects test_size=0, so a zero fraction skips that split.
        train_val_data = list(data)
        test_data: List[Tuple[Path, str]] = []
        if test_split > 0:
            train_val_data, test_data = train_test_split(
                data, test_size=test_split, random_state=self.seed, stratify=labels_of(data)
            )

        train_data = train_val_data
        val_data: List[Tuple[Path, str]] = []
        if val_split > 0:
            # Rescale so that the validation set is val_split of the *whole* dataset.
            train_data, val_data = train_test_split(
                train_val_data,
                test_size=val_split / (1.0 - test_split),
                random_state=self.seed,
                stratify=labels_of(train_val_data),
            )
        return train_data, val_data, test_data

    def _copy_data_to_folder(self, data: List[Tuple[Path, str]], target_dir: Path):
        """Copy each image into ``target_dir / label``, preserving metadata.

        Images are collected recursively and labels are lower-cased, so two
        sources can share a basename. Later ones are renamed
        ``<stem>_<n><suffix>`` instead of silently overwriting earlier copies.
        """
        used_destinations = set()
        for img_path, label in data:
            label_dir = target_dir / label
            label_dir.mkdir(parents=True, exist_ok=True)
            dest_path = label_dir / img_path.name
            counter = 1
            while dest_path in used_destinations:
                dest_path = label_dir / f"{img_path.stem}_{counter}{img_path.suffix}"
                counter += 1
            used_destinations.add(dest_path)
            shutil.copy2(img_path, dest_path)

    def prepare_and_split_data(
        self,
        val_split: float = 0.2,
        test_split: float = 0.1,
        save_stats: bool = True
    ):
        """Load, split and copy data into ``processed_path/{train,validation,test}``.

        Any existing ``processed_path`` is deleted first.

        Args:
            val_split: Fraction of all images assigned to the validation split.
            test_split: Fraction of all images assigned to the test split.
            save_stats: If true, write split sizes and the sorted class list to
                ``processed_path / "data_stats.json"``.

        Raises:
            ValueError: if ``processed_path`` is the dataset directory or one of
                its ancestors (deleting it would destroy the raw data), if no
                images are found, or if the split fractions are invalid.
        """
        # Guard the rmtree below: deleting a directory that contains the raw
        # dataset would destroy the source images before they are copied.
        dataset_root = self.dataset_path.resolve()
        processed_root = self.processed_path.resolve()
        if dataset_root == processed_root or processed_root in dataset_root.parents:
            raise ValueError(
                f"processed_path {self.processed_path} contains the dataset at "
                f"{self.dataset_path}; refusing to delete it."
            )

        all_data = self._load_image_paths_and_labels()
        train_data, val_data, test_data = self._split_dataset(all_data, val_split, test_split)

        splits = {
            "train": train_data,
            "validation": val_data,
            "test": test_data,
        }

        # Clean existing processed folders if any
        if self.processed_path.exists():
            logger.info(f"Cleaning up old processed directory: {self.processed_path}")
            shutil.rmtree(self.processed_path)

        for split_name, split_data in splits.items():
            logger.info(f"Copying {split_name} data with {len(split_data)} samples...")
            self._copy_data_to_folder(split_data, self.processed_path / split_name)

        logger.info(f"Data organization complete. Total: {len(all_data)} images")

        # Optionally, save stats
        if save_stats:
            import json

            stats_path = self.processed_path / "data_stats.json"
            stats = {
                "total": len(all_data),
                "train": len(train_data),
                "validation": len(val_data),
                "test": len(test_data),
                "classes": sorted({label for _, label in all_data})
            }
            stats_path.parent.mkdir(parents=True, exist_ok=True)
            with open(stats_path, "w", encoding="utf-8") as f:
                json.dump(stats, f, indent=4)
            logger.info(f"Statistics saved to {stats_path}")
