In [1]:
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import random
from typing import Dict, List, Tuple, Optional, Union
import json

class CustomAugmentationPipeline:
    """
    A flexible data augmentation pipeline using Albumentations.
    Supports different augmentation strategies for various use cases.
    """

    def __init__(self, config: Optional[Dict] = None):
        """
        Initialize the augmentation pipeline with configuration.

        Args:
            config: Dictionary containing augmentation parameters
        """
        self.config = config or self._get_default_config()
        self.pipelines = self._create_pipelines()

    def _get_default_config(self) -> Dict:
        """Default configuration for augmentation parameters."""
        return {
            'image_size': (256, 256),
            'geometric': {
                'rotate_limit': 30,
                'shift_limit': 0.1,
                'scale_limit': 0.2,
                'perspective_scale': 0.05,
                'probability': 0.7
            },
            'color': {
                'brightness_limit': 0.2,
                'contrast_limit': 0.2,
                'saturation_limit': 0.2,
                'hue_shift_limit': 20,
                'probability': 0.6
            },
            'blur_noise': {
                'blur_limit': 3,
                'noise_var_limit': (10.0, 50.0),
                'jpeg_quality': (70, 100),
                'probability': 0.4
            },
            'weather': {
                'rain_intensity': (0.1, 0.3),
                'fog_coef': (0.1, 0.3),
                'sun_intensity': (0.1, 0.4),
                'probability': 0.3
            }
        }

    def _create_pipelines(self) -> Dict[str, A.Compose]:
        """Create different augmentation pipelines for various scenarios."""

        pipelines = {}

        # Light augmentation pipeline (for validation/testing)
        pipelines['light'] = A.Compose([
            A.Resize(*self.config['image_size']),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(
                brightness_limit=0.1,
                contrast_limit=0.1,
                p=0.3
            ),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Medium augmentation pipeline (for general training)
        pipelines['medium'] = A.Compose([
            A.Resize(*self.config['image_size']),

            # Geometric transformations
            A.OneOf([
                A.Rotate(
                    limit=self.config['geometric']['rotate_limit'],
                    border_mode=cv2.BORDER_CONSTANT,
                    p=1.0
                ),
                A.ShiftScaleRotate(
                    shift_limit=self.config['geometric']['shift_limit'],
                    scale_limit=self.config['geometric']['scale_limit'],
                    rotate_limit=self.config['geometric']['rotate_limit'],
                    border_mode=cv2.BORDER_CONSTANT,
                    p=1.0
                ),
                A.Perspective(
                    scale=self.config['geometric']['perspective_scale'],
                    p=1.0
                )
            ], p=self.config['geometric']['probability']),

            # Flips
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.2),

            # Color transformations
            A.OneOf([
                A.ColorJitter(
                    brightness=self.config['color']['brightness_limit'],
                    contrast=self.config['color']['contrast_limit'],
                    saturation=self.config['color']['saturation_limit'],
                    hue=self.config['color']['hue_shift_limit'] / 360.0,
                    p=1.0
                ),
                A.HueSaturationValue(
                    hue_shift_limit=self.config['color']['hue_shift_limit'],
                    sat_shift_limit=int(self.config['color']['saturation_limit'] * 100),
                    val_shift_limit=int(self.config['color']['brightness_limit'] * 100),
                    p=1.0
                )
            ], p=self.config['color']['probability']),

            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Heavy augmentation pipeline (for robust training)
        pipelines['heavy'] = A.Compose([
            A.Resize(*self.config['image_size']),

            # Geometric transformations
            A.OneOf([
                A.ShiftScaleRotate(
                    shift_limit=self.config['geometric']['shift_limit'],
                    scale_limit=self.config['geometric']['scale_limit'],
                    rotate_limit=self.config['geometric']['rotate_limit'],
                    border_mode=cv2.BORDER_CONSTANT,
                    p=1.0
                ),
                A.Perspective(
                    scale=self.config['geometric']['perspective_scale'],
                    p=1.0
                ),
                A.GridDistortion(p=1.0),
                A.OpticalDistortion(p=1.0)
            ], p=self.config['geometric']['probability']),

            # Flips
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.3),
            A.Transpose(p=0.2),

            # Crops and cuts
            A.OneOf([
                A.RandomCrop(
                    height=int(self.config['image_size'][0] * 0.8),
                    width=int(self.config['image_size'][1] * 0.8),
                    p=1.0
                ),
                A.CenterCrop(
                    height=int(self.config['image_size'][0] * 0.9),
                    width=int(self.config['image_size'][1] * 0.9),
                    p=1.0
                ),
                A.CoarseDropout(
                    max_holes=8,
                    max_height=16,
                    max_width=16,
                    p=1.0
                )
            ], p=0.4),

            # Ensure consistent size after crops
            A.Resize(*self.config['image_size']),

            # Color transformations
            A.OneOf([
                A.ColorJitter(
                    brightness=self.config['color']['brightness_limit'],
                    contrast=self.config['color']['contrast_limit'],
                    saturation=self.config['color']['saturation_limit'],
                    hue=self.config['color']['hue_shift_limit'] / 360.0,
                    p=1.0
                ),
                A.CLAHE(p=1.0),
                A.RandomGamma(p=1.0),
                A.ChannelShuffle(p=1.0)
            ], p=self.config['color']['probability']),

            # Blur and noise
            A.OneOf([
                A.Blur(blur_limit=self.config['blur_noise']['blur_limit'], p=1.0),
                A.GaussianBlur(blur_limit=self.config['blur_noise']['blur_limit'], p=1.0),
                A.MotionBlur(blur_limit=self.config['blur_noise']['blur_limit'], p=1.0),
                A.GaussNoise(var_limit=self.config['blur_noise']['noise_var_limit'], p=1.0),
                A.ImageCompression(
                    quality_lower=self.config['blur_noise']['jpeg_quality'][0],
                    quality_upper=self.config['blur_noise']['jpeg_quality'][1],
                    p=1.0
                )
            ], p=self.config['blur_noise']['probability']),

            # Weather effects
            A.OneOf([
                A.RandomRain(
                    blur_value=2,
                    brightness_coefficient=0.9,
                    p=1.0
                ),
                A.RandomFog(
                    fog_coef_lower=self.config['weather']['fog_coef'][0],
                    fog_coef_upper=self.config['weather']['fog_coef'][1],
                    p=1.0
                ),
                A.RandomSunFlare(
                    flare_roi=(0, 0, 1, 0.5),
                    angle_lower=0,
                    angle_upper=1,
                    p=1.0
                ),
                A.RandomShadow(p=1.0)
            ], p=self.config['weather']['probability']),

            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Custom segmentation pipeline
        pipelines['segmentation'] = A.Compose([
            A.Resize(*self.config['image_size']),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.05,
                scale_limit=0.1,
                rotate_limit=15,
                border_mode=cv2.BORDER_CONSTANT,
                value=0,
                mask_value=0,
                p=0.7
            ),
            A.ColorJitter(
                brightness=0.1,
                contrast=0.1,
                saturation=0.1,
                hue=0.05,
                p=0.5
            ),
            A.OneOf([
                A.Blur(blur_limit=3, p=1.0),
                A.GaussNoise(var_limit=(10.0, 30.0), p=1.0)
            ], p=0.3),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        return pipelines

    def augment(self, image: np.ndarray, pipeline: str = 'medium',
                mask: Optional[np.ndarray] = None) -> Union[Dict, np.ndarray]:
        """
        Apply augmentation to image and optionally mask.

        Args:
            image: Input image as numpy array
            pipeline: Pipeline type ('light', 'medium', 'heavy', 'segmentation')
            mask: Optional mask for segmentation tasks

        Returns:
            Augmented image (and mask if provided)
        """
        if pipeline not in self.pipelines:
            raise ValueError(f"Pipeline '{pipeline}' not found. Available: {list(self.pipelines.keys())}")

        if mask is not None:
            result = self.pipelines[pipeline](image=image, mask=mask)
            return result
        else:
            result = self.pipelines[pipeline](image=image)
            return result['image']

    def augment_batch(self, images: List[np.ndarray], pipeline: str = 'medium',
                     masks: Optional[List[np.ndarray]] = None) -> List:
        """
        Apply augmentation to a batch of images.

        Args:
            images: List of images
            pipeline: Pipeline type
            masks: Optional list of masks

        Returns:
            List of augmented images (and masks if provided)
        """
        results = []

        for i, image in enumerate(images):
            mask = masks[i] if masks is not None else None
            result = self.augment(image, pipeline, mask)
            results.append(result)

        return results

    def visualize_augmentations(self, image: np.ndarray, num_samples: int = 4,
                              pipeline: str = 'medium', figsize: Tuple = (15, 10)):
        """
        Visualize different augmentation results.

        Args:
            image: Input image
            num_samples: Number of augmented samples to show
            pipeline: Pipeline type
            figsize: Figure size for visualization
        """
        fig, axes = plt.subplots(1, num_samples + 1, figsize=figsize)

        # Original image
        if len(image.shape) == 3 and image.shape[2] == 3:
            axes[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        else:
            axes[0].imshow(image, cmap='gray')
        axes[0].set_title('Original')
        axes[0].axis('off')

        # Augmented images
        for i in range(num_samples):
            aug_image = self.augment(image.copy(), pipeline)

            # Denormalize if normalized
            if aug_image.dtype == np.float32 and aug_image.min() < 0:
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                aug_image = aug_image * std + mean
                aug_image = np.clip(aug_image, 0, 1)

            if len(aug_image.shape) == 3 and aug_image.shape[2] == 3:
                axes[i + 1].imshow(aug_image)
            else:
                axes[i + 1].imshow(aug_image, cmap='gray')
            axes[i + 1].set_title(f'Augmented {i + 1}')
            axes[i + 1].axis('off')

        plt.tight_layout()
        plt.show()

    def save_config(self, filepath: str):
        """Save current configuration to JSON file."""
        with open(filepath, 'w') as f:
            json.dump(self.config, f, indent=4)

    def load_config(self, filepath: str):
        """Load configuration from JSON file."""
        with open(filepath, 'r') as f:
            self.config = json.load(f)
        self.pipelines = self._create_pipelines()

    def get_pipeline_info(self) -> Dict:
        """Get information about available pipelines."""
        info = {}
        for name, pipeline in self.pipelines.items():
            transforms = []
            for transform in pipeline.transforms:
                transforms.append({
                    'name': transform.__class__.__name__,
                    'probability': getattr(transform, 'p', 1.0)
                })
            info[name] = {
                'num_transforms': len(transforms),
                'transforms': transforms
            }
        return info


# Example usage and testing
def main():
    """Example usage of the CustomAugmentationPipeline."""

    # Create a sample image (replace with your actual image loading)
    sample_image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)

    # Initialize the pipeline
    augmenter = CustomAugmentationPipeline()

    # Print pipeline information
    print("Available pipelines:")
    pipeline_info = augmenter.get_pipeline_info()
    for name, info in pipeline_info.items():
        print(f"\n{name.upper()} Pipeline:")
        print(f"  - Number of transforms: {info['num_transforms']}")
        for transform in info['transforms']:
            print(f"  - {transform['name']} (p={transform['probability']})")

    # Example 1: Single image augmentation
    print("\n" + "="*50)
    print("SINGLE IMAGE AUGMENTATION")
    print("="*50)

    aug_image = augmenter.augment(sample_image, pipeline='medium')
    print(f"Original shape: {sample_image.shape}")
    print(f"Augmented shape: {aug_image.shape}")

    # Example 2: Batch augmentation
    print("\n" + "="*50)
    print("BATCH AUGMENTATION")
    print("="*50)

    batch_images = [sample_image.copy() for _ in range(3)]
    aug_batch = augmenter.augment_batch(batch_images, pipeline='light')
    print(f"Batch size: {len(aug_batch)}")
    print(f"Each augmented image shape: {aug_batch[0].shape}")

    # Example 3: Segmentation with mask
    print("\n" + "="*50)
    print("SEGMENTATION WITH MASK")
    print("="*50)

    # Create a sample mask
    sample_mask = np.random.randint(0, 2, (256, 256), dtype=np.uint8)

    seg_result = augmenter.augment(sample_image, pipeline='segmentation', mask=sample_mask)
    print(f"Segmentation result keys: {seg_result.keys()}")
    print(f"Image shape: {seg_result['image'].shape}")
    print(f"Mask shape: {seg_result['mask'].shape}")

    # Example 4: Custom configuration
    print("\n" + "="*50)
    print("CUSTOM CONFIGURATION")
    print("="*50)

    custom_config = {
        'image_size': (224, 224),
        'geometric': {
            'rotate_limit': 45,
            'shift_limit': 0.2,
            'scale_limit': 0.3,
            'perspective_scale': 0.1,
            'probability': 0.8
        },
        'color': {
            'brightness_limit': 0.3,
            'contrast_limit': 0.3,
            'saturation_limit': 0.3,
            'hue_shift_limit': 30,
            'probability': 0.7
        },
        'blur_noise': {
            'blur_limit': 5,
            'noise_var_limit': (10.0, 80.0),
            'jpeg_quality': (60, 100),
            'probability': 0.5
        },
        'weather': {
            'rain_intensity': (0.2, 0.5),
            'fog_coef': (0.2, 0.5),
            'sun_intensity': (0.2, 0.6),
            'probability': 0.4
        }
    }

    custom_augmenter = CustomAugmentationPipeline(custom_config)
    custom_aug = custom_augmenter.augment(sample_image, pipeline='heavy')
    print(f"Custom augmented shape: {custom_aug.shape}")

    # Save and load configuration
    augmenter.save_config('augmentation_config.json')
    print("Configuration saved to 'augmentation_config.json'")

    # Visualize augmentations (uncomment to use)
    # augmenter.visualize_augmentations(sample_image, num_samples=4, pipeline='medium')


if __name__ == "__main__":
    main()

  original_init(self, **validated_kwargs)
  A.CoarseDropout(
  A.GaussNoise(var_limit=self.config['blur_noise']['noise_var_limit'], p=1.0),
  A.ImageCompression(
  A.RandomFog(
  A.RandomSunFlare(
  A.ShiftScaleRotate(
  A.GaussNoise(var_limit=(10.0, 30.0), p=1.0)


Available pipelines:

LIGHT Pipeline:
  - Number of transforms: 4
  - Resize (p=1.0)
  - HorizontalFlip (p=0.5)
  - RandomBrightnessContrast (p=0.3)
  - Normalize (p=1.0)

MEDIUM Pipeline:
  - Number of transforms: 6
  - Resize (p=1.0)
  - OneOf (p=0.7)
  - HorizontalFlip (p=0.5)
  - VerticalFlip (p=0.2)
  - OneOf (p=0.6)
  - Normalize (p=1.0)

HEAVY Pipeline:
  - Number of transforms: 11
  - Resize (p=1.0)
  - OneOf (p=0.7)
  - HorizontalFlip (p=0.5)
  - VerticalFlip (p=0.3)
  - Transpose (p=0.2)
  - OneOf (p=0.4)
  - Resize (p=1.0)
  - OneOf (p=0.6)
  - OneOf (p=0.4)
  - OneOf (p=0.3)
  - Normalize (p=1.0)

SEGMENTATION Pipeline:
  - Number of transforms: 6
  - Resize (p=1.0)
  - HorizontalFlip (p=0.5)
  - ShiftScaleRotate (p=0.7)
  - ColorJitter (p=0.5)
  - OneOf (p=0.3)
  - Normalize (p=1.0)

SINGLE IMAGE AUGMENTATION
Original shape: (256, 256, 3)
Augmented shape: (256, 256, 3)

BATCH AUGMENTATION
Batch size: 3
Each augmented image shape: (256, 256, 3)

SEGMENTATION WITH MASK
Segme