In [2]:
%matplotlib inline
import time
import torch
import torchvision
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

In [3]:
class FashionMNIST(d2l.DataModule):  #@save
    """The Fashion-MNIST dataset."""
    def __init__(self, batch_size=64, resize=(28, 28)):
        super().__init__()
        self.save_hyperparameters()
        trans = transforms.Compose([transforms.Resize(resize),
                                    transforms.ToTensor()])
        self.train = torchvision.datasets.FashionMNIST(
            root=self.root, train=True, transform=trans, download=True)
        self.val = torchvision.datasets.FashionMNIST(
            root=self.root, train=False, transform=trans, download=True)
    
    def text_labels(self, indices):
        """Return text labels."""
        labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [labels[int(i)] for i in indices]
    
    def get_dataloader(self, train):
        data = self.train if train else self.val
        return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train,
                                        num_workers=self.num_workers)

In [4]:
data = FashionMNIST(resize=(32, 32))

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ../data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting ../data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ../data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ../data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%

Extracting ../data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\FashionMNIST\raw

In [7]:
def load_time(data):
    tic = time.time()
    for X, y in data.train_dataloader():
        continue
    return time.time() - tic

1. Does reducing the batch_size (for instance, to 1) affect the reading performance?

In [8]:
f'{load_time(FashionMNIST(batch_size=64, resize=(32, 32))):.2f} sec'

'10.73 sec'

In [9]:
f'{load_time(FashionMNIST(batch_size=1, resize=(32, 32))):.2f} sec'

'47.61 sec'

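Yes. With `batch_size=64` the loader yields 938 batches per epoch over the 60,000 training examples; with `batch_size=1` it yields 60,000. The per-example work (resizing, converting to a tensor) is the same in both cases, but the fixed cost paid per batch (collating, passing the batch from a worker process to the main process, one Python loop iteration) is now paid about 64 times as often. This is why the measured time rises from 10.73 sec to 47.61 sec.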
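The batch counts behind this comparison are simple arithmetic. The following is a minimal sketch, assuming the 60,000-example Fashion-MNIST training set and the two timings measured above; `num_batches` is a hypothetical helper, not part of d2l or PyTorch.

```python
import math

def num_batches(num_examples, batch_size):
    """Number of batches yielded when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)

n_train = 60_000  # size of the Fashion-MNIST training set
# (batch_size, seconds) pairs taken from the two measurements above
for batch_size, seconds in [(64, 10.73), (1, 47.61)]:
    batches = num_batches(n_train, batch_size)
    print(f'batch_size={batch_size}: {batches} batches, '
          f'{seconds / batches * 1e3:.2f} ms per batch')
```

Each individual batch is cheaper at `batch_size=1`, but there are so many more of them that the total time still grows.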

2. The data iterator performance is important. Do you think the current implementation is fast enough? Explore various options to improve it. Use a system profiler to find out where the bottlenecks are.

In [10]:
import cProfile

profiler = cProfile.Profile()
profiler.enable()
# Call the function you want to profile
load_time(data)
profiler.disable()
profiler.print_stats(sort="tottime")

         191638 function calls (191632 primitive calls) in 10.267 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      676    8.578    0.013    8.578    0.013 {built-in method _winapi.WaitForMultipleObjects}
       10    0.736    0.074    0.736    0.074 {built-in method _winapi.WaitForSingleObject}
      962    0.173    0.000    0.173    0.000 {method 'acquire' of '_thread.lock' objects}
        4    0.086    0.022    0.086    0.022 {built-in method _winapi.CreateProcess}
        1    0.070    0.070   10.267   10.267 2419004203.py:1(load_time)
     1876    0.070    0.000    0.070    0.000 {built-in method _new_shared_filename_cpu}
      938    0.042    0.000    0.250    0.000 {built-in method _pickle.loads}
     1876    0.039    0.000    0.039    0.000 {built-in method torch.tensor}
     1876    0.031    0.000    0.031    0.000 {built-in method _winapi.PeekNamedPipe}
      939    0.029    0.000    0.029    0.000 {built-in 

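The built-in method `_winapi.WaitForMultipleObjects` accounts for most of the time (8.578 of 10.267 seconds). This is most likely the main process blocking while it waits for the `DataLoader` worker processes to deliver batches. The actual work (decoding, resizing, and converting images) happens inside those workers, where a `cProfile` session running in the main process cannot see it, so the profile shows waiting rather than the real bottleneck.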
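A long profile like the one above is easier to read when it is sorted by cumulative time and cut to the top few rows. The following is a sketch using only the standard-library `cProfile` and `pstats` modules; `profile_top` and `slow_step` are hypothetical names, and `slow_step` is a stand-in that would be replaced by `lambda: load_time(data)` in this notebook.

```python
import cProfile
import io
import pstats
import time

def slow_step():
    """Stand-in workload (hypothetical); swap in load_time(data)."""
    time.sleep(0.01)

def profile_top(func, limit=5, sort_key='cumulative'):
    """Profile func() and return the top `limit` rows of the report as text."""
    profiler = cProfile.Profile()
    profiler.enable()
    func()
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats(sort_key).print_stats(limit)  # keep only `limit` rows
    return stream.getvalue()

report = profile_top(slow_step)
print(report)
```

Note that this still profiles only the calling process. Setting `num_workers=0` on the loader is one way to bring the loading work into the profiled process, at the cost of losing parallelism.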

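Ways to improve it: load data in parallel by tuning the number of worker processes (`num_workers`; the `DataLoader` uses processes rather than threads), and use a more efficient data format, for example by applying the deterministic `Resize` and `ToTensor` transforms once and caching the resulting tensors instead of recomputing them for every example in every epoch.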
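One reading of "efficient data format" is to run the preprocessing once and keep the results in memory. The following is a minimal sketch of that idea in plain Python, not the d2l or torchvision implementation: `PrecomputedDataset` and `CountingDataset` are hypothetical classes, and the approach assumes the transform is deterministic (no random augmentation) and that the transformed data fits in memory.

```python
class PrecomputedDataset:
    """Materialize every transformed example of `dataset` once, up front."""
    def __init__(self, dataset):
        self.items = [dataset[i] for i in range(len(dataset))]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]  # no transform work on access


class CountingDataset:
    """Toy stand-in for a dataset with a costly transform (hypothetical)."""
    def __init__(self, n):
        self.n = n
        self.transform_calls = 0

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        self.transform_calls += 1  # pretend this is Resize + ToTensor
        return idx * 2, idx % 10


raw = CountingDataset(100)
cached = PrecomputedDataset(raw)
for epoch in range(3):           # three passes over the cached data
    for i in range(len(cached)):
        cached[i]
print(raw.transform_calls)       # transform ran once per example, not per epoch
```

Because the wrapper only needs `__len__` and `__getitem__`, it could in principle wrap `self.train` before it is handed to the `DataLoader`; whether that pays off depends on memory and on how expensive the transform really is.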

3. Check out the framework’s online API documentation. Which other datasets are available?

- **Image classification:** Caltech 101 Dataset, Caltech 256 Dataset, Large-scale CelebFaces Attributes (CelebA) Dataset, CIFAR10 Dataset, CIFAR100 Dataset...
- **Image detection or segmentation:** MS COCO Detection Dataset, Cityscapes Dataset, KITTI Dataset...
- **Optical Flow**
- **Stereo Matching**
- ...
