Skip to content

Commit

Permalink
feat(api): make chain stages into classes with max tile size and step…
Browse files Browse the repository at this point in the history
… count estimate
  • Loading branch information
ssube committed Jul 1, 2023
1 parent 5e1b700 commit 2913cd0
Show file tree
Hide file tree
Showing 29 changed files with 1,195 additions and 1,093 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
run_upscale_pipeline,
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import stage_upscale_correction
from .chain.upscale import stage_upscale_correction
from .image.utils import (
expand_image,
)
Expand Down
80 changes: 40 additions & 40 deletions api/onnx_web/chain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,44 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint
from .blend_linear import blend_linear
from .blend_mask import blend_mask
from .correct_codeformer import correct_codeformer
from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk
from .persist_s3 import persist_s3
from .reduce_crop import reduce_crop
from .reduce_thumbnail import reduce_thumbnail
from .source_noise import source_noise
from .source_s3 import source_s3
from .source_txt2img import source_txt2img
from .source_url import source_url
from .upscale_bsrgan import upscale_bsrgan
from .upscale_highres import upscale_highres
from .upscale_outpaint import upscale_outpaint
from .upscale_resrgan import upscale_resrgan
from .upscale_stable_diffusion import upscale_stable_diffusion
from .upscale_swinir import upscale_swinir
from .blend_img2img import BlendImg2ImgStage
from .blend_inpaint import BlendInpaintStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage

CHAIN_STAGES = {
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"blend-linear": blend_linear,
"blend-mask": blend_mask,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,
"persist-s3": persist_s3,
"reduce-crop": reduce_crop,
"reduce-thumbnail": reduce_thumbnail,
"source-noise": source_noise,
"source-s3": source_s3,
"source-txt2img": source_txt2img,
"source-url": source_url,
"upscale-bsrgan": upscale_bsrgan,
"upscale-highres": upscale_highres,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
"upscale-swinir": upscale_swinir,
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": BlendInpaintStage,
"blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,
"reduce-thumbnail": ReduceThumbnailStage,
"source-noise": SourceNoiseStage,
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}
9 changes: 5 additions & 4 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .utils import process_tile_order

logger = getLogger(__name__)
Expand All @@ -35,7 +36,7 @@ def __call__(
pass


PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]


class ChainProgress:
Expand Down Expand Up @@ -131,7 +132,7 @@ def __call__(
logger.info("running pipeline without source image")

for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__name__
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}

Expand All @@ -158,7 +159,7 @@ def __call__(
)

def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(
tile = stage_pipe.run(
job,
server,
stage_params,
Expand All @@ -182,7 +183,7 @@ def stage_tile(tile: Image.Image, _dims) -> Image.Image:
)
else:
logger.debug("image within tile size, running stage")
image = stage_pipe(
image = stage_pipe.run(
job,
server,
stage_params,
Expand Down
Empty file.
138 changes: 71 additions & 67 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,77 +14,81 @@
logger = getLogger(__name__)


def blend_img2img(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)
class BlendImg2ImgStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)

prompt_pairs, loras, inversions = parse_prompt(params)
prompt_pairs, loras, inversions = parse_prompt(params)

pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)

pipe_params = {}
if pipe_type == "controlnet":
pipe_params["controlnet_conditioning_scale"] = strength
elif pipe_type == "img2img":
pipe_params["strength"] = strength
elif pipe_type == "panorama":
pipe_params["strength"] = strength
elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength
pipe_params = {}
if pipe_type == "controlnet":
pipe_params["controlnet_conditioning_scale"] = strength
elif pipe_type == "img2img":
pipe_params["strength"] = strength
elif pipe_type == "panorama":
pipe_params["strength"] = strength
elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength

if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
pipe.unet.set_prompts(prompt_embeds)
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)

rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)

output = result.images[0]
output = result.images[0]

logger.info("final output image size: %sx%s", output.width, output.height)
return output
logger.info("final output image size: %sx%s", output.width, output.height)
return output

0 comments on commit 2913cd0

Please sign in to comment.