### 使用 PyTorch 完成手写数字的识别

In [59]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import DataLoader

# Default mini-batch size; used as the default `batch_size` of get_dataloader.
BATCH_SIZE=128

In [103]:
# Prepare the dataset
def get_dataloader(train=True,batch_size=BATCH_SIZE,root="/home/suzhang/git/Pytorch/HandWritingRecoData"):
    """Build a DataLoader over the MNIST training or test split.

    Args:
        train: If True load the training split, otherwise the test split.
        batch_size: Number of samples per batch.
        root: Directory handed to ``MNIST`` as its ``root``. Defaults to the
            path this notebook was originally written against; pass another
            directory to run it elsewhere.

    Returns:
        A ``DataLoader`` yielding ``(image, label)`` batches. Images are
        converted with ``ToTensor`` and then normalised with mean 0.1307 and
        std 0.3081.
    """
    transform_fn=Compose([
        ToTensor(),
        Normalize(mean=(0.1307,),std=(0.3081,))
    ])
    # `download` is left at its default, so the data is expected to already
    # exist under `root`.
    dataset=MNIST(root=root,train=train,transform=transform_fn)
    # Shuffle only while training: evaluation metrics do not depend on sample
    # order, and a fixed order keeps evaluation runs reproducible.
    data_loader=DataLoader(dataset,batch_size=batch_size,shuffle=train)
    return data_loader

In [61]:
# 构建模型

# 激活函数
import torch.nn.functional as F
import torch
from torch import nn
# Small ReLU demonstration: the result is neither assigned nor printed.
b=torch.tensor([-2,-1,0,1,2])
F.relu(b)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device=torch.device("cuda")
else:
    device=torch.device("cpu")
print(device)

cuda


In [62]:
# Data shapes: now     [BATCH_SIZE,1,28,28]
#              input1  [BATCH_SIZE,28*28]
#              output1 [BATCH_SIZE,28]
#              input2  [BATCH_SIZE,28]
#              output2 [BATCH_SIZE,10]
# Reshaping is therefore required before the first fully connected layer.

class MnistNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        in_features: Flattened size of one input sample (default 28*28*1).
        hidden_features: Width of the hidden layer (default 28).
        num_classes: Number of output logits (default 10).

    The layers keep the attribute names ``fc1`` and ``fc2``, so the
    ``state_dict`` keys are the same as before the sizes were parameterised.
    """
    def __init__(self,in_features=28*28*1,hidden_features=28,num_classes=10):
        super().__init__()
        self.fc1=nn.Linear(in_features,hidden_features)
        self.fc2=nn.Linear(hidden_features,num_classes)

    def forward(self,x):
        """Return raw (un-normalised) class scores for a batch.

        Args:
            x: Tensor whose elements can be regrouped into rows of
                ``in_features`` values, e.g. ``[batch_size,1,28,28]`` with the
                default sizes.

        Returns:
            Tensor of shape ``[rows, num_classes]``; no softmax is applied.
        """
        # -1 lets the row count follow the input, so a smaller final batch
        # works too. reshape (unlike view) also accepts non-contiguous
        # tensors such as the result of a transpose/permute.
        x=x.reshape(-1,self.fc1.in_features)

        # Fully connected layer followed by ReLU (shape unchanged by ReLU).
        x=F.relu(self.fc1(x))

        # Output layer.
        return self.fc2(x)
        

In [63]:
# 损失函数(交叉熵)
# 训练


In [64]:
from torch import optim
# Create the network, move it to the selected device, then attach an Adam
# optimizer (learning rate 0.001) to its parameters.
mnist_net=MnistNet()
mnist_net=mnist_net.to(device)
optimizer=optim.Adam(params=mnist_net.parameters(),lr=0.001)

In [89]:
import time
def train(epoch,log_interval=10):
    """Run one pass over the training split and update ``mnist_net`` in place.

    Uses the notebook-level ``mnist_net``, ``optimizer`` and ``device``.

    Args:
        epoch: Epoch index; only used as a label in the progress output.
        log_interval: Print the loss every ``log_interval`` batches
            (default 10).
    """
    # test() switches the network to eval mode and never switches it back,
    # so make sure training always runs in training mode.
    mnist_net.train()
    data_loader=get_dataloader()
    for idx,(Input,target) in enumerate(data_loader):
        Input=Input.to(device)
        target=target.to(device)
        output=mnist_net(Input)
        # cross_entropy on the raw scores is log_softmax followed by nll_loss.
        loss=F.cross_entropy(output,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx%log_interval==0:
            print(epoch,idx,loss.item())

In [90]:
# Train for 10 epochs; the epoch index is passed to train(), which prints it
# in its progress output.
for i in range(10):
    train(i)

0 0 0.02532069757580757
0 10 0.06606400012969971
0 20 0.07008245587348938
0 30 0.0978085920214653
0 40 0.045961894094944
0 50 0.05757028982043266
0 60 0.08797620236873627
0 70 0.024240365251898766
0 80 0.04234069213271141
0 90 0.05490914359688759
0 100 0.10631413757801056
0 110 0.037515103816986084
0 120 0.04025069996714592
0 130 0.05632774159312248
0 140 0.07919642329216003
0 150 0.10952959954738617
0 160 0.09345036000013351
0 170 0.09402614831924438
0 180 0.08855807781219482
0 190 0.06196647137403488
0 200 0.07229217141866684
0 210 0.09847632050514221
0 220 0.06074308604001999
0 230 0.14775444567203522
0 240 0.03247573599219322
0 250 0.08614680916070938
0 260 0.06386097520589828
0 270 0.1026415154337883
0 280 0.05724901705980301
0 290 0.026624079793691635
0 300 0.09692341089248657
0 310 0.05262065306305885
0 320 0.047548964619636536
0 330 0.027170751243829727
0 340 0.04733964800834656
0 350 0.019532909616827965
0 360 0.08195815980434418
0 370 0.13630469143390656
0 380 0.0367947183549

6 360 0.10319285839796066
6 370 0.06268235296010971
6 380 0.04660429432988167
6 390 0.024313632398843765
6 400 0.055693864822387695
6 410 0.03994910046458244
6 420 0.0854753702878952
6 430 0.03377274423837662
6 440 0.029053762555122375
6 450 0.02943437360227108
6 460 0.04887093976140022
7 0 0.029601991176605225
7 10 0.04493428021669388
7 20 0.02121499739587307
7 30 0.02658945880830288
7 40 0.026546228677034378
7 50 0.04750281199812889
7 60 0.025814982131123543
7 70 0.015782466158270836
7 80 0.02846686728298664
7 90 0.14063477516174316
7 100 0.04230072721838951
7 110 0.017017124220728874
7 120 0.06688115000724792
7 130 0.08717996627092361
7 140 0.018534624949097633
7 150 0.02575472928583622
7 160 0.02617744915187359
7 170 0.059476517140865326
7 180 0.0324057899415493
7 190 0.10888189822435379
7 200 0.11493223905563354
7 210 0.025856809690594673
7 220 0.060347650200128555
7 230 0.09279391169548035
7 240 0.11004915088415146
7 250 0.016559522598981857
7 260 0.0591721348464489
7 270 0.03575

In [87]:
# Save the model weights and the optimizer state to separate files.
torch.save(
    obj=mnist_net.state_dict(),
    f="/home/suzhang/git/Pytorch/HandWritingRecoData/MNIST/model/mnist_net.pt",
)
torch.save(
    obj=optimizer.state_dict(),
    f="/home/suzhang/git/Pytorch/HandWritingRecoData/MNIST/results/mnist_optimizer.pt",
)

In [88]:
# Load the model weights and the optimizer state.
# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint written on a CUDA machine can still be read on a CPU-only one.
mnist_net.load_state_dict(torch.load("/home/suzhang/git/Pytorch/HandWritingRecoData/MNIST/model/mnist_net.pt",map_location=device))
optimizer.load_state_dict(torch.load("/home/suzhang/git/Pytorch/HandWritingRecoData/MNIST/results/mnist_optimizer.pt",map_location=device))

In [106]:
TEST_BATCH_SIZE=1000
import numpy as np
# Model evaluation
def test():
    """Evaluate ``mnist_net`` on the test split and print accuracy and loss.

    Uses the notebook-level ``mnist_net`` and ``device``. The network is
    switched to eval mode and left there. Loss and accuracy are accumulated
    per sample, so the averages stay correct even if the last batch is
    smaller than ``TEST_BATCH_SIZE``. Prints NaN for both values when the
    loader yields no samples.
    """
    total_loss=0.0
    total_correct=0
    total_samples=0
    mnist_net.eval()
    test_dataloader=get_dataloader(train=False,batch_size=TEST_BATCH_SIZE)
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for Input,target in test_dataloader:
            Input=Input.to(device)
            target=target.to(device)
            output=mnist_net(Input)
            # output: [batch_size,10], target: [batch_size].
            # dim is given explicitly; relying on the implicit dim is deprecated.
            log_probs=F.log_softmax(output,dim=-1)
            # Sum (not mean) so batches of different sizes are weighted correctly.
            total_loss+=F.nll_loss(log_probs,target,reduction="sum").item()
            # Predicted class = index of the largest log-probability.
            pred=log_probs.argmax(dim=-1)
            total_correct+=pred.eq(target).sum().item()
            total_samples+=target.size(0)
    if total_samples==0:
        mean_acc=float("nan")
        mean_loss=float("nan")
    else:
        mean_acc=total_correct/total_samples
        mean_loss=total_loss/total_samples
    print("平均准确率，平均损失：",mean_acc,mean_loss)

In [107]:
# Result without using a convolutional neural network
test()

  


平均准确率，平均损失： 0.9640001 0.13496827
