In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [25]:
class MnistNet(torch.nn.Module):
    def __init__(self):
        super(MnistNet, self).__init__()
        self.fc1 = torch.nn.Linear(28*28, 512)
        self.fc2 = torch.nn.Linear(512, 512)
        self.fc3 = torch.nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1) # flatten
        return x

In [26]:
class Model:
    def __init__(self, net, cost, optimist):
        self.net = net
        self.cost = self.create_cost(cost)
        self.optimizer = self.create_optimizer(optimist)
        pass

    def create_cost(self, cost):
        support_cost = {
            'CROSS_ENTROPY': nn.CrossEntropyLoss(),
            'MSE': nn.MSELoss()
        }

        return support_cost[cost]

    def create_optimizer(self, optimist, **rests):
        support_optim = {
            'SGD': optim.SGD(self.net.parameters(), lr=0.1, **rests),
            'ADAM': optim.Adam(self.net.parameters(), lr=0.01, **rests),
            'RMSP':optim.RMSprop(self.net.parameters(), lr=0.001, **rests)
        }

        return support_optim[optimist]

    def train(self, train_loader, epoches=3):
        for epoch in range(epoches):
            running_loss = 0.0
            for i, data in enumerate(train_loader, 0):
                inputs, labels = data

                self.optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.net(inputs)
                loss = self.cost(outputs, labels)
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
                if i % 100 == 0:
                    print('[epoch %d, %.2f%%] loss: %.3f' %
                          (epoch + 1, (i + 1)*1./len(train_loader), running_loss / 100))
                    running_loss = 0.0

        print('Finished Training')

    def evaluate(self, test_loader):
        print('Evaluating ...')
        correct = 0
        total = 0
        with torch.no_grad():  # no grad when test and predict
            for data in test_loader:
                images, labels = data

                outputs = self.net(images)
                predicted = torch.argmax(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))     

In [27]:
def mnist_load_data():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0,], [1,])])

    trainset = torchvision.datasets.MNIST(root='E:\学习资料_summary\八斗2023AI清华班\【10】框架&CNN\代码\pytorch\data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root='E:\学习资料_summary\八斗2023AI清华班\【10】框架&CNN\代码\pytorch\data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=True, num_workers=2)
    return trainloader, testloader

In [28]:
if __name__=='__main__':
    net = MnistNet()
    model = Model(net, 'CROSS_ENTROPY', 'RMSP')
    train_loader, test_loader = mnist_load_data()
    model.train(train_loader)
    model.evaluate(test_loader)

[epoch 1, 0.00%] loss: 0.023
[epoch 1, 0.05%] loss: 1.836
[epoch 1, 0.11%] loss: 1.658
[epoch 1, 0.16%] loss: 1.604
[epoch 1, 0.21%] loss: 1.550
[epoch 1, 0.27%] loss: 1.551
[epoch 1, 0.32%] loss: 1.535
[epoch 1, 0.37%] loss: 1.537
[epoch 1, 0.43%] loss: 1.533
[epoch 1, 0.48%] loss: 1.531
[epoch 1, 0.53%] loss: 1.535
[epoch 1, 0.59%] loss: 1.527
[epoch 1, 0.64%] loss: 1.525
[epoch 1, 0.69%] loss: 1.526
[epoch 1, 0.75%] loss: 1.522
[epoch 1, 0.80%] loss: 1.523
[epoch 1, 0.85%] loss: 1.519
[epoch 1, 0.91%] loss: 1.519
[epoch 1, 0.96%] loss: 1.517
[epoch 2, 0.00%] loss: 0.015
[epoch 2, 0.05%] loss: 1.515
[epoch 2, 0.11%] loss: 1.516
[epoch 2, 0.16%] loss: 1.504
[epoch 2, 0.21%] loss: 1.505
[epoch 2, 0.27%] loss: 1.507
[epoch 2, 0.32%] loss: 1.514
[epoch 2, 0.37%] loss: 1.508
[epoch 2, 0.43%] loss: 1.506
[epoch 2, 0.48%] loss: 1.512
[epoch 2, 0.53%] loss: 1.502
[epoch 2, 0.59%] loss: 1.512
[epoch 2, 0.64%] loss: 1.507
[epoch 2, 0.69%] loss: 1.504
[epoch 2, 0.75%] loss: 1.509
[epoch 2, 0.80