In [None]:
import os

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms


In [None]:
# Root directory where Fashion-MNIST is downloaded/cached. Set the
# FASHION_MNIST_DIR environment variable to override it; the default is a
# `data` folder relative to the notebook's working directory rather than a
# machine-specific absolute path, so the notebook runs on any machine.
data_dir = os.environ.get('FASHION_MNIST_DIR', os.path.join('.', 'data'))


In [None]:
# Preprocessing pipeline applied to every image:
#   1. ToTensor converts the PIL image to a float tensor with values in [0, 1].
#   2. Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5, mapping [0, 1]
#      onto [-1, 1]. Zero-centred inputs generally make gradient-based
#      training better conditioned.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [None]:
# Fetch Fashion-MNIST into data_dir (download=True retrieves the files when
# they are missing) and wrap both splits in DataLoaders of `batch_size` images.
batch_size = 64

# Training batches are reshuffled on every pass over the data.
trainset = datasets.FashionMNIST(root=data_dir, download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# The test set is only aggregated over in full during evaluation, so shuffling
# buys nothing. Reuse batch_size instead of a second hard-coded 64 so the two
# loaders cannot drift apart.
testset = datasets.FashionMNIST(root=data_dir, download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

In [None]:
# Pull a single batch from the (shuffled) training loader to inspect tensor
# shapes and look at one sample image below.
images, labels = next(iter(trainloader))

In [None]:
# Dimensions of the image batch; Tensor.size() is equivalent to .shape.
# Presumably (batch_size, 1, 28, 28) — confirm from the output.
images.size()

In [None]:
# Dimensions of the label batch (one class id per image).
labels.size()

In [None]:
# Select one sample (the first in the batch) for visual inspection.
index = 0
image, label = images[index], labels[index]

In [None]:
image.shape

In [None]:
import matplotlib.pyplot as plt

# Chinese names for the 10 Fashion-MNIST classes, indexed by label id 0-9.
# (Fixes the stray space that was previously inside '运动鞋'.)
labellist = ['T恤', '裤子', '套衫', '裙子', '外套', '凉鞋', '汗衫', '运动鞋', '包包', '靴子']

# The sample holds 784 values; reshape to 28x28 for display. Use a grayscale
# colormap so the picture looks like the original image, not a viridis heatmap.
plt.imshow(image.reshape(28, 28), cmap='gray')

# `label` is a single-element tensor; convert it explicitly to a Python int
# before using it as a list index.
print(f'这张图片对应的标签是 {labellist[int(label)]}')

# 2 搭建并训练四层全连接神经网络

In [None]:
from torch import nn, optim
import torch.nn.functional as F



class Classifier(nn.Module):
    """Four-layer fully connected classifier over 784-feature inputs.

    Layer widths are 784 -> 256 -> 128 -> 64 -> 10, with ReLU after each of
    the three hidden layers. ``forward`` returns log-probabilities
    (``log_softmax`` over the class dimension), so the natural loss to pair it
    with is ``nn.NLLLoss``.
    """

    def __init__(self):
        super().__init__()
        # Created in this order so parameters register as fc1 -> fc4 (this also
        # fixes the order of RNG draws used by the default initialisation).
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class log-probabilities of shape (batch, 10) for batch ``x``.

        Every dimension after the first is flattened, so each sample must hold
        exactly 784 values (e.g. a (batch, 1, 28, 28) image tensor).
        """
        hidden = x.view(x.size(0), -1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        logits = self.fc4(hidden)
        return F.log_softmax(logits, dim=1)

In [None]:
# Seed PyTorch's global RNG before building the model so weight initialisation
# (and the training loader's shuffle order) is reproducible across runs.
torch.manual_seed(0)

# Instantiate the classifier defined above.
model = Classifier()

# Classifier.forward already returns log-probabilities (log_softmax), so the
# matching loss is NLLLoss. CrossEntropyLoss would apply log_softmax a second
# time; because log_softmax is idempotent that gives the same loss and
# gradients, but it is redundant and obscures the intent.
criterion = nn.NLLLoss()

# Plain stochastic gradient descent with learning rate 0.003.
optimizer = optim.SGD(model.parameters(), lr=0.003)

# Number of full passes over the training set.
epochs = 15

# Per-epoch mean training / test losses, kept for the loss-curve plot later.
train_losses, test_losses = [], []


def _evaluate():
    """Evaluate `model` on every batch of `testloader`.

    Returns:
        (mean_test_loss, accuracy):
        - mean_test_loss is the average of the per-batch losses, kept as a 0-d
          tensor (a later plotting cell may call `.cpu()` on stored entries);
        - accuracy is the exact fraction of correctly classified test images.
          Counting correct/total avoids over-weighting a short final batch,
          which averaging per-batch accuracies would do.
    """
    test_loss = 0
    correct = 0
    total = 0
    # Evaluation mode; restored in `finally` even if evaluation raises.
    model.eval()
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for images, labels in testloader:
                log_ps = model(images)
                test_loss += criterion(log_ps, labels)
                # argmax of log-probabilities == argmax of probabilities.
                top_class = log_ps.argmax(dim=1)
                correct += (top_class == labels).sum().item()
                total += labels.shape[0]
    finally:
        model.train()
    return test_loss / len(testloader), correct / total


def train():
    """Train `model` for `epochs` epochs, evaluating on the test set after each.

    After every epoch, appends the mean training loss (float) to
    `train_losses` and the mean test loss (0-d tensor) to `test_losses`, then
    prints a one-line progress summary.
    """
    print('开始训练：')
    for e in range(epochs):
        running_loss = 0

        # One full pass over the training set.
        for images, labels in trainloader:
            # Gradients accumulate across backward() calls, so clear them first.
            optimizer.zero_grad()

            # Forward pass, loss, backward pass, parameter update.
            log_ps = model(images)
            loss = criterion(log_ps, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # After each epoch, measure performance on the held-out test set.
        test_loss, accuracy = _evaluate()

        train_loss = running_loss / len(trainloader)
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        print("训练集学习次数：{}/{}..".format(e + 1, epochs),
              "训练误差：{:.3f}..".format(train_loss),
              "测试误差:{:.3f}..".format(test_loss),
              "模型分类准确性：{:.3f}".format(accuracy))


In [None]:
# Run the full training loop defined above, timed with IPython's %time magic.
%time train()

# 3 验证模型效果

## 绘制训练误差和测试误差随学习次数增加的变化

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Convert every recorded loss to a plain Python float before plotting.
# float() accepts both ordinary numbers and single-element tensors, so this
# works whether the entries are floats or 0-d tensors. Building new lists
# (instead of overwriting test_losses in place) keeps the cell idempotent.
train_losses_plot = [float(loss) for loss in train_losses]
test_losses_plot = [float(loss) for loss in test_losses]

fig, ax = plt.subplots()
ax.plot(train_losses_plot, label='Training loss')
ax.plot(test_losses_plot, label='Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training vs. validation loss')
ax.legend()
plt.show()