Skip to content

Commit

Permalink
fix(api): use server image format when building output name
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 2, 2023
1 parent 59b8055 commit e533dad
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
6 changes: 4 additions & 2 deletions api/onnx_web/output.py
Expand Up @@ -20,6 +20,7 @@

logger = getLogger(__name__)


def hash_value(sha, param: Param):
if param is None:
return
Expand Down Expand Up @@ -62,6 +63,7 @@ def json_params(


def make_output_name(
ctx: ServerContext,
mode: str,
params: ImageParams,
size: Size,
Expand All @@ -86,11 +88,11 @@ def make_output_name(
for param in extras:
hash_value(sha, param)

return '%s_%s_%s_%s' % (mode, params.seed, sha.hexdigest(), now)
return '%s_%s_%s_%s.%s' % (mode, params.seed, sha.hexdigest(), now, ctx.image_format)


def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
path = base_join(ctx.output_path, '%s.%s' % (output, ctx.image_format))
path = base_join(ctx.output_path, output)
image.save(path, format=ctx.image_format)
logger.debug('saved output image to: %s', path)
return path
Expand Down
17 changes: 13 additions & 4 deletions api/onnx_web/serve.py
Expand Up @@ -297,7 +297,8 @@ def load_params(context: ServerContext):
config_params = yaml.safe_load(f)

if 'platform' in config_params and context.default_platform is not None:
logger.info('overriding default platform to %s', context.default_platform)
logger.info('overriding default platform to %s',
context.default_platform)
config_platform = config_params.get('platform')
config_platform['default'] = context.default_platform

Expand Down Expand Up @@ -430,6 +431,7 @@ def img2img():
get_config_value('strength', 'min'))

output = make_output_name(
context,
'img2img',
params,
size,
Expand All @@ -449,6 +451,7 @@ def txt2img():
upscale = upscale_from_request(params.provider)

output = make_output_name(
context,
'txt2img',
params,
size)
Expand Down Expand Up @@ -490,6 +493,7 @@ def inpaint():
get_config_value('strength', 'min'))

output = make_output_name(
context,
'inpaint',
params,
size,
Expand Down Expand Up @@ -539,10 +543,11 @@ def upscale():
upscale = upscale_from_request(params.provider)

output = make_output_name(
context,
'upscale',
params,
size)
logger.info("upscale output: %s", output)
logger.info("upscale job queued for: %s", output)

source_image.thumbnail((size.width, size.height))
executor.submit_stored(output, run_upscale_pipeline,
Expand All @@ -563,7 +568,11 @@ def chain():

# get defaults from the regular parameters
params, size = pipeline_from_request()
output = make_output_name('chain', params, size)
output = make_output_name(
context,
'chain',
params,
size)

pipeline = ChainPipeline()
for stage_data in data.get('stages', []):
Expand All @@ -574,7 +583,7 @@ def chain():
stage = StageParams(
stage_data.get('name', callback.__name__),
tile_size=get_size(kwargs.get('tile_size')),
outscale=get_and_clamp_int(kwargs,'outscale', 1, 4),
outscale=get_and_clamp_int(kwargs, 'outscale', 1, 4),
)

if 'border' in kwargs:
Expand Down

0 comments on commit e533dad

Please sign in to comment.