In [None]:
import os
import torch
import numpy as np
from torch import nn, optim
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import datasets, transforms

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalize its
# single channel with a fixed mean / std.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set pixel
# mean / std — confirm if the dataset changes.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [None]:
# Load (downloading into ./data if necessary) the MNIST training and test
# splits, applying `transform` to every image.
train_set = datasets.MNIST(root='data', train=True, transform=transform, download=True)
test_set = datasets.MNIST(root='data', train=False, transform=transform, download=True)

In [None]:
class LeNet(nn.Module):
    '''
    LeNet-style convolutional network for 2-D image inputs.

    Two (conv 5x5 -> ReLU -> 2x2 max-pool) stages are followed by three fully
    connected layers. The flattened size 16*4*4 assumes 28x28 inputs
    (e.g. MNIST): 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.

    Args:
        in_channels: number of input image channels (default 1, grayscale).
        num_classes: number of output scores per sample (default 10).
    '''
    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.c1 = nn.Conv2d(in_channels, 6, 5)
        self.c3 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        '''
        Map a batch of images to unnormalized class scores.

        Assumes ``x`` has shape (N, in_channels, 28, 28); returns (N, num_classes).
        '''
        x = F.max_pool2d(F.relu(self.c1(x)), 2)
        x = F.max_pool2d(F.relu(self.c3(x)), 2)
        # Flatten all non-batch dimensions into a single feature vector per sample.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        '''Return the number of features per sample: the product of all non-batch dims of ``x``.'''
        num_features = 1
        for s in x.size()[1:]:
            num_features *= s
        return num_features

In [None]:
# Seed torch's global RNG so weight initialization and the shuffling order
# of the training loader are reproducible across Restart & Run All.
torch.manual_seed(0)
model = LeNet()
trainloader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)

In [None]:
def train(trainloader, model, epochs=1, lr=0.001, log_every=1000):
    '''
    Train ``model`` in place with cross-entropy loss and SGD with momentum 0.9.

    Args:
        trainloader: iterable yielding ``(inputs, labels)`` batches.
        model: the ``nn.Module`` to train; its parameters are updated in place.
        epochs: number of passes over ``trainloader``.
        lr: SGD learning rate.
        log_every: every ``log_every`` batches, print the mean loss over those
            batches (default 1000, the previously hard-coded interval).

    Raises:
        ValueError: if ``log_every`` is less than 1.
    '''
    if log_every < 1:
        raise ValueError('log_every must be >= 1, got %r' % (log_every,))
    criterion = nn.CrossEntropyLoss()
    # Stochastic gradient descent with momentum.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # Put the model in training mode in case an earlier evaluation left it in eval mode.
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if (i + 1) % log_every == 0:
                print('[Epoch:%d, Batch:%5d] Loss: %.3f' % (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

    print('Finished Training')

In [None]:
# Two full passes over the training set.
train(trainloader=trainloader, model=model, epochs=2)

In [None]:
def load_param(model, path):
    '''
    Load a state dict stored at ``path`` (e.g. by ``save_param``) into ``model``.

    Returns:
        True if the file existed and was loaded, False otherwise. A missing
        file also emits a ``UserWarning``, so a wrong path no longer silently
        leaves the model with its current weights.
    '''
    import warnings

    if not os.path.exists(path):
        warnings.warn('No parameter file found at %r; model left unchanged.' % (path,))
        return False
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; load_state_dict copies the values into the model's own tensors.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
    # (consider weights_only=True on torch versions that support it).
    model.load_state_dict(torch.load(path, map_location='cpu'))
    return True

def save_param(model, path):
    '''
    Save ``model.state_dict()`` to ``path`` with ``torch.save``.

    Missing parent directories are created first, so a path such as
    ``'pkl_model/model.pkl'`` works even when ``pkl_model/`` does not exist yet.
    '''
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)

In [None]:
# Total number of trainable parameters. Displaying `model.parameters()` itself
# only shows an opaque generator object.
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Iteration order does not affect test accuracy, so skip shuffling and keep evaluation deterministic.
testloader = torch.utils.data.DataLoader(test_set, batch_size=4, shuffle=False, num_workers=2)

In [None]:
def test(testloader, model):
    correct = 0
    total = 0
    for data in testloader:
        image, labels = data
        outputs = model(image)
        _, perdicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (perdicted == labels).sum()
    print('Accuracy on the test set: %d %%' % (100 * correct / total))

In [None]:
# Evaluate the trained model on the held-out MNIST test split.
test(model=model, testloader=testloader)

In [None]:
# Persist the trained weights (state dict only) under pkl_model/.
save_param(model=model, path='pkl_model/model_num_image.pkl')

In [None]:
# Reload from the same path the save cell writes to. The previous path
# 'model_num_image.pkl' did not match, so load_param silently did nothing.
load_param(model, 'pkl_model/model_num_image.pkl')