In [1]:
import numpy as np
def read_image(file_path):
    """Load every image stored in an IDX3 (MNIST-style) image file.

    The file starts with a 16-byte big-endian header made of four 4-byte
    integers (magic number, image count, row count, column count), followed
    by one unsigned byte per pixel.

    Args:
        file_path: Path to the uncompressed IDX3 image file.

    Returns:
        A list holding one ``float64`` array of shape ``(rows, cols)`` per
        image, with pixel values scaled from ``0..255`` down to ``0.0..1.0``.

    Raises:
        ValueError: If the header is incomplete, the magic number is not
            2051, or the file contains fewer pixel bytes than the header
            announces.
    """
    header_size = 16
    # 0x00000803: IDX type code 0x08 (unsigned byte) with 3 dimensions.
    expected_magic = 2051

    with open(file_path, 'rb') as f:
        raw = f.read()

    if len(raw) < header_size:
        raise ValueError(
            "{!r} is too short to hold an IDX3 header".format(file_path))

    magic, img_num, img_h, img_w = (
        int.from_bytes(raw[start:start + 4], byteorder='big')
        for start in range(0, header_size, 4)
    )
    if magic != expected_magic:
        raise ValueError(
            "{!r} is not an IDX3 image file (magic number {})".format(
                file_path, magic))

    pixel_count = img_num * img_h * img_w
    if len(raw) - header_size < pixel_count:
        raise ValueError(
            "{!r} is truncated: header announces {} pixel bytes, found {}".format(
                file_path, pixel_count, len(raw) - header_size))

    # Decode the whole payload in one vectorized pass instead of converting
    # pixel by pixel in Python; trailing bytes beyond the announced pixel
    # count are ignored.
    pixels = np.frombuffer(raw, dtype=np.uint8)[header_size:header_size + pixel_count]
    scaled = pixels.astype(np.float64) / 255.0

    # Iterating the (count, rows, cols) array yields one 2-D array per image.
    return list(scaled.reshape(img_num, img_h, img_w))

def read_label(file_path):
    """Load every label stored in an IDX1 (MNIST-style) label file.

    The file starts with an 8-byte big-endian header made of two 4-byte
    integers (magic number, label count), followed by one unsigned byte per
    label.

    Args:
        file_path: Path to the uncompressed IDX1 label file.

    Returns:
        A list of ``int`` labels, in file order.

    Raises:
        ValueError: If the header is incomplete, the magic number is not
            2049, or the file contains fewer labels than the header
            announces.
    """
    header_size = 8
    # 0x00000801: IDX type code 0x08 (unsigned byte) with 1 dimension.
    expected_magic = 2049

    with open(file_path, 'rb') as f:
        raw = f.read()

    if len(raw) < header_size:
        raise ValueError(
            "{!r} is too short to hold an IDX1 header".format(file_path))

    magic = int.from_bytes(raw[0:4], byteorder='big')
    label_num = int.from_bytes(raw[4:8], byteorder='big')
    if magic != expected_magic:
        raise ValueError(
            "{!r} is not an IDX1 label file (magic number {})".format(
                file_path, magic))

    labels = raw[header_size:header_size + label_num]
    if len(labels) < label_num:
        raise ValueError(
            "{!r} is truncated: header announces {} labels, found {}".format(
                file_path, label_num, len(labels)))

    # Iterating a bytes object yields plain ints, one per label byte.
    return list(labels)

# Parse the test split once; later cells read `test_img` / `test_label`.
TEST_IMAGE_PATH = "mnist_data/test/t10k-images.idx3-ubyte"
TEST_LABEL_PATH = "mnist_data/test/t10k-labels.idx1-ubyte"

test_img = read_image(TEST_IMAGE_PATH)
test_label = read_label(TEST_LABEL_PATH)

In [2]:
import torch
from torch.utils.data import Dataset,DataLoader
import torchvision
from torchvision import datasets,transforms

# Single-channel statistics passed to Normalize.
# NOTE(review): these look like the commonly quoted MNIST mean/std - confirm.
NORMALIZE_MEAN = (0.1307,)
NORMALIZE_STD = (0.3081,)

