## MNIST 手写数字辨识

In [4]:
import torch
import torchvision
from matplotlib import pyplot as plt

In [6]:
# MNIST train / test splits stored under ./data (download=True lets
# torchvision fetch the files there). ToTensor is applied to every image as
# it is read from the dataset.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Size of the raw training image tensor.
print(train_dataset.data.size())

torch.Size([60000, 28, 28])


In [9]:
# Mini-batch loaders over the two splits. Only the training data is shuffled;
# the evaluation pass just counts correct predictions, which does not depend
# on sample order, so the test loader keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Peek at a single training batch to check the tensor shapes.
x, y = next(iter(train_dataloader))
x.shape, y.shape

(torch.Size([8, 1, 28, 28]), tensor([9, 4, 8, 3, 6, 7, 9, 5]))

In [13]:
class Model(torch.nn.Module):
    """Small convolutional classifier: [N, 1, 28, 28] images -> [N, 10] scores.

    Three convolution stages (each wrapped in its own ``Sequential``) with
    ReLU activations, a 2x2 max-pool after the second stage, and a final
    linear layer.
    """

    def __init__(self):
        super().__init__()

        conv = torch.nn.Conv2d

        # 3x3 kernel, stride 2, padding 1: 28x28 -> 14x14.
        self.conv1 = torch.nn.Sequential(conv(1, 16, kernel_size=3, stride=2, padding=1))
        # 3x3 kernel, stride 1, padding 1: spatial size stays 14x14.
        self.conv2 = torch.nn.Sequential(conv(16, 32, kernel_size=3, stride=1, padding=1))
        # 7x7 kernel, no padding: collapses the pooled 7x7 map to 1x1.
        self.conv3 = torch.nn.Sequential(conv(32, 64, kernel_size=7, stride=1, padding=0))

        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Compute the 10 class scores for a batch of [N, 1, 28, 28] inputs."""
        # [N, 1, 28, 28] -> [N, 16, 14, 14]
        h = self.relu(self.conv1(x))

        # [N, 16, 14, 14] -> [N, 32, 14, 14] -> (pool) -> [N, 32, 7, 7]
        h = self.pool(self.relu(self.conv2(h)))

        # [N, 32, 7, 7] -> [N, 64, 1, 1]
        h = self.relu(self.conv3(h))

        # [N, 64, 1, 1] -> [N, 64] -> [N, 10]
        return self.fc(torch.flatten(h, start_dim=1))
        
# The network instance used by the training and evaluation functions below.
model = Model()

# Smoke test: a random batch of 64 single-channel 28x28 inputs should come
# out as 64 rows of 10 scores.
model(torch.randn((64, 1, 28, 28))).shape

torch.Size([64, 10])

## 计算过程
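$$
O = \left\lfloor \frac{I - K + 2P}{S} \right\rfloor + 1
$$
其中 $I$ 为输入尺寸，$K$ 为卷积核（或池化窗口）大小，$P$ 为填充，$S$ 为步幅。

conv1：
$$
O = \left\lfloor \frac{28 - 3 + 2 \times 1}{2} \right\rfloor + 1 = 14
$$
conv2：
$$
O = \left\lfloor \frac{14 - 3 + 2 \times 1}{1} \right\rfloor + 1 = 14
$$
pool（$K = 2$，$S = 2$，$P = 0$）：
$$
O = \left\lfloor \frac{14 - 2 + 2 \times 0}{2} \right\rfloor + 1 = 7
$$
conv3：
$$
O = \left\lfloor \frac{7 - 7 + 2 \times 0}{1} \right\rfloor + 1 = 1
$$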

In [14]:
def train(epochs=10, lr=1e-3, log_every=100, save_path='./model.pth'):
    """Train the global ``model`` on ``train_dataloader`` and save its weights.

    Uses Adam with cross-entropy loss. With the default arguments this runs
    10 epochs at learning rate 1e-3, prints the batch loss every 100 steps
    and writes the weights to ``./model.pth``.

    Args:
        epochs: Number of full passes over ``train_dataloader``.
        lr: Learning rate passed to the Adam optimizer.
        log_every: Print the current batch loss every ``log_every`` steps;
            ``0`` or ``None`` disables the progress output.
        save_path: File that receives ``model.state_dict()`` once all epochs
            have finished.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        for i, (x, y) in enumerate(train_dataloader):
            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()

            # Guard on log_every first so 0/None silences logging instead of
            # raising on the modulo.
            if log_every and i % log_every == 0:
                print('epoch: {}, step: {}, loss: {}'.format(epoch, i, loss.item()))

    torch.save(model.state_dict(), save_path)
    
# Run the training loop; when it finishes the weights are saved to ./model.pth.
train()

epoch: 0, step: 0, loss: 2.307921886444092
epoch: 0, step: 100, loss: 0.24208813905715942
epoch: 0, step: 200, loss: 0.2891041338443756
epoch: 0, step: 300, loss: 0.40162408351898193
epoch: 0, step: 400, loss: 0.023495061323046684
epoch: 0, step: 500, loss: 0.07630366086959839
epoch: 0, step: 600, loss: 0.17589624226093292
epoch: 0, step: 700, loss: 0.805452287197113
epoch: 0, step: 800, loss: 0.023236798122525215
epoch: 0, step: 900, loss: 0.018157295882701874
epoch: 0, step: 1000, loss: 0.05834105238318443
epoch: 0, step: 1100, loss: 0.01430862583220005
epoch: 0, step: 1200, loss: 0.28176063299179077
epoch: 0, step: 1300, loss: 0.05352471023797989
epoch: 0, step: 1400, loss: 0.22636502981185913
epoch: 0, step: 1500, loss: 0.08927323669195175
epoch: 0, step: 1600, loss: 0.0704236850142479
epoch: 0, step: 1700, loss: 0.03793703764677048
epoch: 0, step: 1800, loss: 0.24361760914325714
epoch: 0, step: 1900, loss: 0.22209042310714722
epoch: 0, step: 2000, loss: 0.0038501343224197626
epoch

In [15]:
@torch.no_grad()
def test():
    model.load_state_dict(torch.load('./model.pth'))
    model.eval()
    
    total = 0
    correct = 0
    
    for x, y in test_dataloader:
        y_pred = model(x)
        _, pred = torch.max(y_pred, 1)
        total += x.size(0)
        correct += (pred == y).sum().item()
        
    print('accuracy: {}'.format(correct / total))
    
# Evaluate the weights stored in ./model.pth on the test loader and print the accuracy.
test()

accuracy: 0.9856
