# 利用CNN实现MNIST数据集的手写数字识别任务

In [23]:
import numpy as np
import torch
from torch import nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm

## 读取数据
- 分别构建训练集和测试集
- 使用DataLoader按batch迭代读取数据

In [24]:
# Hyperparameters
input_size = 28    # side length of an MNIST image (1x28x28, see the shape check below)
num_classes = 10   # one class per digit
num_epochs = 3
batch_size = 64
seed = 42          # seed for torch's global RNG

# Seed torch's global RNG so that, when the cells run in order, weight
# initialization and the training-set shuffling order are reproducible.
torch.manual_seed(seed)

# Training set
train_dataset = datasets.MNIST(root='./data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# Test set
test_dataset = datasets.MNIST(root='./data',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# Mini-batch loaders. The training data is reshuffled every epoch; the test
# data keeps a fixed order because evaluation accuracy is order-independent.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

In [25]:
train_dataset[2][0].shape  # shape of element [0] of the entry at index 2, i.e. the THIRD sample (indices start at 0)

torch.Size([1, 28, 28])

In [26]:
test_dataset  # display the test set's summary: number of datapoints, root location, split and transform

Dataset MNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: ToTensor()

## 卷积网络模块的构建
- 一般可以把卷积层、ReLU层、池化层组合成一个模块（一个“套餐”）
- 注意卷积的最终结果仍是一个特征图，需要把它展平为向量后才能用于分类或回归任务

In [27]:
class CNN(nn.Module):
    """Two-stage convolutional classifier for small square images (MNIST by default).

    Each stage is Conv2d(5x5, stride 1, padding 2) -> ReLU -> MaxPool2d(2).
    The convolution keeps the spatial size and the pooling halves it (with
    floor), so an input of side S becomes a 32-channel feature map of side
    S // 4. That map is flattened and fed to one fully connected layer.

    Args:
        in_channels: number of input channels (1 for MNIST).
        num_classes: number of output logits.
        input_size: side length S of the square input images (28 for MNIST).
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=28):
        super().__init__()
        self.conv1 = nn.Sequential(                    # in:  (in_channels, S, S)
            nn.Conv2d(in_channels=in_channels, out_channels=16,
                      kernel_size=5, stride=1, padding=2),  # out: (16, S, S)
            nn.ReLU(),                                  # out: (16, S, S)
            nn.MaxPool2d(kernel_size=2),                # out: (16, S//2, S//2)
        )
        self.conv2 = nn.Sequential(                    # in:  (16, S//2, S//2)
            nn.Conv2d(in_channels=16, out_channels=32,
                      kernel_size=5, stride=1, padding=2),  # out: (32, S//2, S//2)
            nn.ReLU(),                                  # out: (32, S//2, S//2)
            nn.MaxPool2d(kernel_size=2),                # out: (32, S//4, S//4)
        )
        # Two floor-halvings equal a single floor division by 4 (28 -> 7 for MNIST).
        feature_side = input_size // 4
        self.out = nn.Linear(32 * feature_side * feature_side, num_classes)  # fully connected classifier

    def forward(self, x):
        """Map a batch (batch, in_channels, S, S) to logits (batch, num_classes)."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the feature maps into one vector per sample: (batch, 32 * (S//4)**2).
        # Unlike .view, torch.flatten also works on non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        return self.out(x)

## 准确率函数

In [28]:
def accuracy(predictions, labels):
    """Count the correct top-1 predictions in one batch.

    Args:
        predictions: model outputs with one row of class scores per sample
            (assumed shape (batch, num_classes)). The predicted class is the
            index of the largest score along dim 1.
        labels: ground-truth class indices; reshaped to match the predictions.

    Returns:
        (rights, total): `rights` is a 0-d tensor holding the number of
        correct predictions, and `total` is `len(labels)`.
    """
    # detach() instead of the legacy `.data`: only the values are needed, not gradients.
    pred = predictions.detach().argmax(dim=1)  # index of the maximum score per row
    rights = pred.eq(labels.detach().view_as(pred)).sum()  # number of correct predictions
    return rights, len(labels)

## 训练和验证函数

In [37]:
def _evaluate(model, loader):
    """Return (correct, total) for `model` over every batch of `loader`.

    Runs in eval mode under torch.no_grad(), so no autograd graph is built,
    then restores the model's previous train/eval mode.
    """
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in loader:
            rights, n = accuracy(model(data), target)
            correct += rights.item()
            total += n
    model.train(was_training)
    return correct, total


def train_val(train_loader, test_loader, criterion, optimizer,
              model=None, epoch_num=None, eval_every=100):
    """Train for one epoch and report accuracy every `eval_every` batches.

    At each report, prints the current batch loss, the running training
    accuracy accumulated so far in this epoch, and the accuracy on the whole
    of `test_loader`.

    Args:
        train_loader: iterable of (data, target) training batches.
        test_loader: iterable of (data, target) evaluation batches.
        criterion: loss function called as criterion(output, target).
        optimizer: optimizer over the model's parameters.
        model: the network to train. Defaults to the notebook global `net`
            (the behaviour of the original version).
        epoch_num: epoch number shown in the printout. Defaults to the
            notebook global `epoch`.
        eval_every: run the evaluation every this many batches, counting
            from batch 0.
    """
    # Backward compatibility: the original version read these notebook globals directly.
    if model is None:
        model = net
    if epoch_num is None:
        epoch_num = epoch

    train_correct, train_total = 0, 0
    model.train()  # _evaluate restores this mode after each evaluation
    for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader),
                                          position=0, leave=True):
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running training accuracy over the batches seen so far in this epoch.
        rights, n = accuracy(output, target)
        train_correct += rights.item()
        train_total += n

        if batch_idx % eval_every == 0:
            val_correct, val_total = _evaluate(model, test_loader)
            train_acc = np.round(100. * train_correct / train_total, 5)
            val_acc = np.round(100. * val_correct / val_total, 5)
            print(f'当前epoch:{epoch_num},损失：{loss.item()},训练集准确率：{train_acc}%,测试集准确率{val_acc}%')

