In [4]:
"""
1. Download model and save the model to git_root/model/celebahq/200_Network.pth
2. Modify inpainting_celebahq.json
    ["path"]["resume_state"]: "model/celebahq/200"
    ["datasets"]["test"]["args"]["data_root"]: "<Folder Constains Inference Images>"

    (optinally) change ["model"]["which_networks"]["args"]["beta_schedule"]["test"]["n_timestep"] value to reduce # steps inference should take
                more steps yields better results
3. Modify in your particular case in this code:
    model_pth = "<PATH-TO-MODEL>/200_Network.pth"
    input_image_pth = "<PATH-TO-DATASET_PARENT_DIT>/02323.jpg"
5. Run inpainting code (assume save this code to git_root/inference/inpainting.py)
    cd inference
    python inpainting.py -c ../config/inpainting_celebahq.json -p test
"""

import argparse

import core.praser as Praser
import torch
from core.util import set_device, tensor2img
from data.util.mask import get_irregular_mask
from models.network import Network
from PIL import Image
from torchvision import transforms

model_pth = "/home/y/project/dm/Palette-Image-to-Image-Diffusion-Models/save_models/49_Network.pth"
input_image_pth = "/home/y/project/dm/Palette-Image-to-Image-Diffusion-Models/s_image/Out_682_34.png"


In [5]:


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str,
                        default='/home/y/project/dm/Palette-Image-to-Image-Diffusion-Models/config/inpainting_places2.json', help='JSON file for configuration')
    parser.add_argument('-p', '--phase', type=str,
                        choices=['train', 'test'], help='Run train or test', default='test')
    parser.add_argument('-b', '--batch', type=int,
                        default=16, help='Batch size in every gpu')
    parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
    parser.add_argument('-d', '--debug', action='store_true')
    parser.add_argument('-P', '--port', default='21012', type=str)

    args = parser.parse_args()
    opt = Praser.parse(args)
    return opt


# config arg
opt = parse_config()
model_args = opt["model"]["which_networks"][0]["args"]


usage: ipykernel_launcher.py [-h] [-c CONFIG] [-p {train,test}] [-b BATCH]
                             [-gpu GPU_IDS] [-d] [-P PORT]
ipykernel_launcher.py: error: unrecognized arguments: --f=/home/y/.local/share/jupyter/runtime/kernel-v3f5f9f70bbaf08c041b21c700b0f8fde477a6f97e.json


SystemExit: 2

In [None]:

# initializa model
model = Network(**model_args)
state_dict = torch.load(model_pth)
model.load_state_dict(state_dict, strict=False)
device = torch.device('cuda:0')
model.to(device)
model.set_new_noise_schedule(phase='test')
model.eval()

tfs = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# read input and create random mask
img_pillow = Image.open(input_image_pth).convert('RGB')
img = tfs(img_pillow)
mask = get_irregular_mask([256, 256])
mask = torch.from_numpy(mask).permute(2, 0, 1)
cond_image = img*(1. - mask) + mask*torch.randn_like(img)
mask_img = img*(1. - mask) + mask

# save conditional image used a inference input
cond_image_np = tensor2img(cond_image)
Image.fromarray(cond_image_np).save("./result/cond_image.jpg")

# set device
cond_image = set_device(cond_image)
gt_image = set_device(img)
mask = set_device(mask)

# unsqueeze
cond_image = cond_image.unsqueeze(0).to(device)
gt_image = gt_image.unsqueeze(0).to(device)
mask = mask.unsqueeze(0).to(device)

# inference
with torch.no_grad():
    output, visuals = model.restoration(cond_image, y_t=cond_image,
                                        y_0=gt_image, mask=mask, sample_num=8)

# save intermediate processes
output_img = output.detach().float().cpu()
for i in range(visuals.shape[0]):
    img = tensor2img(visuals[i].detach().float().cpu())
    Image.fromarray(img).save(f"./result/process_{i}.jpg")

# save output (output should be the same as last process_{i}.jpg)
img = tensor2img(output_img)
Image.fromarray(img).save("./result/output.jpg")


In [None]:
import torch
import os

# 假设你的 Palette 模型定义在 model.py 中
from models.model import Palette
from models.guided_diffusion_modules.unet import UNet
# 设置设备
device = "cuda" if torch.cuda.is_available() else "cpu"

# 实例化模型
# 根据你的代码调整参数，例如 networks, losses, sample_num, task, optimizers
networks = UNet()  # 假设 UNet 是你的网络定义
losses = None  # 替换为实际的损失函数
sample_num = 10  # 替换为实际的样本数量
task = "image-to-image"  # 替换为实际任务
optimizers = None  # 替换为实际的优化器

model = Palette(networks=networks, losses=losses, sample_num=sample_num, task=task, optimizers=optimizers)
model = model.to(device)


TypeError: __init__() missing 5 required positional arguments: 'networks', 'losses', 'sample_num', 'task', and 'optimizers'

In [None]:

# 文件路径
ema_path = "/home/y/project/dm/Palette-Image-to-Image-Diffusion-Models/save_models/49_Network_ema.pth"
network_path = "/home/y/project/dm/Palette-Image-to-Image-Diffusion-Models/save_models/49_Network.pth"
state_path = "/home/y/project/dm/Palette-Image-to-Image-Diffusion-Models/save_models/49.state"


# 加载权重函数
def load_palette_model(model, ema_path=None, network_path=None, state_path=None):
    if os.path.exists(ema_path):
        state_dict = torch.load(ema_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"Loaded EMA weights from {ema_path}")
    elif os.path.exists(network_path):
        state_dict = torch.load(network_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"Loaded network weights from {network_path}")
    elif os.path.exists(state_path):
        checkpoint = torch.load(state_path, map_location=device)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            print(f"Loaded model state from {state_path}")
        else:
            model.load_state_dict(checkpoint)  # 如果 .state 直接是 state_dict
            print(f"Loaded checkpoint from {state_path}")
    
    model.eval()  # 设置为推理模式
    return model

# 加载模型
model = load_palette_model(model, ema_path, network_path, state_path)