In [None]:
# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

In [1]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
# torchvision.datasets provides many ready-made datasets

In [3]:
# Build the FashionMNIST training split, stored under ./data (downloaded on
# request); ToTensor() is applied to each image when a sample is accessed.
training_data = datasets.FashionMNIST(
    root="data",
    transform=ToTensor(),
    train=True,
    download=True,
)

In [16]:
# Number of samples in the training split loaded above.
len(training_data)

60000

In [8]:
# Index the dataset like a sequence; the first sample unpacks into two values.
x, y = training_data[0]

In [None]:
# A Dataset class has 3 methods:
# - __init__(self, annotations_file, img_dir, transform=None, target_transform=None)
#        self.img_labels = pd.read_csv(annotations_file)
#        self.img_dir = img_dir
#        self.transform = transform
#        self.target_transform = target_transform
# - __len__
# - __getitem__

In [17]:
# __len__: calling len() on the dataset returns its number of samples.
len(training_data)

60000

In [19]:
# __getitem__: like a list, the dataset can be indexed to fetch a single sample.
# Each sample unpacks into (x, y); the cell displays x's shape together with y.
x, y = training_data[0]
x.shape, y

(torch.Size([1, 28, 28]), 9)

In [20]:
# We usually have to define 2 methods: __len__ and __getitem__

In [21]:
# Example
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    """Map-style dataset of image files labelled by a CSV annotations table.

    The annotations file is loaded with ``pd.read_csv``. For every row,
    column 0 is used as the image file name (joined onto ``img_dir``) and
    column 1 as the label.

    Args:
        annotations_file: Path (or anything ``pd.read_csv`` accepts) of the
            CSV file holding file names and labels.
        img_dir: Directory that the file names in column 0 are relative to.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            A ``(image, label)`` tuple: the result of ``read_image`` on the
            row's file and the row's label, each passed through its
            transform when one was supplied.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        # Compare against None instead of relying on truthiness: a callable
        # that defines __len__ (e.g. an empty nn.Sequential) is falsy and
        # would otherwise be skipped silently.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

In [22]:
# __len__ returns the number of images.
# __getitem__ takes an index and, in this case, returns the image and its label. We can return anything we want.
# We can preprocess X with transform and y with target_transform.

In [23]:
# Once we have a dataset that supports access by index and has a __len__ method,
# we can use DataLoader to return batches and shuffle the data.

In [24]:
from torch.utils.data import Dataset, DataLoader

In [28]:
# Wrap the dataset in a loader that yields shuffled mini-batches of 64 samples.
train_dataloader = DataLoader(
    dataset=training_data,
    shuffle=True,
    batch_size=64,
)

In [29]:
# Pull one batch from the loader and unpack it into features and labels.
train_features, train_labels = next(iter(train_dataloader))

In [31]:
# Inspect the dimensions of the batch fetched above.
train_features.shape, train_labels.shape

(torch.Size([64, 1, 28, 28]), torch.Size([64]))