In [1]:
import sys
import os.path as path
import gzip
from typing import Dict, List, Tuple, AnyStr, KeysView, Any
import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
dataset_folder = 'data/MNIST/raw/'
files_name = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'vali_img': 't10k-images-idx3-ubyte.gz',
    'vali_label': 't10k-labels-idx1-ubyte.gz'
}

In [3]:
def dataloader(files_name) -> Tuple:
    with gzip.open(path.join(dataset_folder, files_name['train_img']), mode='rb') as data:
        train_img = torch.frombuffer(data.read(), dtype=torch.uint8, offset=16).reshape(-1, 1, 28, 28)
    # 加载训练集 标签
    with gzip.open(path.join(dataset_folder, files_name['train_label']), mode='rb') as label:
        train_label = torch.frombuffer(label.read(), dtype=torch.uint8, offset=8)
    # 加载验证集 图片
    with gzip.open(path.join(dataset_folder, files_name['vali_img']), mode='rb') as data:
        test_img = torch.frombuffer(data.read(), dtype=torch.uint8, offset=16).reshape(-1, 1, 28, 28)
    # 加载验证集 label
    with gzip.open(path.join(dataset_folder, files_name['vali_label']), mode='rb') as label:
        test_label = torch.frombuffer(label.read(), dtype=torch.uint8, offset=8)
    return (train_img, train_label),(test_img, test_label)

In [4]:
class MNIST_dataset(Dataset):
    def __init__(self, data: List, label: List):
        self.__data = data
        self.__label = label

    def __getitem__(self, item):
        if not item < self.__len__():
            return f'Error, index {item} is out of range'
        return self.__data[item], self.__label[item]

    def __len__(self):
        return len(self.__data)
# 读取数据
train_data,test_data = dataloader(files_name)
# 将数据封装为 MNIST 类
train_dataset = MNIST_dataset(*train_data)
test_dataset = MNIST_dataset(*test_data)
len(train_dataset), len(test_dataset)


  train_img = torch.frombuffer(data.read(), dtype=torch.uint8, offset=16).reshape(-1, 1, 28, 28)


(60000, 10000)

In [5]:
class NetWork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
        self.conv1 = nn.Conv2d(6, 16, kernel_size=(5, 5))
        self.pool0 = nn.AvgPool2d(kernel_size=(2, 2))
        self.pool1 = nn.AvgPool2d(kernel_size=(2, 2))
        self.linear0 = nn.Linear(16*4*4, 120)
        self.linear1 = nn.Linear(120, 84)
        self.linear2 = nn.Linear(84, 10)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
    def forward(self, x):
        output = self.conv0(x)
        output = self.pool0(output)
        output = self.conv1(output)
        output = self.pool1(output)
        output = self.flatten(output)
        output = self.linear0(output)
        output = self.relu(output)
        self.dropout = nn.Dropout(0.5)
        output = self.linear1(output)
        output = self.relu(output)
        output = self.linear2(output)
        output = self.relu(output)
        return output

net = NetWork()

In [6]:
def train(net, loss, train_iter, test_iter, optimizer, epochs, device):
    net = net.to(device)
    epoch_losses = []
    train_correct = 0
    train_len = 0
    test_correct = 0
    test_len = 0

    for epoch in range(epochs):
        net.train()
        epoch_losses.clear()
        for img, label in train_iter:
            img, label = img.to(device, dtype=torch.float), label.to(device)
            optimizer.zero_grad()
            output = net(img)
            l = loss(output, label)
            l.backward()
            optimizer.step()
            epoch_losses.append(l.item())
            pred = output.argmax(dim=1, keepdim=True)
            train_correct += pred.eq(label.view_as(pred)).sum().item()
            train_len += len(label)

        train_loss = sum(epoch_losses) / len(epoch_losses)
        train_acc = train_correct / train_len * 100.0
        print(f'-----------epoch: {epoch + 1} start --------------')
        print(f'epoch: {epoch} train loss: {train_loss}')
        print(f'epoch: {epoch} train acc: {train_acc}')

        # validation
        epoch_losses.clear()
        with torch.no_grad():
            net.eval()
            for img, label in test_iter:
                img, label = img.to(device, dtype=torch.float), label.to(device)
                test_output = net(img)
                l = loss(test_output, label)
                epoch_losses.append(l.item())
                test_pred = test_output.argmax(dim=1, keepdim=True)
                test_correct += (test_pred.squeeze() == label).sum().item()
                test_len += len(label)

            test_loss = sum(epoch_losses) / len(epoch_losses)
            test_acc = test_correct / test_len * 100.0
            print(f'epoch: {epoch} test loss: {test_loss}')
            print(f'epoch: {epoch} test acc: {test_acc}')
            print(f'-----------epoch: {epoch + 1} end --------------')

In [7]:
net = NetWork()
batch_size = 16
train_iter = DataLoader(train_dataset, batch_size=batch_size)
test_iter = DataLoader(test_dataset, batch_size=batch_size)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
num_epoch = 10
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())
train(net, loss, train_iter, test_iter, optimizer, num_epoch, device)# from torch.utils.data import DataLoader

-----------epoch: 1 start --------------
epoch: 0 train loss: 0.5704973290729336
epoch: 0 train acc: 79.73833333333333
epoch: 0 test loss: 0.47110058530736715
epoch: 0 test acc: 83.52000000000001
-----------epoch: 1 end --------------
-----------epoch: 2 start --------------
epoch: 1 train loss: 0.37574759227796456
epoch: 1 train acc: 83.11416666666666
epoch: 1 test loss: 0.4022045310129411
epoch: 1 test acc: 84.65
-----------epoch: 2 end --------------
-----------epoch: 3 start --------------
epoch: 2 train loss: 0.35645316969580015
epoch: 2 train acc: 84.43611111111112
epoch: 2 test loss: 0.37681149199898356
epoch: 2 test acc: 85.26
-----------epoch: 3 end --------------
-----------epoch: 4 start --------------
epoch: 3 train loss: 0.3452461419521289
epoch: 3 train acc: 85.17083333333333
epoch: 3 test loss: 0.3853519351546769
epoch: 3 test acc: 85.545
-----------epoch: 4 end --------------
-----------epoch: 5 start --------------
epoch: 4 train loss: 0.29592213014059937
epoch: 4 trai