# Per-sample preprocessing: convert to a tensor, then standardize.
my_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])
class MnistDataset(Dataset):
    """Map-style dataset pairing pre-loaded images with their labels.

    Args:
        image: Indexable collection of images.
        label: Indexable collection of labels; its length defines the
            dataset length.
        my_transforms: Optional callable applied to each image when it is
            fetched. When ``None`` the image is returned unchanged.
    """

    def __init__(self, image, label, my_transforms=None):
        self.len = len(label)
        self.image = image
        self.label = label
        self.my_transforms = my_transforms

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        sample = self.image[index]
        # Use the transform given to this instance rather than whatever
        # module-level `my_transforms` happens to exist in the notebook.
        if self.my_transforms is not None:
            sample = self.my_transforms(sample)
        return sample, self.label[index]

    def __len__(self):
        """Return the number of samples (the number of labels)."""
        return self.len

In [3]:
import torch.nn as nn
import torch.nn.functional as F
# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


使用类 LeNet 的卷积网络结构（三个卷积块加两层全连接），参考 PPT 中的结构。

In [4]:
class MyNet(nn.Module):
    """Small CNN classifier: three convolutional stages and a two-layer head.

    Each stage is Conv2d(3x3) -> ReLU -> BatchNorm2d -> MaxPool2d(2, 2) ->
    Dropout(0.25). The head's first linear layer expects 512 flattened
    features, which is what a 1x28x28 input produces
    (28 -> 14 -> 7 -> 5 -> 2 spatially, times 128 channels).
    """

    def __init__(self):
        super().__init__()
        # The first two stages pad the convolution; the third does not.
        self.conv_1 = self._conv_stage(1, 32, padding=1)
        self.conv_2 = self._conv_stage(32, 64, padding=1)
        self.conv_3 = self._conv_stage(64, 128, padding=0)

        # Two linear layers back to back, ending in 10 output scores.
        self.fc = nn.Sequential(
            nn.Linear(512, 128),
            nn.Linear(128, 10),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels, padding):
        """Build one Conv -> ReLU -> BatchNorm -> MaxPool -> Dropout stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``."""
        features = self.conv_3(self.conv_2(self.conv_1(x)))
        # Collapse everything but the batch dimension for the linear head.
        flat = features.view(features.size(0), -1)
        scores = self.fc(flat)
        return F.log_softmax(scores, dim=1)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints that come from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects a fully pickled module; if that applies
# to the installed version, pass weights_only=False here - confirm.
# map_location remaps the stored tensors onto `device` while loading, so a
# checkpoint saved from a CUDA session can still be opened on a CPU-only
# machine instead of failing during deserialization.
myNet = torch.load("mnist.h5", map_location=device).to(device)

In [5]:

# Number of samples handed to the model per evaluation step.
TEST_BATCH_SIZE = 256

# Wrap the parsed test split and serve it in fixed-size batches.
testDataset = MnistDataset(
    image=test_img,
    label=test_label,
    my_transforms=my_transforms,
)
test_loader = DataLoader(dataset=testDataset, batch_size=TEST_BATCH_SIZE)


def test_loss_acc():
    """Measure the classification accuracy of ``myNet`` on ``test_loader``.

    Reads the module-level ``myNet``, ``test_loader`` and ``device``. The
    model's train/eval mode is left untouched, so call ``myNet.eval()``
    beforehand. Despite the name, only accuracy is computed - no loss.

    Prints the accuracy as a percentage.

    Returns:
        A ``(total, correct)`` tuple: the number of samples seen and the
        number classified correctly. An empty loader yields ``(0, 0)`` and a
        reported accuracy of ``0.0``.
    """
    correct = 0
    total = 0
    # Inference only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for test_imgs, test_labels in test_loader:
            # Cast to float32 before the forward pass.
            test_imgs = test_imgs.float()
            outputs = myNet(test_imgs.to(device)).to("cpu")
            # Predicted class = index of the largest score along dim 1.
            _, predict_labels = torch.max(outputs, 1)
            total += test_labels.size(0)
            correct += (predict_labels == test_labels).sum().item()

    # Guard the division so an empty loader does not raise ZeroDivisionError.
    accuracy = 100.0 * correct / total if total else 0.0
    print("测试集正确率：{}%".format(accuracy))
    return total, correct
    


In [6]:
import torch.optim as optim

# Switch the model to evaluation mode before measuring accuracy; MyNet
# contains Dropout and BatchNorm2d layers, whose behavior depends on it.
myNet.eval()
# Left as the bare last expression of the cell so the notebook also displays
# the returned (total, correct) tuple.
test_loss_acc()


测试集正确率：99.36%


(10000, 9936)