Skip to content

Commit

Permalink
feat(api): add an option for custom checkpoint config to extras file (f…
Browse files Browse the repository at this point in the history
…ixes #130)
  • Loading branch information
ssube committed Feb 12, 2023
1 parent 82487f5 commit d6201c9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 7 additions & 3 deletions api/onnx_web/convert/diffusion_original.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,9 @@ def get_config_path(
return os.path.abspath(parts)


def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None):
if config_file is not None:
return config_file

config_base_name = "training"

Expand Down Expand Up @@ -1142,6 +1144,7 @@ def extract_checkpoint(
extract_ema=False,
train_unfrozen=False,
is_512=True,
config_file=None,
):
"""
Expand Down Expand Up @@ -1229,7 +1232,7 @@ def extract_checkpoint(
else:
prediction_type = "epsilon"

original_config_file = get_config_file(train_unfrozen, v2, prediction_type)
original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file)

logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
db_config.resolution = image_size
Expand Down Expand Up @@ -1406,6 +1409,7 @@ def convert_diffusion_original(
model: ModelDict,
source: str,
):
config = model["config"]
name = model["name"]
source = source or model["source"]

Expand All @@ -1424,7 +1428,7 @@ def convert_diffusion_original(
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
else:
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
extract_checkpoint(ctx, torch_name, source)
extract_checkpoint(ctx, torch_name, source, config_file=config)
logger.info("Converted original Diffusers checkpoint to Torch model.")

convert_diffusion_stable(ctx, model, working_name)
Expand Down
4 changes: 4 additions & 0 deletions api/schemas/extras.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ $defs:
diffusion_model:
allOf:
- $ref: "#/$defs/base_model"
- type: object
properties:
config:
type: string

upscaling_model:
allOf:
Expand Down

0 comments on commit d6201c9

Please sign in to comment.