Skip to content

Commit

Permalink
Avoid overriding model_type in TasksManager (huggingface#1647)
Browse files Browse the repository at this point in the history
* avoid modifying model_type

* cleanup

* fix test

* fix test

* fix library detection local model

* fix merge

* make library_name non-optional

* fix warning

* trigger ci

* fix library detection
  • Loading branch information
fxmarty authored and young-developer committed May 10, 2024
1 parent bd205dc commit d2c82e0
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 92 deletions.
2 changes: 1 addition & 1 deletion optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def parse_args_onnx(parser):
optional_group.add_argument(
"--library-name",
type=str,
choices=["transformers", "diffusers", "timm"],
choices=["transformers", "diffusers", "timm", "sentence_transformers"],
default=None,
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"),
)
Expand Down
35 changes: 25 additions & 10 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,16 @@ def _get_submodels_and_onnx_configs(
custom_onnx_configs: Dict,
custom_architecture: bool,
_variant: str,
library_name: str,
int_dtype: str = "int64",
float_dtype: str = "fp32",
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
library_name: str = "transformers",
model_kwargs: Optional[Dict] = None,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
if is_stable_diffusion:
if library_name == "diffusers":
onnx_config = None
models_and_onnx_configs = get_stable_diffusion_models_for_export(
model, int_dtype=int_dtype, float_dtype=float_dtype
Expand Down Expand Up @@ -129,7 +128,7 @@ def _get_submodels_and_onnx_configs(
if fn_get_submodels is not None:
submodels_for_export = fn_get_submodels(model)
else:
if is_stable_diffusion:
if library_name == "diffusers":
submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
elif (
model.config.is_encoder_decoder
Expand Down Expand Up @@ -373,12 +372,16 @@ def main_export(

if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True
elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
elif task not in TasksManager.get_supported_tasks_for_model_type(
model_type, "onnx", library_name=library_name
):
if original_task == "auto":
autodetected_message = " (auto-detected)"
else:
autodetected_message = ""
model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
model_tasks = TasksManager.get_supported_tasks_for_model_type(
model_type, exporter="onnx", library_name=library_name
)
raise ValueError(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)
Expand Down Expand Up @@ -422,7 +425,13 @@ def main_export(
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
)

model_type = "stable-diffusion" if "stable-diffusion" in task else model.config.model_type.replace("_", "-")
if "stable-diffusion" in task:
model_type = "stable-diffusion"
elif hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
model_type = model.config.model_type.replace("_", "-")

if (
not custom_architecture
and library_name != "diffusers"
Expand Down Expand Up @@ -513,14 +522,20 @@ def onnx_export(
else:
float_dtype = "fp32"

model_type = "stable-diffusion" if library_name == "diffusers" else model.config.model_type.replace("_", "-")
if "stable-diffusion" in task:
model_type = "stable-diffusion"
elif hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
model_type = model.config.model_type.replace("_", "-")

custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE
task = TasksManager.map_from_synonym(task)

# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_onnx_configs is None:
raise ValueError(
f"Trying to export a {model.config.model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model.config.model_type} to be supported natively in the ONNX export."
f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export."
)

if task is None:
Expand Down Expand Up @@ -690,7 +705,7 @@ def onnx_export(
if library_name == "diffusers":
# TODO: fix Can't pickle local object 'get_stable_diffusion_models_for_export.<locals>.<lambda>'
use_subprocess = False
elif model.config.model_type in UNPICKABLE_ARCHS:
elif model_type in UNPICKABLE_ARCHS:
# Pickling is bugged for nn.utils.weight_norm: https://github.com/pytorch/pytorch/issues/102983
# TODO: fix "Cowardly refusing to serialize non-leaf tensor" error for wav2vec2-conformer
use_subprocess = False
Expand Down
10 changes: 8 additions & 2 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,10 @@ def __init__(

# Set up the encoder ONNX config.
encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type
exporter="onnx",
task="feature-extraction",
model_type=config.encoder.model_type,
library_name="transformers",
)
self._encoder_onnx_config = encoder_onnx_config_constructor(
config.encoder, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
Expand All @@ -353,7 +356,10 @@ def __init__(

# Set up the decoder ONNX config.
decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
exporter="onnx", task="feature-extraction", model_type=config.decoder.model_type
exporter="onnx",
task="feature-extraction",
model_type=config.decoder.model_type,
library_name="transformers",
)
kwargs = {}
if issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast):
Expand Down
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def get_stable_diffusion_models_for_export(
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder,
exporter="onnx",
library_name="diffusers",
task="feature-extraction",
)
text_encoder_onnx_config = text_encoder_config_constructor(
Expand All @@ -334,6 +335,7 @@ def get_stable_diffusion_models_for_export(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.unet,
exporter="onnx",
library_name="diffusers",
task="semantic-segmentation",
model_type="unet",
)
Expand All @@ -345,6 +347,7 @@ def get_stable_diffusion_models_for_export(
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_encoder,
exporter="onnx",
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-encoder",
)
Expand All @@ -356,6 +359,7 @@ def get_stable_diffusion_models_for_export(
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_decoder,
exporter="onnx",
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-decoder",
)
Expand All @@ -366,6 +370,7 @@ def get_stable_diffusion_models_for_export(
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_2,
exporter="onnx",
library_name="diffusers",
task="feature-extraction",
model_type="clip-text-with-projection",
)
Expand Down

0 comments on commit d2c82e0

Please sign in to comment.