Skip to content

Commit

Permalink
feat(api): pass tile order to inpaint and outpaint pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 12, 2023
1 parent 51651ab commit 3a29082
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 19 deletions.
10 changes: 7 additions & 3 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..output import save_image
from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug
from .utils import process_tile_grid
from .utils import process_tile_order

logger = getLogger(__name__)

Expand Down Expand Up @@ -100,8 +100,12 @@ def stage_tile(tile: Image.Image, _dims) -> Image.Image:

return tile

image = process_tile_grid(
image, stage_params.tile_size, stage_params.outscale, [stage_tile]
image = process_tile_order(
stage_params.tile_order,
image,
stage_params.tile_size,
stage_params.outscale,
[stage_tile],
)
else:
logger.info("image within tile size, running stage")
Expand Down
6 changes: 4 additions & 2 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..utils import ServerContext, is_debug
from .utils import process_tile_grid
from .utils import process_tile_order

logger = getLogger(__name__)

Expand Down Expand Up @@ -101,7 +101,9 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):

return result.images[0]

output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
output = process_tile_order(
stage.tile_order, source_image, SizeChart.auto, 1, [outpaint]
)

logger.info("final output image size", output.size)
return output
13 changes: 9 additions & 4 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..params import Border, ImageParams, Size, SizeChart, StageParams, TileOrder
from ..utils import ServerContext, is_debug
from .utils import process_tile_grid, process_tile_spiral
from .utils import process_tile_grid, process_tile_order

logger = getLogger(__name__)

Expand Down Expand Up @@ -120,8 +120,13 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_spiral(
source_image, SizeChart.auto, 1, [outpaint], overlap=overlap
output = process_tile_order(
stage.tile_order,
source_image,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
)
else:
logger.debug("outpainting with an uneven border, using grid tiling")
Expand Down
23 changes: 23 additions & 0 deletions api/onnx_web/chain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from PIL import Image

from ..params import TileOrder

logger = getLogger(__name__)


Expand All @@ -16,6 +18,7 @@ def process_tile_grid(
tile: int,
scale: int,
filters: List[TileCallback],
**kwargs,
) -> Image.Image:
width, height = source.size
image = Image.new("RGB", (width * scale, height * scale))
Expand Down Expand Up @@ -46,6 +49,7 @@ def process_tile_spiral(
scale: int,
filters: List[TileCallback],
overlap: float = 0.5,
**kwargs,
) -> Image.Image:
if scale != 1:
raise Exception("unsupported scale")
Expand Down Expand Up @@ -87,3 +91,22 @@ def process_tile_spiral(
image.paste(tile_image, (left * scale, top * scale))

return image


def process_tile_order(
order: TileOrder,
source: Image.Image,
tile: int,
scale: int,
filters: List[TileCallback],
**kwargs,
) -> Image.Image:
if order == TileOrder.grid:
logger.debug("using grid tile order with tile size: %s", tile)
return process_tile_grid(source, tile, scale, filters, **kwargs)
elif order == TileOrder.kernel:
logger.debug("using kernel tile order with tile size: %s", tile)
raise NotImplementedError()
elif order == TileOrder.spiral:
logger.debug("using spiral tile order with tile size: %s", tile)
return process_tile_spiral(source, tile, scale, filters, **kwargs)
3 changes: 2 additions & 1 deletion api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,11 @@ def run_inpaint_pipeline(
mask_filter: Any,
strength: float,
fill_color: str,
tile_order: str,
) -> None:
# device = job.get_device()
# progress = job.get_progress_callback()
stage = StageParams()
stage = StageParams(tile_order=tile_order)

image = upscale_outpaint(
job,
Expand Down
12 changes: 10 additions & 2 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ class SizeChart(IntEnum):
hd64k = 2**16


class TileOrder:
grid = "grid"
kernel = "kernel"
spiral = "spiral"


Param = Union[str, int, float]
Point = Tuple[int, int]

Expand Down Expand Up @@ -122,13 +128,15 @@ class StageParams:
def __init__(
self,
name: Optional[str] = None,
tile_size: int = SizeChart.auto,
outscale: int = 1,
tile_order: str = TileOrder.grid,
tile_size: int = SizeChart.auto,
# batch_size: int = 1,
) -> None:
self.name = name
self.tile_size = tile_size
self.outscale = outscale
self.tile_order = tile_order
self.tile_size = tile_size


class UpscaleParams:
Expand Down
5 changes: 4 additions & 1 deletion api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
noise_source_uniform,
)
from .output import json_params, make_output_name
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams, TileOrder
from .utils import (
ServerContext,
base_join,
Expand Down Expand Up @@ -589,6 +589,7 @@ def inpaint():
get_config_value("strength", "max"),
get_config_value("strength", "min"),
)
tile_order = get_from_list(request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral])

output = make_output_name(
context,
Expand All @@ -604,6 +605,7 @@ def inpaint():
noise_source.__name__,
strength,
fill_color,
tile_order,
),
)
logger.info("inpaint job queued for: %s", output)
Expand All @@ -625,6 +627,7 @@ def inpaint():
mask_filter,
strength,
fill_color,
tile_order,
needs_device=device,
)

Expand Down
8 changes: 4 additions & 4 deletions api/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@
"default": "histogram",
"keys": []
},
"order": {
"default": "spiral",
"keys": []
},
"outscale": {
"default": 1,
"min": 1,
Expand Down Expand Up @@ -118,6 +114,10 @@
"max": 1,
"step": 0.01
},
"tileOrder": {
"default": "spiral",
"keys": []
},
"top": {
"default": 0,
"min": 0,
Expand Down
4 changes: 2 additions & 2 deletions gui/src/components/tab/Inpaint.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Box, Button, FormControl, FormControlLabel, InputLabel, Select, Stack } from '@mui/material';
import { Box, Button, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useMutation, useQuery, useQueryClient } from 'react-query';
Expand Down Expand Up @@ -161,7 +161,7 @@ export function Inpaint() {
tileOrder: e.target.value,
});
}}
></Select>
>{['grid', 'kernel', 'spiral'].map((name) => <MenuItem key={name} value={name}>{name}</MenuItem>)}</Select>
</FormControl>
<Stack direction='row' spacing={2}>
<FormControlLabel
Expand Down

0 comments on commit 3a29082

Please sign in to comment.