In [1]:
import os
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from model import UNet
from tqdm.notebook import tqdm

def is_image_file(filename, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
    """Return True if ``filename`` ends with one of the image ``extensions``.

    The comparison is case-insensitive, so ``photo.JPG`` and ``scan.PNG`` are
    accepted as well.

    Args:
        filename: File name (or path) to check.
        extensions: Iterable of allowed suffixes, including the leading dot.

    Returns:
        bool: whether the name has an accepted image extension.
    """
    # str.endswith accepts a tuple, so one call checks every suffix.
    suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)

def main():
    """Segment every image in ``./inference`` with a trained UNet and save/show the masks.

    Steps:
      1. Read one integer gray level per class from ``./data/grayList.txt``
         (blank lines ignored); the number of levels sets ``num_classes``.
      2. Load the UNet weights from ``./run_results/best_model.pth``.
      3. For each image file in ``./inference`` (sorted by name):
         - normalize it and predict a per-pixel class map;
         - convert class indices to gray levels;
         - save the single-channel mask as ``<stem>_result.png`` in
           ``./processed_images``;
         - add an (original, prediction) row to a figure shown at the end.

    Raises:
        ValueError: if the gray list is empty or a level is outside [0, 255].
    """
    weights_path = "./run_results/best_model.pth"
    test_path = './inference'
    save_path = './processed_images'  # directory where the prediction masks are saved

    # Create the output directory if it does not already exist.
    os.makedirs(save_path, exist_ok=True)

    # One gray level per class. Blank lines are skipped so that, e.g., a trailing
    # empty line does not inflate num_classes (which must match the weights).
    with open('./data/grayList.txt', 'r', encoding='utf-8') as f:
        gray = [line.strip() for line in f if line.strip()]
    if not gray:
        raise ValueError("./data/grayList.txt contains no gray levels")
    gray_levels = [int(value) for value in gray]
    if any(not 0 <= level <= 255 for level in gray_levels):
        raise ValueError(f"gray levels must be in [0, 255], got {gray_levels}")
    # Lookup table: class index -> uint8 gray level.
    gray_lut = np.array(gray_levels, dtype=np.uint8)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet(num_classes=len(gray))
    # NOTE(review): torch.load unpickles the checkpoint, which can run arbitrary
    # code; only load trusted files (consider weights_only=True if the installed
    # PyTorch supports it).
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()

    # Sorted so the processing/plotting order is deterministic (os.listdir is not).
    test_imgs = [os.path.join(test_path, name)
                 for name in sorted(os.listdir(test_path)) if is_image_file(name)]
    num_imgs = len(test_imgs)
    if num_imgs == 0:
        # Avoid building a zero-height figure when there is nothing to process.
        print(f"No images found in {test_path}")
        return

    # The transform is identical for every image, so build it once.
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # squeeze=False keeps `axes` 2-D (num_imgs x 2) even for a single image.
    fig, axes = plt.subplots(num_imgs, 2, figsize=(10, 5 * num_imgs), squeeze=False)

    for idx, test_img in enumerate(tqdm(test_imgs, desc="Processing images")):
        # The context manager closes the file handle; convert() returns a new image.
        with Image.open(test_img) as raw_img:
            original_img = raw_img.convert('RGB')
        img = data_transform(original_img).unsqueeze(0)  # add batch dimension

        with torch.no_grad():
            output = model(img.to(device))
        # Keep class indices in argmax's integer dtype (no uint8 truncation) and
        # map every pixel at once. An in-place per-class replacement could remap
        # pixels twice when a gray level equals a later class index.
        class_map = output.argmax(1).squeeze(0).cpu().numpy()
        prediction = gray_lut[class_map]

        # Save the processed mask with PIL (single-channel, exact gray levels).
        processed_img_name = os.path.splitext(os.path.basename(test_img))[0] + '_result.png'
        processed_img_path = os.path.join(save_path, processed_img_name)
        Image.fromarray(prediction).save(processed_img_path)

        ax_orig, ax_pred = axes[idx]
        ax_orig.imshow(original_img)
        ax_orig.set_title("Original Image")
        ax_orig.axis('off')

        # Fixed vmin/vmax so the displayed shades match the saved gray levels
        # rather than being stretched to each image's own min/max.
        ax_pred.imshow(prediction, cmap='gray', vmin=0, vmax=255)
        ax_pred.set_title("Prediction")
        ax_pred.axis('off')

    fig.tight_layout()
    plt.show()
    print("所有图像处理完成。")

# Executing this cell (where __name__ is normally "__main__") runs the full inference pass.
if "__main__" == __name__:
    main()


In [13]:
import os
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from model import UNet
from tqdm.notebook import tqdm

