Skip to content

Commit

Permalink
fix(api): allow random seed in reseed regions
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Nov 11, 2023
1 parent c0a4fb6 commit 798fa5f
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 20 deletions.
14 changes: 12 additions & 2 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..diffusers.load import load_pipeline
from ..diffusers.utils import (
LATENT_FACTOR,
encode_prompt,
get_latents_from_seed,
get_tile_latents,
Expand Down Expand Up @@ -78,8 +79,12 @@ def run(
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)

# reseed latents as needed
reseed_rng = np.random.default_rng(params.seed)
prompt, reseed = parse_reseed(prompt)
for top, left, bottom, right, region_seed in reseed:
if region_seed == -1:
region_seed = reseed_rng.integers(2**32)

logger.debug(
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
top,
Expand All @@ -89,8 +94,13 @@ def run(
region_seed,
)
latents[
:, :, top // 8 : bottom // 8, left // 8 : right // 8
] = get_latents_from_seed(region_seed, Size(right - left, bottom - top), params.batch)
:,
:,
top // LATENT_FACTOR : bottom // LATENT_FACTOR,
left // LATENT_FACTOR : right // LATENT_FACTOR,
] = get_latents_from_seed(
region_seed, Size(right - left, bottom - top), params.batch
)

pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline(
Expand Down
10 changes: 6 additions & 4 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..diffusers.utils import expand_prompt
from ..diffusers.utils import LATENT_FACTOR, expand_prompt
from ..params import DeviceParams, ImageParams
from ..server import ModelTypes, ServerContext
from ..torch_before_ort import InferenceSession
Expand Down Expand Up @@ -264,19 +264,21 @@ def load_pipeline(
if hasattr(pipe, vae):
vae_model = getattr(pipe, vae)
vae_model.set_tiled(tiled=params.tiled_vae)
vae_model.set_window_size(params.vae_tile // 8, params.vae_overlap)
vae_model.set_window_size(
params.vae_tile // LATENT_FACTOR, params.vae_overlap
)

# update panorama params
if params.is_panorama():
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
logger.debug(
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
params.unet_tile,
unet_stride,
params.vae_tile,
params.vae_overlap,
)
pipe.set_window_size(params.unet_tile // 8, unet_stride)
pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)

run_gc([device])

Expand Down
21 changes: 13 additions & 8 deletions api/onnx_web/diffusers/pipelines/panorama.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from onnx_web.chain.tile import make_tile_mask

from ..utils import parse_regions
from ..utils import LATENT_CHANNELS, LATENT_FACTOR, parse_regions

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -512,7 +512,12 @@ def text2img(

# get the initial random noise unless the user supplied it
latents_dtype = prompt_embeds.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
LATENT_CHANNELS,
height // LATENT_FACTOR,
width // LATENT_FACTOR,
)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
Expand Down Expand Up @@ -612,10 +617,10 @@ def text2img(
)

# convert coordinates to latent space
h_start = top // 8
h_end = bottom // 8
w_start = left // 8
w_end = right // 8
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR

# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
Expand Down Expand Up @@ -1170,8 +1175,8 @@ def inpaint(
latents_shape = (
batch_size * num_images_per_prompt,
num_channels_latents,
height // 8,
width // 8,
height // LATENT_FACTOR,
width // LATENT_FACTOR,
)
latents_dtype = prompt_embeds.dtype
if latents is None:
Expand Down
10 changes: 5 additions & 5 deletions api/onnx_web/diffusers/pipelines/panorama_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from onnx_web.chain.tile import make_tile_mask

from ..utils import parse_regions
from ..utils import LATENT_FACTOR, parse_regions

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -457,10 +457,10 @@ def text2img(
)

# convert coordinates to latent space
h_start = top // 8
h_end = bottom // 8
w_start = left // 8
w_end = right // 8
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR

# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/diffusers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# Pre-compiled patterns for special tokens embedded in prompt text.
# NOTE(review): `compile` is presumably `re.compile`, imported outside this
# view — confirm.

# Region token: "<region:N:N:N:N:X:X:text>" — four unsigned integer fields,
# two numeric fields with an optional leading minus, then free text running up
# to the closing ">".
# NOTE(review): inside the character class `[\.|\d]` the `|` is a literal, so
# a pipe character is also accepted in the two numeric fields — confirm that
# this is intended rather than a leftover alternation.
REGION_TOKEN = compile(
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
)
# Reseed token: "<reseed:N:N:N:N:seed>" — four unsigned integer fields followed
# by a seed.
# NOTE(review): two assignments follow because this view is a rendered diff
# (1 addition, 1 deletion); the first looks like the removed version, which
# only accepts an unsigned seed. The second assignment is the one that takes
# effect and additionally allows a leading minus (e.g. a seed of -1).
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(\d+)\>")
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")

# Interval range: a word, a hyphen, then "{a,b}" or "{a,b,c}" with unsigned
# integers; the third integer is optional.
# NOTE(review): the meaning of the three integers is not visible here — verify
# against the code that consumes this pattern.
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
# Alternative range: a parenthesized, non-empty span of text containing no ")".
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
Expand Down

0 comments on commit 798fa5f

Please sign in to comment.