In [64]:
import torch
import numpy as np
from torch import nn
from torchvision import transforms,datasets
from torch.utils.data import DataLoader

In [65]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [66]:
#通过PyTorch下载数据集
train_dataset = datasets.MNIST(root='',train=True,download=True,transform=transforms.Compose([transforms.Resize(size=(32,32)),transforms.ToTensor()]))
test_dataset = datasets.MNIST(root='',train=False,download=True,transform=transforms.Compose([transforms.Resize(size=(32,32)),transforms.ToTensor()]))

In [67]:
train_dataset_size = len(train_dataset)
test_dataset_size = len(test_dataset)
print(train_dataset_size)
print(test_dataset_size)

batch_size = 64
#shuffle=True可以打乱数据集，batch_size=64将会让这个数据生成器每次给我们64个数据,drop_last=True会把不够64一组的舍去（影响不大）。
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,drop_last=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True,drop_last=True)

60000
10000


In [68]:
'''这里先单独获取一次DataLoader的数据，用来观察数据结构'''
#enumerate将可迭代对象组合为索引序列，同时列出数据和数据下标。
for index,data in enumerate(train_loader):
    inputs, labels = data
    print(inputs.shape)
    print(labels.shape)
    print(labels)
    break

torch.Size([64, 1, 32, 32])
torch.Size([64])
tensor([7, 8, 6, 3, 9, 2, 8, 9, 8, 6, 6, 8, 3, 8, 2, 2, 0, 8, 9, 9, 4, 7, 4, 7,
        5, 6, 5, 6, 9, 4, 1, 8, 7, 4, 8, 7, 6, 3, 0, 9, 0, 6, 3, 5, 0, 6, 1, 3,
        1, 3, 6, 4, 6, 3, 5, 4, 6, 5, 4, 0, 0, 5, 5, 0])


In [69]:
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,stride=1,padding=0),
            nn.MaxPool2d(kernel_size=(2,2)),
            nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,stride=1,padding=0),
            nn.MaxPool2d(kernel_size=(2,2)),
            nn.Flatten(),
            nn.Linear(in_features=16*5*5,out_features=1*120),
            nn.Linear(in_features=1*120,out_features=84),
            nn.Linear(in_features=84,out_features=10)
        )
    
    def forward(self,x):
        x = self.model(x)
        return x
        

In [70]:
LR = 1e-2
#神经网络模型对象创建
net = Net()
net = net.to(device)
print(net)
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
optim = torch.optim.SGD(net.parameters(),LR)

Net(
  (model): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): Linear(in_features=400, out_features=120, bias=True)
    (6): Linear(in_features=120, out_features=84, bias=True)
    (7): Linear(in_features=84, out_features=10, bias=True)
  )
)


In [71]:
def train():
    for index,data in enumerate(train_loader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = net(inputs)
        loss = loss_fn(outputs,labels)
        optim.zero_grad()
        loss.backward()
        optim.step()

def test():
    times = 0
    for index,data in enumerate(test_loader):
        times += 1
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = net(inputs)
        accuracy = (outputs.argmax(axis=1) == labels).sum()
        if times % 15 == 0:
            print("Test accuracy:{0}".format(accuracy/len(test_loader)))

In [72]:
for epoch in range(25):
    train()
    if epoch % 5 == 0:
        print("epoch {0}".format(epoch))
        test()
print("Final accuracy")
test()

epoch 0
Test accuracy:0.3782051205635071
Test accuracy:0.3717948794364929
Test accuracy:0.3717948794364929
Test accuracy:0.36538460850715637
Test accuracy:0.3589743673801422
Test accuracy:0.3461538553237915
Test accuracy:0.3782051205635071
Test accuracy:0.3461538553237915
Test accuracy:0.3782051205635071
Test accuracy:0.39743590354919434
epoch 5
Test accuracy:0.39743590354919434
Test accuracy:0.38461539149284363
Test accuracy:0.3910256326198578
Test accuracy:0.41025641560554504
Test accuracy:0.4038461446762085
Test accuracy:0.39743590354919434
Test accuracy:0.41025641560554504
Test accuracy:0.39743590354919434
Test accuracy:0.4038461446762085
Test accuracy:0.41025641560554504
epoch 10
Test accuracy:0.4038461446762085
Test accuracy:0.41025641560554504
Test accuracy:0.38461539149284363
Test accuracy:0.4038461446762085
Test accuracy:0.41025641560554504
Test accuracy:0.3910256326198578
Test accuracy:0.41025641560554504
Test accuracy:0.41025641560554504
Test accuracy:0.39743590354919434
Tes