Skip to content

Commit

Permalink
fix(api): store both pre-parse and parsed prompts (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 22, 2023
1 parent 2a7a068 commit 6e7f202
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
25 changes: 12 additions & 13 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
logger = getLogger(__name__)


def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tuple[str, float]]]:
prompt, loras = get_loras_from_prompt(params.input_prompt)
prompt, inversions = get_inversions_from_prompt(prompt)
params.prompt = prompt

return loras, inversions


def run_highres(
job: WorkerContext,
server: ServerContext,
Expand Down Expand Up @@ -164,10 +172,7 @@ def run_txt2img_pipeline(
highres: HighresParams,
) -> None:
latents = get_latents_from_seed(params.seed, size, batch=params.batch)

(prompt, loras) = get_loras_from_prompt(params.prompt)
(prompt, inversions) = get_inversions_from_prompt(prompt)
params.prompt = prompt
loras, inversions = parse_prompt(params)

pipe_type = "lpw" if params.lpw() else "txt2img"
pipe = load_pipeline(
Expand Down Expand Up @@ -260,9 +265,7 @@ def run_img2img_pipeline(
strength: float,
source_filter: Optional[str] = None,
) -> None:
(prompt, loras) = get_loras_from_prompt(params.prompt)
(prompt, inversions) = get_inversions_from_prompt(prompt)
params.prompt = prompt
loras, inversions = parse_prompt(params)

# filter the source image
if source_filter is not None:
Expand Down Expand Up @@ -376,9 +379,7 @@ def run_inpaint_pipeline(
progress = job.get_progress_callback()
stage = StageParams(tile_order=tile_order)

(prompt, loras) = get_loras_from_prompt(params.prompt)
(prompt, inversions) = get_inversions_from_prompt(prompt)
params.prompt = prompt
loras, inversions = parse_prompt(params)

# calling the upscale_outpaint stage directly needs accumulating progress
progress = ChainProgress.from_progress(progress)
Expand Down Expand Up @@ -444,9 +445,7 @@ def run_upscale_pipeline(
progress = job.get_progress_callback()
stage = StageParams()

(prompt, loras) = get_loras_from_prompt(params.prompt)
(prompt, inversions) = get_inversions_from_prompt(prompt)
params.prompt = prompt
loras, inversions = parse_prompt(params)

image = run_upscale_correction(
job, server, stage, params, source, upscale=upscale, callback=progress
Expand Down
5 changes: 5 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class ImageParams:
eta: float
batch: int
control: Optional[NetworkModel]
input_prompt: str

def __init__(
self,
Expand All @@ -187,6 +188,7 @@ def __init__(
eta: float = 0.0,
batch: int = 1,
control: Optional[NetworkModel] = None,
input_prompt: Optional[str] = None,
) -> None:
self.model = model
self.pipeline = pipeline
Expand All @@ -199,6 +201,7 @@ def __init__(
self.eta = eta
self.batch = batch
self.control = control
self.input_prompt = input_prompt or prompt

def lpw(self):
return self.pipeline == "lpw"
Expand All @@ -216,6 +219,7 @@ def tojson(self) -> Dict[str, Optional[Param]]:
"eta": self.eta,
"batch": self.batch,
"control": self.control.name if self.control is not None else "",
"input_prompt": self.input_prompt,
}

def with_args(self, **kwargs):
Expand All @@ -231,6 +235,7 @@ def with_args(self, **kwargs):
kwargs.get("eta", self.eta),
kwargs.get("batch", self.batch),
kwargs.get("control", self.control),
kwargs.get("input_prompt", self.input_prompt),
)


Expand Down

0 comments on commit 6e7f202

Please sign in to comment.