# Lab 3: Building Custom Datasets

In this lab, we will build a custom Dataset class from scratch to understand how PyTorch loads data.

## Learning Objectives

By the end of this lab, you will be able to:
- Understand when and why to create custom Dataset classes
- Implement the three essential methods: `__init__`, `__len__`, `__getitem__`
- Build a dataset that replicates `ImageFolder` functionality
- Compare custom datasets with built-in alternatives

## 0. Setup

In [2]:
import torch
import os
import pathlib
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import random
from typing import Tuple, Dict, List

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

In [3]:
import requests
import zipfile
from pathlib import Path

data_path = Path('data/')
image_path = data_path / 'pizza_steak_sushi'

if image_path.is_dir():
    print(f'{image_path} directory exists.')
else:
    print(f'Creating {image_path} directory...')
    image_path.mkdir(parents=True, exist_ok=True)
    # Download first and fail loudly on HTTP errors instead of saving an error page as a .zip
    print('Downloading data...')
    request = requests.get('https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip')
    request.raise_for_status()
    with open(data_path / 'pizza_steak_sushi.zip', 'wb') as f:
        f.write(request.content)
    with zipfile.ZipFile(data_path / 'pizza_steak_sushi.zip', 'r') as zip_ref:
        print('Unzipping data...')
        zip_ref.extractall(image_path)
    print('Done!')

train_dir = image_path / 'train'
test_dir = image_path / 'test'

print(f'Train dir: {train_dir}')
print(f'Test dir: {test_dir}')

## 1. Helper Function: Find Classes

Before building our custom dataset, we need a helper function to discover classes from the directory structure.

### What This Function Does:

1. **Scans a directory** for subdirectories (each subdirectory = one class)
2. **Creates a mapping** from class names to integer indices
3. **Returns both** the class list and the mapping dictionary

### Implementation Details:

We'll use `os.scandir()` to traverse the target directory, which should follow the standard image classification layout (`root/class_name/image.jpg`):
- Get the class names by finding all subdirectories
- Raise an error if no class folders are found
- Turn the class names into a dictionary of numerical labels

**Important:** This mimics what `ImageFolder` does internally!

In [4]:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    
    if not classes:
        raise FileNotFoundError(f'No class folders found in {directory}')
    
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

In [5]:
# Test the function
classes, class_to_idx = find_classes(train_dir)
print(f'Classes: {classes}')
print(f'Class to idx: {class_to_idx}')

### Function Works Correctly! ✅

Our `find_classes()` function successfully:
- Found all 3 classes in alphabetical order
- Created the correct index mapping (pizza→0, steak→1, sushi→2)
- Matches the behavior of PyTorch's `ImageFolder`
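
The run above only exercises the happy path. The sketch below repeats the same `find_classes()` logic against a throwaway directory tree created with `tempfile`, so it runs without the downloaded data. The folder names and the stray `notes.txt` file are illustrative assumptions. It shows that plain files are ignored, that sorting makes the indices independent of creation order, and that an empty directory raises `FileNotFoundError`.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def find_classes(directory) -> Tuple[List[str], Dict[str, int]]:
    # Same logic as the lab's helper: subdirectories are classes, sorted by name
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f'No class folders found in {directory}')
    return classes, {cls_name: i for i, cls_name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    # Create class folders out of alphabetical order, plus a stray file
    for name in ['steak', 'pizza', 'sushi']:
        os.mkdir(os.path.join(root, name))
    open(os.path.join(root, 'notes.txt'), 'w').close()

    classes, class_to_idx = find_classes(root)
    print(classes)       # ['pizza', 'steak', 'sushi']
    print(class_to_idx)  # {'pizza': 0, 'steak': 1, 'sushi': 2}

with tempfile.TemporaryDirectory() as empty_root:
    try:
        find_classes(empty_root)
        raised_on_empty = False
    except FileNotFoundError:
        raised_on_empty = True
print('Empty directory raised FileNotFoundError:', raised_on_empty)
```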

## 2. Build Custom Dataset Class

Now we're ready to build our own custom Dataset.

### Why Build Custom Datasets?

While `ImageFolder` works for standard image classification, you might need custom datasets when:
- Your data has a **non-standard format** (CSV annotations, JSON metadata)
- You need **custom preprocessing** or **complex augmentations**
- You're working with **multiple data sources** (images + text + tabular)
- Your labels are stored **separately** from images

### The Three Essential Methods

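A map-style PyTorch Dataset (one indexed by integers) typically implements these three methods. Strictly, the base class only requires subclasses to override `__getitem__`, but `DataLoader`'s default samplers also call `__len__`: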

| Method | Purpose |
|--------|---------|
| `__init__(self, ...)` | Initialize paths, transforms, load metadata |
| `__len__(self)` | Return total number of samples |
| `__getitem__(self, idx)` | Return one sample (image, label) given an index |
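
To see the three methods in isolation before any images are involved, here is a minimal sketch of a map-style dataset over in-memory data. `SquaresDataset` and its contents are hypothetical; only `torch.utils.data.Dataset` comes from PyTorch. Each sample is a `(tensor, label)` pair, mirroring the `(image, label)` pairs the image dataset will return.

```python
from typing import Tuple

import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (tensor([i, i**2]), i % 2)."""

    def __init__(self, n: int) -> None:
        # __init__: store whatever is needed to locate samples later
        self.n = n

    def __len__(self) -> int:
        # __len__: total number of samples
        return self.n

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        # __getitem__: build one (features, label) pair on demand
        if not 0 <= index < self.n:
            raise IndexError(index)
        features = torch.tensor([index, index ** 2], dtype=torch.float32)
        label = index % 2
        return features, label


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (tensor([3., 9.]), 1)
```

Nothing is computed until an index is requested, which is the same lazy pattern `ImageFolderCustom` uses for reading images from disk.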

In [6]:
class ImageFolderCustom(Dataset):
    
    def __init__(self, targ_dir: str, transform=None) -> None:
        # Sort so the sample order is deterministic (glob order depends on the filesystem)
        self.paths = sorted(pathlib.Path(targ_dir).glob('*/*.jpg'))
        self.transform = transform
        self.classes, self.class_to_idx = find_classes(targ_dir)
    
    def load_image(self, index: int) -> Image.Image:
        image_path = self.paths[index]
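        # Convert to RGB so every sample has 3 channels, as ImageFolder's default loader does
        return Image.open(image_path).convert('RGB')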
    
    def __len__(self) -> int:
        return len(self.paths)
    
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        img = self.load_image(index)
        class_name = self.paths[index].parent.name
        class_idx = self.class_to_idx[class_name]
        
        if self.transform:
            return self.transform(img), class_idx
        else:
            return img, class_idx

### Understanding the Implementation

Let's break down each part of our custom dataset:

**1. Subclass `torch.utils.data.Dataset`**
- Inherit from PyTorch's base Dataset class

**2. `__init__` Method**
- `targ_dir`: Target directory containing the images
- `transform`: Optional transforms to apply to images
- `self.paths`: List of all image file paths using `pathlib.Path.glob()`
- `self.classes` & `self.class_to_idx`: From our `find_classes()` function

**3. `load_image` Method**
- Helper method to load images from file using PIL
- Returns a PIL Image object
- Separated for clarity and potential customization

**4. `__len__` Method (Required)**
- Implements `__len__`, which the base `Dataset` class leaves undefined; `DataLoader`'s default samplers call it, so it is required here
- Returns the number of samples in the dataset
- Allows calling `len(dataset)`

**5. `__getitem__` Method (Required)**
- Overrides the Dataset's `__getitem__` method
- Loads the image at the given index
- Extracts class name from the file path
- Converts class name to integer index
- Applies transforms if provided
- Returns tuple: `(transformed_image, class_index)`
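
The label logic in `__getitem__` depends only on the folder layout: the parent directory of each file is its class name. A quick check with `pathlib`, using a made-up file path and assuming the mapping printed earlier (`pizza→0, steak→1, sushi→2`):

```python
from pathlib import Path

# Hypothetical file path in the standard <root>/<class_name>/<file> layout
image_file = Path('data/pizza_steak_sushi/train/steak/1234.jpg')
class_to_idx = {'pizza': 0, 'steak': 1, 'sushi': 2}

class_name = image_file.parent.name   # the immediate parent folder
class_idx = class_to_idx[class_name]  # look up its integer label
print(class_name, class_idx)          # steak 1
```

Because the glob pattern `'*/*.jpg'` only looks exactly one level below `targ_dir`, images in deeper subfolders or with other extensions (such as `.png`) are skipped. `ImageFolder` walks each class folder recursively and accepts several image extensions, so the two can diverge on less tidy directories.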


## 3. Create Custom Datasets

Now let's instantiate our custom dataset class with proper transforms.

### Defining Transforms

Before testing our `ImageFolderCustom` class, we need to create transforms to prepare our images.

**Critical Distinction:**

| Transform Set | Augmentation | Purpose |
|---------------|--------------|---------|
| **Training** | ✅ `RandomHorizontalFlip(p=0.5)` | Improve generalization |
| **Testing** | ❌ No augmentation | Consistent, fair evaluation |

**Why no augmentation for testing?**
- Test set should represent real-world conditions
- Augmentation adds randomness that makes evaluation inconsistent
- We want reproducible results for benchmarking

In [7]:
train_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

test_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

### Instantiate Datasets

Now let's create `Dataset` objects for both training and testing images using our `ImageFolderCustom` class:

In [8]:
train_data_custom = ImageFolderCustom(
    targ_dir=train_dir,
    transform=train_transforms
)

test_data_custom = ImageFolderCustom(
    targ_dir=test_dir,
    transform=test_transforms
)

print(f'Train samples: {len(train_data_custom)}')
print(f'Test samples: {len(test_data_custom)}')
print(f'Classes: {train_data_custom.classes}')
print(f'Class_to_idx: {train_data_custom.class_to_idx}')

In [9]:
img, label = train_data_custom[0]
print(f'Image shape: {img.shape}')
print(f'Label: {label} ({train_data_custom.classes[label]})')

### Dataset Created Successfully! ✅

Perfect! Our custom dataset:
- Returns tensors with correct shape `[3, 64, 64]`
- Provides integer labels (0, 1, or 2)
- Can be indexed like a Python list
- Can be used anywhere PyTorch expects a map-style dataset, such as a `DataLoader`

## 4. Compare with ImageFolder

Let's verify our custom dataset produces the same results as PyTorch's built-in `ImageFolder`.

### Verification Checklist:

| Property | Should Match |
|----------|--------------|
| Number of samples | `len(dataset)` |
| Class names | `dataset.classes` |
| Class indices | `dataset.class_to_idx` |
| Sample format | `(tensor, label)` tuple |

**Why this comparison matters:**
- Ensures our implementation is correct
- Validates that we understand how `ImageFolder` works internally
- Confirms our dataset can be used as a drop-in replacement
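
The checklist can also be turned into a reusable check. `compare_datasets` below is a hypothetical helper; it relies only on what both datasets expose (`len()`, `.classes`, `.class_to_idx`, indexing) and compares labels for the first few samples. Label-by-index agreement additionally assumes both datasets list files in the same order: `ImageFolder` sorts its file list, while `glob()` order is not guaranteed, so an unsorted custom dataset can disagree at the same index even when both are correct. The `ToyDataset` objects are stand-ins, not real datasets.

```python
from typing import Any, List


def compare_datasets(custom: Any, reference: Any, n_samples: int = 10) -> List[str]:
    """Return human-readable mismatches between two ImageFolder-style datasets."""
    problems = []
    if len(custom) != len(reference):
        problems.append(f'length {len(custom)} != {len(reference)}')
    if list(custom.classes) != list(reference.classes):
        problems.append(f'classes {custom.classes} != {reference.classes}')
    if dict(custom.class_to_idx) != dict(reference.class_to_idx):
        problems.append('class_to_idx differs')
    # Labels at the same index only agree if both list files in the same order
    for i in range(min(n_samples, len(custom), len(reference))):
        _, label_a = custom[i]
        _, label_b = reference[i]
        if label_a != label_b:
            problems.append(f'label mismatch at index {i}: {label_a} != {label_b}')
    return problems


class ToyDataset:
    """Stand-in exposing the same attributes as the lab's datasets (illustration only)."""

    def __init__(self, labels):
        self.classes = ['pizza', 'steak', 'sushi']
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return None, self.labels[i]


same = compare_datasets(ToyDataset([0, 1, 2]), ToyDataset([0, 1, 2]))
shuffled = compare_datasets(ToyDataset([0, 1, 2]), ToyDataset([2, 1, 0]))
print(same)      # []
print(shuffled)  # ['label mismatch at index 0: 0 != 2', 'label mismatch at index 2: 2 != 0']
```

In the lab itself the call would be `compare_datasets(train_data_custom, train_data_imagefolder)`, and an empty list means every checked property matched.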

In [10]:
train_data_imagefolder = datasets.ImageFolder(
    root=train_dir,
    transform=train_transforms
)

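print(f'Custom length: {len(train_data_custom)}')
print(f'ImageFolder length: {len(train_data_imagefolder)}')
print(f'Custom classes: {train_data_custom.classes}')
print(f'ImageFolder classes: {train_data_imagefolder.classes}')
print(f'Custom class_to_idx: {train_data_custom.class_to_idx}')
print(f'ImageFolder class_to_idx: {train_data_imagefolder.class_to_idx}')

# Check each row of the verification checklist programmatically
assert len(train_data_custom) == len(train_data_imagefolder)
assert train_data_custom.classes == train_data_imagefolder.classes
assert train_data_custom.class_to_idx == train_data_imagefolder.class_to_idx
custom_img, custom_label = train_data_custom[0]
folder_img, folder_label = train_data_imagefolder[0]
assert custom_img.shape == folder_img.shape and isinstance(custom_label, int)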

## 5. Create DataLoaders

Custom datasets work seamlessly with `DataLoader` - this is the power of PyTorch's modular design!

### The Beauty of PyTorch's Design

The `DataLoader` doesn't care whether you're using:
- Built-in datasets (`ImageFolder`, `MNIST`, `CIFAR10`)
- Custom datasets (like our `ImageFolderCustom`)
- Any class that implements `__len__` and `__getitem__`

**As long as your dataset implements the required methods correctly, everything just works!**

### Why This Matters:

1. **Modularity** - Swap datasets without changing training code
2. **Consistency** - Same DataLoader API for all datasets
3. **Flexibility** - Build custom datasets for any use case
4. **Integration** - Works with all PyTorch training loops
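
To back up the claim that `DataLoader` only needs `__len__` and `__getitem__`, the sketch below hands it a plain Python class that does not subclass `Dataset` at all. `PlainSequence` is a hypothetical name. The default collate function stacks the per-sample tensors into one batch tensor and turns the integer labels into a 1-D `int64` tensor.

```python
import torch
from torch.utils.data import DataLoader


class PlainSequence:
    """Hypothetical class that is NOT a torch Dataset subclass, only duck-typed."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, index: int):
        # Fake "image" of shape [3, 4, 4] filled with the index, plus an integer label
        return torch.full((3, 4, 4), float(index)), index % 3


loader = DataLoader(PlainSequence(), batch_size=4, shuffle=False)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([4, 3, 4, 4])
print(labels)        # tensor([0, 1, 2, 0])
print(labels.dtype)  # torch.int64
```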

In [None]:
BATCH_SIZE = 32

train_dataloader_custom = DataLoader(
    dataset=train_data_custom,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

test_dataloader_custom = DataLoader(
    dataset=test_data_custom,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

images, labels = next(iter(train_dataloader_custom))
print(f'Batch shape: {images.shape}')
print(f'Labels shape: {labels.shape}')
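
With `drop_last=False` (the default), the number of batches per epoch is the sample count divided by the batch size, rounded up, and the final batch holds the remainder. The sketch below uses a `TensorDataset` of zeros; the size of 225 is only an example value, not a claim about `train_data_custom`.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

num_samples, batch_size = 225, 32  # example values only
toy = TensorDataset(torch.zeros(num_samples, 1), torch.zeros(num_samples, dtype=torch.long))
loader = DataLoader(toy, batch_size=batch_size, shuffle=False)

batch_sizes = [x.shape[0] for x, _ in loader]
print(len(loader), math.ceil(num_samples / batch_size))  # 8 8
print(batch_sizes[-1])  # 1, since 225 = 7 * 32 + 1
```

Passing `drop_last=True` would discard that one-sample batch and leave 7 batches per epoch.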

## 6. Visualize Custom Dataset

Let's verify our custom dataset returns properly formatted images by visualizing a random sample.

**What we're checking:**
- Images display correctly (no corruption)
- Transforms are applied properly (64×64 size)
- Labels match the images
- Random sampling works as expected
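
The one conversion the plotting code needs is a reordering of axes: the dataset yields `[C, H, W]` tensors, while `matplotlib.pyplot.imshow` expects `[H, W, C]` for color images. A quick shape check, with a random tensor standing in for a real sample:

```python
import torch

chw = torch.rand(3, 64, 64)  # stand-in for one transformed sample
hwc = chw.permute(1, 2, 0)   # move channels last for matplotlib
print(tuple(chw.shape), '->', tuple(hwc.shape))  # (3, 64, 64) -> (64, 64, 3)

# permute only reorders axes; the pixel values are unchanged
assert torch.equal(hwc[10, 20], chw[:, 10, 20])
```

Because `ToTensor` already scaled the values to `[0, 1]`, `imshow` can display the permuted float tensor without further normalization.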

In [None]:
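def display_random_images(dataset, classes, n=6, seed=42, ncols=3):
    # Seed Python's RNG so the same indices are picked on every run
    random.seed(seed)
    random_idx = random.sample(range(len(dataset)), k=n)

    # Size the grid from n instead of hard-coding 2x3
    nrows = (n + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, idx in zip(axes.flat, random_idx):
        img, label = dataset[idx]
        # Tensors are [C, H, W]; matplotlib expects [H, W, C]
        ax.imshow(img.permute(1, 2, 0))
        ax.set_title(classes[label])
    plt.tight_layout()
    plt.show()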

In [None]:
display_random_images(train_data_custom, train_data_custom.classes)

### Custom Dataset Working Perfectly! 🎉

Excellent! Our `ImageFolderCustom` dataset is working exactly as expected:
- Images are loaded correctly from disk
- Transforms are applied (Resize, RandomHorizontalFlip, ToTensor)
- Labels match the visual content
- Ready for training with DataLoader

We've successfully built a custom dataset from scratch!

## Key Takeaways

### The Three Essential Methods

| Method | Purpose | Implementation |
|--------|---------|----------------|
| **`__init__`** | Setup: store paths, find classes, define transforms | Initialize all attributes |
| **`__len__`** | Return `len(self.paths)` - total samples | Enable `len(dataset)` |
| **`__getitem__`** | Load image, apply transform, return (tensor, label) | Enable indexing `dataset[i]` |

### What We Learned

1. **Custom Dataset Structure** - How to subclass `torch.utils.data.Dataset`
2. **Helper Functions** - Using `find_classes()` to discover labels
3. **Transform Separation** - Different transforms for train/test
4. **DataLoader Integration** - Custom datasets work seamlessly
5. **Verification** - Comparing with built-in `ImageFolder`

### When to Use Custom vs ImageFolder

| Use Case | Recommendation | Reason |
|----------|----------------|--------|
| Standard image classification folders | `ImageFolder` | Fast, tested, optimized |
| Annotations in CSV/JSON | Custom Dataset | Need custom label parsing |
| Multiple input types (image + text) | Custom Dataset | Flexible data handling |
| Complex preprocessing pipeline | Custom Dataset | Full control over loading |
| Learning PyTorch internals | Custom Dataset | Educational value |
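
For the CSV row of the table, a custom dataset mainly changes where `__init__` gets its labels. The sketch below is hypothetical: it assumes a CSV with `filename,label` columns whose `filename` values are relative to an image root, and it accepts an optional `loader` callable so the label parsing can be exercised without real images. By default it opens files with Pillow and converts them to RGB.

```python
import csv
import os
import tempfile
from pathlib import Path
from typing import Callable, List, Optional, Tuple

from torch.utils.data import Dataset


def pil_rgb_loader(path: Path):
    # Imported lazily so the sketch runs without Pillow when a custom loader is passed
    from PIL import Image
    with Image.open(path) as img:
        return img.convert('RGB')


class CSVAnnotatedDataset(Dataset):
    """Hypothetical dataset whose labels live in a CSV file, not in folder names."""

    def __init__(self, csv_path: str, root: str, transform: Optional[Callable] = None,
                 loader: Callable = pil_rgb_loader) -> None:
        self.root = Path(root)
        self.transform = transform
        self.loader = loader
        with open(csv_path, newline='') as f:
            rows = list(csv.DictReader(f))
        self.files: List[Path] = [self.root / row['filename'] for row in rows]
        label_names = [row['label'] for row in rows]
        # Same convention as find_classes(): sorted names -> consecutive indices
        self.classes = sorted(set(label_names))
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        self.targets = [self.class_to_idx[name] for name in label_names]

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[object, int]:
        sample = self.loader(self.files[index])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.targets[index]


# Usage with a throwaway CSV and a loader that just returns the file name
with tempfile.TemporaryDirectory() as tmp:
    csv_path = os.path.join(tmp, 'labels.csv')
    with open(csv_path, 'w', newline='') as f:
        f.write('filename,label\nimg_001.jpg,sushi\nimg_002.jpg,pizza\nimg_003.jpg,sushi\n')
    ds = CSVAnnotatedDataset(csv_path, root='images', loader=lambda p: p.name)

print(len(ds), ds.classes)  # 3 ['pizza', 'sushi']
print(ds[0])                # ('img_001.jpg', 1)
```

The CSV is read once in `__init__`, so the dataset keeps working after the temporary file is gone; only the image loading is deferred to `__getitem__`.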

### Custom Dataset Advantages

✅ **Complete control** over data loading process  
✅ **Flexible** - handle any data format or structure  
✅ **Customizable** - add special preprocessing logic  
✅ **Educational** - understand PyTorch internals  

### What's Next?

In **Lab 4**, we'll build and train a **TinyVGG** convolutional neural network using our complete data pipeline!

**The full workflow:**
```
Custom Dataset → DataLoader → TinyVGG Model → Training → Evaluation
```

---

**Remember:** The power of custom datasets is flexibility. While `ImageFolder` is great for standard use cases, custom datasets unlock PyTorch's full potential for complex data scenarios!