Skip to content

Commit

Permalink
feat(api): support wildcards in nested folders
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 12, 2023
1 parent 1a084eb commit 865b25e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/diffusers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-\w]+)__")
WILDCARD_TOKEN = compile(r"__([-/\w]+)__")

INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
Expand Down
23 changes: 17 additions & 6 deletions api/onnx_web/server/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,17 +255,24 @@ def load_extras(server: ServerContext):


def list_model_globs(
server: ServerContext, globs: List[str], base_path: Optional[str] = None
server: ServerContext,
globs: List[str],
base_path: Optional[str] = None,
recursive=False,
filename_only=True,
) -> List[str]:
if base_path is None:
base_path = server.model_path

models = []
for pattern in globs:
pattern_path = path.join(base_path or server.model_path, pattern)
pattern_path = path.join(base_path, pattern)
logger.debug("loading models from %s", pattern_path)
for name in glob(pattern_path):
for name in glob(pattern_path, recursive=recursive):
base = path.basename(name)
(file, ext) = path.splitext(base)
if ext not in IGNORE_EXTENSIONS:
models.append(file)
models.append(file if filename_only else path.relpath(name, base_path))

unique_models = list(set(models))
unique_models.sort()
Expand Down Expand Up @@ -456,7 +463,11 @@ def load_wildcards(server: ServerContext) -> None:

# simple wildcards
wildcard_files = list_model_globs(
server, ["*.txt"], base_path=path.join(server.model_path, "wildcard")
server,
["**/*.txt"],
base_path=path.join(server.model_path, "wildcard"),
filename_only=False,
recursive=True,
)

for file in wildcard_files:
Expand All @@ -465,6 +476,6 @@ def load_wildcards(server: ServerContext) -> None:
lines = [line.strip() for line in lines if not line.startswith("#")]
lines = [line for line in lines if len(line) > 0]
logger.debug("loading wildcards from %s: %s", file, lines)
wildcard_data[file] = lines
wildcard_data[path.splitext(file)[0]] = lines

# TODO: structured wildcards

0 comments on commit 865b25e

Please sign in to comment.