# 将 PyTorch 模型转换为 ONNX 模型

In [None]:
import dnnlib
import numpy as np
import torch
import legacy
import functools
import click
import os

In [None]:
# Convert the trained PyTorch generator to an ONNX model.
# Adjust before running: (1) the path of the weights (.pkl) file,
# (2) the output path/name of the ONNX model.
# Load the model.
device = torch.device('cpu')
# Forward slashes work on both Windows and POSIX (backslashes only on Windows).
source = 'models/pkl/network-snapshot-005000.pkl'
# The onnxruntime inference cell reads the model from this location.
onnx_path = os.path.join('results', 'stylegan2.onnx')
with dnnlib.util.open_url(source) as f:
    # Export G_ema: the same network the PyTorch image-generation cell uses,
    # so the PyTorch and ONNX images come from identical weights.
    GG = legacy.load_network_pkl(f)['G_ema'].to(device)
# Force fp32 for every call traced by the exporter (presumably to avoid fp16
# layers on CPU — confirm).
GG.forward = functools.partial(GG.forward, force_fp32=True)
# NOTE(review): the PyTorch cell renders with truncation_psi=0.7, but this export
# does not pass truncation_psi; add it to the partial above if the two should match.

# Tracing inputs: latent z drawn with seed 0 (float64, as produced by NumPy) and an
# all-zero class label.
dummy_input = torch.from_numpy(np.random.RandomState(0).randn(1, GG.z_dim)).to(device)
label = torch.zeros([1, GG.c_dim], device=device)

# Export to ONNX.
in_names = ["z", "c"]
out_names = ["Y"]
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
torch.onnx.export(model=GG,
                  args=(dummy_input, label),
                  f=onnx_path,
                  input_names=in_names,
                  output_names=out_names,
                  verbose=False,
                  opset_version=10,
                  # Embed the trained weights. With False they would become extra
                  # graph inputs that the inference cell never feeds.
                  export_params=True,
                  do_constant_folding=False,
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX)

# 不同模型（PyTorch 与 ONNX）生成图片对比

In [None]:
#pytorch模型生成图片
import os
import re
from typing import List, Optional

import click
import numpy as np
import PIL.Image
import torch

import dnnlib
import legacy


In [None]:
def num_range(s: str) -> List[int]:
    """Parse a seed specification into a list of ints.

    Two forms are accepted:

    * ``'a-c'`` (two digit runs joined by one hyphen) expands to the inclusive
      range ``[a, a + 1, ..., c]``; it is empty when ``c < a``.
    * anything else is split on commas and every piece is passed to ``int``,
      so ``'a,b,c'`` gives ``[a, b, c]`` and malformed pieces raise ``ValueError``.
    """
    bounds = re.match(r'^(\d+)-(\d+)$', s)
    if bounds is None:
        return list(map(int, s.split(',')))
    first, last = (int(group) for group in bounds.groups())
    return list(range(first, last + 1))
#需要调整的有1.network_pkl，权重文件；2.seeds可以修改；3.outdir，生成图片的路径根据需要修改
# ----------------------------------------------------------------------------
def generate_images():
    network_pkl = r'models\pkl\network-snapshot-005000.pkl'
    seeds = [85,265,297,849]
    outdir = r'results/'
    projected_w = None
    class_idx = None
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)

    os.makedirs(outdir, exist_ok=True)

    if projected_w is not None:
        if seeds is not None:
            print('warn: --seeds is ignored when using --projected-w')
        print(f'Generating images from projected W "{projected_w}"')
        ws = np.load(projected_w)['w']
        ws = torch.tensor(ws, device=device)
        assert ws.shape[1:] == (G.num_ws, G.w_dim)
        for idx, w in enumerate(ws):
            img = G.synthesis(w.unsqueeze(0), noise_mode='random')
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
        return

    if seeds is None:
        ctx.fail('--seeds option is required when not using --projected-w')


    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            ctx.fail('Must specify class label with --class when using a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print('warn: --class=lbl ignored when running on an unconditional network')

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        img = G(z, label, truncation_psi=0.7, noise_mode='random')
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')

In [None]:
# Generate and save the PyTorch reference images using generate_images' own settings.
generate_images()

In [None]:
#onnx模型生成图片
import dnnlib
import numpy as np
import torch
import legacy
import functools
import click
import os
from PIL import Image
import cv2
import onnxruntime

In [None]:
# Run the exported ONNX model with onnxruntime and save one generated image.
# Adjust before running: (1) the path of the ONNX model, (2) the output image path.
onnx_path = 'results/stylegan2.onnx'
result_path = './results/result.png'
# Seed for the latent z. 85 is the first seed rendered by the PyTorch cell, so the
# two images start from the same z (they can still differ: random noise, and
# truncation settings may not match). Set to None for an unseeded draw.
seed = 85

# Create an InferenceSession for the model file.
onnx_model = onnxruntime.InferenceSession(onnx_path)
# onnxruntime type strings -> NumPy dtypes, so every input matches the graph.
ort_dtypes = {'tensor(float)': np.float32,
              'tensor(double)': np.float64,
              'tensor(float16)': np.float16}
model_inputs = onnx_model.get_inputs()

# Input: latent z. Read z_dim from the graph instead of hard-coding 512
# (fall back to 512 if the dimension is symbolic).
z_meta = model_inputs[0]
z_dim = z_meta.shape[1] if isinstance(z_meta.shape[1], int) else 512
dummy_input = (np.random.RandomState(seed).randn(1, z_dim)
               .astype(ort_dtypes.get(z_meta.type, np.double)))
feed = {z_meta.name: dummy_input}
# Any further input (e.g. the class label "c" from the export cell) is fed as
# zeros, like the all-zero label used during export. Symbolic dims become 1.
for extra in model_inputs[1:]:
    extra_shape = [d if isinstance(d, int) else 1 for d in extra.shape]
    feed[extra.name] = np.zeros(extra_shape, dtype=ort_dtypes.get(extra.type, np.float32))

# Generate the image.
output_name = onnx_model.get_outputs()[0].name
outputs = onnx_model.run([output_name], feed)[0]
# First image of the NCHW batch -> HWC, with the same scaling as the PyTorch path.
output = outputs[0].transpose((1, 2, 0)) * 127.5 + 128
image = np.clip(output, 0, 255).astype(np.uint8)
# The PyTorch cell saves this layout as RGB; cv2.imwrite expects BGR.
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
os.makedirs(os.path.dirname(result_path) or '.', exist_ok=True)
# cv2.imwrite reports failure via its return value rather than raising.
if not cv2.imwrite(result_path, image):
    raise IOError(f'cv2.imwrite failed to write {result_path}')