Skip to content

Commit

Permalink
fix(api): keep network tokens while replacing wildcards in the saved …
Browse files Browse the repository at this point in the history
…prompt
  • Loading branch information
ssube committed Jul 7, 2023
1 parent b8c0bb0 commit f7fc442
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 60 deletions.
12 changes: 7 additions & 5 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def run(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)

prompt_pairs, loras, inversions = parse_prompt(params)
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params
)

pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
Expand Down Expand Up @@ -67,10 +69,10 @@ def run(
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
Expand All @@ -84,11 +86,11 @@ def run(

rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
Expand Down
20 changes: 14 additions & 6 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL import Image

from ..diffusers.load import load_pipeline
from ..diffusers.utils import get_latents_from_seed, parse_prompt
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
Expand Down Expand Up @@ -43,7 +43,9 @@ def run(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
)

_prompt_pairs, loras, inversions = parse_prompt(params)
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params
)
pipe_type = params.get_valid_pipeline("inpaint")
pipe = load_pipeline(
server,
Expand Down Expand Up @@ -88,30 +90,36 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
params.prompt,
prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)

rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
Expand Down
12 changes: 7 additions & 5 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def run(
"a source image was passed to a txt2img stage, and will be discarded"
)

prompt_pairs, loras, inversions = parse_prompt(params)
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params
)

