In [1]:
from models.cycle_gan_model import CycleGANModel
from models.networks import define_G, define_D
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html
import torch
from torch.nn import DataParallel
try:
    import wandb
except ImportError:
    print('Warning: wandb package cannot be found. The option "--use_wandb" will result in error.')
    
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image



In [2]:
class ImageFolderDataset(Dataset):
    """Dataset over the image files found directly inside one folder.

    Each item is an ``(image, filename)`` pair: ``image`` is the file opened
    with PIL, converted to RGB and passed through ``transform`` (when one is
    given), and ``filename`` is the file's base name.

    Args:
        image_folder: Directory to scan. Sub-directories are not searched and
            directory entries themselves are ignored.
        transform: Optional callable applied to each RGB PIL image; skipped
            when falsy.
        extensions: File-name suffix (str) or suffixes (iterable of str) to
            include. Matching is case-insensitive, so ``'.jpg'`` also picks
            up ``photo.JPG``. Defaults to ``('.jpg',)``.
    """

    def __init__(self, image_folder, transform=None, extensions=('.jpg',)):
        self.image_folder = image_folder
        # Accept a single suffix string as well as a collection of suffixes.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so dataset indices are reproducible across machines and runs.
        self.image_paths = sorted(
            os.path.join(image_folder, name)
            for name in os.listdir(image_folder)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(image_folder, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Context manager releases the file handle as soon as the pixels have
        # been converted to RGB (convert() returns a new, loaded image).
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        # Return the (possibly transformed) image together with its file name.
        return img, os.path.basename(img_path)

In [3]:
def batch_inference(image_folder, output_folder, model, batch_size=1, one_only=False, suffix="", device='cuda'):
    """Run ``model`` over the images in ``image_folder`` and save the results.

    Each input image is converted to a tensor at its native resolution (no
    resize) and normalised from [0, 1] to [-1, 1] before being fed to
    ``model``. The outputs are mapped back to [0, 1], clamped, and written as
    JPEG files into ``output_folder``.

    Output naming:
        * exactly one image produced -> ``fake_epoch_{suffix}.jpg``
        * several images produced    -> ``fake_epoch_{suffix}_{index}.jpg``
          with a 1-based ``index``, so later images no longer overwrite
          earlier ones.

    Args:
        image_folder: Folder of input images, read via ``ImageFolderDataset``.
        output_folder: Destination folder; created if it does not exist.
        model: Module supporting ``eval()``/``to(device)`` and mapping a batch
            tensor to a batch tensor. NOTE(review): assumed to output values in
            [-1, 1] (e.g. a tanh-terminated generator); confirm.
        batch_size: Number of images per forward pass.
        one_only: If True, stop after the first batch.
        suffix: Text inserted into the output file names.
        device: Device the model and inputs are moved to.

    Returns:
        List of paths of the saved images, in the order they were written.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),                                   # PIL -> float tensor in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ])
    to_pil = transforms.ToPILImage()  # loop-invariant; build once

    dataset = ImageFolderDataset(image_folder, preprocess)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Decide the naming scheme up front: a single output keeps the historical
    # un-indexed name; multiple outputs need an index to avoid overwriting.
    expected = min(batch_size, len(dataset)) if one_only else len(dataset)
    indexed_names = expected > 1

    os.makedirs(output_folder, exist_ok=True)

    model.eval()
    model.to(device)
    saved_paths = []
    img_index = 1  # 1-based index of the next image to be saved
    with torch.no_grad():
        for inputs, _filenames in dataloader:
            outputs = model(inputs.to(device))
            # Undo the [-1, 1] normalisation. Clamp because ToPILImage scales
            # float tensors by 255 and casts to uint8, so any value outside
            # [0, 1] would wrap around instead of saturating.
            outputs = ((outputs + 1) / 2.0).clamp(0.0, 1.0)

            for output in outputs:
                if indexed_names:
                    name = f"fake_epoch_{suffix}_{img_index}.jpg"
                else:
                    name = f"fake_epoch_{suffix}.jpg"
                path = os.path.join(output_folder, name)
                to_pil(output.cpu()).save(path)
                print(f"Saved: {name}")
                saved_paths.append(path)
                img_index += 1

            if one_only:
                break
    return saved_paths

In [None]:
# Folder paths used by the inference cells below.
image_folder, output_folder = (
    "./datasets/photo2monet/testA",         # input image folder
    "./datasets/photo2monet/fakeB_epochs",  # output image folder
)