Skip to content

Commit

Permalink
fix(api): continue converting other models after an error in one (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 17, 2023
1 parent b3c8fce commit c74d22a
Showing 1 changed file with 40 additions and 26 deletions.
66 changes: 40 additions & 26 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
else:
model_format = source_format(model)
source = model["source"]
dest = fetch_model(ctx, name, source, model_format=model_format)
logger.info("finished downloading source: %s -> %s", source, dest)

try:
dest = fetch_model(ctx, name, source, model_format=model_format)
logger.info("finished downloading source: %s -> %s", source, dest)
except Exception as e:
logger.error("error fetching source %s: %s", name, e)

if args.diffusion and "diffusion" in models:
for model in models.get("diffusion"):
Expand All @@ -206,23 +210,27 @@ def convert_models(ctx: ConversionContext, args, models: Models):
logger.info("skipping model: %s", name)
else:
model_format = source_format(model)
source = fetch_model(
ctx, name, model["source"], model_format=model_format
)

if model_format in model_formats_original:
convert_diffusion_original(
ctx,
model,
source,
)
else:
convert_diffusion_stable(
ctx,
model,
source,

try:
source = fetch_model(
ctx, name, model["source"], model_format=model_format
)

if model_format in model_formats_original:
convert_diffusion_original(
ctx,
model,
source,
)
else:
convert_diffusion_stable(
ctx,
model,
source,
)
except Exception as e:
logger.error("error converting diffusion model %s: %s", name, e)

if args.upscaling and "upscaling" in models:
for model in models.get("upscaling"):
model = tuple_to_upscaling(model)
Expand All @@ -232,10 +240,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
logger.info("skipping model: %s", name)
else:
model_format = source_format(model)
source = fetch_model(
ctx, name, model["source"], model_format=model_format
)
convert_upscale_resrgan(ctx, model, source)

try:
source = fetch_model(
ctx, name, model["source"], model_format=model_format
)
convert_upscale_resrgan(ctx, model, source)
except Exception as e:
logger.error("error converting upscaling model %s: %s", name, e)

if args.correction and "correction" in models:
for model in models.get("correction"):
Expand All @@ -246,11 +258,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
logger.info("skipping model: %s", name)
else:
model_format = source_format(model)
source = fetch_model(
ctx, name, model["source"], model_format=model_format
)
convert_correction_gfpgan(ctx, model, source)

try:
source = fetch_model(
ctx, name, model["source"], model_format=model_format
)
convert_correction_gfpgan(ctx, model, source)
except Exception as e:
logger.error("error converting correction model %s: %s", name, e)

def main() -> int:
parser = ArgumentParser(
Expand Down

0 comments on commit c74d22a

Please sign in to comment.