### Overview of the code below

# Class: `AugmentedDataset`

#### Purpose
- Manage image datasets.
- Apply specified transformations and augmentations to images.
- Provide methods to retrieve, display, and save augmented images.

#### Constructor: `__init__`
- **Parameters**:
  - `data_path`: Path to the directory containing images.
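  - `base_transform`: (Optional) Transformation applied to the image that is returned as the "original".
  - `aug_transforms`: (Optional) List of augmentation transformations. Each one is applied independently to the untransformed RGB image, not on top of `base_transform`.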
- **Attributes**:
  - `data_path`, `base_transform`, `aug_transforms`: stored as given.
  - `labels`: the filenames of the JPEG and PNG images found in `data_path`.

#### Method: `__len__`
- Returns the number of images in the dataset.

#### Method: `__getitem__`
- **Parameters**:
  - `index`: Index of the image to retrieve.
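- **Returns** (as a tuple):
  - The original image, with `base_transform` applied if one was given.
  - A list of augmented images, one per entry in `aug_transforms`.
  - The image filename.
  - The names of the applied transformations, one string per augmented image.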

#### Method: `_get_transform_names`
- **Parameters**:
  - `transform`: A transformation or a `Compose` of transformations.
- **Returns**:
  - String representation of the transformation(s) applied.
- Handles both single transformations and composed transformations.

#### Method: `display_images`
- Displays original and augmented images from the dataset using `matplotlib`.
- Iterates over the dataset, displaying each original image alongside its augmented versions with titles indicating the applied transformations.

#### Method: `aug_img_dir`
- **Parameters**:
  - `directory`: Path to the directory where augmented images will be saved.
- Saves augmented images to the specified directory.
- Creates the directory if it does not exist.
- Saves each original image as `org_<image_name>`, and each augmented image under a filename containing the name(s) of the transformation(s) applied to it.

In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torchvision import transforms as T

class AugmentedDataset(Dataset):
    """
    Dataset of images from one directory, each paired with augmented variants.

    Args:
        data_path (str): Path to the directory containing images.
        base_transform (callable, optional): Transformation applied to the image
            returned as the "original".
        aug_transforms (list of callables, optional): Augmentation transformations.
            Each one is applied independently to the raw RGB image (not on top of
            ``base_transform``).

    Attributes:
        data_path (str): Path to the directory containing images.
        base_transform (callable, optional): Transformation for the "original" image.
        aug_transforms (list of callables, optional): Augmentation transformations.
        labels (list of str): Sorted filenames of the ``.jpg``/``.jpeg``/``.png``
            files (case-insensitive) found directly in ``data_path``.

    Methods:
        __len__(): Returns the number of images in the dataset.
        __getitem__(index): Returns the original and augmented images at ``index``.
        display_images(): Displays original and augmented images with matplotlib.
        aug_img_dir(directory): Saves original and augmented images to ``directory``.
    """

    # Recognised image suffixes; compared against the lower-cased filename so that
    # files such as "photo.JPG" are picked up too.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_path, base_transform=None, aug_transforms=None):
        self.data_path = data_path
        self.base_transform = base_transform
        self.aug_transforms = aug_transforms
        # os.listdir returns entries in arbitrary order; sort so that a given index
        # always refers to the same file. Directories whose names happen to end in
        # an image suffix are excluded.
        self.labels = sorted(
            name
            for name in os.listdir(data_path)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(data_path, name))
        )

    def __len__(self):
        """Returns the number of images in the dataset."""
        return len(self.labels)

    def __getitem__(self, index):
        """
        Returns the original and augmented images at the specified index.

        Args:
            index (int): Index into ``self.labels``.

        Returns:
            tuple: ``(org_img, aug_img, image_name, transform_names)`` where
            ``org_img`` is the RGB image (after ``base_transform`` if set),
            ``aug_img`` is a list with one result per augmentation, ``image_name``
            is the filename, and ``transform_names`` holds one name string per
            augmentation.

        Raises:
            IndexError: If ``index`` is out of range for ``self.labels``.
        """
        image_name = self.labels[index]
        image_path = os.path.join(self.data_path, image_name)
        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded image, so it stays valid after the with-block.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        org_img = self.base_transform(image) if self.base_transform else image

        aug_img = []
        transform_names = []
        for transform in self.aug_transforms or ():
            # Every augmentation starts from the raw RGB image, not from org_img.
            aug_img.append(transform(image))
            transform_names.append(self._get_transform_names(transform))

        return org_img, aug_img, image_name, transform_names

    def _get_transform_names(self, transform):
        """
        Returns a string representation of the transformation(s) applied.

        ``Compose`` objects are flattened recursively, so nested compositions list
        their leaf transforms instead of the bare word "Compose".

        Args:
            transform (callable): A transformation or a Compose of transformations.

        Returns:
            str: Comma-separated class names of the transformations.
        """
        if isinstance(transform, T.Compose):
            return ', '.join(self._get_transform_names(t) for t in transform.transforms)
        return transform.__class__.__name__

    def display_images(self):
        """
        Displays original and augmented images from the dataset.

        Opens one matplotlib figure per image: the original on the left, followed
        by each augmented version titled with its transformation name(s).

        NOTE(review): plt.imshow expects an HxW or HxWxC array / PIL image; if a
        transform returns a CxHxW tensor (e.g. via ToTensor) it will not display
        as-is — confirm the transforms used with this method return images.
        """
        for i in range(len(self)):
            org_img, aug_img, image_name, transform_names = self[i]
            n_cols = len(aug_img) + 1

            plt.figure(figsize=(15, 5))

            plt.subplot(1, n_cols, 1)
            plt.imshow(org_img)
            plt.title(f"Original - {image_name}")
            plt.axis('off')

            for col, (img, name) in enumerate(zip(aug_img, transform_names), start=2):
                plt.subplot(1, n_cols, col)
                plt.imshow(img)
                plt.title(name)
                plt.axis('off')

            plt.show()

    @staticmethod
    def _aug_filenames(image_name, transform_names):
        """
        Builds one output filename per augmentation of ``image_name``.

        The full stem is kept via os.path.splitext (so "my.photo.jpg" keeps
        "my.photo"). When several augmentations would map to the same name, later
        ones get a ``_1``, ``_2``, ... suffix so no file overwrites another.

        Args:
            image_name (str): Source image filename, e.g. "cat.jpg".
            transform_names (list of str): One name per augmentation.

        Returns:
            list of str: Filenames in the same order as ``transform_names``.
        """
        stem, ext = os.path.splitext(image_name)
        used = set()
        filenames = []
        for name in transform_names:
            base = f"{stem}_{name}"
            candidate, suffix = base, 1
            while candidate in used:
                candidate = f"{base}_{suffix}"
                suffix += 1
            used.add(candidate)
            filenames.append(candidate + ext)
        return filenames

    def aug_img_dir(self, directory):
        """
        Saves original and augmented images to a specified directory.

        The directory (and any missing parents) is created if needed. Each original
        is written as ``org_<image_name>``; augmented images are written under the
        names produced by ``_aug_filenames``.

        NOTE(review): relies on the images having a ``.save`` method (PIL images);
        transforms that return tensors would need converting first — confirm.

        Args:
            directory (str): Directory path where images will be saved.
        """
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(directory, exist_ok=True)

        for i in range(len(self)):
            org_img, aug_img, image_name, transform_names = self[i]
            org_img.save(os.path.join(directory, f"org_{image_name}"))

            filenames = self._aug_filenames(image_name, transform_names)
            for aug_image, filename in zip(aug_img, filenames):
                aug_image.save(os.path.join(directory, filename))
