Skip to content

Commit

Permalink
fix(api): pass hardware platform to upscaling pipeline (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 22, 2023
1 parent fe9206c commit f319e6a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
19 changes: 10 additions & 9 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom)


def upscale_from_request() -> UpscaleParams:
def upscale_from_request(provider: str) -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
Expand All @@ -199,13 +199,14 @@ def upscale_from_request() -> UpscaleParams:

return UpscaleParams(
upscaling,
provider,
correction_model=correction,
scale=scale,
outscale=outscale,
faces=faces,
platform='onnx',
denoise=denoise,
faces=faces,
face_strength=face_strength,
format='onnx',
outscale=outscale,
scale=scale,
)


Expand Down Expand Up @@ -355,7 +356,7 @@ def img2img():
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')

params, size = pipeline_from_request()
upscale = upscale_from_request()
upscale = upscale_from_request(params.provider)

strength = get_and_clamp_float(
request.args,
Expand Down Expand Up @@ -385,7 +386,7 @@ def img2img():
@app.route('/api/txt2img', methods=['POST'])
def txt2img():
params, size = pipeline_from_request()
upscale = upscale_from_request()
upscale = upscale_from_request(params.provider)

output = make_output_name(
'txt2img',
Expand Down Expand Up @@ -413,7 +414,7 @@ def inpaint():

params, size = pipeline_from_request()
expand = border_from_request()
upscale = upscale_from_request()
upscale = upscale_from_request(params.provider)

fill_color = get_not_empty(request.args, 'fillColor', 'white')
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
Expand Down Expand Up @@ -474,7 +475,7 @@ def upscale():
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')

params, size = pipeline_from_request()
upscale = upscale_from_request()
upscale = upscale_from_request(params.provider)

output = make_output_name(
'upscale',
Expand Down
28 changes: 15 additions & 13 deletions api/onnx_web/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Any, Union
from typing import Any, Literal, Union

import numpy as np
import torch
Expand Down Expand Up @@ -86,43 +86,45 @@ class UpscaleParams():
def __init__(
self,
upscale_model: str,
provider: str,
correction_model: Union[str, None] = None,
scale: int = 4,
outscale: int = 1,
denoise: float = 0.5,
faces=True,
face_strength: float = 0.5,
platform: str = 'onnx',
half=False
format: Literal['onnx', 'pth'] = 'onnx',
half=False,
outscale: int = 1,
scale: int = 4,
) -> None:
self.upscale_model = upscale_model
self.provider = provider
self.correction_model = correction_model
self.scale = scale
self.outscale = outscale
self.denoise = denoise
self.faces = faces
self.face_strength = face_strength
self.platform = platform
self.format = format
self.half = half
self.outscale = outscale
self.scale = scale

def resize(self, size: Size) -> Size:
return Size(size.width * self.outscale, size.height * self.outscale)


def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_file = '%s.%s' % (params.upscale_model, params.platform)
model_file = '%s.%s' % (params.upscale_model, params.format)
model_path = path.join(ctx.model_path, model_file)
if not path.isfile(model_path):
raise Exception('Real ESRGAN model not found at %s' % model_path)

# use ONNX acceleration, if available
if params.platform == 'onnx':
model = ONNXNet(ctx, model_file)
elif params.platform == 'pth':
if params.format == 'onnx':
model = ONNXNet(ctx, model_file, provider=params.provider)
elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
else:
raise Exception('unknown platform %s' % params.platform)
raise Exception('unknown platform %s' % params.format)

dni_weight = None
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
Expand Down

0 comments on commit f319e6a

Please sign in to comment.