feat(api): parse LoRA names from prompt
ssube committed Mar 15, 2023
1 parent 03f4e1b commit 143904f
Showing 2 changed files with 23 additions and 7 deletions.
api/onnx_web/diffusers/load.py: 21 changes (16 additions & 5 deletions)
@@ -1,5 +1,6 @@
 from logging import getLogger
 from os import path
+from re import compile
 from typing import Any, List, Optional, Tuple
 
 import numpy as np
@@ -106,11 +107,21 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
-def get_loras_from_prompt(prompt: str) -> List[str]:
-    return [
-        "arch",
-        "glass",
-    ]
+def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
+    remaining_prompt = prompt
+    lora_expr = compile(r"\<lora:(\w+):([\d\.]+)\>")
+
+    loras = []
+    next_match = lora_expr.search(remaining_prompt)
+    while next_match is not None:
+        logger.debug("found LoRA token in prompt: %s", next_match)
+        name, weight = next_match.groups()
+        loras.append(name)
+        # remove this match and look for another
+        remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+        next_match = lora_expr.search(remaining_prompt)
+
+    return (remaining_prompt, loras)
 
 
 def optimize_pipeline(
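For reference, a quick sketch of how the new parser behaves on a made-up prompt (the prompt text and weights below are illustrative, not from this commit). Only the matched <lora:name:weight> spans are removed, so whatever whitespace surrounded them stays in the returned prompt, and the weights are captured by the regex but not returned yet.

# Illustrative only; the prompt below is made up, not from this commit.
from onnx_web.diffusers.load import get_loras_from_prompt

prompt = "a cathedral of <lora:glass:0.8> stained glass with <lora:arch:1.0> gothic arches"
remaining, loras = get_loras_from_prompt(prompt)

print(loras)      # ['glass', 'arch']: names in order of appearance, weights discarded for now
print(remaining)  # 'a cathedral of  stained glass with  gothic arches' (double spaces where tokens were)
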
api/onnx_web/diffusers/run.py: 9 changes (7 additions & 2 deletions)
@@ -28,7 +28,10 @@ def run_txt2img_pipeline(
     upscale: UpscaleParams,
 ) -> None:
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
-    loras = get_loras_from_prompt(params.prompt)
+
+    (prompt, loras) = get_loras_from_prompt(params.prompt)
+    params.prompt = prompt
+
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionPipeline,
@@ -101,7 +104,9 @@ def run_img2img_pipeline(
     source: Image.Image,
     strength: float,
 ) -> None:
-    loras = get_loras_from_prompt(params.prompt)
+    (prompt, loras) = get_loras_from_prompt(params.prompt)
+    params.prompt = prompt
+
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionImg2ImgPipeline,
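Neither file adds tests for the new parsing path. A minimal pytest-style sketch of the behaviour both call sites rely on could look like the following; the test names and prompts are illustrative and assume the module is importable as onnx_web.diffusers.load in the test environment.

# Not part of this commit: illustrative tests for get_loras_from_prompt.
from onnx_web.diffusers.load import get_loras_from_prompt

def test_prompt_without_lora_tokens_is_unchanged():
    prompt, loras = get_loras_from_prompt("a painting of a lighthouse")
    assert prompt == "a painting of a lighthouse"
    assert loras == []

def test_lora_tokens_are_stripped_and_collected():
    prompt, loras = get_loras_from_prompt("portrait <lora:arch:1.0> in <lora:glass:0.25> style")
    assert loras == ["arch", "glass"]
    assert "<lora:" not in prompt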
