In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [7]:
torch.manual_seed(4242)

<torch._C.Generator at 0x2d0a74b7cd0>

In [8]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data/p1ch2/mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True)

In [9]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [10]:
model = Net()

In [11]:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)#随机梯度算法

In [12]:
for epoch in range(10):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
    print('Current loss', float(loss))

Current loss 0.4915657341480255
Current loss 0.13548722863197327
Current loss 0.4440402388572693
Current loss 0.5570538640022278
Current loss 0.22641773521900177
Current loss 0.28907525539398193
Current loss 0.24486860632896423
Current loss 0.265625
Current loss 0.055082887411117554
Current loss 0.06868812441825867


In [13]:
torch.save(model.state_dict(), '../data/p1ch2/mnist/mnist.pth')

In [14]:
pretrained_model = Net()
pretrained_model.load_state_dict(torch.load('../data/p1ch2/mnist/mnist.pth'))

<All keys matched successfully>