## 训练网络模型

In [38]:
# Build the model, the loss and the optimizer. train_val reads the global
# names `net` and `epoch`, so these names must be kept.
net = CNN()
criterion = nn.CrossEntropyLoss()  # classification loss over the raw logits
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.005)

# Run num_epochs passes over the training data; `epoch` counts from 1.
for epoch in range(1, num_epochs + 1):
    train_val(train_loader=train_loader,
              test_loader=test_loader,
              criterion=criterion,
              optimizer=optimizer)

  1%|▌                                                                                 | 7/938 [00:01<02:55,  5.32it/s]

当前epoch:1,损失：2.2879769802093506,训练集准确率：14.0625%,测试集准确率11.96%


 12%|█████████▎                                                                      | 109/938 [00:05<01:02, 13.17it/s]

当前epoch:1,损失：0.16246330738067627,训练集准确率：84.90099%,测试集准确率95.6%


 22%|█████████████████▉                                                              | 210/938 [00:08<01:00, 12.09it/s]

当前epoch:1,损失：0.0948672816157341,训练集准确率：90.31405%,测试集准确率96.91%


 33%|██████████████████████████▏                                                     | 307/938 [00:12<00:53, 11.76it/s]

当前epoch:1,损失：0.23753592371940613,训练集准确率：92.45224%,测试集准确率97.77%


 43%|██████████████████████████████████▊                                             | 408/938 [00:16<00:39, 13.28it/s]

当前epoch:1,损失：0.13571897149085999,训练集准确率：93.65648%,测试集准确率97.93%


 54%|███████████████████████████████████████████▍                                    | 510/938 [00:19<00:34, 12.54it/s]

当前epoch:1,损失：0.20600497722625732,训练集准确率：94.40806%,测试集准确率98.39%


 65%|████████████████████████████████████████████████████                            | 611/938 [00:23<00:25, 12.74it/s]

当前epoch:1,损失：0.058954231441020966,训练集准确率：95.01872%,测试集准确率98.45%


 76%|████████████████████████████████████████████████████████████▌                   | 710/938 [00:27<00:18, 12.07it/s]

当前epoch:1,损失：0.0011798848863691092,训练集准确率：95.44178%,测试集准确率98.41%


 86%|████████████████████████████████████████████████████████████████████▉           | 808/938 [00:30<00:10, 11.99it/s]

当前epoch:1,损失：0.06992898881435394,训练集准确率：95.75726%,测试集准确率98.32%


 97%|█████████████████████████████████████████████████████████████████████████████▌  | 909/938 [00:34<00:02, 12.16it/s]

当前epoch:1,损失：0.05768150836229324,训练集准确率：96.00444%,测试集准确率98.47%


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:35<00:00, 26.63it/s]
  1%|▌                                                                                 | 7/938 [00:01<03:03,  5.08it/s]

