In [1]:
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

In [2]:
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data/",train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data/",train=False,transform=trans,download=True)
print(len(mnist_train),len(mnist_test))
mnist_train[0][0].shape

60000 10000


torch.Size([1, 28, 28])

transforms.ToTensor():将图像数据转为tensor，并归一化至[0-1]
torchvision.datasets : ⼀些加载数据的函数及常⽤的数据集接⼝；
torchvision.transforms : 常⽤的图⽚变换，例如裁剪、旋转等；

In [1]:
def get_fashion_mnist_labels(labels):     #@save
    """"返回Fashion-Mnist数据集的文本标签"""
    text_labels = ['t-shirts','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot']#[ , ,套衫, , ,凉鞋, ,运动鞋, ,裸靴]
    return [text_labels[int(i)] for i in labels]

In [5]:
def show_images(imgs,num_rows,num_cols,titles = None,scale = 1):  #@save
    """"绘制图像列表"""
    figsize = (num_cols * scale,num_rows * scale)
    
    _,axes = d2l.plt.subplot(num_rows,num_cols,figsize)
    axes = axes.flatten()
    for i,(ax,img) in enumerate(zip(axes,imgs)):
        if torch.is_tensor(img):
            #图片张量
            ax.imshow(img.numpy())
        else:
            #PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

subplot(nrows, ncols, index, **kwargs)，一般我们只用到前三个参数，将整个绘图区域分成 nrows 行和 ncols 列，而 index 用于对子图进行编号。
enumerate() 函数:遍历一个集合对象，在遍历的同时还可以得到当前元素的索引位置
zip()函数可以接收多个可迭代对象，然后把每个可迭代对象中的第i个元素组合在一起，形成一个新的迭代器，类型为元组。
flatten()数组降维

In [8]:
X,y = next(iter(data.DataLoader(mnist_train,batch_size=18)))
show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))

IndexError: GridSpec slice would result in no space allocated for subplot

<Figure size 640x480 with 0 Axes>

In [11]:
batch_size = 256

def get_dataloader_workers():   #@save
    """"使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())

In [10]:
timer = d2l.Timer()
for X,y in train_iter:
    continue
f'{timer.stop():.2f}sec'

'4.21sec'

读取训练集所需时间

In [14]:
def load_data_fashion_mnist(batch_size,resize = None):  #@save
    """"下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0,transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root="../data/",train=True,transform=trans,download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root="../data/",train=False,transform=trans,download=True)
    return (data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))
    

In [15]:
train_iter,test_iter = load_data_fashion_mnist(32,resize=64)
for X,y in train_iter:
    print(X.shape,X.dtype,y.shape,y.dtype)
    break

torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
