# Datasets and Dataloaders

PyTorch takes much of the data handling off our hands during training. A `DataLoader` groups the data into mini-batches for every epoch (reshuffling the order each epoch when `shuffle=True`). To do this, it needs to know how to fetch a single sample and how many samples there are. This is where the `Dataset` class comes in.

In [12]:
# import some super basic data
from sklearn.datasets import load_iris

# Fetch scikit-learn's bundled iris data; as_frame=True returns the features
# and target as pandas objects rather than NumPy arrays.
iris = load_iris(return_X_y=False, as_frame=True)
print(iris.keys())  # show which fields the returned Bunch exposes

dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])


In [29]:
# create a Dataset class for the iris data

# import the Dataset class
from torch.utils.data import Dataset

class IrisDataset(Dataset):
    """Map-style ``Dataset`` serving ``(features, target)`` pairs to a ``DataLoader``.

    A map-style dataset must implement ``__getitem__`` (fetch one sample by
    integer position). ``__len__`` is also required by the default samplers
    a ``DataLoader`` builds. ``__init__`` only stores the inputs.

    Args:
        data: Feature table with one row per sample, e.g. the pandas
            ``DataFrame`` from ``load_iris(as_frame=True).data``, or any
            array/sequence indexable by integer position.
        targets: Labels aligned row-for-row with ``data``: a pandas
            ``Series`` or any array/sequence.

    Raises:
        ValueError: If ``data`` and ``targets`` have different lengths.

    NOTE(review): with pandas inputs each item is a ``(Series, scalar)``
    pair. torch's ``default_collate`` cannot batch pandas objects, so convert
    to tensors/arrays (or pass a ``collate_fn``) before batching.
    """

    def __init__(self, data, targets):
        super().__init__()
        # Misaligned inputs would silently pair features with wrong labels.
        if len(data) != len(targets):
            raise ValueError(
                "data and targets must have the same length, "
                f"got {len(data)} and {len(targets)}"
            )
        self.data = data
        self.targets = targets

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    @staticmethod
    def _at(values, idx):
        """Return element ``idx`` of ``values`` by position, never by label."""
        # On pandas objects ``obj[idx]`` is label-based (a Series) or a
        # column lookup (a DataFrame); after a shuffle/split the labels no
        # longer equal positions. ``.iloc`` is always positional, keeping
        # features and targets aligned. Plain arrays/tensors/lists index
        # positionally already.
        positional = getattr(values, "iloc", None)
        return values[idx] if positional is None else positional[idx]

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair at position ``idx``."""
        return self._at(self.data, idx), self._at(self.targets, idx)

# Wrap the iris features and labels in the Dataset class so a DataLoader
# can consume them.
iris_dataset = IrisDataset(data=iris.data, targets=iris.target)