# Step 2: Data augmentation
To be implemented in `src/augment.py` later.

### Apply techniques like:
- Rotation
- Horizontal/Vertical flipping
- Scaling / Cropping
- Color jitter / brightness adjustment
### Target:
- Increase the dataset by at least 30%, especially for minority classes.
- Either save the augmented images to disk or apply augmentation on the fly during training.

> After exploring the data in the previous notebook, removing corrupted images, and counting the images in each class, we found the following:
### Images per class:
- cardboard: 259 images
- glass: 401 images
- metal: 328 images
- paper: 476 images
- plastic: 386 images
- trash: 110 images
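
The counts above are clearly imbalanced. As a rough check, the sketch below uses only the numbers listed above to compute the ratio between the largest and smallest class and the per-class size implied by the 30% minimum target. The variable names are illustrative and do not come from the notebook.

```python
import math

# Image counts per class, copied from the list above
counts = {
    "cardboard": 259, "glass": 401, "metal": 328,
    "paper": 476, "plastic": 386, "trash": 110,
}

largest = max(counts, key=counts.get)
smallest = min(counts, key=counts.get)
imbalance_ratio = counts[largest] / counts[smallest]

# Smallest acceptable size per class if each class grows by at least 30%
min_targets = {cls: math.ceil(n * 130 / 100) for cls, n in counts.items()}

print(f"{largest} has {imbalance_ratio:.1f}x as many images as {smallest}")
for cls, n in counts.items():
    print(f"{cls:10} {n:4} -> at least {min_targets[cls]}")
```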

# Apply Data Augmentation Using ImageDataGenerator (Keras)
> Documentation: https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator
>
> Video: https://www.youtube.com/watch?v=Ahy50JCRYNk

- `ImageDataGenerator`: Class used to apply data augmentation to images.
- `img_to_array`: Converts a PIL image to a NumPy array (needed for processing).
- `load_img`: Loads an image from disk as a PIL image.

In [31]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
import os

# Enhanced ImageDataGenerator with multiple augmentation techniques
A single generator is shared by all classes; the smallest classes (trash, cardboard) simply receive more augmented copies per image. It combines:
- **Geometric**: rotation, shifts, zoom, horizontal flips
- **Color**: brightness adjustment
- **Spatial**: shear transformations


# Step 0: Inspect initial image counts


In [32]:
import os
from collections import defaultdict

DATA_ROOT = "/Users/rodynaamr/Image_Classification_SVM_kNN/data"
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
TARGET_IMAGES = 5000  # Target number of images per class

# Count current images in each class
initial_counts = {}
for waste_class in ["cardboard", "glass", "metal", "paper", "plastic", "trash"]:
    folder = os.path.join(DATA_ROOT, waste_class)
    if os.path.exists(folder):
        count = len([f for f in os.listdir(folder) if f.lower().endswith(IMAGE_EXTENSIONS)])
        initial_counts[waste_class] = count

print("=" * 60)
print("INITIAL IMAGE COUNTS")
print("=" * 60)
for cls, count in sorted(initial_counts.items()):
    multiplier = TARGET_IMAGES / count if count > 0 else 0
    print(f"{cls:12} → {count:4} images (need {multiplier:5.1f}x augmentation to reach {TARGET_IMAGES})")
print("=" * 60)

# Calculate required augmentations per class
augment_counts = {}
for cls, count in initial_counts.items():
    if count > 0:
        # How many times to generate augmented versions of each original image
        augment_counts[cls] = max(1, round(TARGET_IMAGES / count) - 1)
    else:
        augment_counts[cls] = 1

print("\nAugmentation multipliers to apply:")
for cls, mult in sorted(augment_counts.items()):
    print(f"{cls:12} → generate {mult} augmented versions per original image")


INITIAL IMAGE COUNTS
cardboard    →  259 images (need  19.3x augmentation to reach 5000)
glass        →  401 images (need  12.5x augmentation to reach 5000)
metal        →  328 images (need  15.2x augmentation to reach 5000)
paper        →  476 images (need  10.5x augmentation to reach 5000)
plastic      →  386 images (need  13.0x augmentation to reach 5000)
trash        →  110 images (need  45.5x augmentation to reach 5000)

Augmentation multipliers to apply:
cardboard    → generate 18 augmented versions per original image
glass        → generate 11 augmented versions per original image
metal        → generate 14 augmented versions per original image
paper        → generate 10 augmented versions per original image
plastic      → generate 12 augmented versions per original image
trash        → generate 44 augmented versions per original image
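
The multipliers above follow from the rule `max(1, round(TARGET_IMAGES / count) - 1)`: each original is kept, and `multiplier` augmented copies are added per original, so a class should end up with `count * (1 + multiplier)` images. A minimal sketch of that arithmetic, using the counts printed above (`expected_totals` is an illustrative name):

```python
TARGET_IMAGES = 5000

# Initial counts as printed above
initial_counts = {
    "cardboard": 259, "glass": 401, "metal": 328,
    "paper": 476, "plastic": 386, "trash": 110,
}

expected_totals = {}
for cls, count in initial_counts.items():
    # Same rule as the notebook: round to the nearest whole multiple, minus the original
    multiplier = max(1, round(TARGET_IMAGES / count) - 1)
    expected_totals[cls] = count * (1 + multiplier)
    print(f"{cls:10} {count:4} x (1 + {multiplier:2}) = {expected_totals[cls]}")
```

Because `round` goes to the nearest integer, the expected totals fall on both sides of 5000 (for example 4921 for cardboard and 5236 for paper). Rounding up with `math.ceil` would be one way to guarantee at least 5000 per class; that is a design alternative, not what the notebook does.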


> “Every time you load an image, randomly rotate it, shift it, zoom it, or flip it.”

| Parameter                    | Meaning                                                      |
| ---------------------------- | ------------------------------------------------------------ |
| `rotation_range=25`          | Rotate image randomly between -25° and +25°                  |
| `width_shift_range=0.15`     | Move image left/right by up to 15% of its width              |
| `height_shift_range=0.15`    | Move image up/down by up to 15% of its height                |
| `shear_range=0.15`           | Shear by a random angle of up to ±0.15 (value is in degrees) |
| `zoom_range=0.2`             | Randomly zoom in/out by up to ±20%                           |
| `horizontal_flip=True`       | Randomly flip image horizontally (mirror)                    |
| `brightness_range=[0.7,1.3]` | Randomly change brightness between 70% and 130%              |
| `fill_mode='nearest'`        | Fill pixels exposed by a transform with the nearest pixel    |

In [33]:
datagen = ImageDataGenerator(
    rotation_range=25,
    width_shift_range=0.15,
    height_shift_range=0.15,
    shear_range=0.15,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.7, 1.3],
    fill_mode='nearest'
)
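
To make these ranges concrete, here is a plain-Python sketch of the kind of random draw each setting implies. It illustrates the ranges described in the Keras documentation and is not Keras's actual implementation; `sample_params`, its dictionary keys, and the 384x512 image size are hypothetical choices made for the example.

```python
import random

def sample_params(height, width, rng=random):
    """Draw one set of transform parameters matching the generator settings above."""
    return {
        "rotation_deg": rng.uniform(-25, 25),             # rotation_range=25
        "shift_x_px": rng.uniform(-0.15, 0.15) * width,   # width_shift_range=0.15
        "shift_y_px": rng.uniform(-0.15, 0.15) * height,  # height_shift_range=0.15
        "shear_deg": rng.uniform(-0.15, 0.15),            # shear_range=0.15 (degrees)
        "zoom": rng.uniform(0.8, 1.2),                    # zoom_range=0.2 -> [1 - 0.2, 1 + 0.2]
        "flip_horizontal": rng.random() < 0.5,            # horizontal_flip=True
        "brightness": rng.uniform(0.7, 1.3),              # brightness_range=[0.7, 1.3]
    }

# Seeded generator so repeated runs draw the same parameters
params = sample_params(height=384, width=512, rng=random.Random(0))
for name, value in params.items():
    print(f"{name:16} {value}")
```

Each call draws a fresh combination, which is why every augmented copy of the same original looks slightly different. Keras draws the zoom factor separately for each axis; the sketch uses a single value for brevity.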

# 2. Function to augment a single class folder


In [34]:
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

def augment_class(input_dir, output_dir, augment_count=5, class_name=""):
    """Augment a class folder to reach target number of images.
    
    Args:
        input_dir: Path to input class directory
        output_dir: Path to output augmented class directory
        augment_count: Number of augmented versions per original image
        class_name: Name of class for logging
    """
    os.makedirs(output_dir, exist_ok=True)

    for image_name in os.listdir(input_dir):
        if not image_name.lower().endswith(IMAGE_EXTENSIONS):
            continue

        img_path = os.path.join(input_dir, image_name)
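        try:
            img = load_img(img_path)
        except Exception as e:
            # Unreadable or corrupted files are skipped rather than aborting the run
            print(f"⏭ Skipping corrupted file: {image_name} ({e})")
            continue

        x = img_to_array(img)
        x = x.reshape((1,) + x.shape)  # add a batch dimension: (1, H, W, C)

        # Strip only the final extension so names containing dots keep their full stem
        prefix = os.path.splitext(image_name)[0]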
        i = 0
        for batch in datagen.flow(
            x,
            batch_size=1,
            save_to_dir=output_dir,
            save_prefix=prefix,
            save_format='jpg'
        ):
            i += 1
            if i >= augment_count:
                break
        
        if augment_count > 0:
            print(f"  ✓ {class_name}: Augmented {image_name}")
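
The save prefix for each augmented file is derived from the original file name. The sketch below compares two ways of stripping the extension, since they behave differently when a name contains more than one dot. The file names are made up for illustration.

```python
import os

# Hypothetical file names, including one with extra dots
names = ["3d1b6d5e.jpg", "glass.bottle.01.png", "IMG_0042.JPEG"]

results = {}
for name in names:
    by_split = name.split(".")[0]            # keeps only the text before the first dot
    by_splitext = os.path.splitext(name)[0]  # removes only the final extension
    results[name] = (by_split, by_splitext)
    print(f"{name:22} split -> {by_split:14} splitext -> {by_splitext}")
```

With `split`, names such as `glass.bottle.01.png` and `glass.bottle.02.png` would share the prefix `glass`, so `os.path.splitext` is the safer choice when file names are not guaranteed to be dot-free.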


# 3. Function to augment the entire dataset


In [35]:
def augment_dataset(root="/Users/rodynaamr/Image_Classification_SVM_kNN/data"):
    """Apply class-specific augmentation to reach ~5000 images per class."""
    
    print("\n" + "=" * 60)
    print("STARTING CLASS-SPECIFIC AUGMENTATION")
    print("=" * 60)
    
    for cls in ["cardboard", "glass", "metal", "paper", "plastic", "trash"]:
        input_dir = os.path.join(root, cls)
        output_dir = os.path.join(root, cls + "_aug")
        
        if not os.path.exists(input_dir):
            print(f"⏭ Skipping {cls}: directory not found")
            continue
        
        aug_count = augment_counts.get(cls, 1)
        print(f"\n▶ Augmenting '{cls}':")
        print(f"  Original images: {initial_counts.get(cls, 0)}")
        print(f"  Augmentations per image: {aug_count}")
        print(f"  Expected total: {initial_counts.get(cls, 0) * (1 + aug_count)}")
        
        augment_class(input_dir, output_dir, augment_count=aug_count, class_name=cls)

# 4. Run augmentation


In [36]:
augment_dataset(root="/Users/rodynaamr/Image_Classification_SVM_kNN/data")


STARTING CLASS-SPECIFIC AUGMENTATION

▶ Augmenting 'cardboard':
  Original images: 259
  Augmentations per image: 18
  Expected total: 4921
  ✓ cardboard: Augmented 3d1b6d5e-d81f-4bf1-b3d7-5c702625a5a7.jpg
  ✓ cardboard: Augmented 2fbdea77-1129-4be4-8014-8667aa793080.jpg
  ✓ cardboard: Augmented efb2516b-eefd-4e59-a7aa-9470b9c7e77c.jpg
  ✓ cardboard: Augmented 3896e55a-e949-427f-b02c-1132874a2b76.jpg
  ✓ cardboard: Augmented db482691-ff52-4197-bc61-637ed20aac3b.jpg
  ✓ cardboard: Augmented 9e6305ca-dcdc-48fa-93a9-54154a8f110e.jpg
  ✓ cardboard: Augmented c8040630-1679-42ed-a483-d2b780ba6e37.jpg
  ✓ cardboard: Augmented 608b7b7d-a4b1-457c-86dc-633ec96f2eb4.jpg
  ✓ cardboard: Augmented 25954c60-fbe3-45e4-a806-31676251910e.jpg
  ✓ cardboard: Augmented bbaa862c-4d91-48aa-9d4d-56c7cfa6c5f3.jpg
  ✓ cardboard: Augmented b76f3714-d7d0-4fce-a7ac-3cfa81b2582e.jpg
  ✓ cardboard: Augmented aa22f27b-691c-4836-ae90-d1c7b514d3e6.jpg
  ✓ cardboard: Augmented 358ba7a3-dc0d-451c-b122-b4deeb252e19.jpg
 

In [37]:

import os
for waste in ["cardboard", "glass", "metal", "paper", "plastic", "trash"]:
    count = 0
    for suffix in ["", "_aug"]:
        folder = f"/Users/rodynaamr/Image_Classification_SVM_kNN/data/{waste}{suffix}"
        if os.path.exists(folder):
            # Case-insensitive match, consistent with the counting in the first cell
            count += len([f for f in os.listdir(folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
    print(f"{waste}: {count}")

cardboard: 4842
glass: 4631
metal: 4736
paper: 4963
plastic: 4740
trash: 4759
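
As a sanity check, these totals can be compared with what the multipliers should yield, `count * (1 + multiplier)`. The sketch below copies the numbers from the outputs in this notebook; the dictionary names are illustrative.

```python
# Numbers copied from the outputs above
initial_counts = {
    "cardboard": 259, "glass": 401, "metal": 328,
    "paper": 476, "plastic": 386, "trash": 110,
}
augment_counts = {
    "cardboard": 18, "glass": 11, "metal": 14,
    "paper": 10, "plastic": 12, "trash": 44,
}
actual_totals = {
    "cardboard": 4842, "glass": 4631, "metal": 4736,
    "paper": 4963, "plastic": 4740, "trash": 4759,
}

shortfalls = {}
for cls, count in initial_counts.items():
    expected = count * (1 + augment_counts[cls])
    shortfalls[cls] = expected - actual_totals[cls]
    print(f"{cls:10} expected {expected:5}  actual {actual_totals[cls]:5}  missing {shortfalls[cls]:4}")
```

Every class ends up below its expected total. The notebook does not establish why; possible causes worth checking include files skipped as unreadable and augmented files overwriting one another when their saved names coincide.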
