Skip to content

Commit

Permalink
feat: allow users to add their own labels for models (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 5, 2023
1 parent 628812f commit 5d459ab
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 10 deletions.
20 changes: 13 additions & 7 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,20 +353,26 @@ def main() -> int:
logger.info("converting base models")
convert_models(ctx, args, base_models)

for file in args.extras:
extras = []
extras.extend(ctx.extra_models)
extras.extend(args.extras)
extras = list(set(extras))
extras.sort()
logger.debug("loading extra files: %s", extras)

with open("./schemas/extras.yaml", "r") as f:
extra_schema = safe_load(f.read())

for file in extras:
if file is not None and file != "":
logger.info("loading extra models from %s", file)
try:
with open(file, "r") as f:
data = safe_load(f.read())

with open("./schemas/extras.yaml", "r") as f:
schema = safe_load(f.read())

logger.debug("validating chain request: %s against %s", data, schema)

logger.debug("validating extras file %s", data)
try:
validate(data, schema)
validate(data, extra_schema)
logger.info("converting extra models")
convert_models(ctx, args, data)
except ValidationError as err:
Expand Down
5 changes: 5 additions & 0 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def introspect(context: ServerContext, app: Flask):
}


def get_extra_strings(context: ServerContext):
return jsonify(get_extra_strings())


def list_mask_filters(context: ServerContext):
return jsonify(list(get_mask_filters().keys()))

Expand Down Expand Up @@ -464,6 +468,7 @@ def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExec
app.route("/api/settings/params")(wrap_route(list_params, context)),
app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
app.route("/api/settings/strings")(wrap_route(get_extra_strings, context)),
app.route("/api/img2img", methods=["POST"])(
wrap_route(img2img, context, pool=pool)
),
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
cache_path: Optional[str] = None,
show_progress: bool = True,
optimizations: Optional[List[str]] = None,
extra_models: Optional[List[str]] = None,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
Expand All @@ -40,6 +41,7 @@ def __init__(
self.cache_path = cache_path or path.join(model_path, ".cache")
self.show_progress = show_progress
self.optimizations = optimizations or []
self.extra_models = extra_models or []

@classmethod
def from_environ(cls):
Expand All @@ -63,4 +65,5 @@ def from_environ(cls):
cache=ModelCache(limit=cache_limit),
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
)
65 changes: 64 additions & 1 deletion api/onnx_web/server/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
from glob import glob
from logging import getLogger
from os import path
from typing import Dict, List, Union
from typing import Any, Dict, List, Union
from jsonschema import ValidationError, validate

import torch
import yaml
from yaml import safe_load

from ..utils import merge
from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
Expand Down Expand Up @@ -58,6 +61,9 @@
inversion_models: List[str] = []
upscaling_models: List[str] = []

# Loaded from extra_models
extra_strings: Dict[str, Any] = {}


def get_config_params():
return config_params
Expand All @@ -83,6 +89,10 @@ def get_upscaling_models():
return upscaling_models


def get_extra_strings():
return extra_strings


def get_mask_filters():
return mask_filters

Expand All @@ -101,6 +111,59 @@ def get_model_name(model: str) -> str:
return file


def load_extras(context: ServerContext):
"""
Load the extras file(s) and collect the relevant parts for the server: labels and strings
"""
global extra_strings

labels = {}
strings = {}

with open("./schemas/extras.yaml", "r") as f:
extra_schema = safe_load(f.read())

for file in context.extra_models:
if file is not None and file != "":
logger.info("loading extra models from %s", file)
try:
with open(file, "r") as f:
data = safe_load(f.read())

logger.debug("validating extras file %s", data)
try:
validate(data, extra_schema)
except ValidationError as err:
logger.error("invalid data in extras file: %s", err)
continue

if "strings" in data:
logger.debug("collecting strings from %s", file)
merge(strings, data["strings"])

for model_type in ["diffusion", "correction", "upscaling"]:
if model_type in data:
for model in data[model_type]:
if "label" in model:
model_name = model["name"]
logger.debug("collecting label for model %s from %s", model_name, file)
labels[model_name] = model["label"]

except Exception as err:
logger.error("error loading extras file: %s", err)

logger.debug("adding labels to strings: %s", labels)
merge(strings, {
"en": {
"translation": {
"model": labels,
}
}
})

extra_strings = strings


def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
models = []
for pattern in globs:
Expand Down
16 changes: 16 additions & 0 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,19 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):

def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))


def merge(a, b, path=None):
"merges b into a"
if path is None: path = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
merge(a[key], b[key], path + [str(key)])
elif a[key] == b[key]:
pass # same leaf value
else:
raise Exception("Conflict at %s" % '.'.join(path + [str(key)]))
else:
a[key] = b[key]
return a
9 changes: 8 additions & 1 deletion api/schemas/extras.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ $defs:
format:
type: string
enum: [concept, embeddings]
label:
type: string
token:
type: string

Expand All @@ -33,6 +35,8 @@ $defs:
enum: [onnx, pth, ckpt, safetensors]
half:
type: boolean
label:
type: string
name:
type: string
opset:
Expand Down Expand Up @@ -104,4 +108,7 @@ properties:
items:
oneOf:
- $ref: "#/$defs/legacy_tuple"
- $ref: "#/$defs/source_model"
- $ref: "#/$defs/source_model"
strings:
type: object
# /\w{2}/: translation: {}
10 changes: 10 additions & 0 deletions gui/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ export interface ApiClient {
*/
schedulers(): Promise<Array<string>>;

/**
* Load extra strings from the server.
*/
strings(): Promise<Record<string, unknown>>;

/**
* Start a txt2img pipeline.
*/
Expand Down Expand Up @@ -389,6 +394,11 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async strings(): Promise<Record<string, unknown>> {
const path = makeApiUrl(root, 'settings', 'strings');
const res = await f(path);
return await res.json() as Record<string, unknown>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);
Expand Down
5 changes: 4 additions & 1 deletion gui/src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ export async function main() {
returnEmptyString: false,
});

i18n.addResourceBundle(i18n.resolvedLanguage, 'model', params.model.keys);
const strings = await client.strings();
for (const [lang, data] of Object.entries(strings)) {
i18n.addResourceBundle(lang, 'translation', data, true);
}

// prep zustand with a slice for each tab, using local storage
const {
Expand Down

0 comments on commit 5d459ab

Please sign in to comment.