# Purpose
In place of an NSFW classifier, we will train a CNN classifier that detects whether an image contains a cat.

In [1]:
# --- Dataset curation settings -------------------------------------------
# Input folders, one per class.
CAT_DIR = "data/cat"
NON_CAT_DIR = "data/non-cat"

# Where the balanced, split dataset is written.
OUTPUT_DIR = "data/balanced_dataset_output"

# Desired number of images to draw from EACH class.
SAMPLE_SIZE = 2_000

# Seed handed to the sampler so runs can be repeated.
SEED = 42

# Fraction of each class reserved for the test split (0.2 -> 20%).
TEST_RATIO = 0.2

In [2]:
import shutil
import random
import pandas as pd
from pathlib import Path
from tqdm import tqdm

def get_image_paths(directory):
    """Recursively collect the image files found under ``directory``.

    A path is considered an image when it is a regular file whose suffix,
    compared case-insensitively, is one of ``.png``, ``.jpg``, ``.jpeg``,
    ``.gif`` or ``.bmp``.

    Args:
        directory: Root folder to search (``str`` or ``pathlib.Path``).

    Returns:
        A sorted ``list`` of ``pathlib.Path`` objects. Sorting makes the
        result independent of the order in which the filesystem happens to
        list entries, so seeded sampling done by callers is repeatable.
    """
    root = Path(directory)
    extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp'}
    return sorted(
        candidate
        for candidate in root.rglob('*')
        # is_file() rejects directories that merely carry an image-like
        # suffix (e.g. a folder named "album.jpg"), which cannot be copied.
        if candidate.suffix.lower() in extensions and candidate.is_file()
    )

def _unique_filename(candidate, taken):
    """Return ``candidate``, or a numbered variant of it, that is not in ``taken``.

    Names are compared case-insensitively so that two names differing only by
    case do not clash on case-insensitive filesystems. The chosen name is
    recorded in ``taken`` (a ``set`` of lower-cased names) before returning.
    """
    if candidate.lower() not in taken:
        taken.add(candidate.lower())
        return candidate

    as_path = Path(candidate)
    counter = 1
    while True:
        alternative = f"{as_path.stem}_{counter}{as_path.suffix}"
        if alternative.lower() not in taken:
            taken.add(alternative.lower())
            return alternative
        counter += 1


