# Data Management in PyTorch

## Datasets

Often the data we are given is not in a form that can be passed directly to classes like DataLoader. For example, the image file names may be generic, and the labels might live in a separate file (such as a .mat file). In such cases, it is crucial to know how to define our own Datasets and DataLoaders.

## Data Access
### Defining our own Custom Datasets

In [None]:
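import os

import scipy.io
from torch.utils.data import Dataset

class OxfordFlowersDataset(Dataset):

  def __init__(self, root_dir):
    self.root_dir = root_dir
    # Store only the location of the images; nothing is read from disk yet.
    self.img_dir = os.path.join(root_dir, 'images')

    # The labels live in a MATLAB file inside the root directory.
    labels_matlab = scipy.io.loadmat(os.path.join(root_dir, 'imagelabels.mat'))

    # Shift the labels from 1-based to 0-based.
    self.labels = labels_matlab['labels'][0] - 1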

This approach is called lazy loading: if we loaded all of the data when initializing the class, it would use up a lot of RAM unnecessarily. Instead, we only record where to find the data and read each image when it is requested.
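The lazy loading idea can be illustrated without PyTorch. The sketch below is only an illustration under stated assumptions: `LazyTextFiles` and the `sample_*.txt` files are hypothetical stand-ins (small text files instead of images), not part of the flowers dataset or of any library.

```python
import os
import tempfile


class LazyTextFiles:
    """Hypothetical stand-in for a lazily loaded dataset (not a PyTorch class)."""

    def __init__(self, root_dir):
        # Only the sorted file paths are stored here; no file is read yet.
        self.paths = sorted(
            os.path.join(root_dir, name) for name in os.listdir(root_dir)
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # The file contents are read only when this index is requested.
        with open(self.paths[idx], "r", encoding="utf-8") as f:
            return f.read()


# Build a throwaway directory with three small files to exercise the class.
demo_dir = tempfile.mkdtemp()
for i in range(3):
    path = os.path.join(demo_dir, f"sample_{i}.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"contents of sample {i}")

lazy = LazyTextFiles(demo_dir)
print(len(lazy), lazy[1])
```

Memory use in `__init__` grows only with the number of paths, not with the size of the files themselves.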

The labels are adjusted by subtracting 1 because PyTorch loss functions such as CrossEntropyLoss expect class indices to start from 0.

In [None]:
def __len__(self):
  return len(self.labels)

The `__len__` method returns the total number of samples in the dataset.

In [None]:
from PIL import Image

def __getitem__(self, idx):

  img_name = f'image_{idx+1:05d}.jpg'
  img_path = os.path.join(self.img_dir, img_name)

  image = Image.open(img_path)
  label = self.labels[idx]

  return image, label

This dunder method returns the image and its corresponding label for the given index. The img_name pattern depends on how the files in the directory are actually named, so it only works for this dataset's naming scheme.

Also, idx is incremented by 1 here because the dataset's images start at image_00001. Without the added 1, index 0 would point to image_00000, which does not exist.
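The index-to-filename mapping can be checked in isolation. This is a small sketch; `image_filename` is a hypothetical helper name introduced only for illustration.

```python
def image_filename(idx):
    """Map a zero-based dataset index to the one-based, zero-padded file name."""
    # ':05d' pads the number with leading zeros to a width of five digits.
    return f"image_{idx + 1:05d}.jpg"


print(image_filename(0))   # image_00001.jpg
print(image_filename(41))  # image_00042.jpg
```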

So study the data, especially its metadata, before writing the Dataset class.

## Transform Pipelines (Quality)

### Learning why raw data won't work

Batching raw data won't work because PyTorch expects all items in a batch to have the same dimensions, which is rarely the case for image data. Also, PyTorch expects tensors, not image objects.

In [None]:
from torchvision import transforms

transform = transforms.Compose([
  transforms.Resize(256),
  transforms.CenterCrop(224),
])

transforms.Resize(256) resizes the shorter edge to 256 whilst preserving the aspect ratio of the image. Hard resizing where we give both dimensions (256, 256) would distort the image. 

Then transforms.CenterCrop(224) takes the central 224x224 square of the resized image.
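The arithmetic behind these two steps can be sketched in plain Python. The helper names below are hypothetical, the 500x375 input size is an arbitrary example, and the rounding here is an assumption that may differ slightly from what torchvision does internally.

```python
def resize_shorter_edge(width, height, size):
    """Scale so the shorter edge equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    """Return (left, top, right, bottom) of a centered crop x crop square."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


# A 500x375 (width x height) image: the shorter edge (375) becomes 256.
new_w, new_h = resize_shorter_edge(500, 375, 256)
box = center_crop_box(new_w, new_h, 224)
print((new_w, new_h), box)
```

The resized image keeps its roughly 4:3 shape, and only the crop discards pixels, which is why the result is not distorted.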

Now to convert the images into tensors:

In [None]:
transform = transforms.Compose([
  transforms.Resize(256),
  transforms.CenterCrop(224),
  transforms.ToTensor(),        #  <------- Add this
  transforms.Normalize(mean= [...],
                        std= [...])
])

ToTensor() can be thought of as 'the tensor bridge'. Before the bridge the data is a PIL image; after the bridge it is a tensor. Transforms that only work on tensors (such as Normalize) must therefore come after ToTensor(), because applying them to a PIL image causes errors.

Adding transforms to the OxfordFlowersDataset class:

In [None]:
class OxfordFlowersDataset(Dataset):

  def __init__(self, root_dir, transform = None):
    # all other code
    self.transform = transform

  def __getitem__(self, idx):
    # all other code
    if self.transform:
      image = self.transform(image)
    return image, label

Now the dataset can be batched:

In [None]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

For debugging, take single data points and apply the transforms one at a time.

## DataLoader

### Splitting the data

In [None]:
train_dataset, val_dataset, test_dataset = random_split(
  dataset, [train_size, val_size, test_size]
)

This shuffles the entire dataset and distributes the samples among the splits according to the sizes given.
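When integer sizes are passed to random_split, they must add up to the length of the dataset. One way to guarantee that is to compute the last size as the remainder. The sketch below assumes a 70/15/15 split and a dataset of 1003 samples; both numbers, and the `split_sizes` helper, are illustrative choices rather than values from the source.

```python
def split_sizes(n_samples, train_pct=70, val_pct=15):
    """Return (train, val, test) sizes that always sum to n_samples."""
    # Integer arithmetic avoids floating-point rounding surprises.
    train_size = n_samples * train_pct // 100
    val_size = n_samples * val_pct // 100
    # The test split takes whatever is left over.
    test_size = n_samples - train_size - val_size
    return train_size, val_size, test_size


train_size, val_size, test_size = split_sizes(1003)
print(train_size, val_size, test_size)
```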

### Batching using DataLoader

Iterating through the DataLoader object gives us the data batch by batch.
To fetch the first batch without writing a loop:

In [None]:
images, labels = next(iter(train_loader))
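The `next(iter(...))` idiom works on any iterable, which can be seen with a toy loader. `SimpleLoader` below is a hypothetical stand-in written for this sketch; it only mimics the batching behaviour and is not how DataLoader is implemented.

```python
class SimpleLoader:
    """Hypothetical stand-in for a DataLoader without shuffling."""

    def __init__(self, items, batch_size):
        self.items = list(items)
        self.batch_size = batch_size

    def __iter__(self):
        # Yield consecutive slices of batch_size items; the last may be shorter.
        for start in range(0, len(self.items), self.batch_size):
            yield self.items[start:start + self.batch_size]


loader = SimpleLoader(range(10), batch_size=4)

# iter() creates a fresh iterator, next() pulls exactly one batch from it.
first_batch = next(iter(loader))
all_batches = list(loader)
print(first_batch, len(all_batches))
```

Each call to `iter()` starts over from the beginning, so with a real DataLoader and shuffle=True, repeating the idiom would generally return a different batch each time.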

## Bug-proofing

### On-the-fly transformation of PyTorch

Random transforms are applied to the training data as it is loaded, without extra memory usage, so that the model sees a different version of the same image each time.

In [None]:
train_transform = transforms.Compose([
  #Random augmentation transforms
  transforms.RandomHorizontalFlip(p=0.5),
  transforms.RandomRotation(degrees=10),
  transforms.ColorJitter(brightness=0.2),

  #Other preprocessing steps
  transforms.Resize(256),
  transforms.CenterCrop(224),
  transforms.ToTensor(),
  transforms.Normalize(mean= [...],
                        std= [...])
])
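The on-the-fly idea can be sketched without images: a random transform is re-drawn on every access, while only one copy of the data is stored. The `random_hflip` helper and the four-element "row of pixels" are hypothetical simplifications, not torchvision code.

```python
import random


def random_hflip(row, p, rng):
    """Return the row reversed with probability p, otherwise unchanged."""
    if rng.random() < p:
        return row[::-1]
    return row


stored_row = [1, 2, 3, 4]   # the single copy kept in memory
rng = random.Random(0)      # seeded only to make the sketch repeatable

# Each access draws a new random decision; no augmented copies are stored.
views = [random_hflip(stored_row, 0.5, rng) for _ in range(6)]
print(views)
```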


### Handling corrupted files gracefully

Inside the `__getitem__` method, include:

In [None]:
image.verify()                    # Check the file's integrity without decoding the image data.
image = Image.open(img_path)      # Reopen the image, because it cannot be used after verify().

if image.size[0] < 32 or image.size[1] < 32:
  raise ValueError(f"Image too small: {image.size}")

if image.mode != 'RGB':           # Convert to RGB if needed.
  image = image.convert('RGB')

If any other exception occurs, fall back to the next idx:

In [None]:
next_idx = (idx + 1) % len(self)
return self.__getitem__(next_idx)
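One caveat with the recursive fallback is that it never stops if every sample is bad. A bounded loop avoids that. The sketch below is a minimal illustration; `load_with_fallback`, `fake_load`, and the set of bad indices are all hypothetical names made up for the example.

```python
def load_with_fallback(idx, n_samples, load):
    """Try idx, then following indices (wrapping around), at most n_samples times."""
    for _ in range(n_samples):
        try:
            return idx, load(idx)
        except Exception:
            # Move to the next index, wrapping back to 0 after the last one.
            idx = (idx + 1) % n_samples
    raise RuntimeError("no loadable sample found")


BAD_INDICES = {3, 4}  # pretend these two samples are corrupted


def fake_load(idx):
    if idx in BAD_INDICES:
        raise ValueError(f"corrupted sample {idx}")
    return f"sample-{idx}"


print(load_with_fallback(3, 5, fake_load))  # (0, 'sample-0')
```

Note that the fallback returns a different sample than the one requested, so some samples may appear more than once per epoch.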

### Monitoring data

In [None]:
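import time

def __getitem__(self, idx):
  start_time = time.time()

  # Count how many times each index has been requested.
  self.access_counts[idx] = self.access_counts.get(idx, 0) + 1

  result = super().__getitem__(idx)

  # Record how long this sample took to load.
  load_time = time.time() - start_time
  self.load_times.append(load_time)

  if load_time > 1.0:
    # Both parts need the f prefix, otherwise {load_time} is printed literally.
    print(f"Slow load for image index: {idx}. "
          f"Time taken: {load_time:.2f}s")
  return result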
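Once load times and access counts have been collected, they can be condensed into a few numbers. This is a sketch under assumptions: `summarize_monitoring` is a hypothetical helper, and the sample values passed to it are made up for the example.

```python
from statistics import mean


def summarize_monitoring(load_times, access_counts, slow_threshold=1.0):
    """Condense the collected timings and access counts into a small dict."""
    return {
        "mean_load_time": mean(load_times) if load_times else 0.0,
        "max_load_time": max(load_times, default=0.0),
        # Number of loads that took longer than the threshold (in seconds).
        "slow_loads": sum(t > slow_threshold for t in load_times),
        # Index that was requested most often, or None if nothing was loaded.
        "most_accessed": max(access_counts, key=access_counts.get, default=None),
    }


summary = summarize_monitoring([0.5, 1.5, 0.25], {0: 1, 7: 3, 2: 2})
print(summary)
```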

# Important note on data augmentation and data subsets created using random_split

The splits produced by random_split all reference the same Dataset object from which they were partitioned, so we can't change their transformations separately. We need to define a new wrapper class.

**This is important because we only need augmentation transforms for training data, but regular transformations for testing and validation.**
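The shared-reference problem can be reproduced with plain Python objects. `BaseData` and `IndexView` below are hypothetical stand-ins for a dataset and a subset; they only imitate the idea that a subset stores a reference to its parent plus a list of indices.

```python
class BaseData:
    """Hypothetical stand-in for a dataset with an optional transform."""

    def __init__(self, items, transform=None):
        self.items = items
        self.transform = transform

    def __getitem__(self, idx):
        item = self.items[idx]
        return self.transform(item) if self.transform else item


class IndexView:
    """Hypothetical stand-in for a subset: a reference to the base plus indices."""

    def __init__(self, base, indices):
        self.base = base
        self.indices = indices

    def __getitem__(self, idx):
        return self.base[self.indices[idx]]


base = BaseData([1, 2, 3, 4])
train_view = IndexView(base, [0, 1, 2])
val_view = IndexView(base, [3])

# A transform intended only for training is set on the shared base...
base.transform = lambda x: x * 10

# ...and it leaks into the validation view as well.
print(train_view[0], val_view[0])
```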

In [None]:
class SubsetWithTransform(Dataset):
    """
    A wrapper for a PyTorch Subset that applies a specific transformation.

    This class allows for applying a different set of transformations to a
    subset of a dataset, which is useful for creating distinct training,
    validation, or test sets with different preprocessing steps from the
    same base dataset.
    """
    def __init__(self, subset, transform=None):
        """
        Initializes the SubsetWithTransform object.

        Args:
            subset: A PyTorch Subset object containing a portion of a dataset.
            transform (callable, optional): An optional transform to be applied
                to the samples within this subset.
        """
        # Store the original subset of the dataset.
        self.subset = subset
        # Store the transformations to be applied.
        self.transform = transform

    def __len__(self):
        """
        Returns the total number of samples in the subset.
        """
        # Return the length of the underlying subset.
        return len(self.subset)

    def __getitem__(self, idx):
        """
        Retrieves a sample and applies the transform.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the transformed image and its label.
        """
        # Get the original image and label from the underlying subset.
        image, label = self.subset[idx]
        # Check if a transform has been provided.
        if self.transform:
            # Apply the transform to the image.
            image = self.transform(image)
        # Return the transformed image and its label.
        return image, label

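This looks just like the Dataset class we defined earlier, but it wraps a subset and lets us attach a separate transform to each split. For this to work, the base dataset should typically be created without a transform, so that each wrapper receives the raw image.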

# Robust Error Handling 

In [None]:
import os

import scipy.io
from PIL import Image
from torch.utils.data import Dataset

class RobustFlowerDataset(Dataset):
    """
    A custom dataset class with robust error handling for loading images.

    This class is designed to gracefully handle issues with individual data
    samples, such as corrupted files or incorrect formats. It logs any errors
    and attempts to load a different sample instead of crashing.
    """
    def __init__(self, root_dir, transform=None):
        """
        Initializes the dataset object.

        Args:
            root_dir (str): The root directory where the dataset is stored.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        # Store the root directory path.
        self.root_dir = root_dir
        # Construct the full path to the image directory.
        self.img_dir = os.path.join(root_dir, "jpg")
        # Store the optional transformations.
        self.transform = transform
        # Load and process the labels from the corresponding file.
        self.labels = self.load_and_correct_labels()
        # Initialize a list to keep track of any errors encountered.
        self.error_logs = []

    def __getitem__(self, idx):
        """
        Retrieves a sample, handling errors by trying the next available item.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the image and its label.
        """
        # Loop to attempt loading a valid sample, preventing an infinite loop.
        for attempt in range(len(self)):
            # Attempt to load and process the sample.
            try:
                # Retrieve the image using the helper method.
                image = self.retrieve_image(idx)
                # Check if a transform has been provided.
                if self.transform:
                    # Apply the transform to the image.
                    image = self.transform(image)
                # Get the label for the current index.
                label = self.labels[idx]
                # Return the valid image and its corresponding label.
                return image, label
            # Catch any exception that occurs during the process.
            except Exception as e:
                # Log the error with its index and message.
                self.log_error(idx, e)
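                # Move to the next index, wrapping around if necessary.
                idx = (idx + 1) % len(self)
        # Every sample failed to load, so raise instead of silently returning None.
        raise RuntimeError("No valid sample could be loaded from the dataset.")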

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        # The total number of samples is the number of labels.
        return len(self.labels)

    def retrieve_image(self, idx):
        """
        Loads and validates a single image from disk.

        Args:
            idx (int): The index of the image to load.

        Returns:
            PIL.Image.Image: The validated and loaded image object.
        """
        # Construct the image filename based on the index.
        img_name = f"image_{idx+1:05d}.jpg"
        # Construct the full path to the image file.
        img_path = os.path.join(self.img_dir, img_name)
        # Open the image file to check its integrity without loading fully.
        with Image.open(img_path) as img:
            # Perform a quick verification of the file's structure.
            img.verify()
        # Re-open the image file after successful verification.
        image = Image.open(img_path)
        # Fully load the image data into memory.
        image.load()
        # Check if the image dimensions are below a minimum threshold.
        if image.size[0] < 32 or image.size[1] < 32:
            # Raise an error for images that are too small.
            raise ValueError(f"Image too small: {image.size}")
        # Check if the image is not in the RGB color mode.
        if image.mode != "RGB":
            # Convert the image to RGB.
            image = image.convert("RGB")
        # Return the fully loaded and validated image.
        return image

    def load_and_correct_labels(self):
        """
        Loads labels from a .mat file and adjusts them.

        Returns:
            numpy.ndarray: An array of zero-indexed integer labels.
        """
        # Load the MATLAB file containing the labels.
        self.labels_mat = scipy.io.loadmat(
            os.path.join(self.root_dir, "imagelabels.mat")
        )
        # Extract the labels array and correct for zero-based indexing.
        labels = self.labels_mat["labels"][0] - 1
        # Truncate the dataset to the first 10 labels for quick testing.
        labels = labels[:10]
        # Return the processed labels.
        return labels

    def log_error(self, idx, e):
        """
        Records the details of an error encountered during data loading.

        Args:
            idx (int): The index of the problematic sample.
            e (Exception): The exception object that was raised.
        """
        # Construct the filename of the problematic image.
        img_name = f"image_{idx + 1:05d}.jpg"
        # Construct the full path to the image file.
        img_path = os.path.join(self.img_dir, img_name)
        # Append a dictionary with error details to the log.
        self.error_logs.append(
            {
                "index": idx,
                "error": str(e),
                "path": img_path,
            }
        )
        # Print a warning to the console about the skipped image.
        print(f"Warning: Skipping corrupted image {idx}: {e}")

    def get_error_summary(self):
        """
        Prints a summary of all errors encountered during dataset processing.
        """
        # Check if the error log is empty.
        if not self.error_logs:
            # Print a message indicating the dataset is clean.
            print("No errors encountered - dataset is clean!")
        else:
            # Print the total number of problematic images found.
            print(f"\nEncountered {len(self.error_logs)} problematic images:")
            # Iterate through the first few logged errors.
            for error in self.error_logs[:5]:
                # Print the details of an individual error.
                print(f"  Index {error['index']}: {error['error']}")
            # Check if there are more errors than were displayed.
            if len(self.error_logs) > 5:
                # Print a summary of the remaining errors.
                print(f"  ... and {len(self.error_logs) - 5} more")

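The implementation above logs each failure, falls back to the next sample instead of crashing, and can report a summary of the errors through get_error_summary().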