Skip to content

Commit

Permalink
fix(api): turn alternatives back off for SDXL
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Nov 12, 2023
1 parent 6eb014c commit 3ffbc00
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
9 changes: 5 additions & 4 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ def run(
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)

rng = np.random.RandomState(params.seed)
result = pipe(
Expand Down
9 changes: 5 additions & 4 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,11 @@ def run(
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)

rng = np.random.RandomState(params.seed)
result = pipe(
Expand Down
9 changes: 5 additions & 4 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ def run(
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)

rng = np.random.RandomState(params.seed)
result = pipe(
Expand Down
15 changes: 8 additions & 7 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ def run(
)
generator = torch.manual_seed(params.seed)

prompt_embeds = encode_prompt(
pipeline,
prompt_pairs,
num_images_per_prompt=params.batch,
do_classifier_free_guidance=params.do_cfg(),
)
pipeline.unet.set_prompts(prompt_embeds)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipeline,
prompt_pairs,
num_images_per_prompt=params.batch,
do_classifier_free_guidance=params.do_cfg(),
)
pipeline.unet.set_prompts(prompt_embeds)

outputs = []
for source in sources:
Expand Down

0 comments on commit 3ffbc00

Please sign in to comment.