In [None]:
!pip3 install torch torchvision numpy

In [7]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as T
import os
import gzip
import numpy as np

In [9]:
class MNISTDataset(Dataset):
    """MNIST digits read directly from the raw IDX ``.gz`` archives.

    Args:
        datadir: Directory containing the raw archives
            (``train-{labels,images}-idx*-ubyte.gz`` and
            ``t10k-{labels,images}-idx*-ubyte.gz``).
        transform: Optional callable applied to each ``(28, 28)`` ``uint8``
            image in ``__getitem__``; ``None`` returns the raw array.
        is_train: Load the ``train-*`` files when true, the ``t10k-*`` files
            otherwise.
    """

    # (labels archive, images archive) for each split.
    _TRAIN_FILES = ('train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
    _TEST_FILES = ('t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')

    def __init__(self, datadir, transform=None, is_train=True):
        super().__init__()
        self.datadir = datadir
        self.transform = transform
        self.img, self.label = self.load_data(self.datadir, is_train=is_train)
        self.len_data = len(self.img)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image goes through ``transform`` if set."""
        img = self.img[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label[index]

    def __len__(self):
        return self.len_data

    def load_data(self, datadir, is_train):
        """Read the label and image archives for one split.

        Returns:
            ``(img, label)``: ``img`` is a writable ``uint8`` array of shape
            ``(N, 28, 28)``; ``label`` is an ``int64`` array of shape ``(N,)``.

        Raises:
            ValueError: if the image archive does not contain exactly
                ``N * 28 * 28`` pixel bytes for the ``N`` labels read.
        """
        label_name, image_name = self._TRAIN_FILES if is_train else self._TEST_FILES

        with gzip.open(os.path.join(datadir, label_name), 'rb') as lbpath:
            # Skip the 8-byte IDX header that precedes the label bytes.
            label = np.frombuffer(lbpath.read(), np.uint8, offset=8)
        with gzip.open(os.path.join(datadir, image_name), 'rb') as imgpath:
            # Skip the 16-byte IDX header that precedes the pixel bytes.
            pixels = np.frombuffer(imgpath.read(), np.uint8, offset=16)

        expected = len(label) * 28 * 28
        if pixels.size != expected:
            raise ValueError(
                f'{image_name}: expected {expected} pixel bytes for '
                f'{len(label)} labels, found {pixels.size}'
            )
        # np.frombuffer returns a read-only view; copy so torch.from_numpy
        # (used by ToTensor) receives a writable array and does not warn.
        # The copy stays uint8 so ToTensor still rescales to [0, 1].
        img = pixels.reshape(len(label), 28, 28).copy()
        # int64 labels collate to the Long targets nn.CrossEntropyLoss expects.
        return img, label.astype(np.int64)

In [10]:
# MNIST datasets and loaders, built from the raw IDX archives under DATA_DIR.
DATA_DIR = 'MNIST/raw'

train_dataset = MNISTDataset(
    datadir=DATA_DIR,
    transform=T.ToTensor(),
    is_train=True,
)
test_dataset = MNISTDataset(
    datadir=DATA_DIR,
    transform=T.ToTensor(),
    is_train=False,
)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
# batch_size=10000: presumably the whole t10k split in a single batch.
test_loader = DataLoader(test_dataset, batch_size=10000, shuffle=False, drop_last=False)

# Sanity check: tensor shape and label of the first training sample.
data, target = train_dataset[0]
print(data.shape)
print(target)

torch.Size([1, 28, 28])
5


  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()


In [11]:
class SimpleNN(nn.Module):
    """Three-layer MLP classifier.

    Each input is flattened (all dims after the batch dim) to 28*28 features,
    then mapped 784 -> 128 -> 64 -> 10 with a ReLU after the first two linear
    layers. The output is raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        in_features = 28 * 28
        hidden1, hidden2, num_classes = 128, 64, 10
        # Attribute names and creation order are significant: they fix the
        # state_dict keys and the order parameters are initialised in.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        hidden = self.relu1(self.fc1(self.flatten(x)))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)

In [12]:
# Seed torch's global RNG so weight initialisation and the training
# DataLoader's shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = SimpleNN()  # instantiate the model
loss_fn = nn.CrossEntropyLoss()  # cross-entropy loss on the model's raw logits
opt = torch.optim.Adam(model.parameters(), lr=1e-3)  # optimizer

In [13]:
# Training loop
def train(epoch, net=None, loader=None, optimizer=None, criterion=None):
    """Run one training epoch and print the mean per-batch loss.

    Args:
        epoch: Epoch index; used only in the printed message.
        net: Model to train; defaults to the global ``model``.
        loader: Iterable of ``(data, target)`` batches; defaults to the
            global ``train_loader``.
        optimizer: Defaults to the global ``opt``.
        criterion: Loss function; defaults to the global ``loss_fn``.

    Returns:
        Mean loss over the batches seen this epoch, or ``nan`` if the
        loader yielded no batches.
    """
    # Conditional expressions only touch the globals when no override is given.
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    optimizer = opt if optimizer is None else optimizer
    criterion = loss_fn if criterion is None else criterion

    net.train()
    total_loss = 0.0
    n_batches = 0
    for data, target in loader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        output = net(data)
        loss = criterion(output, target)  # compute the loss
        loss.backward()  # backpropagate to compute gradients
        optimizer.step()  # update parameters
        total_loss += loss.item()
        n_batches += 1

    # Count batches instead of using len(loader): an empty loader then reports
    # nan instead of raising ZeroDivisionError, and unsized iterables work too.
    avg_loss = total_loss / n_batches if n_batches else float('nan')
    print(f'Train Epoch: {epoch} Loss: {avg_loss:.3f}')
    return avg_loss

In [14]:
'''测试代码'''
def test(epoch):
    model.eval()
    correct = 0
    tot = 0
    for data,target in test_loader:
        output = model(data)
        # 使用for循环找到每个输出的最大值对应的索引（即预测的类别）
        pred = []
        for i in range(output.size(0)):  # 遍历每个样本
            max_index = 0
            max_value = output[i][0]
            for j in range(1, output.size(1)):  # 遍历每个类别的得分
                if output[i][j] > max_value:
                    max_value = output[i][j]
                    max_index = j
            pred.append(max_index)
        
        # 使用for循环计算正确预测的数量
        for i in range(len(pred)):
            if pred[i] == target[i]:
                correct += 1
        
        tot+=data.shape[0]

    print(f'Test Epoch:{epoch} Accuracy: {correct/tot*100:.2f}%')

In [15]:

NUM_EPOCHS = 5

for epoch in range(NUM_EPOCHS):
    train(epoch)  # training: one pass over train_loader
    test(epoch)   # evaluation: accuracy on test_loader after this epoch

Train Epoch: 0 Loss: 2.269
Test Epoch:0 Accuracy: 33.06%
Train Epoch: 1 Loss: 2.141
Test Epoch:1 Accuracy: 51.49%
Train Epoch: 2 Loss: 1.848
Test Epoch:2 Accuracy: 67.73%
Train Epoch: 3 Loss: 1.341
Test Epoch:3 Accuracy: 77.19%
Train Epoch: 4 Loss: 0.913
Test Epoch:4 Accuracy: 82.37%
