# Unit 4.4: Defining Efficient Data Loaders

In [1]:
%load_ext watermark
%watermark -v -p matplotlib,numpy,pandas,torchvision,torch --conda

Python implementation: CPython
Python version       : 3.12.7
IPython version      : 8.27.0

matplotlib : 3.9.2
numpy      : 1.26.4
pandas     : 2.2.3
torchvision: 0.20.1
torch      : 2.5.1

conda environment: base



## 1) Defining the Dataset Class

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torchvision.utils as vutils
from PIL import Image
from torch.utils.data import Dataset



class MyDataset(Dataset):
    """Map-style dataset that reads images listed in a CSV file.

    The CSV file must provide a ``filepath`` column (image path that is
    joined onto ``img_dir``) and a ``label`` column.

    Args:
        csv_path: Path to the CSV file listing the images and their labels.
        img_dir: Directory that the ``filepath`` entries are joined onto.
        transform: Optional callable applied to every loaded PIL image.
    """

    def __init__(self, csv_path, img_dir, transform=None):
        df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.transform = transform

        # based on DataFrame columns
        self.img_names = df["filepath"]
        self.labels = df["label"]

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at position ``index``.

        Args:
            index: Integer position of the sample; negative values count
                from the end, like list indexing.

        Returns:
            Tuple ``(img, label)`` where ``img`` is the transformed image
            (or the loaded PIL image when no transform was given) and
            ``label`` is the matching entry of the ``label`` column.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        # Positional (.iloc) lookup instead of label-based Series indexing:
        # an out-of-range position raises IndexError rather than KeyError,
        # which is what the sequence protocol needs so that
        # ``for sample in dataset`` terminates cleanly, and it makes
        # negative indices work.
        img_path = os.path.join(self.img_dir, self.img_names.iloc[index])

        # Image.open is lazy and keeps the file open. Read the pixel data
        # inside the context manager so the handle is closed right away,
        # even when no transform forces the data to be loaded.
        with Image.open(img_path) as img:
            img.load()

        if self.transform is not None:
            img = self.transform(img)

        label = self.labels.iloc[index]
        return img, label

    def __len__(self):
        """Return the number of samples (rows in the CSV file)."""
        return self.labels.shape[0]

## 2) Defining an optional batch visualization function

In [3]:
def viz_batch_images(batch, max_images=64):
    """Plot a grid of the images contained in one data batch.

    Args:
        batch: Indexable batch whose first element holds the images; only
            that element is used. Presumably an ``(images, labels)`` pair
            as yielded by a DataLoader, with images shaped (N, C, H, W) --
            confirm against the caller.
        max_images: Maximum number of images placed in the grid. Defaults
            to 64, the previously hard-coded value.
    """
    # make_grid tiles the images into a single channels-first image;
    # imshow expects the channel axis last, hence the (1, 2, 0) transpose.
    image_grid = vutils.make_grid(
        batch[0][:max_images], padding=2, normalize=True
    )

    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Training images")
    plt.imshow(np.transpose(image_grid, (1, 2, 0)))
    plt.show()

## 3) Defining optional image transformations

In [4]:
from torchvision import transforms

# Both pipelines share resize -> crop -> tensor -> normalize and differ
# only in the crop: a random 28x28 crop for training, a center crop for
# evaluation.
data_transforms = {
    split: transforms.Compose(
        [
            transforms.Resize(32),
            crop,
            transforms.ToTensor(),
            # normalize images to [-1, 1] range
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    for split, crop in (
        ("train", transforms.RandomCrop((28, 28))),
        ("test", transforms.CenterCrop((28, 28))),
    )
}

## 4) Defining the data loaders

In [5]:
from torch.utils.data import DataLoader

# Training split: uses the "train" transform pipeline.
train_dataset = MyDataset(
    transform=data_transforms["train"],
    img_dir="mnist-pngs/",
    csv_path="mnist-pngs/new_train.csv",
)

# shuffle=True: reshuffle the training data every epoch.
# num_workers=0: load the data in the main process (no worker subprocesses).
train_loader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, num_workers=0
)

In [6]:
# Validation split: uses the "test" transform pipeline.
val_dataset = MyDataset(
    transform=data_transforms["test"],
    img_dir="mnist-pngs/",
    csv_path="mnist-pngs/new_val.csv",
)

# No shuffling: the validation data is iterated in a fixed order.
val_loader = DataLoader(
    val_dataset, batch_size=32, shuffle=False, num_workers=0
)

In [7]:
# Test split: uses the "test" transform pipeline.
test_dataset = MyDataset(
    transform=data_transforms["test"],
    img_dir="mnist-pngs/",
    csv_path="mnist-pngs/test.csv",
)

# No shuffling: the test data is iterated in a fixed order.
test_loader = DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=0
)

## 5) Testing the data loaders

In [8]:
import time

num_epochs = 1
num_batches_to_show = 3  # only peek at the first few batches of each epoch

for epoch in range(num_epochs):

    for batch_idx, (x, y) in enumerate(train_loader):
        # NOTE(review): the one-second pause presumably stands in for the
        # time a real training step would take -- confirm.
        time.sleep(1)
        print(" Batch index:", batch_idx, end="")
        print(" | Batch size:", y.shape[0], end="")
        print(" | x shape:", x.shape, end="")
        print(" | y shape:", y.shape)

        # Stop right after the last reported batch. Checking here, not at
        # the top of the next iteration, avoids loading (and sleeping on)
        # an extra batch and keeps `y` equal to the last batch printed.
        if batch_idx + 1 >= num_batches_to_show:
            break

print("Labels from current batch:", y)

# Uncomment to visualize a data batch:
# batch = next(iter(train_loader))
# viz_batch_images(batch)  # pass the whole batch; the function reads batch[0]

 Batch index: 0 | Batch size: 32 | x shape: torch.Size([32, 1, 28, 28]) | y shape: torch.Size([32])
 Batch index: 1 | Batch size: 32 | x shape: torch.Size([32, 1, 28, 28]) | y shape: torch.Size([32])
 Batch index: 2 | Batch size: 32 | x shape: torch.Size([32, 1, 28, 28]) | y shape: torch.Size([32])
Labels from current batch: tensor([5, 4, 7, 9, 2, 8, 9, 1, 2, 8, 9, 9, 0, 4, 3, 0, 8, 9, 1, 1, 6, 5, 3, 3,
        9, 5, 3, 2, 6, 2, 8, 1])
