Skip to content

Commit

Permalink
feat(api): add support for highres images
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 1, 2023
1 parent f4c7f02 commit cdaf1b8
Showing 1 changed file with 61 additions and 1 deletion.
62 changes: 61 additions & 1 deletion api/onnx_web/diffusers/run.py
Expand Up @@ -6,10 +6,12 @@
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image

from onnx_web.chain.utils import process_tile_order

from ..chain import blend_mask, upscale_outpaint
from ..chain.base import ChainProgress
from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams, UpscaleParams
from ..params import Border, ImageParams, Size, StageParams, TileOrder, UpscaleParams
from ..server import ServerContext
from ..upscale import run_upscale_correction
from ..utils import run_gc
Expand Down Expand Up @@ -79,6 +81,64 @@ def run_txt2img_pipeline(
)

for image, output in zip(result.images, outputs):
highres_scale = 4
highres_strength = 0.5

if params.highres > 1:

def highres(tile: Image.Image, dims):
highpipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,
params.model,
params.scheduler,
job.get_device(),
params.lpw,
inversions,
loras,
)
progress = job.get_progress_callback()
if params.lpw:
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = highpipe.img2img(
tile,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=params.steps,
strength=highres_strength,
eta=params.eta,
callback=progress,
)
return result.images[0]
else:
rng = np.random.RandomState(params.seed)
result = highpipe(
params.prompt,
tile,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=params.steps,
strength=highres_strength,
eta=params.eta,
callback=progress,
)
return result.images[0]

logger.info("running highres fix for %s tiles", highres_scale)
image = process_tile_order(
TileOrder.grid,
image,
size.height // highres_scale,
highres_scale,
[highres],
)

image = run_upscale_correction(
job,
server,
Expand Down

0 comments on commit cdaf1b8

Please sign in to comment.