Skip to content

Commit

Permalink
feat(api): provide a way for users to add models to the convert list (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 2, 2023
1 parent 0050cea commit c837830
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
8 changes: 8 additions & 0 deletions api/extras.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"diffusion": [
["diffusion-knollingcase", "Aybeeceedee/knollingcase"],
["diffusion-openjourney", "prompthero/openjourney"]
],
"correction": [],
"upscaling": []
}
29 changes: 14 additions & 15 deletions api/onnx_web/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from json import loads
from logging import getLogger
from onnx import load, save_model
from os import environ, makedirs, mkdir, path
Expand Down Expand Up @@ -60,16 +61,6 @@
],
}

# other neat models
extra_models: Models = {
'diffusion': [
('diffusion-knollingcase', 'Aybeeceedee/knollingcase'),
('diffusion-openjourney', 'prompthero/openjourney'),
],
'correction': [],
'upscaling': [],
}

model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
Expand Down Expand Up @@ -491,16 +482,18 @@ def main() -> int:
# model groups
parser.add_argument('--correction', action='store_true', default=False)
parser.add_argument('--diffusion', action='store_true', default=False)
parser.add_argument('--extras', action='store_true', default=False)
parser.add_argument('--upscaling', action='store_true', default=False)

# extra models
parser.add_argument('--extras', nargs='*', type=str, default=[])
parser.add_argument('--skip', nargs='*', type=str, default=[])

# export options
parser.add_argument(
'--half',
action='store_true',
default=False,
help='Export models for half precision, faster on some Nvidia cards'
help='Export models for half precision, faster on some Nvidia cards.'
)
parser.add_argument(
'--opset',
Expand All @@ -524,9 +517,15 @@ def main() -> int:
logger.info('Converting base models.')
load_models(args, base_models)

if args.extras:
logger.info('Converting extra models.')
load_models(args, extra_models)
for file in args.extras:
logger.info('Loading extra models from %s', file)
try:
with open(file, 'r') as f:
data = loads(f.read())
logger.info('Converting extra models.')
load_models(args, data)
except Exception as err:
logger.error('Error converting extra models: %s', err)

return 0

Expand Down

0 comments on commit c837830

Please sign in to comment.