Skip to content

Commit

Permalink
feat(api): add conversion script for models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 16, 2023
1 parent 1283bc3 commit e59449f
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 8 deletions.
89 changes: 89 additions & 0 deletions api/onnx_web/convert.py
@@ -0,0 +1,89 @@
from argparse import ArgumentParser
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from os import path, environ
from sys import exit

import torch
import torch.onnx

from .upscale import (
gfpgan_url,
resrgan_url,
resrgan_name,
)

model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))


def convert_real_esrgan():
dest_path = path.join(model_path, resrgan_name + '.pth')
print('converting Real ESRGAN into %s' % dest_path)

if not path.isfile(dest_path):
print('existing model not found, downloading...')
for url in resrgan_url:
dest_path = load_file_from_url(
url=url, model_dir=path.join(dest_path, resrgan_name), progress=True, file_name=None)

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)

print('loading and training Real ESRGAN model')
model.load_state_dict(torch.load(dest_path)['params_ema'])
model.train(False)
model.eval()

rng = torch.rand(1, 3, 64, 64)
input_names = ['data']
output_names = ['output']
dynamic_axes = {'data': {2: 'width', 3: 'height'},
'output': {2: 'width', 3: 'height'}}

with torch.no_grad():
dest_onnx = path.join(model_path, resrgan_name + '.onnx')
print('exporting Real ESRGAN model to %s' % dest_onnx)
torch.onnx.export(
model,
rng,
dest_onnx,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=11,
export_params=True
)
print('Real ESRGAN exported to ONNX')


def convert_gfpgan():
pass


def convert_diffuser():
pass


def main() -> int:
parser = ArgumentParser(
prog='onnx-web model converter',
description='convert checkpoint models to ONNX')
parser.add_argument('--diffusers', type=str, nargs='+',
help='models using the diffusers pipeline')
parser.add_argument('--gfpgan', action='store_true')
parser.add_argument('--resrgan', action='store_true')
args = parser.parse_args()
print(args)

for model in args.diffusers:
print('convert ' + model)

if args.resrgan:
convert_real_esrgan()

return 0


if __name__ == '__main__':
exit(main())
3 changes: 2 additions & 1 deletion api/onnx_web/image.py
Expand Up @@ -159,7 +159,7 @@ def noise_source_histogram(source_image: Image, dims: Point, origin: Point) -> I
return noise


# based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
def expand_image(
source_image: Image,
mask_image: Image,
Expand All @@ -179,6 +179,7 @@ def expand_image(

full_mask = mask_filter(mask_image, dims, origin)
full_noise = noise_source(source_image, dims, origin)
# TODO: multiply noise by mask before compositing
full_source = Image.composite(full_noise, full_source, full_mask.convert('L'))

return (full_source, full_mask, full_noise, (full_width, full_height))
15 changes: 8 additions & 7 deletions api/onnx_web/upscale.py
Expand Up @@ -8,30 +8,31 @@
import numpy as np

denoise_strength = 0.5
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
resrgan_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
fp16 = False
model_name = 'RealESRGAN_x4plus'
netscale = 4
outscale = 4
pre_pad = 0
tile = 0
tile_pad = 10

gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
resrgan_name = 'RealESRGAN_x4plus'
resrgan_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']


def make_resrgan(model_path):
model_path = path.join(model_path, model_name + '.pth')
model_path = path.join(model_path, resrgan_name + '.onnx')
if not path.isfile(model_path):
for url in resrgan_url:
model_path = load_file_from_url(
url=url, model_dir=path.join(model_path, model_name), progress=True, file_name=None)
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)

dni_weight = None
if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
if resrgan_name == 'realesr-general-x4v3' and denoise_strength != 1:
wdn_model_path = model_path.replace(
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path]
Expand Down

0 comments on commit e59449f

Please sign in to comment.