def is_image_file(filename, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
    """Return True if ``filename`` ends with one of the image ``extensions``.

    The comparison is case-insensitive, so ``photo.JPG`` and ``scan.PNG`` are
    accepted as well.

    Args:
        filename: File name (or path) to check.
        extensions: Iterable of allowed suffixes, including the leading dot.

    Returns:
        bool: whether the name has an accepted image extension.
    """
    # str.endswith accepts a tuple, so one call checks every suffix.
    suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)

def main():
    """Segment every image in ``./inference`` with a trained UNet and save/show the masks.

    Steps:
      1. Read one integer gray level per class from ``./data/grayList.txt``
         (blank lines ignored); the number of levels sets ``num_classes``.
      2. Load the UNet weights from ``./run_results/best_model.pth``.
      3. For each image file in ``./inference`` (sorted by name):
         - normalize it and predict a per-pixel class map;
         - convert class indices to gray levels;
         - write the mask with ``plt.imsave`` as ``<stem>_result.png`` in
           ``./processed_images``;
         - add an (original, prediction) row to a figure shown at the end.

    Raises:
        ValueError: if the gray list is empty or a level is outside [0, 255].
    """
    weights_path = "./run_results/best_model.pth"
    test_path = './inference'
    save_path = './processed_images'  # directory where the prediction masks are saved

    # Create the output directory if it does not already exist.
    os.makedirs(save_path, exist_ok=True)

    # One gray level per class. Blank lines are skipped so that, e.g., a trailing
    # empty line does not inflate num_classes (which must match the weights).
    with open('./data/grayList.txt', 'r', encoding='utf-8') as f:
        gray = [line.strip() for line in f if line.strip()]
    if not gray:
        raise ValueError("./data/grayList.txt contains no gray levels")
    gray_levels = [int(value) for value in gray]
    if any(not 0 <= level <= 255 for level in gray_levels):
        raise ValueError(f"gray levels must be in [0, 255], got {gray_levels}")
    # Lookup table: class index -> uint8 gray level.
    gray_lut = np.array(gray_levels, dtype=np.uint8)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet(num_classes=len(gray))
    # NOTE(review): torch.load unpickles the checkpoint, which can run arbitrary
    # code; only load trusted files (consider weights_only=True if the installed
    # PyTorch supports it).
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()

    # Sorted so the processing/plotting order is deterministic (os.listdir is not).
    test_imgs = [os.path.join(test_path, name)
                 for name in sorted(os.listdir(test_path)) if is_image_file(name)]
    num_imgs = len(test_imgs)
    if num_imgs == 0:
        # Avoid building a zero-height figure when there is nothing to process.
        print(f"No images found in {test_path}")
        return

    # The transform is identical for every image, so build it once.
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # squeeze=False keeps `axes` 2-D (num_imgs x 2) even for a single image.
    fig, axes = plt.subplots(num_imgs, 2, figsize=(10, 5 * num_imgs), squeeze=False)

    for idx, test_img in enumerate(tqdm(test_imgs, desc="Processing images")):
        # The context manager closes the file handle; convert() returns a new image.
        with Image.open(test_img) as raw_img:
            original_img = raw_img.convert('RGB')
        img = data_transform(original_img).unsqueeze(0)  # add batch dimension

        with torch.no_grad():
            output = model(img.to(device))
        # Keep class indices in argmax's integer dtype (no uint8 truncation) and
        # map every pixel at once. An in-place per-class replacement could remap
        # pixels twice when a gray level equals a later class index.
        class_map = output.argmax(1).squeeze(0).cpu().numpy()
        prediction = gray_lut[class_map]

        processed_img_name = os.path.splitext(os.path.basename(test_img))[0] + '_result.png'
        processed_img_path = os.path.join(save_path, processed_img_name)

        # Save with matplotlib's imsave. vmin/vmax pin the gray colormap to the
        # full uint8 range. Without them, imsave stretches each mask to its own
        # min/max, so the configured gray levels are lost. The same class could
        # then be written with different shades in different files.
        plt.imsave(processed_img_path, prediction, cmap='gray', vmin=0, vmax=255)

        ax_orig, ax_pred = axes[idx]
        ax_orig.imshow(original_img)
        ax_orig.set_title("Original Image")
        ax_orig.axis('off')

        # Same fixed range for display, so the preview matches the saved file.
        ax_pred.imshow(prediction, cmap='gray', vmin=0, vmax=255)
        ax_pred.set_title("Prediction")
        ax_pred.axis('off')

    fig.tight_layout()
    plt.show()
    print("所有图像处理完成。")

# Executing this cell (where __name__ is normally "__main__") runs the full inference pass.
if "__main__" == __name__:
    main()
