In [7]:
import os, numpy as np, h5py
from pathlib import Path
from PIL import Image

# Dataset root. Later cells use each image's parent-folder name as its label.
dir_name = Path('images')

In [2]:
# Detect image files that are corrupt or in a format Pillow cannot read.
def check_image_file(path):
    """Return every file below ``path`` that Pillow fails to open and decode.

    Each failure is reported on stdout as ``Error: <file> - <exception>``.
    Nothing is deleted here; the caller decides what to do with the list.

    Args:
        path: ``pathlib.Path`` of the directory to scan recursively.

    Returns:
        List of ``Path`` objects for the files that raised while loading,
        in the order ``rglob`` produced them.
    """
    unreadable = []
    regular_files = (entry for entry in path.rglob('*') if entry.is_file())
    for candidate in regular_files:
        try:
            with Image.open(candidate) as picture:
                # Image.open is lazy; load() forces the pixel data to be
                # decoded so truncated/corrupt bodies are detected as well.
                picture.load()
        except Exception as err:
            # Any failure (unidentified format, bad data, I/O error) marks
            # the file as broken.
            print(f'Error: {candidate} - {err}')
            unreadable.append(candidate)
    return unreadable

# Delete every file Pillow could not decode so later cells only see valid
# images. NOTE: irreversible, and it removes anything Pillow cannot open
# (e.g. stray non-image files) because check_image_file flags those too.
broken_files = check_image_file(dir_name)

for broken in broken_files:
    os.remove(broken)

# Display what was removed (a plain loop, not a side-effect comprehension).
broken_files

[]

In [3]:
# Flatten palette / grayscale / alpha-channel images onto a solid background.
def process_rgba_image(path, modes=('P', 'PA', 'L', 'LA', 'RGBA'),
                       background_color=(255, 255, 255)):
    """Write an RGB JPEG copy of every image under ``path`` whose mode is in ``modes``.

    Each matching image is converted to RGBA and pasted onto a solid
    ``background_color`` canvas using its alpha band as the mask, so
    transparent regions become the background colour. The copy is saved
    next to the source as ``<stem>_t.jpg``. Source files are NOT deleted
    here; the caller receives them in the returned list.

    Args:
        path: ``pathlib.Path`` of the directory to scan recursively.
        modes: Pillow image modes that trigger conversion. 'RGBA' is part of
            the default so fully-RGBA images are flattened too, instead of
            having their alpha silently dropped by a later ``convert('RGB')``.
        background_color: RGB fill used behind transparent pixels.

    Returns:
        List of the source ``Path`` objects that were converted.
    """
    converted = []
    # Snapshot (and sort) the listing first: this loop writes new *_t.jpg
    # files into the same tree, which must not leak into the iteration.
    candidates = sorted(p for p in path.rglob('*') if p.is_file())
    for file in candidates:
        with Image.open(file) as img:
            if img.mode not in modes:
                continue
            # Message text kept as-is even though non-palette modes reach here.
            print(f'发现调色板图像：{file} - {img.mode}')
            rgba = img.convert('RGBA')
            background = Image.new('RGB', rgba.size, background_color)
            # The last band of an RGBA image is alpha; use it as paste mask.
            background.paste(rgba, mask=rgba.split()[-1])
            # with_name swaps only the final component. The old
            # str(path).split('.')[0] cut at the FIRST dot anywhere in the
            # path ('a.b.png' -> 'a_t.jpg', '../images/...' -> '_t.jpg').
            new_file = file.with_name(f'{file.stem}_t.jpg')
            background.save(new_file)
            print(f'已转换为新RGB图像：{new_file.as_posix()}')
            converted.append(file)
    return converted

# process_rgba_image has already written a flattened *_t.jpg copy of each
# listed file; remove the originals so only the RGB copies remain.
rgba_files = process_rgba_image(dir_name)

for original in rgba_files:
    os.remove(original)

# Display the replaced originals (a plain loop, not a side-effect comprehension).
rgba_files

[]

In [4]:
# Build the (file path, label) index and save it as a CSV file.
# The label of each image is the name of the folder that contains it.
# Paths are sorted because rglob order depends on the filesystem; sorting
# makes the index, and therefore the npz built from it, reproducible.
data = [
    [file.as_posix(), file.parent.name]
    for file in sorted(dir_name.rglob('*'))
    if file.is_file()
]

# reshape keeps `a` two-dimensional, (0, 2), even when no files are found.
a = np.array(data, dtype=str).reshape(-1, 2)
# Explicit UTF-8 so non-ASCII folder/file names are written portably
# instead of depending on the platform/NumPy default encoding.
np.savetxt('./filelist.csv', a, fmt='%s', delimiter=',', encoding='utf-8')

In [5]:
# Lazily load each indexed image as a fixed-size RGB pixel array.
def images_labels_generator(data, size=(64, 64)):
    """Yield ``(pixels, label)`` for every row of ``data``.

    Args:
        data: iterable of rows where ``row[0]`` is an image path and
            ``row[1]`` its label (e.g. the 2-D string array of the file
            index). Extra columns are ignored.
        size: target ``(width, height)`` handed to ``Image.resize``. Every
            image is forced to this size (aspect ratio is not preserved) so
            the results can be stacked into one array. Defaults to 64x64,
            the previously hard-coded value.

    Yields:
        ``(np.ndarray, label)``: ``np.array`` of the resized RGB image and
        ``row[1]`` unchanged.
    """
    for row in data:
        path, label = row[0], row[1]
        with Image.open(path) as img:
            # convert('RGB') first so every array has the same three channels
            # regardless of the source mode.
            image = img.convert('RGB').resize(size)
        yield np.array(image), label

In [6]:
# Load every (image, label) pair from the generator and save them together
# as the compressed archive shengxiao.npz. Images are decoded one at a time,
# but all pixel arrays are still held in memory before saving.
images, labels = [], []

for img, label in images_labels_generator(a):
    images.append(img)
    labels.append(label)

images, labels = np.array(images), np.array(labels)
# Every indexed file must have produced exactly one sample.
assert len(images) == len(labels) == len(a)
np.savez_compressed('./shengxiao.npz', images=images, labels=labels)