Skip to content

Commit

Permalink
feat(api): parse named tile sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 29, 2023
1 parent db9189f commit 8f1cbc8
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 9 deletions.
4 changes: 2 additions & 2 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ def blend_img2img(
prompt: str = None,
**kwargs,
) -> Image.Image:
logger.info('generating image using img2img, %s steps: %s', params.steps, params.prompt)
prompt = prompt or params.prompt
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)

pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)

prompt = prompt or params.prompt
rng = np.random.RandomState(params.seed)

result = pipe(
Expand Down
4 changes: 3 additions & 1 deletion api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def upscale_stable_diffusion(
source: Image.Image,
*,
upscale: UpscaleParams,
prompt: str = None,
**kwargs,
) -> Image.Image:
logger.info('upscaling with Stable Diffusion')
prompt = prompt or params.prompt
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)

pipeline = load_stable_diffusion(ctx, upscale)
generator = torch.manual_seed(params.seed)
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
get_from_list,
get_from_map,
get_not_empty,
get_size,
make_output_name,
base_join,
ServerContext,
Expand Down Expand Up @@ -577,8 +578,8 @@ def chain():

stage = StageParams(
stage_data.get('name', callback.__name__),
tile_size=int(kwargs.get('tile_size', SizeChart.auto)),
outscale=int(kwargs.get('outscale', 1)),
tile_size=get_size(kwargs.get('tile_size')),
outscale=get_and_clamp_int(kwargs,'outscale', 1, 4),
)

# TODO: create Border from border
Expand Down
19 changes: 18 additions & 1 deletion api/onnx_web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from os import environ, path
from struct import pack
from time import time
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Union, Tuple

from .params import (
ImageParams,
Param,
Size,
SizeChart,
)

logger = getLogger(__name__)
Expand Down Expand Up @@ -92,6 +93,22 @@ def get_not_empty(args: Any, key: str, default: Any) -> Any:
return val


def get_size(val: Union[int, str, None]) -> SizeChart:
if val is None:
return SizeChart.auto

if type(val) is str:
if val in SizeChart:
return SizeChart[val]
else:
return int(val)

if type(val) is int:
return val

raise Exception('invalid size')


def hash_value(sha, param: Param):
if param is None:
return
Expand Down
9 changes: 6 additions & 3 deletions common/pipelines/example.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,24 @@
"prompt": "a magical wizard in a robe fighting a dragon",
"scale": 4,
"outscale": 4,
"tile_size": 128
"tile_size": "mini"
}
},
{
"name": "save-local",
"type": "persist-disk",
"params": {}
"params": {
"tile_size": "8k"
}
},
{
"name": "save-ceph",
"type": "persist-s3",
"params": {
"bucket": "storage-stable-diffusion",
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
"profile_name": "ceph"
"profile_name": "ceph",
"tile_size": "8k"
}
}
]
Expand Down

0 comments on commit 8f1cbc8

Please sign in to comment.