In [1]:
import json
with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

In [2]:
cat_to_name

{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 

In [4]:
import sys
import torch
import torchvision.transforms as transforms
from PIL import Image
from PyQt5 import QtWidgets, QtGui, QtCore
from torchvision import models
import torch.nn as nn

class GUI(QtWidgets.QWidget):
    def __init__(self):
        super().__init__()
        # 定义数据预处理
        self.data_transform = transforms.Compose([
            transforms.Resize((200, 200)),
            transforms.RandomResizedCrop(180),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
#         self.classes = cat_to_namea

        # 加载模型权重
        self.model = models.resnet18(pretrained=True)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, 102)
        model_dict = torch.load('models/res18.1.pt',map_location=torch.device('cpu'))
        self.model.load_state_dict(model_dict)
        self.model.eval()

        # 创建GUI界面
        self.initUI()

    def initUI(self):
        # 创建QLabel用于显示图片
        self.image_label = QtWidgets.QLabel(self) #创建一个QLabel对象，用于显示图片，并将其设置为GUI的子对象。
        self.image_label.setGeometry(QtCore.QRect(50, 50, 400, 360))
        self.image_label.setAlignment(QtCore.Qt.AlignCenter)
        self.image_label.setScaledContents(True)#设置QLabel对象中显示的图片可以自动缩放以适应QLabel对象的大小。
        self.image_label.setText('图片展示区域')

        # 创建QLabel用于显示分类结果
        font = QtGui.QFont()
        font.setPointSize(14)  # 设置字体大小为14
        self.result_label = QtWidgets.QLabel(self)
        self.result_label.setGeometry(QtCore.QRect(50, 250, 400, 360))
        self.result_label.setAlignment(QtCore.Qt.AlignCenter)
        self.result_label.setFont(font)

        # 创建QPushButton用于选择图片
        self.select_button = QtWidgets.QPushButton(self)
        self.select_button.setGeometry(QtCore.QRect(150, 445, 200, 20))
        self.select_button.setText('选择文件')
        self.select_button.clicked.connect(self.selectImage)

        # 设置窗口属性
        self.setGeometry(300, 400, 500, 500)
        self.setWindowTitle('花朵分类')
        self.show()

    def selectImage(self):
        # 打开文件选择对话框，选择要分类的图片
        file_name, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Open File', '.', 'Image Files(*.jpg *.png *.jpeg)')
        if file_name:
            # 显示选择的图片
            pixmap = QtGui.QPixmap(file_name)
            self.image_label.setPixmap(pixmap)

            # 对图片进行预处理
            image = Image.open(file_name)
            image = self.data_transform(image)
            image = image.unsqueeze(0)

            # 进行分类并显示结果
            with torch.no_grad():
                outputs = self.model(image)
                _, predicted = torch.max(outputs.data, 1)
#                 result = self.classes[predicted.item()]
                predicted_label = str(predicted.item())
                result = cat_to_name[predicted_label]
                self.result_label.setText(result)
                
                



if __name__ == '__main__':
    app = QtWidgets.QApplication(sys.argv)
    gui = GUI()
    sys.exit(app.exec_())




SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
