In [2]:
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import time
import sys
import numpy as np
%matplotlib widget

#### 获取FashionMNIST数据集

In [5]:
## torchvision.datasets.FashionMNIST参数意义
# 与其他torchvision.datasets.XXX数据集函数的参数意义相同
# root是datasets的位置；
# train如果为True则表示加载训练集，否则加载验证集；
# download如果为True表示如果root路径下没有数据就将数据下载到该路径下
# transform 接收一个函数对原本的PIL格式的（数据类型uint8）图片数据进行转换处理
mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST',train=True,
                                                download=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST',train=False,
                                               download=True,transform=transforms.ToTensor())

`transforms.ToTensor`函数的用处：<br>
1. 将uint8型的数据转变为float32的范围为[0,1]的浮点数
2. 如果不进行转换返回的数据为PIL图片

In [7]:
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000


**访问样本**

In [8]:
feature, label = mnist_train[0]

In [10]:
feature.shape

torch.Size([1, 28, 28])

In [20]:
print(type(label))
label

<class 'int'>


9

可知每一个样本由一个`tensor型`的feature和`int型`的label组成<br>
注意！：<br>
feature的尺寸是$(C\times H\times W)$而不是$(H\times W\times C)$

**将label映射对应类别**<br>
0-9 --> 't-shirt','trouser',...

In [23]:
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

In [26]:
def show_fashion_mnist(images, labels):
    # 这里的_表示忽略（不使用）的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        # 不显示xy坐标轴
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

In [28]:
# test
x = [mnist_train[i][0] for i in range(10)]
y = [mnist_train[i][1] for i in range(10)]
show_fashion_mnist(x,get_fashion_mnist_labels(y))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

#### 读取小批量数据

**多进程读取数据**

In [31]:
sys.platform.startswith('win')

True

In [50]:
batch_size = 256
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True, num_workers=4)

In [51]:
start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

7.22 sec
