Skip to content

Commit

Permalink
fix(api): parse config files as JSON or YAML depending on extension (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed May 4, 2023
1 parent ac4aef6 commit b6692f0
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 57 deletions.
9 changes: 3 additions & 6 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from jsonschema import ValidationError, validate
from onnx import load_model, save_model
from transformers import CLIPTokenizer
from yaml import safe_load

from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from ..utils import load_config
from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control
from .diffusion.diffusers import convert_diffusion_diffusers
Expand Down Expand Up @@ -604,16 +604,13 @@ def main() -> int:
extras.sort()
logger.debug("loading extra files: %s", extras)

with open("./schemas/extras.yaml", "r") as f:
extra_schema = safe_load(f.read())
extra_schema = load_config("./schemas/extras.yaml")

for file in extras:
if file is not None and file != "":
logger.info("loading extra models from %s", file)
try:
with open(file, "r") as f:
data = safe_load(f.read())

data = load_config(file)
logger.debug("validating extras file %s", data)
try:
validate(data, extra_schema)
Expand Down
30 changes: 0 additions & 30 deletions api/onnx_web/convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from onnxruntime.transformers.float16 import convert_float_to_float16
from packaging import version
from torch.onnx import export
from yaml import safe_load

from ..constants import ONNX_WEIGHTS
from ..server import ServerContext
Expand Down Expand Up @@ -171,35 +170,6 @@ def source_format(model: Dict) -> Optional[str]:
return None


class Config(object):
"""
Shim for pydantic-style config.
"""

def __init__(self, kwargs):
self.__dict__.update(kwargs)
for k, v in self.__dict__.items():
Config.config_from_key(self, k, v)

def __iter__(self):
for k in self.__dict__.keys():
yield k

@classmethod
def config_from_key(cls, target, k, v):
if isinstance(v, dict):
tmp = Config(v)
setattr(target, k, tmp)
else:
setattr(target, k, v)


def load_yaml(file: str) -> Config:
with open(file, "r") as f:
data = safe_load(f.read())
return Config(data)


def remove_prefix(name: str, prefix: str) -> str:
if name.startswith(prefix):
return name[len(prefix) :]
Expand Down
8 changes: 4 additions & 4 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from logging import getLogger
from os import path

import yaml
from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate
from PIL import Image
Expand All @@ -28,6 +27,8 @@
get_from_map,
get_not_empty,
get_size,
load_config,
load_config_str,
sanitize_name,
)
from ..worker.pool import DevicePoolExecutor
Expand Down Expand Up @@ -352,9 +353,8 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
if body is None:
return error_reply("chain pipeline must have a body")

data = yaml.safe_load(body)
with open("./schemas/chain.yaml", "r") as f:
schema = yaml.safe_load(f.read())
data = load_config_str(body)
schema = load_config("./schemas/chain.yaml")

logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema)
Expand Down
28 changes: 11 additions & 17 deletions api/onnx_web/server/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from typing import Any, Dict, List, Optional, Union

import torch
import yaml
from jsonschema import ValidationError, validate
from yaml import safe_load

from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
Expand Down Expand Up @@ -35,7 +33,7 @@
from ..models.meta import NetworkModel
from ..params import DeviceParams
from ..torch_before_ort import get_available_providers
from ..utils import merge
from ..utils import load_config, merge
from .context import ServerContext

logger = getLogger(__name__)
Expand Down Expand Up @@ -154,16 +152,13 @@ def load_extras(server: ServerContext):
labels = {}
strings = {}

with open("./schemas/extras.yaml", "r") as f:
extra_schema = safe_load(f.read())
extra_schema = load_config("./schemas/extras.yaml")

for file in server.extra_models:
if file is not None and file != "":
logger.info("loading extra models from %s", file)
try:
with open(file, "r") as f:
data = safe_load(f.read())

data = load_config(file)
logger.debug("validating extras file %s", data)
try:
validate(data, extra_schema)
Expand Down Expand Up @@ -349,16 +344,15 @@ def load_params(server: ServerContext) -> None:
params_file = path.join(server.params_path, "params.json")
logger.debug("loading server parameters from file: %s", params_file)

with open(params_file, "r") as f:
config_params = yaml.safe_load(f)
config_params = load_config(params_file)

if "platform" in config_params and server.default_platform is not None:
logger.info(
"overriding default platform from environment: %s",
server.default_platform,
)
config_platform = config_params.get("platform", {})
config_platform["default"] = server.default_platform
if "platform" in config_params and server.default_platform is not None:
logger.info(
"overriding default platform from environment: %s",
server.default_platform,
)
config_platform = config_params.get("platform", {})
config_platform["default"] = server.default_platform


def load_platforms(server: ServerContext) -> None:
Expand Down
30 changes: 30 additions & 0 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import gc
import importlib
import json
import threading
from json import JSONDecodeError
from logging import getLogger
from os import environ, path
from platform import system
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
from yaml import safe_load

from .params import DeviceParams, SizeChart

Expand Down Expand Up @@ -167,3 +170,30 @@ def show_system_toast(msg: str) -> None:
)
else:
logger.info("system notifications not yet available for %s", sys_name)


def load_json(file: str) -> Dict:
with open(file, "r") as f:
data = json.loads(f.read())
return data


def load_yaml(file: str) -> Dict:
with open(file, "r") as f:
data = safe_load(f.read())
return data


def load_config(file: str) -> Dict:
name, ext = path.splitext(file)
if ext in [".yml", ".yaml"]:
return load_yaml(file)
elif ext in [".json"]:
return load_json(file)


def load_config_str(raw: str) -> Dict:
try:
return json.loads(raw)
except JSONDecodeError:
return safe_load(raw)

0 comments on commit b6692f0

Please sign in to comment.