latents = get_latents_from_seed(params.seed, size, params.batch)
pipe_type = params.get_valid_pipeline("txt2img")
Expand All @@ -58,13 +60,13 @@ def run(
logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
params.prompt,
prompt,
height=size.height,
width=size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
Expand All @@ -79,13 +81,13 @@ def run(

rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
prompt,
height=size.height,
width=size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
Expand Down
25 changes: 19 additions & 6 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from PIL import Image, ImageDraw, ImageOps

from ..diffusers.load import load_pipeline
from ..diffusers.utils import get_latents_from_seed, get_tile_latents, parse_prompt
from ..diffusers.utils import (
encode_prompt,
get_latents_from_seed,
get_tile_latents,
parse_prompt,
)
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
Expand Down Expand Up @@ -39,7 +44,9 @@ def run(
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> List[Image.Image]:
_prompt_pairs, loras, inversions = parse_prompt(params)
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params
)

pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
pipe = load_pipeline(
Expand Down Expand Up @@ -108,27 +115,33 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
result = pipe.inpaint(
tile_source,
tile_mask,
params.prompt,
prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
latents=latents,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)

rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
prompt,
tile_source,
tile_mask,
height=size.height,
width=size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
generator=rng,
latents=latents,
callback=callback,
Expand Down
8 changes: 5 additions & 3 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def run(
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
)

prompt_pairs, _loras, _inversions = parse_prompt(params)
prompt_pairs, _loras, _inversions, (prompt, negative_prompt) = parse_prompt(
params
)

pipeline = load_pipeline(
server,
Expand All @@ -57,11 +59,11 @@ def run(
outputs = []
for source in sources:
result = pipeline(
params.prompt,
prompt,
source,
generator=generator,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
negative_prompt=negative_prompt,
num_inference_steps=params.steps,
eta=params.eta,
noise_level=upscale.denoise,
Expand Down
8 changes: 4 additions & 4 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def run_txt2img_pipeline(
progress = job.get_progress_callback()
images = chain(job, server, params, [], callback=progress)

_prompt_pairs, loras, inversions = parse_prompt(params, use_input=True)
_pairs, loras, inversions, _rest = parse_prompt(params)

for image, output in zip(images, outputs):
dest = save_image(
Expand Down Expand Up @@ -177,7 +177,7 @@ def run_img2img_pipeline(
images.append(source)

# save with metadata
_prompt_pairs, loras, inversions = parse_prompt(params, use_input=True)
_pairs, loras, inversions, _rest = parse_prompt(params)
size = Size(*source.size)

for image, output in zip(images, outputs):
Expand Down Expand Up @@ -263,7 +263,7 @@ def run_inpaint_pipeline(
progress = job.get_progress_callback()
images = chain(job, server, params, [source], callback=progress)

_prompt_pairs, loras, inversions = parse_prompt(params, use_input=True)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
dest = save_image(
server,
Expand Down Expand Up @@ -331,7 +331,7 @@ def run_upscale_pipeline(
progress = job.get_progress_callback()
images = chain(job, server, params, [source], callback=progress)

_prompt_pairs, loras, inversions = parse_prompt(params, use_input=True)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
dest = save_image(
server,
Expand Down
24 changes: 18 additions & 6 deletions api/onnx_web/diffusers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,22 +312,26 @@ def get_scaled_latents(
def parse_prompt(
params: ImageParams,
use_input: bool = False,
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, float]], List[Tuple[str, float]]]:
) -> Tuple[
List[Tuple[str, str]],
List[Tuple[str, float]],
List[Tuple[str, float]],
Tuple[str, str],
]:
prompt, loras = get_loras_from_prompt(
params.input_prompt if use_input else params.prompt
)
prompt, inversions = get_inversions_from_prompt(prompt)
params.prompt = prompt
# params.prompt = prompt

neg_prompt = None
if params.input_negative_prompt is not None:
neg_prompt, neg_loras = get_loras_from_prompt(
params.input_negative_prompt if use_input else params.negative_prompt
)
neg_prompt, neg_inversions = get_inversions_from_prompt(neg_prompt)
params.negative_prompt = neg_prompt
# params.negative_prompt = neg_prompt

# TODO: check whether these need to be * -1
loras.extend(neg_loras)
inversions.extend(neg_inversions)

Expand All @@ -352,7 +356,7 @@ def parse_prompt(
for i in range(neg_prompt_count, prompt_count):
neg_prompts.append(neg_prompts[i % neg_prompt_count])

return list(zip(prompts, neg_prompts)), loras, inversions
return list(zip(prompts, neg_prompts)), loras, inversions, (prompt, neg_prompt)


def encode_prompt(
Expand All @@ -372,7 +376,7 @@ def encode_prompt(
]


def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -> str:
def parse_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -> str:
next_match = WILDCARD_TOKEN.search(prompt)
remaining_prompt = prompt

Expand Down Expand Up @@ -400,6 +404,14 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -
return remaining_prompt


def replace_wildcards(params: ImageParams, wildcards: Dict[str, List[str]]):
    """Expand wildcard tokens in both prompts of ``params``, in place.

    Rewrites ``params.prompt`` — and ``params.negative_prompt`` when one is
    set — through ``parse_wildcards``, seeding each expansion with the
    params' own seed so results are reproducible for a given request.
    """
    seed = params.seed
    params.prompt = parse_wildcards(params.prompt, seed, wildcards)

    negative = params.negative_prompt
    if negative is not None:
        params.negative_prompt = parse_wildcards(negative, seed, wildcards)


def pop_random(list: List[str]) -> str:
"""
From https://stackoverflow.com/a/14088129
Expand Down
30 changes: 5 additions & 25 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
get_config_value("strength", "min"),
)

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)
replace_wildcards(params, get_wildcard_data())

output_count = params.batch
if source_filter is not None and source_filter != "none":
Expand Down Expand Up @@ -221,11 +217,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
upscale = upscale_from_request()
highres = highres_from_request()

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)
replace_wildcards(params, get_wildcard_data())

output = make_output_name(server, "txt2img", params, size)

Expand Down Expand Up @@ -271,11 +263,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
)

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)
replace_wildcards(params, get_wildcard_data())

output = make_output_name(
server,
Expand Down Expand Up @@ -334,11 +322,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
upscale = upscale_from_request()
highres = highres_from_request()

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)
replace_wildcards(params, get_wildcard_data())

output = make_output_name(server, "upscale", params, size)

Expand Down Expand Up @@ -380,11 +364,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
output = make_output_name(server, "chain", params, size)
job_name = output[0]

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)
replace_wildcards(params, get_wildcard_data())

pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
Expand Down

0 comments on commit f7fc442

Please sign in to comment.