In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

  warn(f"Failed to load image Python extension: {e}")


In [15]:
class Model:
    def __init__(self, net, cost, optimist):
        self.net = net
        self.cost = self.create_cost(cost)
        self.optimizer = self.create_optimizer(optimist)
        pass
    
    def create_cost(self, cost):
        #字典形式，可选择交叉熵损失函数或者MSE平均损失函数
        support_cost = {'CROSS_ENTROPY': nn.CrossEntropyLoss(), 'MSE':nn.MSELoss()}
        return support_cost[cost]
        
    
    def create_optimizer(self, optimist,**rests):
        support_optimizer = {
                            'SGD': optim.SGD(self.net.parameters(), lr = 0.1,**rests), 
                             'ADAM':optim.Adam(self.net.parameters(), lr=0.001, **rests),
                             'RMSP':optim.RMSprop(self.net.parameters(), lr=0.001, **rests)
                             }
        return support_optimizer[optimist]
    
    def train (self, trainloader, epochs=3):
        for epoch in range(epochs):
            running_loss = 0 #每次重置loss
            for i, data in enumerate(trainloader,0):
                #返回trainloaer的索引和data
                inputs, labels = data
                self.optimizer.zero_grad()
                #优化器每次重置梯度
                #前向传播 反向传播 优化器
                outputs = self.net(inputs)
                loss = self.cost(outputs, labels)
                loss.backward()
                self.optimizer.step()
                
                running_loss += loss.item()
                if i % 100 == 0:
                    print('[epoch %d, %.2f%%] loss: %.3f' %
                          (epoch + 1, (i + 1)*1./len(trainloader), running_loss / 100))
                    running_loss = 0.0
        print('Finished Training')
            
            
    def evaluate(self, testloader):
        print('Evaluating....')
        correct = 0
        total = 0
        with torch.no_grad():
            for i, data in enumerate(testloader,0):
                images, labels = data
                outputs = self.net(images)
                predicted = torch.argmax(outputs, 1)
                total +=labels.size(0)
                correct += (predicted == labels).sum().item()
                if i % 100 == 0:
                    print ('predicted', predicted)
                    print ('labels', labels)
        print ('Accuracy of the network on the test image: %d %%'% (100* correct/total))
            
            

In [5]:
def mnist_load_data():
    
    transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0,],[1,])])
    #数据预处理 tensor转化，归一化                               
    trainset = torchvision.datasets.MNIST(root='./data', train = True, download = True, transform = transform)
    testset = torchvision.datasets.MNIST(root='./data', train = False, download = True, transform = transform)
    #下载训练和测试数据集
    
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=False, num_workers=2)
    #实例化dataloader, 设置batchsize 是否shuffle num_worker
    
    return trainloader, testloader

In [16]:
class Linear(torch.nn.Module):
    def __init__(self):
        super(Linear, self).__init__()
        #继承父类nn.Module的属性
        self.linear1 = nn.Linear(28*28, 512)
        self.linear2 = nn.Linear(512, 512)
        self.linear3 = nn.Linear(512, 10)
        
    def forward(self, x):
        #前向传播过程
        x = x.view(-1, 28*28)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.softmax(self.linear3(x), dim =1)
        
        return x
        

In [17]:
if __name__ == '__main__':
    net = Linear() #实例化Linear
    model = Model(net, 'CROSS_ENTROPY', 'ADAM')
    trainloader,testloader =  mnist_load_data()
    
    model.train(trainloader)
    model.evaluate(testloader)
    

[epoch 1, 0.00%] loss: 0.023
[epoch 1, 0.05%] loss: 1.847
[epoch 1, 0.11%] loss: 1.666
[epoch 1, 0.16%] loss: 1.656
[epoch 1, 0.21%] loss: 1.643
[epoch 1, 0.27%] loss: 1.609
[epoch 1, 0.32%] loss: 1.559
[epoch 1, 0.37%] loss: 1.546
[epoch 1, 0.43%] loss: 1.537
[epoch 1, 0.48%] loss: 1.540
[epoch 1, 0.53%] loss: 1.537
[epoch 1, 0.59%] loss: 1.536
[epoch 1, 0.64%] loss: 1.538
[epoch 1, 0.69%] loss: 1.533
[epoch 1, 0.75%] loss: 1.522
[epoch 1, 0.80%] loss: 1.528
[epoch 1, 0.85%] loss: 1.532
[epoch 1, 0.91%] loss: 1.527
[epoch 1, 0.96%] loss: 1.524
[epoch 2, 0.00%] loss: 0.015
[epoch 2, 0.05%] loss: 1.526
[epoch 2, 0.11%] loss: 1.515
[epoch 2, 0.16%] loss: 1.514
[epoch 2, 0.21%] loss: 1.518
[epoch 2, 0.27%] loss: 1.517
[epoch 2, 0.32%] loss: 1.526
[epoch 2, 0.37%] loss: 1.516
[epoch 2, 0.43%] loss: 1.505
[epoch 2, 0.48%] loss: 1.513
[epoch 2, 0.53%] loss: 1.508
[epoch 2, 0.59%] loss: 1.511
[epoch 2, 0.64%] loss: 1.518
[epoch 2, 0.69%] loss: 1.505
[epoch 2, 0.75%] loss: 1.516
[epoch 2, 0.80