当前epoch:2,损失：0.003348784288391471,训练集准确率：100.0%,测试集准确率98.4%


 12%|█████████▏                                                                      | 108/938 [00:05<01:20, 10.37it/s]

当前epoch:2,损失：0.058884285390377045,训练集准确率：98.63861%,测试集准确率98.34%


 22%|█████████████████▋                                                              | 208/938 [00:09<01:10, 10.42it/s]

当前epoch:2,损失：0.020244725048542023,训练集准确率：98.49969%,测试集准确率98.4%


 33%|██████████████████████████▎                                                     | 308/938 [00:13<01:00, 10.48it/s]

当前epoch:2,损失：0.06680842489004135,训练集准确率：98.54132%,测试集准确率98.46%


 43%|██████████████████████████████████▊                                             | 408/938 [00:17<00:54,  9.71it/s]

当前epoch:2,损失：0.06239129230380058,训练集准确率：98.56998%,测试集准确率98.82%


 54%|███████████████████████████████████████████▎                                    | 508/938 [00:21<00:42, 10.21it/s]

当前epoch:2,损失：0.12983562052249908,训练集准确率：98.56537%,测试集准确率98.69%


 65%|███████████████████████████████████████████████████▊                            | 608/938 [00:25<00:31, 10.41it/s]

当前epoch:2,损失：0.026609448716044426,训练集准确率：98.57009%,测试集准确率98.73%


 75%|████████████████████████████████████████████████████████████▎                   | 707/938 [00:29<00:22, 10.26it/s]

当前epoch:2,损失：0.12852519750595093,训练集准确率：98.5757%,测试集准确率98.62%


 86%|████████████████████████████████████████████████████████████████████▊           | 807/938 [00:33<00:13,  9.98it/s]

当前epoch:2,损失：0.05168810486793518,训练集准确率：98.55064%,测试集准确率98.03%


 97%|█████████████████████████████████████████████████████████████████████████████▎  | 907/938 [00:37<00:03, 10.25it/s]

当前epoch:2,损失：0.005207865033298731,训练集准确率：98.55369%,测试集准确率98.79%


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:38<00:00, 24.47it/s]
  1%|▌                                                                                 | 6/938 [00:01<03:50,  4.05it/s]

当前epoch:3,损失：0.010867721401154995,训练集准确率：100.0%,测试集准确率98.55%


 11%|█████████                                                                       | 106/938 [00:06<01:24,  9.86it/s]

当前epoch:3,损失：0.02380608394742012,训练集准确率：98.87067%,测试集准确率98.76%


 22%|█████████████████▌                                                              | 206/938 [00:10<01:11, 10.22it/s]

当前epoch:3,损失：0.022123737260699272,训练集准确率：98.93501%,测试集准确率98.73%


 33%|██████████████████████████                                                      | 306/938 [00:14<01:01, 10.33it/s]

当前epoch:3,损失：0.04047602787613869,训练集准确率：99.02409%,测试集准确率98.69%


 43%|██████████████████████████████████▋                                             | 406/938 [00:18<00:52, 10.07it/s]

当前epoch:3,损失：0.1586921215057373,训练集准确率：98.95963%,测试集准确率98.77%


 54%|███████████████████████████████████████████▏                                    | 506/938 [00:22<00:42, 10.28it/s]

当前epoch:3,损失：0.005970741622149944,训练集准确率：98.94274%,测试集准确率98.78%


 65%|███████████████████████████████████████████████████▋                            | 606/938 [00:26<00:32, 10.21it/s]

当前epoch:3,损失：0.10967163741588593,训练集准确率：98.94967%,测试集准确率98.81%


 75%|████████████████████████████████████████████████████████████▏                   | 706/938 [00:30<00:23, 10.05it/s]

当前epoch:3,损失：0.03585703298449516,训练集准确率：98.9301%,测试集准确率98.62%


 86%|████████████████████████████████████████████████████████████████████▋           | 806/938 [00:34<00:13,  9.51it/s]

当前epoch:3,损失：0.02716967463493347,训练集准确率：98.90566%,测试集准确率98.83%


 97%|█████████████████████████████████████████████████████████████████████████████▎  | 906/938 [00:38<00:03, 10.35it/s]

当前epoch:3,损失：0.050107236951589584,训练集准确率：98.89012%,测试集准确率98.88%


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [00:39<00:00, 23.97it/s]
