We implement a data loader from scratch. The data loader takes a set of features X and targets y and splits them into batches that can be iterated over. Optionally, the dataset is shuffled before it is split into batches; a new random order is drawn on every pass over the data. We also implement the `__len__` method, so that `len()` returns the number of batches.

In [1]:
import torch
import numpy as np

In [2]:
class DataLoader:
    """Minimal mini-batch loader over paired features and targets.

    Iterating over the loader yields ``(X_batch, y_batch)`` tuples. When
    ``shuffle`` is true, a fresh permutation is drawn from NumPy's global RNG
    at the start of every iteration, so each pass sees a new sample order.
    The final batch is smaller when ``len(X)`` is not a multiple of
    ``batch_size``.
    """

    def __init__(self, X, y, batch_size=64, shuffle=True):
        """
        Custom DataLoader for batching and iterating over a dataset.

        :param X: Input features: a list, tuple, NumPy array, or PyTorch tensor.
            Lists and tuples are converted to NumPy arrays so that they can be
            indexed with an array of sample indices.
        :param y: Targets aligned with ``X`` (same length). Accepts the same
            types as ``X``.
        :param batch_size: Number of samples per batch; must be >= 1.
        :param shuffle: Whether to reshuffle the sample order at the start of
            each iteration.
        :raises ValueError: If ``batch_size`` < 1 or ``X`` and ``y`` have
            different lengths.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        self.X = self._as_indexable(X)
        self.y = self._as_indexable(y)
        # A length mismatch would silently pair the wrong targets with the
        # features (or fail mid-epoch), so reject it at construction time.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same length, got {len(self.X)} and {len(self.y)}"
            )

        self.batch_size = batch_size
        self.shuffle = shuffle

    @staticmethod
    def _as_indexable(data):
        """Return *data* in a form that supports integer-array indexing."""
        # Python lists/tuples reject ``seq[index_array]``; arrays and tensors
        # already support it and are passed through unchanged.
        if isinstance(data, (list, tuple)):
            return np.asarray(data)
        return data

    def __len__(self):
        """Returns the number of batches in the DataLoader."""
        # Exact integer ceil-division; avoids float rounding for large sizes.
        return (len(self.X) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        """Iterator to generate ``(X_batch, y_batch)`` tuples."""
        n_samples = len(self.X)
        indices = np.arange(n_samples)

        if self.shuffle:
            indices = np.random.permutation(indices)

        for start_idx in range(0, n_samples, self.batch_size):
            # Slicing clamps at the end, so the last batch is simply shorter.
            batch_indices = indices[start_idx:start_idx + self.batch_size]
            yield self.X[batch_indices], self.y[batch_indices]

In [3]:
# Synthetic linear-regression data: targets are a random linear map of the
# features plus Gaussian noise scaled by 0.5.
n_samples, n_features = 1000, 5
X_train = torch.randn(n_samples, n_features)
true_weights = torch.randn(n_features, 1)
y_train = torch.matmul(X_train, true_weights) + 0.5 * torch.randn(n_samples, 1)

# Wrap the tensors in the custom loader: 256 samples per batch, reshuffled each pass.
train_dataloader = DataLoader(X_train, y_train, shuffle=True, batch_size=256)

In [4]:
# One pass over the loader, printing the feature and target shape of each batch.
for batch in train_dataloader:
    X_batch, y_batch = batch
    print(X_batch.shape, y_batch.shape)

torch.Size([256, 5]) torch.Size([256, 1])
torch.Size([256, 5]) torch.Size([256, 1])
torch.Size([256, 5]) torch.Size([256, 1])
torch.Size([232, 5]) torch.Size([232, 1])


In [5]:
# Number of batches per pass: ceil(1000 / 256) == 4 (the last batch is partial).
len(train_dataloader)

4