# InfraOwl Data Preprocessing

This notebook handles data preprocessing for the InfraOwl model training pipeline.

## Tasks
- Load and validate raw data
- Split data into train/validation/test sets
- Resize and normalize images
- Apply data augmentation
- Generate preprocessing statistics

In [None]:
import sys
sys.path.append('../scripts')

import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
from PIL import Image
import shutil
from sklearn.model_selection import train_test_split

# Import our preprocessing module
from data_preprocessing import DataPreprocessor

# Notebook banner: title line followed by its underline.
print("🔄 InfraOwl Data Preprocessing Notebook", "======================================", sep="\n")

## 1. Configuration Setup

In [None]:
# Load the training configuration (the same file DataPreprocessor is built from below).
# An explicit encoding keeps parsing independent of the OS default locale.
with open('../configs/training_config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("📋 Training Configuration:")
print(f"  Classes: {config['classes']}")
print(f"  Target Size: {config['data']['preprocessing']['target_size']}")
print(f"  Train Split: {config['data']['train_split']}")
print(f"  Validation Split: {config['data']['val_split']}")
print(f"  Test Split: {config['data']['test_split']}")
print(f"  Data Augmentation: {config['data']['augmentation']['enabled']}")

# Sanity check: the three configured ratios are meant to partition the dataset.
# A small tolerance absorbs floating-point rounding (e.g. 0.7 + 0.15 + 0.15).
_split_total = (
    config['data']['train_split']
    + config['data']['val_split']
    + config['data']['test_split']
)
if abs(_split_total - 1.0) > 1e-6:
    print(f"⚠️  Split ratios sum to {_split_total:.3f} (expected 1.0); check training_config.yaml")

## 2. Raw Data Validation

In [None]:
# Build the project preprocessor from the YAML config, then check the raw
# image folders before any splitting. On failure, `raw_stats` is left
# unassigned and the expected folder layout is printed instead.
preprocessor = DataPreprocessor('../configs/training_config.yaml')

print("🔍 Validating raw data...")
try:
    raw_stats = preprocessor.validate_raw_data()
except Exception as e:
    print(f"❌ Raw data validation failed: {e}")
    print("\nPlease ensure images are properly organized in:")
    # One expected folder per configured class.
    for class_name in config['classes']:
        print(f"  - ../data/raw_images/{class_name}/")
else:
    print("✅ Raw data validation completed")

## 3. Data Splitting Strategy

In [None]:
# Visualize the train/validation/test split and the expected per-class counts.
train_split = config['data']['train_split']
val_split = config['data']['val_split']
# The test share is whatever remains. Clamp at zero: floating-point rounding
# (e.g. 1 - 0.8 - 0.2 == -5.55e-17) would otherwise yield a negative wedge,
# which makes Axes.pie raise ValueError.
test_split = max(0.0, 1 - train_split - val_split)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Left panel: split proportions.
splits = [train_split, val_split, test_split]
labels = ['Train', 'Validation', 'Test']
colors = ['#FF9999', '#66B2FF', '#99FF99']

ax1.pie(splits, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
ax1.set_title('Data Split Strategy')

# Right panel: expected samples per class. This needs `raw_stats` from the
# validation cell (section 2). That name is unbound if validation failed or was
# not run, so look it up defensively rather than via locals().
_raw_stats = globals().get('raw_stats')
if _raw_stats:
    # Indexed by class name with numeric counts (as used below).
    class_names = list(_raw_stats.keys())

    # int() truncates, so the three split counts can sum to slightly less than
    # a class's total; these are estimates only.
    expected_counts = {
        split_name: np.array([int(_raw_stats[name] * split_ratio) for name in class_names])
        for split_name, split_ratio in zip(labels, splits)
    }

    # Stacked bars: Validation sits on Train, Test sits on Train + Validation.
    x = np.arange(len(class_names))
    width = 0.6
    bottom_val = expected_counts['Train']
    bottom_test = expected_counts['Train'] + expected_counts['Validation']

    ax2.bar(x, expected_counts['Train'], width, label='Train', color=colors[0])
    ax2.bar(x, expected_counts['Validation'], width, bottom=bottom_val, label='Validation', color=colors[1])
    ax2.bar(x, expected_counts['Test'], width, bottom=bottom_test, label='Test', color=colors[2])

    ax2.set_xlabel('Classes')
    ax2.set_ylabel('Number of Images')
    ax2.set_title('Expected Samples per Class')
    ax2.set_xticks(x)
    # Right-align rotated labels so each one ends under its own bar.
    ax2.set_xticklabels(class_names, rotation=45, ha='right')
    ax2.legend()
else:
    ax2.axis('off')
    ax2.text(0.5, 0.5, 'Raw data statistics unavailable\n(run section 2 successfully)',
             ha='center', va='center')

plt.tight_layout()
plt.show()

## 4. Image Preprocessing Pipeline

In [None]:
# Demonstrate image preprocessing
def show_preprocessing_example():
    """Show a before/after example of fitting one raw image to the target size.

    Picks the first ``*.jpg``/``*.jpeg``/``*.png`` file, in sorted name order,
    under ``../data/raw_images/<class>/`` for the classes in ``config['classes']``.
    It converts that image to RGB and scales it down to fit inside
    ``config['data']['preprocessing']['target_size']``, preserving aspect ratio.
    The result is pasted centred onto a black canvas of exactly that size
    (padding, not cropping). Original and result are plotted side by side.

    This is a local simulation for illustration only; it does not call
    DataPreprocessor, whose actual resizing may differ.

    Prints a message and returns None when no sample image is found.
    """
    raw_data_path = Path('../data/raw_images')

    # Lazily walk classes -> extensions -> files (sorted for a deterministic
    # pick); Path.glob on a missing directory simply yields nothing.
    candidates = (
        image_file
        for class_name in config['classes']
        for ext in ('*.jpg', '*.jpeg', '*.png')
        for image_file in sorted((raw_data_path / class_name).glob(ext))
    )
    sample_image = next(candidates, None)

    if sample_image is None:
        print("No sample images found for preprocessing demo")
        return

    print(f"📸 Processing sample image: {sample_image.name}")

    # Use a context manager so the file handle is released; copy() forces the
    # pixel data to load before the file is closed.
    with Image.open(sample_image) as opened:
        original_img = opened.copy()

    # NOTE(review): PIL treats size tuples as (width, height); confirm the config's
    # target_size uses that order.
    target_size = tuple(config['data']['preprocessing']['target_size'])

    # convert() always returns a new image (a copy even when already RGB), so
    # the in-place thumbnail() below never mutates original_img.
    processed_img = original_img.convert('RGB')

    # Scale down to fit inside target_size, preserving aspect ratio.
    # thumbnail() only shrinks: images already smaller than the target keep their
    # size and simply receive more padding.
    processed_img.thumbnail(target_size, Image.Resampling.LANCZOS)

    # Centre the scaled image on a black canvas of exactly target_size.
    final_img = Image.new('RGB', target_size, (0, 0, 0))
    paste_x = (target_size[0] - processed_img.width) // 2
    paste_y = (target_size[1] - processed_img.height) // 2
    final_img.paste(processed_img, (paste_x, paste_y))

    # Display comparison
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    ax1.imshow(original_img)
    ax1.set_title(f'Original\n{original_img.size[0]} x {original_img.size[1]}')
    ax1.axis('off')

    ax2.imshow(final_img)
    ax2.set_title(f'Processed\n{final_img.size[0]} x {final_img.size[1]}')
    ax2.axis('off')

    plt.suptitle('Image Preprocessing Example')
    plt.tight_layout()
    plt.show()

    print(f"Original size: {original_img.size}")
    print(f"Target size: {target_size}")
    print(f"Final size: {final_img.size}")

show_preprocessing_example()

## 5. Run Data Preprocessing

In [None]:
# Run the complete preprocessing pipeline (split + process via DataPreprocessor).
print("🚀 Running complete data preprocessing pipeline...")
print("This may take a few minutes depending on dataset size.\n")

# Only the pipeline call is guarded. If it fails, processed_stats is None and
# section 6 skips statistics. Reporting runs in `else` so a problem while
# printing results is not misreported as a preprocessing failure, and does not
# throw away stats that were actually produced.
try:
    processed_stats = preprocessor.split_and_process_data()
except Exception as e:
    print(f"❌ Preprocessing failed: {e}")
    processed_stats = None
else:
    print("\n✅ Data preprocessing completed successfully!")

    # Per-split totals followed by per-class counts. processed_stats is indexed
    # as {split_name: {class_name: count}} for the three splits below.
    print("\n📊 Processing Results:")
    for split_name in ['train', 'validation', 'test']:
        split_counts = processed_stats[split_name]
        total = sum(split_counts.values())
        print(f"  {split_name.capitalize()}: {total} images")
        for class_name, count in split_counts.items():
            print(f"    {class_name}: {count}")

## 6. Generate Statistics and Visualizations

In [None]:
# Summarise the processed dataset via DataPreprocessor.generate_statistics.
# Skipped when section 5 left processed_stats falsy (None on failure).
if not processed_stats:
    print("⏭️  Skipping statistics generation (preprocessing not completed)")
else:
    print("📈 Generating statistics and visualizations...")

    try:
        preprocessor.generate_statistics(processed_stats)
        print("✅ Statistics generated successfully!")

        # Show the figure inline if it exists at this path; presumably
        # generate_statistics writes it there (confirm in DataPreprocessor).
        stats_plot_path = Path('../outputs') / 'dataset_statistics.png'
        if stats_plot_path.exists():
            from IPython.display import Image as IPImage, display
            print("\n📊 Dataset Statistics:")
            display(IPImage(filename=str(stats_plot_path)))

    except Exception as e:
        print(f"⚠️  Statistics generation failed: {e}")

## 7. Data Augmentation Preview

In [None]:
# Show data augmentation examples
def show_augmentation_examples():
    """Preview the configured data augmentation on one processed training image.

    Only prints a notice and returns in three cases:
    - augmentation is disabled in ``config['data']['augmentation']``;
    - ``../data/processed/train`` does not exist;
    - no ``*.jpg`` exists in any configured class folder there.

    Otherwise it builds a Keras ``ImageDataGenerator`` from the augmentation
    config. It plots the original plus seven random augmentations in a 2x4 grid,
    then prints the augmentation parameters (every key except ``enabled``).
    """
    aug_config = config['data']['augmentation']
    if not aug_config['enabled']:
        print("ℹ️  Data augmentation is disabled in configuration")
        return

    processed_path = Path('../data/processed/train')
    if not processed_path.exists():
        print("ℹ️  No processed data found. Run preprocessing first.")
        return

    # First *.jpg in class order (sorted within a class so the pick is stable);
    # Path.glob on a missing class folder simply yields nothing.
    sample_image = next(
        (
            image_file
            for class_name in config['classes']
            for image_file in sorted((processed_path / class_name).glob('*.jpg'))
        ),
        None,
    )
    if sample_image is None:
        print("No processed images found for augmentation demo")
        return

    # Imported here rather than at the top, presumably so the rest of the
    # notebook does not hard-require TensorFlow.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array

    img = load_img(sample_image)
    # ImageDataGenerator.flow consumes a batch: add a leading batch axis of size 1.
    img_batch = np.expand_dims(img_to_array(img), axis=0)

    # Missing keys fall back to ImageDataGenerator's own defaults (no-op transforms)
    # instead of raising KeyError.
    datagen = ImageDataGenerator(
        rotation_range=aug_config.get('rotation_range', 0),
        width_shift_range=aug_config.get('width_shift_range', 0.0),
        height_shift_range=aug_config.get('height_shift_range', 0.0),
        shear_range=aug_config.get('shear_range', 0.0),
        zoom_range=aug_config.get('zoom_range', 0.0),
        horizontal_flip=aug_config.get('horizontal_flip', False),
        brightness_range=aug_config.get('brightness_range'),
        fill_mode='nearest',
    )

    print(f"🔄 Data Augmentation Examples: {sample_image.name}")

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    # Row-major cells: the first shows the original, the remaining 7 show augmentations.
    original_ax, *aug_axes = axes.flat

    original_ax.imshow(img)
    original_ax.set_title('Original')
    original_ax.axis('off')

    aug_iter = datagen.flow(img_batch, batch_size=1)
    for i, ax in enumerate(aug_axes, start=1):
        # Clip before casting: converting out-of-range floats to uint8 wraps
        # around instead of saturating, producing garbage pixels.
        aug_img = np.clip(next(aug_iter)[0], 0, 255).astype('uint8')
        ax.imshow(aug_img)
        ax.set_title(f'Augmented {i}')
        ax.axis('off')

    plt.suptitle('Data Augmentation Examples')
    plt.tight_layout()
    plt.show()

    print("\n🎯 Augmentation Parameters:")
    for param, value in aug_config.items():
        if param != 'enabled':
            print(f"  {param}: {value}")

show_augmentation_examples()

## 8. Validation and Next Steps

In [None]:
# Final validation: report what is actually on disk under ../data/processed.
print("✅ Data Preprocessing Summary")
print("============================")

processed_path = Path('../data/processed')
if processed_path.exists():
    print("📁 Processed data structure:")
    for split in ['train', 'validation', 'test']:
        split_path = processed_path / split
        if not split_path.exists():
            continue
        # Count *.jpg files in each class folder exactly once. Sort by name so
        # the listing is stable (iterdir order is filesystem-dependent).
        class_counts = {
            class_dir.name: sum(1 for _ in class_dir.glob('*.jpg'))
            for class_dir in sorted(split_path.iterdir())
            if class_dir.is_dir()
        }
        print(f"  {split}/: {sum(class_counts.values())} images")
        for class_name, class_count in class_counts.items():
            print(f"    {class_name}/: {class_count} images")

    print("\n🚀 Ready for Model Training!")
    print("\nNext steps:")
    print("1. 🎯 Train the model:")
    print("   python ../scripts/train_model.py")
    print()
    print("2. 📱 Convert to TensorFlow Lite:")
    print("   python ../scripts/convert_to_tflite.py")
    print()
    print("3. 📊 Evaluate performance:")
    print("   python ../scripts/evaluate_model.py")

else:
    print("❌ Processed data not found")
    print("Please run the preprocessing steps above successfully before proceeding.")

print("\n💡 Tips:")
print("• Monitor training in TensorBoard: tensorboard --logdir ../logs")
print("• Adjust hyperparameters in ../configs/training_config.yaml")
print("• Add more data if model performance is insufficient")