In [6]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image
from torch import nn
from collections import OrderedDict

In [7]:
# Load the handwritten digit, convert it to single-channel grayscale, resize it
# to the 28x28 resolution the network expects, and add a leading batch
# dimension so the tensor has shape (batch, channel, height, width).
image_path = '../images/six.png'
# Open the file in a context manager so the handle is released deterministically.
# convert() returns a new, fully loaded image that stays usable after the file is closed.
with Image.open(image_path) as raw_img:
    img = raw_img.convert('L')  # convert to 8-bit grayscale

# NOTE(review): if the network was trained with extra preprocessing (e.g.
# transforms.Normalize, or white-digit-on-black inputs), apply the same here.
# Confirm against the training code.
transform = transforms.Compose([transforms.Resize((28, 28)),
                                transforms.ToTensor()])

# ToTensor gives (1, 28, 28) for a grayscale image; unsqueeze adds the batch
# axis -> (1, 1, 28, 28) without hard-coding the full shape.
image = transform(img).unsqueeze(0)
print(image.shape)

torch.Size([1, 1, 28, 28])


In [8]:
# Define the convolutional neural network model.
class Net(nn.Module):
    """Small CNN mapping a single-channel 28x28 image to 10 class scores.

    The sub-layers live inside ``self.model`` under the names ``conv1`` ... ``fc2``.
    Those names form the keys of the state_dict (e.g. ``model.conv1.weight``),
    so they must not change or saved checkpoints will no longer load.

    Shape walk-through for an input of shape (N, 1, 28, 28):
        conv1 (5x5) -> (N, 10, 24, 24) -> pool1 -> (N, 10, 12, 12)
        conv2 (5x5) -> (N, 20, 8, 8)   -> pool2 -> (N, 20, 4, 4)
        flatten     -> (N, 320)
        fc1 -> (N, 50), fc2 -> (N, 10) raw scores (no softmax applied).
    """

    def __init__(self):
        super().__init__()
        layers = [
            # Feature extractor: two conv -> ReLU -> 2x2 max-pool stages.
            ('conv1', nn.Conv2d(1, 10, kernel_size=5)),
            ('relu1', nn.ReLU()),
            ('pool1', nn.MaxPool2d(kernel_size=2, stride=2)),
            ('conv2', nn.Conv2d(10, 20, kernel_size=5)),
            ('relu2', nn.ReLU()),
            ('pool2', nn.MaxPool2d(kernel_size=2, stride=2)),
            # Classifier head: 20 channels * 4 * 4 spatial = 320 features.
            ('flatten', nn.Flatten()),
            ('fc1', nn.Linear(20 * 4 * 4, 50)),
            # NOTE(review): there is no activation between fc1 and fc2, so the
            # two Linear layers compose to a single affine map. Changing this
            # would require retraining, because existing checkpoints assume
            # this exact layout.
            ('fc2', nn.Linear(50, 10)),
        ]
        self.model = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        """Return class scores of shape (N, 10) for a batch ``x`` of shape (N, 1, 28, 28)."""
        return self.model(x)

In [9]:
# Instantiate the network and restore its trained parameters.
model = Net()
# If the parameters were trained on a GPU, remap them onto the CPU at load time
# so the checkpoint can be used on a machine without CUDA.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (newer torch versions also accept weights_only=True for
# plain state_dicts).
model_weight = torch.load(
    "./models/MNIST_cnn_epoch_10.pth",
    map_location='cpu',
)
model.load_state_dict(model_weight)
print(model)

Net(
  (model): Sequential(
    (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
    (relu1): ReLU()
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
    (relu2): ReLU()
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=320, out_features=50, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
  )
)


In [11]:
# Run inference on the preprocessed image and print the predicted digit.
model.eval()  # disable training-only behaviour (e.g. dropout / batch-norm updates)
with torch.no_grad():  # no autograd graph is needed for prediction
    output = model(image)
# Take argmax over the class dimension (dim=1) to get one prediction per sample.
# Without dim, argmax flattens the whole (batch, 10) tensor and returns a flat
# index, which is only the class label when the batch holds a single image.
predicted = output.argmax(dim=1)
print(predicted.item())  # .item() assumes a single-image batch

6
