Skip to content

Commit

Permalink
feat(api): add support for wildcards
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 4, 2023
1 parent b816307 commit c8a9dd4
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 3 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/chain/source_txt2img.py
Expand Up @@ -88,4 +88,4 @@ def run(
callback=callback,
)

return result.images
return result.images
30 changes: 29 additions & 1 deletion api/onnx_web/diffusers/utils.py
@@ -1,7 +1,8 @@
import random
from logging import getLogger
from math import ceil
from re import Pattern, compile
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
Expand All @@ -18,6 +19,8 @@
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-\w]+)__")

INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")

Expand Down Expand Up @@ -361,3 +364,28 @@ def encode_prompt(
)
for prompt, neg_prompt in prompt_pairs
]


def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -> str:
    """
    Replace each `__name__` wildcard token in the prompt with a randomly
    selected option from the matching wildcard list.

    The RNG is seeded with the image seed so that the same seed always
    produces the same prompt expansion. Unknown wildcard names are left
    in the prompt unchanged and logged as warnings.
    """
    random.seed(seed)

    remaining_prompt = prompt
    next_match = WILDCARD_TOKEN.search(remaining_prompt)

    while next_match is not None:
        logger.debug("found wildcard in prompt: %s", next_match)
        # group(1) is the wildcard name captured by WILDCARD_TOKEN;
        # groups() would return a tuple, which never matches a str key
        name = next_match.group(1)

        if name in wildcards:
            replacement = random.choice(wildcards[name])
        else:
            logger.warning("unknown wildcard: %s", name)
            # keep the original token text so the loop still advances
            # past it instead of spinning forever on the same match
            replacement = next_match.group(0)

        start, end = next_match.span()
        remaining_prompt = (
            remaining_prompt[:start] + replacement + remaining_prompt[end:]
        )

        # resume the search after the spliced-in text; re-searching from
        # the start (or not re-searching at all, as before) can loop forever
        next_match = WILDCARD_TOKEN.search(
            remaining_prompt, start + len(replacement)
        )

    return remaining_prompt
2 changes: 2 additions & 0 deletions api/onnx_web/main.py
Expand Up @@ -21,6 +21,7 @@
load_models,
load_params,
load_platforms,
load_wildcards,
)
from .server.static import register_static_routes
from .server.utils import check_paths
Expand All @@ -46,6 +47,7 @@ def main():
load_models(server)
load_params(server)
load_platforms(server)
load_wildcards(server)

if is_debug():
gc.set_debug(gc.DEBUG_STATS)
Expand Down
3 changes: 2 additions & 1 deletion api/onnx_web/server/admin.py
Expand Up @@ -6,7 +6,7 @@
from ..utils import load_config, load_config_str
from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
from .load import load_extras, load_models
from .load import load_extras, load_models, load_wildcards
from .utils import wrap_route

logger = getLogger(__name__)
Expand Down Expand Up @@ -88,6 +88,7 @@ def update_extra_models(server: ServerContext):

logger.info("finished converting models, reloading server")
load_models(server)
load_wildcards(server)
load_extras(server)

conversion_lock = False
Expand Down
32 changes: 32 additions & 0 deletions api/onnx_web/server/api.py
Expand Up @@ -15,6 +15,7 @@
run_txt2img_pipeline,
run_upscale_pipeline,
)
from ..diffusers.utils import replace_wildcards
from ..output import json_params, make_output_name
from ..params import Border, StageParams, TileOrder, UpscaleParams
from ..transformers.run import run_txt2txt_pipeline
Expand Down Expand Up @@ -44,6 +45,7 @@
get_noise_sources,
get_source_filters,
get_upscaling_models,
get_wildcard_data,
)
from .params import (
border_from_request,
Expand Down Expand Up @@ -177,6 +179,12 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
get_config_value("strength", "min"),
)

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)

output_count = params.batch
if source_filter is not None and source_filter != "none":
logger.debug(
Expand Down Expand Up @@ -213,6 +221,12 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
upscale = upscale_from_request()
highres = highres_from_request()

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)

output = make_output_name(server, "txt2img", params, size)

job_name = output[0]
Expand Down Expand Up @@ -257,6 +271,12 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
)

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)

output = make_output_name(
server,
"inpaint",
Expand Down Expand Up @@ -314,6 +334,12 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
upscale = upscale_from_request()
highres = highres_from_request()

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)

output = make_output_name(server, "upscale", params, size)

job_name = output[0]
Expand Down Expand Up @@ -354,6 +380,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
output = make_output_name(server, "chain", params, size)
job_name = output[0]

params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data())
if params.negative_prompt is not None:
params.negative_prompt = replace_wildcards(
params.negative_prompt, params.seed, get_wildcard_data()
)

pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")]
Expand Down
22 changes: 22 additions & 0 deletions api/onnx_web/server/load.py
Expand Up @@ -90,6 +90,7 @@
diffusion_models: List[str] = []
network_models: List[NetworkModel] = []
upscaling_models: List[str] = []
# wildcard name -> list of replacement options; populated by load_wildcards,
# which assigns by string key — so this must be a dict, not a list
wildcard_data: Dict[str, List[str]] = {}

# Loaded from extra_models
extra_hashes: Dict[str, str] = {}
Expand Down Expand Up @@ -120,6 +121,10 @@ def get_upscaling_models():
return upscaling_models


def get_wildcard_data():
    """Return the module-level mapping of wildcard names to their option lists."""
    return wildcard_data


def get_extra_strings():
return extra_strings

Expand Down Expand Up @@ -444,3 +449,20 @@ def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
"available acceleration platforms: %s",
", ".join([str(p) for p in available_platforms]),
)


def load_wildcards(server: ServerContext) -> None:
    """
    Populate the module-level wildcard data from text files in the
    server's `wildcard` model directory, one option per line.

    Entries are keyed by the file's stem (no directory, no extension)
    so they match the bare `__name__` tokens that replace_wildcards
    looks up; keying by the raw file value would never match a token.
    """
    global wildcard_data

    # simple wildcards: models/wildcard/*.txt
    wildcard_files = list_model_globs(
        server, ["*.txt"], base_path=path.join(server.model_path, "wildcard")
    )

    for file in wildcard_files:
        # NOTE(review): assumes list_model_globs yields openable paths —
        # confirm it does not strip the directory/extension before returning
        with open(file, "r") as f:
            lines = f.read().splitlines()
            logger.debug("loading wildcards from %s: %s", file, lines)
            name = path.splitext(path.basename(file))[0]
            wildcard_data[name] = lines

    # TODO: structured wildcards

# TODO: structured wildcards

0 comments on commit c8a9dd4

Please sign in to comment.