Skip to content

Commit

Permalink
fix(api): improve handling of non-square images around tile size
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 10, 2023
1 parent 95cad90 commit d9dd1e4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 38 deletions.
29 changes: 7 additions & 22 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@
from PIL import Image

from ..diffusers.load import load_pipeline
from ..diffusers.utils import (
encode_prompt,
get_latents_from_seed,
get_tile_latents,
parse_prompt,
)
from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
Expand Down Expand Up @@ -54,23 +49,13 @@ def run(
)

tile_size = params.tiles

if max(size) > tile_size:
latent_size = Size(tile_size, tile_size)
pipe_width = pipe_height = tile_size
else:
latent_size = Size(size.width, size.height)
pipe_width = size.width
pipe_height = size.height
latent_size = size.min(tile_size, tile_size)

# generate new latents or slice existing
if latents is None:
# generate new latents
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
else:
# slice existing latents
latents = get_tile_latents(latents, dims, Size(tile_size, tile_size))
pipe_width = pipe_height = tile_size
latents = get_tile_latents(latents, dims, latent_size)

pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline(
Expand All @@ -87,8 +72,8 @@ def run(
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
prompt,
height=pipe_height,
width=pipe_width,
height=latent_size.height,
width=latent_size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
Expand All @@ -108,8 +93,8 @@ def run(
rng = np.random.RandomState(params.seed)
result = pipe(
prompt,
height=pipe_height,
width=pipe_width,
height=latent_size.height,
width=latent_size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
Expand Down
23 changes: 7 additions & 16 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,15 @@ def run(
outputs.append(source)
continue

size = Size(*source.size)
tile_size = params.tiles
if max(size) > tile_size:
latent_size = Size(tile_size, tile_size)
pipe_width = pipe_height = tile_size
else:
latent_size = Size(size.width, size.height)
pipe_width = size.width
pipe_height = size.height
size = Size(*source.size)
latent_size = size.min(tile_size, tile_size)

# generate new latents or slice existing
if latents is None:
# generate new latents
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
else:
# slice existing latents and make sure there is a complete tile
latents = get_tile_latents(latents, dims, Size(tile_size, tile_size))
pipe_width = pipe_height = tile_size
latents = get_tile_latents(latents, dims, latent_size)

if params.lpw():
logger.debug("using LPW pipeline for inpaint")
Expand All @@ -98,8 +89,8 @@ def run(
tile_mask,
prompt,
negative_prompt=negative_prompt,
height=pipe_height,
width=pipe_width,
height=latent_size.height,
width=latent_size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
generator=rng,
Expand All @@ -119,8 +110,8 @@ def run(
source,
tile_mask,
negative_prompt=negative_prompt,
height=pipe_height,
width=pipe_width,
height=latent_size.height,
width=latent_size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
generator=rng,
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def add_border(self, border: Border):
border.top + self.height + border.bottom,
)

def min(self, width: int, height: int):
return Size(min(self.width, width), min(self.height, height))

def round_to_tile(self, tile=512):
return Size(
ceil(self.width / tile) * tile,
Expand Down

0 comments on commit d9dd1e4

Please sign in to comment.