## Course Content

1. Introduction to the LeNet model
2. Building the LeNet network
3. Image recognition with LeNet on the Fashion-MNIST dataset

# Convolutional Neural Networks

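Limitations of fully connected layers:

- Pixels that are adjacent within the same column of an image can end up far apart in the flattened vector, so the patterns they form may be hard for the model to recognize.
- For large input images, fully connected layers easily make the model too large.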
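The following is a minimal plain-Python sketch of the first point. It assumes row-major flattening of a $28 \times 28$ image, which is the order `x.view(x.shape[0], -1)` produces for a contiguous tensor. The helper `flat_index` is hypothetical and exists only for illustration.

```python
H, W = 28, 28  # Fashion-MNIST image size

def flat_index(row, col, width=W):
    """Position of pixel (row, col) after row-major flattening (hypothetical helper)."""
    return row * width + col

# Horizontal neighbours stay adjacent in the flattened vector ...
print(flat_index(10, 5), flat_index(10, 6))  # 285 286
# ... but vertical neighbours (same column, adjacent rows) end up W positions apart.
print(flat_index(10, 5), flat_index(11, 5))  # 285 313
```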

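Advantages of convolutional layers:

- Convolutional layers preserve the spatial shape of the input.
- A convolutional layer slides a window over the input and applies the same kernel at every position, which keeps the number of parameters small.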
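A rough parameter count makes these points concrete. This sketch assumes every layer has a bias, which is the default for `nn.Linear` and `nn.Conv2d`. It compares two ways of mapping a $1 \times 28 \times 28$ image to a $6 \times 28 \times 28$ output: a fully connected layer and a $5 \times 5$ convolution. The helper names are hypothetical.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a convolutional layer with a square kernel."""
    return out_channels * in_channels * kernel_size * kernel_size + out_channels

# Map a 1x28x28 image to a 6x28x28 output (the convolution needs padding 2 to keep 28x28,
# but padding does not change the parameter count).
fc = linear_params(1 * 28 * 28, 6 * 28 * 28)
conv = conv2d_params(1, 6, 5)
print(fc)    # 3692640
print(conv)  # 156
```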


## The LeNet Model

LeNet consists of two parts: a convolutional block and a fully connected block. We introduce each of them below.


![Image Name](https://cdn.kesci.com/upload/image/q5ndwsmsao.png?imageView2/0/w/960/h/960)


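The basic unit of the convolutional block is a convolutional layer followed by an average pooling layer. The convolutional layer recognizes spatial patterns in the image, such as lines and local parts of objects. The average pooling layer that follows reduces the convolutional layer's sensitivity to position.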
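The following plain-Python sketch illustrates both roles on a toy single-channel input. `corr2d` computes the 2D cross-correlation that a convolutional layer performs, ignoring channels and bias. `avg_pool2d` performs $2 \times 2$ average pooling with stride 2. Both names and the hand-picked edge-detecting kernel `K` are illustrative assumptions and are not part of the course code. Shifting the edge by one column changes the convolution output, but the pooled result stays the same.

```python
def corr2d(X, K):
    """2D cross-correlation of one channel, stride 1, no padding (what a conv layer computes)."""
    h, w = len(K), len(K[0])
    out_h, out_w = len(X) - h + 1, len(X[0]) - w + 1
    return [[sum(X[i + a][j + b] * K[a][b] for a in range(h) for b in range(w))
             for j in range(out_w)]
            for i in range(out_h)]

def avg_pool2d(X, size=2, stride=2):
    """Average pooling over size x size windows, no padding."""
    out_h = (len(X) - size) // stride + 1
    out_w = (len(X[0]) - size) // stride + 1
    return [[sum(X[i * stride + a][j * stride + b]
                 for a in range(size) for b in range(size)) / (size * size)
             for j in range(out_w)]
            for i in range(out_h)]

K = [[1, -1]]  # 1x2 kernel that responds to a left-to-right intensity drop

X = [[1, 1, 1, 0, 0, 0] for _ in range(6)]          # vertical edge after column 2
Y = corr2d(X, K)                                    # 6x5 map, nonzero only at the edge
print(Y[0])                                         # [0, 0, 1, 0, 0]
print(avg_pool2d(Y))                                # [[0.0, 0.5], [0.0, 0.5], [0.0, 0.5]]

X_shifted = [[1, 1, 1, 1, 0, 0] for _ in range(6)]  # same edge, one column to the right
print(corr2d(X_shifted, K)[0])                      # [0, 0, 0, 1, 0]
print(avg_pool2d(corr2d(X_shifted, K)))             # same pooled map as before
```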

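The convolutional block stacks two of these basic units. Each convolutional layer in the block uses a $5 \times 5$ window and applies a sigmoid activation to its output. The first convolutional layer has 6 output channels, and the second increases this to 16.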

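The fully connected block contains 3 fully connected layers with 120, 84, and 10 outputs respectively. The final 10 is the number of output classes.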
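As a sanity check on these sizes, the sketch below counts parameters layer by layer. It assumes every layer has a bias and that inputs are $28 \times 28$, so the flattened feature vector has $16 \times 5 \times 5 = 400$ entries, as in the implementation that follows. The count shows that the fully connected block holds the vast majority of the parameters.

```python
# (layer name, parameter count) for LeNet with 28x28 inputs; biases included (PyTorch default).
layers = [
    ("conv1", 6 * 1 * 5 * 5 + 6),     # 156
    ("conv2", 16 * 6 * 5 * 5 + 16),   # 2416
    ("fc1", 16 * 5 * 5 * 120 + 120),  # 48120
    ("fc2", 120 * 84 + 84),           # 10164
    ("fc3", 84 * 10 + 10),            # 850
]
total = sum(n for _, n in layers)
conv_total = sum(n for name, n in layers if name.startswith("conv"))
fc_total = total - conv_total
print(total, conv_total, fc_total)  # 61706 2572 59134
print(f"{fc_total / total:.1%} of the parameters are in the fully connected block")  # 95.8%
```

If PyTorch is available, running `sum(p.numel() for p in net.parameters())` on the `net` built below should report the same total.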

Below we implement the LeNet model with the `nn.Sequential` class.

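```python
# imports
import sys
# d2lzh1981 is the course's helper module; adjust this path to wherever it lives on your machine
sys.path.append(r"C:\000Disk\Learn\PyTorch\start_to_learn\data\input")
import d2lzh1981 as d2l
import torch
import torch.nn as nn
import torch.optim as optim
import time
```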

```python
# network
class Flatten(torch.nn.Module):  # flatten (B, C, H, W) into (B, C*H*W)
    def forward(self, x):
        return x.view(x.shape[0], -1)

class Reshape(torch.nn.Module):  # reshape the input into images of shape (B x C x H x W)
    def forward(self, x):
        return x.view(-1, 1, 28, 28)

net = torch.nn.Sequential(  # LeNet
    Reshape(),
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),  # b*1*28*28  => b*6*28*28
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),                               # b*6*28*28  => b*6*14*14
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),            # b*6*14*14  => b*16*10*10
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),                               # b*16*10*10 => b*16*5*5
    Flatten(),                                                           # b*16*5*5   => b*400
    nn.Linear(in_features=16*5*5, out_features=120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10)
)
```

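Next we construct a single-channel data sample with a height and width of 28. We then run the forward pass layer by layer and check each layer's output shape.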

```python
# print each layer's output shape
X = torch.randn(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape: \t', X.shape)
```

```
Reshape output shape: 	 torch.Size([1, 1, 28, 28])
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])
```


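The height and width shrink as the data passes through the convolutional block. The second convolutional layer uses a $5 \times 5$ kernel without padding, so it reduces the height and width by 4 (from 14 to 10). The first convolutional layer also uses a $5 \times 5$ kernel, but its `padding=2` keeps the output at $28 \times 28$. Each pooling layer halves the height and width. Meanwhile, the number of channels grows from 1 to 16. The fully connected layers then reduce the number of outputs layer by layer until it equals the number of image classes, 10.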
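These shape changes follow the usual output-size rule for convolution and pooling layers: $\lfloor (n + 2p - k)/s \rfloor + 1$, where $n$ is the input size, $k$ the kernel size, $p$ the padding, and $s$ the stride. The rule assumes dilation 1 and no `ceil_mode`. The small sketch below replays the shape trace above; the helper `conv_out` is hypothetical.

```python
def conv_out(n, kernel, padding=0, stride=1):
    """Output size along one spatial dimension (dilation 1, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1

n = 28
n = conv_out(n, kernel=5, padding=2)  # conv1: 28 -> 28
n = conv_out(n, kernel=2, stride=2)   # pool1: 28 -> 14
n = conv_out(n, kernel=5)             # conv2: 14 -> 10
n = conv_out(n, kernel=2, stride=2)   # pool2: 10 -> 5
print(n, 16 * n * n)                  # 5 400
```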


![Image Name](https://cdn.kesci.com/upload/image/q5ndxi6jl5.png?imageView2/0/w/640/h/640)


## Getting the Data and Training the Model

Now we train the LeNet model. As before, we use Fashion-MNIST as the training dataset.

```python
# data
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(
    batch_size=batch_size, root=r'C:\000Disk\Learn\PyTorch\start_to_learn\data\input\FashionMNIST2065')
print(len(train_iter))  # number of mini-batches per epoch
```



0it [00:00, ?it/s][A

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to C:\000Disk\Learn\PyTorch\start_to_learn\data\input\FashionMNIST2065\FashionMNIST\raw\train-images-idx3-ubyte.gz



  0%|                                                                                     | 0/26421880 [00:00<?, ?it/s][A
  0%|                                                                      | 16384/26421880 [00:01<11:32, 38146.74it/s][A
  0%|▏                                                                     | 49152/26421880 [00:01<08:56, 49123.47it/s][A
  0%|▏                                                                     | 81920/26421880 [00:01<07:58, 55068.20it/s][A
  0%|▏                                                                     | 90112/26421880 [00:01<08:58, 48928.94it/s][A
  0%|▎                                                                     | 98304/26421880 [00:02<09:44, 45063.82it/s][A
  0%|▎                                                                    | 114688/26421880 [00:02<08:52, 49426.61it/s][A
  0%|▎                                                                    | 122880/26421880 [00:02<10:16, 42692.35it/s][A
  0%|▎         

  2%|█▋                                                                   | 630784/26421880 [00:36<36:53, 11651.31it/s][A
  2%|█▋                                                                   | 638976/26421880 [00:37<35:57, 11952.19it/s][A
  2%|█▋                                                                   | 647168/26421880 [00:38<38:31, 11150.35it/s][A
  2%|█▋                                                                   | 655360/26421880 [00:38<40:21, 10640.75it/s][A
  3%|█▋                                                                   | 663552/26421880 [00:39<34:55, 12294.85it/s][A
  3%|█▊                                                                   | 671744/26421880 [00:39<34:29, 12443.06it/s][A
  3%|█▊                                                                   | 679936/26421880 [00:40<31:27, 13641.34it/s][A
  3%|█▊                                                                   | 688128/26421880 [00:40<28:41, 14946.86it/s][A
  3%|█▊         

  4%|███                                                                 | 1179648/26421880 [01:14<30:35, 13754.95it/s][A
  4%|███                                                                 | 1187840/26421880 [01:16<41:09, 10218.36it/s][A
  5%|███                                                                 | 1196032/26421880 [01:16<39:31, 10637.32it/s][A
  5%|███                                                                | 1204224/26421880 [01:19<1:02:04, 6771.51it/s][A
  5%|███▏                                                                 | 1212416/26421880 [01:20<56:32, 7430.22it/s][A
  5%|███▏                                                                 | 1220608/26421880 [01:21<59:25, 7067.56it/s][A
  5%|███                                                                | 1228800/26421880 [01:22<1:04:22, 6521.92it/s][A
  5%|███▏                                                                 | 1236992/26421880 [01:23<52:11, 8042.53it/s][A
  5%|███▎       

  7%|████▍                                                               | 1728512/26421880 [02:07<34:43, 11850.19it/s][A
  7%|████▍                                                               | 1736704/26421880 [02:07<27:19, 15052.72it/s][A
  7%|████▍                                                               | 1744896/26421880 [02:08<32:02, 12834.87it/s][A
  7%|████▌                                                               | 1753088/26421880 [02:08<32:31, 12640.97it/s][A
  7%|████▌                                                               | 1761280/26421880 [02:09<38:07, 10782.34it/s][A
  7%|████▌                                                               | 1769472/26421880 [02:10<35:48, 11473.74it/s][A
  7%|████▌                                                               | 1777664/26421880 [02:10<31:49, 12908.03it/s][A
  7%|████▌                                                               | 1785856/26421880 [02:11<32:22, 12684.61it/s][A
  7%|████▋      

  9%|█████▉                                                              | 2285568/26421880 [02:48<27:00, 14891.91it/s][A
  9%|█████▉                                                              | 2293760/26421880 [02:49<27:52, 14430.15it/s][A
  9%|█████▉                                                              | 2310144/26421880 [02:49<23:56, 16783.58it/s][A
  9%|█████▉                                                              | 2318336/26421880 [02:50<26:01, 15434.31it/s][A
  9%|█████▉                                                              | 2326528/26421880 [02:51<31:03, 12927.99it/s][A
  9%|██████                                                              | 2334720/26421880 [02:51<27:39, 14519.07it/s][A
  9%|██████                                                              | 2342912/26421880 [02:52<34:09, 11747.44it/s][A
  9%|██████                                                              | 2351104/26421880 [02:52<27:27, 14610.92it/s][A
  9%|██████     

 11%|███████▎                                                            | 2842624/26421880 [03:25<26:05, 15062.81it/s][A
 11%|███████▎                                                            | 2850816/26421880 [03:25<21:09, 18562.54it/s][A
 11%|███████▎                                                            | 2859008/26421880 [03:26<22:21, 17566.01it/s][A
 11%|███████▍                                                            | 2867200/26421880 [03:26<22:06, 17753.44it/s][A
 11%|███████▍                                                            | 2875392/26421880 [03:27<19:30, 20109.43it/s][A
 11%|███████▍                                                            | 2883584/26421880 [03:27<19:26, 20184.31it/s][A
 11%|███████▍                                                            | 2891776/26421880 [03:28<28:58, 13536.13it/s][A
 11%|███████▌                                                             | 2899968/26421880 [03:29<40:56, 9577.31it/s][A
 11%|███████▍   

 13%|████████▋                                                           | 3399680/26421880 [04:08<32:08, 11940.13it/s][A
 13%|████████▊                                                           | 3407872/26421880 [04:09<28:09, 13618.03it/s][A
 13%|████████▊                                                           | 3416064/26421880 [04:09<25:23, 15105.15it/s][A
 13%|████████▊                                                           | 3424256/26421880 [04:11<37:55, 10105.61it/s][A
 13%|████████▊                                                           | 3440640/26421880 [04:11<29:23, 13033.57it/s][A
 13%|████████▉                                                           | 3448832/26421880 [04:12<29:03, 13175.49it/s][A
 13%|████████▉                                                           | 3457024/26421880 [04:12<23:09, 16531.48it/s][A
 13%|████████▉                                                           | 3465216/26421880 [04:12<19:04, 20059.29it/s][A
 13%|████████▉  

 15%|██████████▏                                                         | 3973120/26421880 [04:48<13:15, 28226.88it/s][A
 15%|██████████▏                                                         | 3981312/26421880 [04:48<17:33, 21305.48it/s][A
 15%|██████████▎                                                         | 3989504/26421880 [04:49<18:21, 20357.88it/s][A
 15%|██████████▎                                                         | 3997696/26421880 [04:49<15:35, 23967.68it/s][A
 15%|██████████▎                                                         | 4005888/26421880 [04:49<19:11, 19465.53it/s][A
 15%|██████████▎                                                         | 4014080/26421880 [04:50<24:39, 15141.11it/s][A
 15%|██████████▎                                                         | 4022272/26421880 [04:51<28:20, 13171.78it/s][A
 15%|██████████▎                                                         | 4030464/26421880 [04:52<28:06, 13280.58it/s][A
 15%|██████████▍

 17%|███████████▋                                                        | 4530176/26421880 [05:28<31:08, 11714.78it/s][A
 17%|███████████▋                                                        | 4538368/26421880 [05:28<30:18, 12036.67it/s][A
 17%|███████████▊                                                         | 4546560/26421880 [05:30<37:34, 9704.14it/s][A
 17%|███████████▋                                                        | 4554752/26421880 [05:30<35:18, 10323.24it/s][A
 17%|███████████▉                                                         | 4562944/26421880 [05:31<38:52, 9369.77it/s][A
 17%|███████████▊                                                        | 4571136/26421880 [05:32<32:34, 11179.86it/s][A
 17%|███████████▊                                                        | 4579328/26421880 [05:32<30:54, 11776.61it/s][A
 17%|███████████▊                                                        | 4587520/26421880 [05:33<27:00, 13476.43it/s][A
 17%|███████████

 19%|█████████████                                                       | 5070848/26421880 [06:01<14:41, 24214.05it/s][A
 19%|█████████████                                                       | 5079040/26421880 [06:02<18:13, 19519.22it/s][A
 19%|█████████████                                                       | 5087232/26421880 [06:03<28:45, 12366.92it/s][A
 19%|█████████████▏                                                      | 5103616/26421880 [06:04<24:32, 14476.38it/s][A
 19%|█████████████▏                                                      | 5111808/26421880 [06:04<22:24, 15848.12it/s][A
 19%|█████████████▏                                                      | 5120000/26421880 [06:05<21:36, 16426.01it/s][A
 19%|█████████████▏                                                      | 5128192/26421880 [06:05<19:39, 18054.24it/s][A
 19%|█████████████▏                                                      | 5136384/26421880 [06:05<16:21, 21679.26it/s][A
 19%|███████████

 21%|██████████████▌                                                     | 5660672/26421880 [06:35<16:59, 20358.39it/s][A
 21%|██████████████▌                                                     | 5668864/26421880 [06:35<16:59, 20360.21it/s][A
 21%|██████████████▌                                                     | 5677056/26421880 [06:36<19:30, 17728.72it/s][A
 22%|██████████████▋                                                     | 5685248/26421880 [06:36<16:42, 20682.90it/s][A
 22%|██████████████▋                                                     | 5693440/26421880 [06:36<14:16, 24202.18it/s][A
 22%|██████████████▋                                                     | 5701632/26421880 [06:37<15:05, 22891.77it/s][A
 22%|██████████████▋                                                     | 5709824/26421880 [06:37<13:07, 26305.56it/s][A
 22%|██████████████▋                                                     | 5718016/26421880 [06:38<19:54, 17332.29it/s][A
 22%|███████████

 24%|████████████████                                                    | 6225920/26421880 [07:02<11:25, 29442.20it/s][A
 24%|████████████████                                                    | 6234112/26421880 [07:02<10:29, 32068.62it/s][A
 24%|████████████████                                                    | 6242304/26421880 [07:03<12:18, 27317.22it/s][A
 24%|████████████████                                                    | 6250496/26421880 [07:03<13:32, 24824.78it/s][A
 24%|████████████████                                                    | 6258688/26421880 [07:03<12:27, 26979.96it/s][A
 24%|████████████████▏                                                   | 6266880/26421880 [07:04<16:35, 20250.52it/s][A
 24%|████████████████▏                                                   | 6275072/26421880 [07:04<16:33, 20275.75it/s][A
 24%|████████████████▏                                                   | 6283264/26421880 [07:05<16:33, 20266.77it/s][A
 24%|███████████

 26%|█████████████████▍                                                  | 6799360/26421880 [07:31<10:21, 31581.76it/s][A
 26%|█████████████████▌                                                  | 6807552/26421880 [07:31<09:40, 33766.38it/s][A
 26%|█████████████████▌                                                  | 6815744/26421880 [07:31<09:15, 35313.71it/s][A
 26%|█████████████████▌                                                  | 6823936/26421880 [07:32<08:54, 36695.79it/s][A
 26%|█████████████████▌                                                  | 6832128/26421880 [07:32<11:07, 29349.69it/s][A
 26%|█████████████████▋                                                  | 6848512/26421880 [07:32<10:13, 31897.71it/s][A
 26%|█████████████████▋                                                  | 6864896/26421880 [07:33<08:22, 38888.05it/s][A
 26%|█████████████████▋                                                  | 6873088/26421880 [07:33<08:18, 39228.03it/s][A
 26%|███████████

 28%|███████████████████                                                 | 7413760/26421880 [07:57<28:43, 11026.81it/s][A
 28%|███████████████████                                                 | 7421952/26421880 [07:58<24:47, 12777.06it/s][A
 28%|███████████████████▍                                                 | 7430144/26421880 [07:59<39:04, 8101.29it/s][A
 28%|███████████████████▍                                                 | 7438336/26421880 [08:00<37:21, 8468.78it/s][A
 28%|███████████████████▏                                                | 7446528/26421880 [08:01<31:16, 10114.49it/s][A
 28%|███████████████████▏                                                | 7454720/26421880 [08:01<26:33, 11904.10it/s][A
 28%|███████████████████▏                                                | 7462912/26421880 [08:02<30:21, 10405.64it/s][A
 28%|███████████████████▏                                                | 7471104/26421880 [08:03<25:53, 12198.56it/s][A
 28%|███████████

 30%|████████████████████▍                                               | 7962624/26421880 [08:37<28:14, 10893.35it/s][A
 30%|████████████████████▌                                               | 7970816/26421880 [08:38<28:49, 10670.98it/s][A
 30%|████████████████████▌                                               | 7979008/26421880 [08:38<26:56, 11406.47it/s][A
 30%|████████████████████▌                                               | 7987200/26421880 [08:39<23:42, 12958.08it/s][A
 30%|████████████████████▌                                               | 7995392/26421880 [08:39<23:35, 13013.34it/s][A
 30%|████████████████████▌                                               | 8003584/26421880 [08:40<18:48, 16324.21it/s][A
 30%|████████████████████▌                                               | 8011776/26421880 [08:40<23:02, 13316.59it/s][A
 30%|████████████████████▉                                                | 8019968/26421880 [08:43<44:18, 6922.63it/s][A
 30%|███████████

 32%|█████████████████████▉                                              | 8536064/26421880 [09:08<16:26, 18134.61it/s][A
 32%|█████████████████████▉                                              | 8544256/26421880 [09:09<18:08, 16430.26it/s][A
 32%|██████████████████████                                              | 8552448/26421880 [09:09<17:30, 17004.20it/s][A
 32%|██████████████████████                                              | 8560640/26421880 [09:10<21:23, 13913.40it/s][A
 32%|██████████████████████                                              | 8568832/26421880 [09:11<22:03, 13486.92it/s][A
 32%|██████████████████████                                              | 8577024/26421880 [09:11<17:37, 16878.52it/s][A
 32%|██████████████████████                                              | 8585216/26421880 [09:12<21:19, 13939.38it/s][A
 33%|██████████████████████                                              | 8593408/26421880 [09:12<16:55, 17551.88it/s][A
 33%|███████████

 34%|███████████████████████▋                                             | 9084928/26421880 [10:04<33:05, 8731.12it/s][A
 34%|███████████████████████▋                                             | 9093120/26421880 [10:05<40:13, 7180.08it/s][A
 34%|███████████████████████▊                                             | 9101312/26421880 [10:06<34:31, 8362.97it/s][A
 34%|███████████████████████▊                                             | 9109504/26421880 [10:07<32:39, 8837.18it/s][A
 35%|███████████████████████▊                                             | 9117696/26421880 [10:07<29:20, 9830.10it/s][A
 35%|███████████████████████▍                                            | 9125888/26421880 [10:08<24:50, 11606.05it/s][A
 35%|███████████████████████▌                                            | 9134080/26421880 [10:08<19:31, 14755.53it/s][A
 35%|███████████████████████▌                                            | 9142272/26421880 [10:09<20:08, 14298.32it/s][A
 35%|███████████

 36%|█████████████████████████▏                                           | 9625600/26421880 [10:57<38:33, 7258.98it/s][A
 36%|█████████████████████████▏                                           | 9633792/26421880 [10:58<38:08, 7336.51it/s][A
 36%|█████████████████████████▏                                           | 9641984/26421880 [10:58<33:20, 8389.25it/s][A
 37%|████████████████████████▊                                           | 9650176/26421880 [10:59<27:26, 10189.30it/s][A
 37%|████████████████████████▊                                           | 9658368/26421880 [10:59<21:15, 13138.52it/s][A
 37%|████████████████████████▉                                           | 9666560/26421880 [11:00<23:06, 12084.58it/s][A
 37%|████████████████████████▉                                           | 9674752/26421880 [11:00<18:13, 15309.38it/s][A
 37%|████████████████████████▉                                           | 9682944/26421880 [11:01<23:17, 11973.69it/s][A
 37%|███████████

 38%|█████████████████████████▋                                         | 10149888/26421880 [15:12<8:39:53, 521.65it/s][A
 38%|█████████████████████████▋                                         | 10149888/26421880 [15:27<8:39:53, 521.65it/s][A
 38%|█████████████████████████▊                                         | 10158080/26421880 [15:33<9:30:05, 475.47it/s][A
 38%|█████████████████████████▊                                         | 10158080/26421880 [15:47<9:30:05, 475.47it/s][A
 38%|█████████████████████████▍                                        | 10166272/26421880 [15:53<10:00:24, 451.23it/s][A
 38%|█████████████████████████▍                                        | 10166272/26421880 [16:07<10:00:24, 451.23it/s][A
 39%|█████████████████████████▍                                        | 10174464/26421880 [16:15<10:40:29, 422.78it/s][A
 39%|█████████████████████████▍                                        | 10174464/26421880 [16:27<10:40:29, 422.78it/s][A
 39%|███████████

To give readers a more concrete view of the data, we add an extra cell that displays some of the images.

```python
# show some training images
import matplotlib.pyplot as plt

def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    # _ is the Figure object, which we do not use
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

Xdata, ylabel = next(iter(train_iter))  # take the first mini-batch
X, y = [], []
for i in range(10):
    print(Xdata[i].shape, ylabel[i].numpy())
    X.append(Xdata[i])           # add the i-th image to X
    y.append(ylabel[i].numpy())  # add the i-th label to y
show_fashion_mnist(X, y)
```
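The plot titles above are numeric class labels. The following small sketch shows class names instead. It assumes the standard Fashion-MNIST label order (0 through 9). `FASHION_MNIST_CLASSES` and `get_text_labels` are hypothetical names, not functions from `d2lzh1981`.

```python
# Fashion-MNIST class names, indexed by label 0-9.
FASHION_MNIST_CLASSES = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                         'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

def get_text_labels(labels):
    """Map integer labels (ints, 0-d NumPy arrays, or scalar tensors) to class names."""
    return [FASHION_MNIST_CLASSES[int(i)] for i in labels]

print(get_text_labels([9, 0, 0, 3]))  # ['ankle boot', 't-shirt', 't-shirt', 'dress']
```

For example, `show_fashion_mnist(X, get_text_labels(y))` would title each image with its class name.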

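Convolutional neural networks are more computationally expensive than multilayer perceptrons, so we recommend using a GPU to speed up computation. We first check whether a GPU is available. If it is, we use `cuda:0`; otherwise we fall back to `cpu`.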

```python
# This function has been saved in the d2l package for future use
# use GPU
def try_gpu():
    """If GPU is available, return torch.device as cuda:0; else return torch.device as cpu."""
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    return device

device = try_gpu()
print(device)
```

We implement the `evaluate_accuracy` function, which computes the accuracy of the model `net` on the dataset `data_iter`.

```python
# compute accuracy
'''
(1) net.train(): training mode; Dropout is active and BatchNorm uses batch statistics.
(2) net.eval():  evaluation mode; Dropout is disabled and BatchNorm uses its running statistics.
This LeNet has neither layer, but switching modes is still good practice.
'''

def evaluate_accuracy(data_iter, net, device=torch.device('cpu')):
    """Evaluate accuracy of a model on the given data set."""
    net.eval()  # switch to evaluation mode once, before the loop
    acc_sum, n = torch.tensor([0], dtype=torch.float32, device=device), 0
    with torch.no_grad():
        for X, y in data_iter:
            # If device is the GPU, copy the data to the GPU.
            X, y = X.to(device), y.to(device)
            y = y.long()
            # argmax picks the highest-scoring class per row, e.g.
            # [[0.2, 0.4, 0.5, 0.6, 0.8], [0.1, 0.2, 0.4, 0.3, 0.1]] => [4, 2]
            acc_sum += torch.sum(torch.argmax(net(X), dim=1) == y)
            n += y.shape[0]
    return acc_sum.item() / n
```
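To see what the `argmax` comparison inside `evaluate_accuracy` computes, here is a plain-Python sketch using the two-row example from the code comment. `accuracy` is a hypothetical helper. Ties resolve to the first maximal index, because Python's `max` returns the first maximal element it encounters.

```python
def accuracy(scores, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)

scores = [[0.2, 0.4, 0.5, 0.6, 0.8],
          [0.1, 0.2, 0.4, 0.3, 0.1]]
# The predictions are [4, 2], so with true labels [4, 1] one of two is correct.
print(accuracy(scores, [4, 1]))  # 0.5
```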

We define the function `train_ch5` to train the model.

```python
# training function
def train_ch5(net, train_iter, test_iter, criterion, num_epochs, batch_size, device, lr=None):
    """Train and evaluate a model with CPU or GPU (batch_size is kept for compatibility but unused)."""
    print('training on', device)
    net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        net.train()  # evaluate_accuracy switches to eval mode, so switch back every epoch
        train_l_sum = torch.tensor([0.0], dtype=torch.float32, device=device)
        train_acc_sum = torch.tensor([0.0], dtype=torch.float32, device=device)
        n, start = 0, time.time()
        for X, y in train_iter:
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device).long()  # CrossEntropyLoss expects integer class labels
            y_hat = net(X)
            loss = criterion(y_hat, y)  # mean loss over the mini-batch
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                # multiply by the batch size so that train_l_sum / n is the per-sample mean loss
                train_l_sum += loss.float() * y.shape[0]
                train_acc_sum += torch.sum(torch.argmax(y_hat, dim=1) == y).float()
                n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net, device)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum.item() / n, train_acc_sum.item() / n, test_acc,
                 time.time() - start))
```
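With its defaults (no momentum, no weight decay), `optimizer.step()` for `optim.SGD` applies the update $w \leftarrow w - \eta \nabla_w \ell$ to every parameter. Here is a toy sketch of that rule on a one-parameter function; the helper `sgd_step` is hypothetical and only mirrors the update, not PyTorch's implementation.

```python
def sgd_step(params, grads, lr):
    """Plain SGD update p <- p - lr * grad (no momentum, no weight decay)."""
    return [p - lr * g for p, g in zip(params, grads)]

# Minimise f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w = [0.0]
for _ in range(50):
    w = sgd_step(w, [2 * (w[0] - 3)], lr=0.1)
print(round(w[0], 4))  # 3.0
```

Each step here shrinks the distance to the minimum by a factor of $1 - 2\eta = 0.8$. This also shows why the learning rate matters: with $\eta > 1$ the same update would diverge on this function.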

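We move the model to the chosen device `device` (`cpu` or `cuda:0`) and re-initialize its parameters with Xavier random initialization. As before, we use cross-entropy as the loss function and mini-batch stochastic gradient descent as the training algorithm.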
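For reference, Xavier uniform initialization (`torch.nn.init.xavier_uniform_` with its default gain of 1) samples weights from $U(-a, a)$ with $a = \sqrt{6/(\text{fan\_in} + \text{fan\_out})}$. For a convolutional layer, both fan values include the kernel area. The plain-Python sketch below computes $a$ for two LeNet layers; the helper name is hypothetical.

```python
import math

def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    """Half-width a of U(-a, a) used by Xavier (Glorot) uniform initialization."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

# For Conv2d, fan_in = in_channels * kh * kw and fan_out = out_channels * kh * kw.
print(round(xavier_uniform_bound(1 * 5 * 5, 6 * 5 * 5), 4))  # conv1: 0.1852
print(round(xavier_uniform_bound(400, 120), 4))              # first Linear layer: 0.1074
```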

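```python
# training
lr, num_epochs = 0.9, 10

def init_weights(m):
    # Xavier-uniform initialization for the weights of every Linear and Conv2d layer
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)
net = net.to(device)

# Cross-entropy measures how far the predicted class distribution is from the true labels;
# the smaller it is, the closer the two distributions are
criterion = nn.CrossEntropyLoss()
train_ch5(net, train_iter, test_iter, criterion, num_epochs, batch_size, device, lr)
```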
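`nn.CrossEntropyLoss` takes raw scores (logits), applies log-softmax, and returns the negative log-probability of the true class. By default it averages this value over the mini-batch. Below is a single-sample plain-Python sketch of that computation. It uses the max-subtraction trick for numerical stability, and `cross_entropy` is a hypothetical helper.

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample: -log(softmax(logits)[label])."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

logits = [2.0, 1.0, 0.1]
print(round(cross_entropy(logits, 0), 4))  # 0.417: true class has the largest score
print(round(cross_entropy(logits, 2), 4))  # 2.317: true class has the smallest score
```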

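```python
# test: predict labels for one test mini-batch
testdata, testlabel = next(iter(test_iter))
testdata, testlabel = testdata.to(device), testlabel.to(device)
print(testdata.shape, testlabel.shape)
net.eval()
with torch.no_grad():
    y_pre = net(testdata)
print(torch.argmax(y_pre, dim=1)[:10])  # predicted labels of the first 10 images
print(testlabel[:10])                   # true labels of the first 10 images
```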

## Summary

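A convolutional neural network is a network that contains convolutional layers.
LeNet classifies images by alternating convolutional layers and pooling layers (average pooling in the implementation above), followed by fully connected layers.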

# Exercises

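### 1. Which of the following statements about LeNet is incorrect?

A. LeNet consists of two main parts: a convolutional block and a fully connected block

B. Most of LeNet's parameters are in the convolutional block

C. LeNet needs a flatten operation where the convolutional block connects to the fully connected block

D. LeNet's convolutional block alternates convolutional layers and pooling layers

Answer explanation

Option A: correct; see the structure of the LeNet model.

Option B: incorrect; in LeNet, more than 90% of the parameters are in the fully connected block.

Option C: correct; see the structure of the LeNet model.

Option D: correct; see the structure of the LeNet model.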

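### 2. Which of the following statements about convolutional neural networks is incorrect?

A. Because fully connected layers have more parameters than convolutional layers, fully connected layers are better at extracting spatial information

B. A pooling layer with a $2 \times 2$ window and a stride of 2 halves both the height and the width

C. Convolutional neural networks reduce the number of parameters by repeating the same computation at different positions of the input with a sliding window

D. After a convolutional or pooling layer, the output height and width may shrink; to preserve as many input features as possible, we can increase the number of channels while reducing the height and width

Answer explanation

Option A: incorrect. Having more parameters does not make a layer better at extracting spatial information; flattening the image discards the neighborhood structure between pixels. See the discussion of the limitations of fully connected layers at about the 1-minute mark of the video.

Option B: correct; see the pooling layers in LeNet.

Option C: correct; see the discussion of the advantages of convolutional layers at about 1:30 in the video.

Option D: correct; see the explanation at about the 3-minute mark of the video.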