def curate_dataset(cat_dir, non_cat_dir, output_dir, target_size, test_ratio=0.2, seed=42, clear_output_folder=False):
    """Build a class-balanced cat / non-cat dataset with a train/test split.

    The same number of images is sampled from each class, split into a
    training and a test part, and copied to::

        output_dir/
            training_data/   images + annotations.csv
            test_data/       images + annotations.csv

    Each ``annotations.csv`` has the columns ``img_path`` (file name relative
    to the CSV's own folder) and ``label`` (1 = cat, 0 = non-cat).

    Args:
        cat_dir: Folder searched recursively for cat images.
        non_cat_dir: Folder searched recursively for non-cat images.
        output_dir: Destination root; created if missing.
        target_size: Desired number of images per class. Reduced (with a
            printed warning) to the size of the smaller class when there are
            not enough images.
        test_ratio: Fraction of each class placed in ``test_data``; must lie
            in ``[0, 1]``.
        seed: Seed for the global ``random`` module used for sampling.
        clear_output_folder: When True an existing ``output_dir`` is deleted
            first. When False only a warning is printed and files from
            earlier runs are left in place next to the new ones.

    Raises:
        ValueError: If ``test_ratio`` is outside ``[0, 1]``.
    """
    # A ratio outside [0, 1] would yield a negative or oversized split index
    # and silently scramble the train/test split via slice semantics.
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio must be between 0 and 1, got {test_ratio!r}")

    random.seed(seed)
    output_path = Path(output_dir)

    # 1. Clean and set up the output directory
    if output_path.exists():
        if clear_output_folder:
            print(f"Clearing existing output directory: {output_path}")
            shutil.rmtree(output_path)
        else:
            # Best-effort: keep going, but images from earlier runs remain on
            # disk even though annotations.csv is rewritten below.
            print(f"Warning: Output directory {output_path} exists and clear_output_folder is False.")

    output_path.mkdir(parents=True, exist_ok=True)

    # 2. Collect files. Sorting makes the seeded sample below independent of
    #    the (arbitrary) order in which the filesystem lists directory entries.
    print("Scanning directories...")
    cat_files = sorted(get_image_paths(cat_dir))
    print(f"Found {len(cat_files)} cat images.")
    non_cat_files = sorted(get_image_paths(non_cat_dir))
    print(f"Found {len(non_cat_files)} non-cat images.")

    # 3. Cap the sample size at what the smaller class can provide
    available_min = min(len(cat_files), len(non_cat_files))
    if available_min < target_size:
        print(f"⚠️ Insufficient data for target size {target_size}. Adjusting to {available_min}.")
        target_size = available_min

    # 4. Draw the same number of images from each class (order is random)
    selected_cats = random.sample(cat_files, target_size)
    selected_non_cats = random.sample(non_cat_files, target_size)

    # 5. First `split_idx` images of each class go to training, the rest to test
    split_idx = int(target_size * (1 - test_ratio))

    data_splits = {
        'training_data': [
            (selected_cats[:split_idx], 'cat', 1),
            (selected_non_cats[:split_idx], 'non-cat', 0)
        ],
        'test_data': [
            (selected_cats[split_idx:], 'cat', 1),
            (selected_non_cats[split_idx:], 'non-cat', 0)
        ]
    }

    # 6. Copy the files and write one annotations.csv per split
    for split_name, categories in data_splits.items():
        split_dir = output_path / split_name
        split_dir.mkdir(exist_ok=True)

        csv_data = []
        # Sources are gathered recursively, so two images in different
        # sub-folders may share a basename; track names used in this split so
        # a later copy never overwrites an earlier one.
        used_names = set()

        print(f"\nProcessing {split_name}...")

        for files, class_name, class_id in categories:
            for file_path in tqdm(files, desc=f"  Copying {class_name}"):
                # Prefix with the label, e.g. 1_imageName.jpg; a numeric
                # suffix is appended only when that name is already taken.
                new_filename = _unique_filename(f"{class_id}_{file_path.name}", used_names)

                shutil.copy2(file_path, split_dir / new_filename)

                # The CSV lives next to the images, so the bare name suffices.
                csv_data.append({
                    'img_path': new_filename,
                    'label': class_id
                })

        # Explicit columns keep the header present even when the split is empty.
        df = pd.DataFrame(csv_data, columns=['img_path', 'label'])
        df.to_csv(split_dir / 'annotations.csv', index=False)

    print(f"\nSUCCESS! Dataset curated at: {output_path}")
    print("Structure:")
    print("  ├── training_data/ (images + annotations.csv)")
    print("  └── test_data/     (images + annotations.csv)")

In [3]:
from torch.utils.data import DataLoader

# Build the balanced train/test dataset on disk from the configured folders.
# clear_output_folder=True wipes any previous output directory first.
curate_dataset(
    CAT_DIR,
    NON_CAT_DIR,
    OUTPUT_DIR,
    SAMPLE_SIZE,
    test_ratio=TEST_RATIO,
    seed=SEED,
    clear_output_folder=True,
)

Clearing existing output directory: data/balanced_dataset_output
Scanning directories...
Found 1668 cat images.
Found 24511 non-cat images.
⚠️ Insufficient data for target size 2000. Adjusting to 1668.

Processing training_data...


  Copying cat: 100%|██████████| 1334/1334 [00:00<00:00, 4643.06it/s]
  Copying non-cat: 100%|██████████| 1334/1334 [00:00<00:00, 6323.22it/s]



Processing test_data...


  Copying cat: 100%|██████████| 334/334 [00:00<00:00, 4982.12it/s]
  Copying non-cat: 100%|██████████| 334/334 [00:00<00:00, 8955.71it/s]


SUCCESS! Dataset curated at: data/balanced_dataset_output
Structure:
  ├── training_data/ (images + annotations.csv)
  └── test_data/     (images + annotations.csv)



