# 4.1.5. Exercises


In [1]:
%matplotlib inline
import time
import torch
import torchvision
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()


class FashionMNIST(d2l.DataModule):  # @save
  """The Fashion-MNIST dataset."""

  def __init__(self, batch_size=64, resize=(28, 28)):
    super().__init__()
    self.save_hyperparameters()
    trans = transforms.Compose([transforms.Resize(resize), transforms.ToTensor()])
    self.train = torchvision.datasets.FashionMNIST(
      root=self.root, train=True, transform=trans, download=True
    )
    self.val = torchvision.datasets.FashionMNIST(
      root=self.root, train=False, transform=trans, download=True
    )

  def text_labels(self, indices):
    """Return text labels."""
    labels = [
      "t-shirt",
      "trouser",
      "pullover",
      "dress",
      "coat",
      "sandal",
      "shirt",
      "sneaker",
      "bag",
      "ankle boot",
    ]
    return [labels[int(i)] for i in indices]

  def get_dataloader(self, train):
    data = self.train if train else self.val
    return torch.utils.data.DataLoader(
      data, self.batch_size, shuffle=train, num_workers=self.num_workers
    )


In [2]:
data = FashionMNIST(resize=(32, 32))


In [3]:
def load_time(data):
  tic = time.time()
  for X, y in data.train_dataloader():
    continue
  return time.time() - tic

##### 1. Does reducing the `batch_size` (for instance, to 1) affect the reading performance?


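Answer: Yes. A smaller batch size puts fewer examples in each iteration, so one pass over the data needs more iterations. The fixed per-batch costs are therefore paid more often: the Python loop, collating samples into a batch, and handing batches from the worker processes to the main process. In the runs below, reducing `batch_size` from 64 to 1 increases the time for one pass over the training set from about 1.8 sec to about 11 sec.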
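The iteration count follows directly from the arithmetic. The training split has 60,000 examples. With `batch_size=64` that gives ceil(60000 / 64) = 938 batches, and with `batch_size=1` it gives 60,000 batches. The sketch below makes the comparison in isolation using an in-memory `TensorDataset` of 6,000 random 1x32x32 tensors. That dataset is an assumption, so the timing excludes the image decoding and resizing done by the real dataset. The helper `time_epoch` is hypothetical, and absolute times will vary by machine.

```python
import math
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def time_epoch(dataset, batch_size, num_workers=0):
    """Iterate once over `dataset`; return (elapsed seconds, number of batches)."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers)
    tic = time.perf_counter()
    n_batches = 0
    for _ in loader:
        n_batches += 1
    return time.perf_counter() - tic, n_batches


# Synthetic stand-in (10% of the Fashion-MNIST training size): random images, 10 classes.
n = 6000
dataset = TensorDataset(torch.rand(n, 1, 32, 32), torch.randint(0, 10, (n,)))

for bs in (64, 1):
    elapsed, n_batches = time_epoch(dataset, bs)
    assert n_batches == math.ceil(n / bs)  # one loader step per batch
    print(f"batch_size={bs:>2}: {n_batches:>5} batches, {elapsed:.3f} sec")
```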


In [5]:
f"{load_time(FashionMNIST(batch_size=64, resize=(32, 32))):.2f} sec"


'1.81 sec'

In [6]:
f"{load_time(FashionMNIST(batch_size=1, resize=(32, 32))):.2f} sec"


'11.05 sec'

##### 2. The data iterator performance is important. Do you think the current implementation is fast enough? Explore various options to improve it. Use a system profiler to find out where the bottlenecks are.


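Answer: Possible improvements fall into two groups. The first is loading data in parallel; PyTorch's `DataLoader` does this with worker processes, controlled by `num_workers`. The second is using a more efficient data format, for example storing the already-resized tensors instead of decoding and resizing every image on every pass.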
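Here is a minimal sketch of the "efficient data format" idea, assuming the transforms are deterministic (as `Resize` + `ToTensor` are here; random augmentation would make caching wrong). The hypothetical helper `materialize` runs the per-sample pipeline once and keeps the results as tensors. The hypothetical `fast_loader` then serves batches from memory and can optionally add worker processes. The source dataset below is synthetic.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def materialize(dataset, batch_size=1024):
    """Apply the dataset's per-sample transforms once and keep the results in memory.

    Only valid when the transforms are deterministic (no random augmentation).
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    xs, ys = [], []
    for X, y in loader:
        xs.append(X)
        ys.append(y)
    return TensorDataset(torch.cat(xs), torch.cat(ys))


def fast_loader(dataset, batch_size=64, train=True, num_workers=0):
    """DataLoader over already-materialized tensors, with optional worker processes."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,  # reuse workers across epochs (requires num_workers > 0)
        pin_memory=torch.cuda.is_available(),  # page-locked memory speeds up host-to-GPU copies
    )


# Synthetic stand-in for a transformed dataset (hypothetical data, not Fashion-MNIST).
source = TensorDataset(torch.rand(1000, 1, 32, 32), torch.randint(0, 10, (1000,)))
cached = materialize(source)
print(len(cached), cached.tensors[0].shape)  # 1000 torch.Size([1000, 1, 32, 32])
```

With the data module above, `materialize(data.train)` would run the `Resize`/`ToTensor` pipeline once. At 32x32 in `float32`, the 60,000 training images then occupy about 246 MB. Whether this beats simply adding workers depends on the machine.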

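_On a Mac M1 Max, `{method 'poll' of 'select.poll' objects}` takes the most time. This is the main process waiting for the worker processes to deliver batches. The per-sample work (`Resize`, `ToTensor`) runs inside those workers and therefore does not appear in this profile._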


In [7]:
import cProfile

profiler = cProfile.Profile()
profiler.enable()
# Call the function you want to profile
load_time(data)
profiler.disable()
profiler.print_stats(sort="tottime")

         203749 function calls (203743 primitive calls) in 1.920 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      942    1.634    0.002    1.634    0.002 {method 'poll' of 'select.poll' objects}
     1876    0.058    0.000    0.058    0.000 {built-in method _new_shared_filename_cpu}
        1    0.043    0.043    1.920    1.920 2229652835.py:1(load_time)
        4    0.017    0.004    0.017    0.004 {built-in method _posixsubprocess.fork_exec}
      938    0.015    0.000    0.115    0.000 {built-in method _pickle.loads}
     1876    0.009    0.000    0.009    0.000 {built-in method torch.tensor}
      939    0.009    0.000    0.009    0.000 {built-in method torch._ops.profiler._record_function_enter_new}
        8    0.009    0.001    0.009    0.001 {method '_share_filename_cpu_' of 'torch._C.StorageBase' objects}
     1876    0.007    0.000    0.007    0.000 {method 'set_' of 'torch._C._TensorBase' objects}
      940 

##### 3. Check out the framework's online API documentation. Which other datasets are available?


Answer: The torchvision datasets include, for example:

- Image classification: Caltech 101 Dataset, Caltech 256 Dataset, Large-scale CelebFaces Attributes (CelebA) Dataset, CIFAR10 Dataset, CIFAR100 Dataset, ...
- Object detection or segmentation: MS Coco Detection Dataset, Cityscapes Dataset, KITTI Dataset, ...
- Optical flow
- Stereo matching
- ...
