diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index aeda4dec9e..937ed1e50d 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -140,7 +140,7 @@ def parse_args_onnx(parser): optional_group.add_argument( "--library-name", type=str, - choices=["transformers", "diffusers", "timm"], + choices=["transformers", "diffusers", "timm", "sentence_transformers"], default=None, help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"), ) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 120164f23a..8f6f96d54a 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -68,17 +68,16 @@ def _get_submodels_and_onnx_configs( custom_onnx_configs: Dict, custom_architecture: bool, _variant: str, + library_name: str, int_dtype: str = "int64", float_dtype: str = "fp32", fn_get_submodels: Optional[Callable] = None, preprocessors: Optional[List[Any]] = None, legacy: bool = False, - library_name: str = "transformers", model_kwargs: Optional[Dict] = None, ): - is_stable_diffusion = "stable-diffusion" in task if not custom_architecture: - if is_stable_diffusion: + if library_name == "diffusers": onnx_config = None models_and_onnx_configs = get_stable_diffusion_models_for_export( model, int_dtype=int_dtype, float_dtype=float_dtype @@ -129,7 +128,7 @@ def _get_submodels_and_onnx_configs( if fn_get_submodels is not None: submodels_for_export = fn_get_submodels(model) else: - if is_stable_diffusion: + if library_name == "diffusers": submodels_for_export = _get_submodels_for_export_stable_diffusion(model) elif ( model.config.is_encoder_decoder @@ -373,12 +372,16 @@ def main_export( if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: custom_architecture = True - elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"): + elif task not in TasksManager.get_supported_tasks_for_model_type( + model_type, "onnx", library_name=library_name + ): if original_task == "auto": autodetected_message = " (auto-detected)" else: autodetected_message = "" - model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx") + model_tasks = TasksManager.get_supported_tasks_for_model_type( + model_type, exporter="onnx", library_name=library_name + ) raise ValueError( f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}." 
) @@ -422,7 +425,13 @@ def main_export( "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument" ) - model_type = "stable-diffusion" if "stable-diffusion" in task else model.config.model_type.replace("_", "-") + if "stable-diffusion" in task: + model_type = "stable-diffusion" + elif hasattr(model.config, "export_model_type"): + model_type = model.config.export_model_type.replace("_", "-") + else: + model_type = model.config.model_type.replace("_", "-") + if ( not custom_architecture and library_name != "diffusers" @@ -513,14 +522,20 @@ def onnx_export( else: float_dtype = "fp32" - model_type = "stable-diffusion" if library_name == "diffusers" else model.config.model_type.replace("_", "-") + if "stable-diffusion" in task: + model_type = "stable-diffusion" + elif hasattr(model.config, "export_model_type"): + model_type = model.config.export_model_type.replace("_", "-") + else: + model_type = model.config.model_type.replace("_", "-") + custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE task = TasksManager.map_from_synonym(task) # TODO: support onnx_config.py in the model repo if custom_architecture and custom_onnx_configs is None: raise ValueError( - f"Trying to export a {model.config.model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model.config.model_type} to be supported natively in the ONNX export." + f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export." ) if task is None: @@ -690,7 +705,7 @@ def onnx_export( if library_name == "diffusers": # TODO: fix Can't pickle local object 'get_stable_diffusion_models_for_export..' use_subprocess = False - elif model.config.model_type in UNPICKABLE_ARCHS: + elif model_type in UNPICKABLE_ARCHS: # Pickling is bugged for nn.utils.weight_norm: https://github.com/pytorch/pytorch/issues/102983 # TODO: fix "Cowardly refusing to serialize non-leaf tensor" error for wav2vec2-conformer use_subprocess = False diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 2eaa78d85e..c505237948 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -344,7 +344,10 @@ def __init__( # Set up the encoder ONNX config. 
encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor( - exporter="onnx", task="feature-extraction", model_type=config.encoder.model_type + exporter="onnx", + task="feature-extraction", + model_type=config.encoder.model_type, + library_name="transformers", ) self._encoder_onnx_config = encoder_onnx_config_constructor( config.encoder, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors @@ -353,7 +356,10 @@ def __init__( # Set up the decoder ONNX config. decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor( - exporter="onnx", task="feature-extraction", model_type=config.decoder.model_type + exporter="onnx", + task="feature-extraction", + model_type=config.decoder.model_type, + library_name="transformers", ) kwargs = {} if issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast): diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index d927fee4ce..e5d6de2524 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -323,6 +323,7 @@ def get_stable_diffusion_models_for_export( text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( model=pipeline.text_encoder, exporter="onnx", + library_name="diffusers", task="feature-extraction", ) text_encoder_onnx_config = text_encoder_config_constructor( @@ -334,6 +335,7 @@ def get_stable_diffusion_models_for_export( onnx_config_constructor = TasksManager.get_exporter_config_constructor( model=pipeline.unet, exporter="onnx", + library_name="diffusers", task="semantic-segmentation", model_type="unet", ) @@ -345,6 +347,7 @@ def get_stable_diffusion_models_for_export( vae_config_constructor = TasksManager.get_exporter_config_constructor( model=vae_encoder, exporter="onnx", + library_name="diffusers", task="semantic-segmentation", model_type="vae-encoder", ) @@ -356,6 +359,7 @@ def get_stable_diffusion_models_for_export( vae_config_constructor = TasksManager.get_exporter_config_constructor( model=vae_decoder, exporter="onnx", + library_name="diffusers", task="semantic-segmentation", model_type="vae-decoder", ) @@ -366,6 +370,7 @@ def get_stable_diffusion_models_for_export( onnx_config_constructor = TasksManager.get_exporter_config_constructor( model=pipeline.text_encoder_2, exporter="onnx", + library_name="diffusers", task="feature-extraction", model_type="clip-text-with-projection", ) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 515111f34a..2ac1caee57 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -292,6 +292,46 @@ class TasksManager: "timm": "default-timm-config", } + _DIFFUSERS_SUPPORTED_MODEL_TYPE = { + "clip-text-model": supported_tasks_mapping( + "feature-extraction", + onnx="CLIPTextOnnxConfig", + ), + "clip-text-with-projection": supported_tasks_mapping( + "feature-extraction", + onnx="CLIPTextWithProjectionOnnxConfig", + ), + "unet": supported_tasks_mapping( + "semantic-segmentation", + onnx="UNetOnnxConfig", + ), + "vae-encoder": supported_tasks_mapping( + "semantic-segmentation", + onnx="VaeEncoderOnnxConfig", + ), + "vae-decoder": supported_tasks_mapping( + "semantic-segmentation", + onnx="VaeDecoderOnnxConfig", + ), + } + + _TIMM_SUPPORTED_MODEL_TYPE = { + "default-timm-config": supported_tasks_mapping("image-classification", onnx="TimmDefaultOnnxConfig"), + } + + _SENTENCE_TRANSFORMERS_SUPPORTED_MODEL_TYPE = { + "clip": supported_tasks_mapping( + "feature-extraction", + "sentence-similarity", + 
onnx="SentenceTransformersCLIPOnnxConfig", + ), + "transformer": supported_tasks_mapping( + "feature-extraction", + "sentence-similarity", + onnx="SentenceTransformersTransformerOnnxConfig", + ), + } + # TODO: some models here support text-generation export but are not supported in ORTModelForCausalLM # Set of model topologies we support associated to the tasks supported by each topology and the factory _SUPPORTED_MODEL_TYPE = { @@ -404,14 +444,6 @@ class TasksManager: "zero-shot-image-classification", onnx="CLIPOnnxConfig", ), - "clip-text-model": supported_tasks_mapping( - "feature-extraction", - onnx="CLIPTextOnnxConfig", - ), - "clip-text-with-projection": supported_tasks_mapping( - "feature-extraction", - onnx="CLIPTextWithProjectionOnnxConfig", - ), "codegen": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", @@ -839,7 +871,6 @@ class TasksManager: "resnet": supported_tasks_mapping( "feature-extraction", "image-classification", onnx="ResNetOnnxConfig", tflite="ResNetTFLiteConfig" ), - "default-timm-config": supported_tasks_mapping("image-classification", onnx="TimmDefaultOnnxConfig"), "roberta": supported_tasks_mapping( "feature-extraction", "fill-mask", @@ -876,16 +907,6 @@ class TasksManager: "semantic-segmentation", onnx="SegformerOnnxConfig", ), - "sentence-transformers-clip": supported_tasks_mapping( - "feature-extraction", - "sentence-similarity", - onnx="SentenceTransformersCLIPOnnxConfig", - ), - "sentence-transformers-transformer": supported_tasks_mapping( - "feature-extraction", - "sentence-similarity", - onnx="SentenceTransformersTransformerOnnxConfig", - ), "sew": supported_tasks_mapping( "feature-extraction", "automatic-speech-recognition", @@ -954,10 +975,6 @@ class TasksManager: "image-to-text-with-past", onnx="TrOCROnnxConfig", ), - "unet": supported_tasks_mapping( - "semantic-segmentation", - onnx="UNetOnnxConfig", - ), "unispeech": supported_tasks_mapping( "feature-extraction", "automatic-speech-recognition", @@ -972,14 +989,6 @@ class TasksManager: "audio-xvector", onnx="UniSpeechSATOnnxConfig", ), - "vae-encoder": supported_tasks_mapping( - "semantic-segmentation", - onnx="VaeEncoderOnnxConfig", - ), - "vae-decoder": supported_tasks_mapping( - "semantic-segmentation", - onnx="VaeDecoderOnnxConfig", - ), "vision-encoder-decoder": supported_tasks_mapping( "image-to-text", "image-to-text-with-past", @@ -1057,15 +1066,26 @@ class TasksManager: onnx="YolosOnnxConfig", ), } + _LIBRARY_TO_SUPPORTED_MODEL_TYPES = { + "diffusers": _DIFFUSERS_SUPPORTED_MODEL_TYPE, + "sentence_transformers": _SENTENCE_TRANSFORMERS_SUPPORTED_MODEL_TYPE, + "timm": _TIMM_SUPPORTED_MODEL_TYPE, + "transformers": _SUPPORTED_MODEL_TYPE, + } _UNSUPPORTED_CLI_MODEL_TYPE = { "unet", "vae-encoder", "vae-decoder", "clip-text-model", "clip-text-with-projection", - "trocr", + "trocr", # TODO: why? 
} - _SUPPORTED_CLI_MODEL_TYPE = set(_SUPPORTED_MODEL_TYPE.keys()) - _UNSUPPORTED_CLI_MODEL_TYPE + _SUPPORTED_CLI_MODEL_TYPE = ( + set(_SUPPORTED_MODEL_TYPE.keys()) + | set(_DIFFUSERS_SUPPORTED_MODEL_TYPE.keys()) + | set(_TIMM_SUPPORTED_MODEL_TYPE.keys()) + | set(_SENTENCE_TRANSFORMERS_SUPPORTED_MODEL_TYPE.keys()) + ) - _UNSUPPORTED_CLI_MODEL_TYPE @classmethod def create_register( @@ -1094,9 +1114,15 @@ def create_register( ``` """ - def wrapper(model_type: str, *supported_tasks: str) -> Callable[[Type], Type]: + def wrapper( + model_type: str, *supported_tasks: str, library_name: str = "transformers" + ) -> Callable[[Type], Type]: def decorator(config_cls: Type) -> Type: - mapping = cls._SUPPORTED_MODEL_TYPE.get(model_type, {}) + supported_model_type_for_library = TasksManager._LIBRARY_TO_SUPPORTED_MODEL_TYPES[ + library_name + ] # This is a reference to the class-level dict, so registrations mutate it in place. + + mapping = supported_model_type_for_library.get(model_type, {}) mapping_backend = mapping.get(backend, {}) for task in supported_tasks: if task not in cls.get_all_tasks(): @@ -1108,7 +1134,7 @@ def decorator(config_cls: Type) -> Type: continue mapping_backend[task] = make_backend_config_constructor_for_task(config_cls, task) mapping[backend] = mapping_backend - cls._SUPPORTED_MODEL_TYPE[model_type] = mapping + supported_model_type_for_library[model_type] = mapping return config_cls return decorator @@ -1117,7 +1143,7 @@ def decorator(config_cls: Type) -> Type: @staticmethod def get_supported_tasks_for_model_type( - model_type: str, exporter: str, model_name: Optional[str] = None, library_name: str = "transformers" + model_type: str, exporter: str, model_name: Optional[str] = None, library_name: Optional[str] = None ) -> TaskNameToExportConfigDict: """ Retrieves the `task -> exporter backend config constructors` map from the model type. @@ -1129,13 +1155,29 @@ def get_supported_tasks_for_model_type( The name of the exporter. model_name (`Optional[str]`, defaults to `None`): The name attribute of the model object, only used for the exception message. - library_name (defaults to `transformers`): - The library name of the model. + library_name (`Optional[str]`, defaults to `None`): + The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". Returns: `TaskNameToExportConfigDict`: The dictionary mapping each task to a corresponding `ExportConfig` constructor. """ + if library_name is None: + logger.warning( + 'Not passing the argument `library_name` to `get_supported_tasks_for_model_type` is deprecated and support for it will be removed in a future version of Optimum. Please specify a `library_name`. Defaulting to `"transformers"`.' + ) + + # Caveat: if different libraries register the same model type, the merged mapping below is ambiguous.
+ supported_model_type_for_library = { + **TasksManager._DIFFUSERS_SUPPORTED_MODEL_TYPE, + **TasksManager._TIMM_SUPPORTED_MODEL_TYPE, + **TasksManager._SENTENCE_TRANSFORMERS_SUPPORTED_MODEL_TYPE, + **TasksManager._SUPPORTED_MODEL_TYPE, + } + library_name = "transformers" + else: + supported_model_type_for_library = TasksManager._LIBRARY_TO_SUPPORTED_MODEL_TYPES[library_name] + model_type = model_type.lower().replace("_", "-") model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type @@ -1143,28 +1185,28 @@ def get_supported_tasks_for_model_type( if library_name in TasksManager._MODEL_TYPE_FOR_DEFAULT_CONFIG: default_model_type = TasksManager._MODEL_TYPE_FOR_DEFAULT_CONFIG[library_name] - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: + if model_type not in supported_model_type_for_library: if default_model_type is not None: model_type = default_model_type else: raise KeyError( f"{model_type_and_model_name} is not supported yet for {library_name}. " - f"Only {list(TasksManager._SUPPORTED_MODEL_TYPE.keys())} are supported. " + f"Only {list(supported_model_type_for_library.keys())} are supported for the library {library_name}. " f"If you want to support {model_type} please propose a PR or open up an issue." ) - if exporter not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]: + if exporter not in supported_model_type_for_library[model_type]: raise KeyError( f"{model_type_and_model_name} is not supported yet with the {exporter} backend. " - f"Only {list(TasksManager._SUPPORTED_MODEL_TYPE[model_type].keys())} are supported. " + f"Only {list(supported_model_type_for_library[model_type].keys())} are supported. " f"If you want to support {exporter} please propose a PR or open up an issue." ) - return TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter] + return supported_model_type_for_library[model_type][exporter] @staticmethod def get_supported_model_type_for_task(task: str, exporter: str) -> List[str]: """ - Returns the list of supported architectures by the exporter for a given task. + Returns the list of supported architectures by the exporter for a given task. Transformers-specific. """ return [ model_type.replace("-", "_") @@ -1226,7 +1268,7 @@ def get_model_class_for_task( parameter is useful for example for "automatic-speech-recognition", that may map to AutoModelForSpeechSeq2Seq or to AutoModelForCTC. library (`str`, defaults to `transformers`): - The library name of the model. + The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". Returns: The AutoModel class corresponding to the task. @@ -1475,7 +1517,9 @@ def _infer_task_from_model_name_or_path( if is_local: # TODO: maybe implement that. - raise RuntimeError("Cannot infer the task from a local directory yet, please specify the task manually.") + raise RuntimeError( + f"Cannot infer the task from a local directory yet, please specify the task manually ({', '.join(TasksManager.get_all_tasks())})." + ) else: if subfolder != "": raise RuntimeError( @@ -1580,7 +1624,7 @@ def infer_library_from_model( subfolder: str = "", revision: Optional[str] = None, cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE, - library_name: str = None, + library_name: Optional[str] = None, ): """ Infers the library from the model repo. 
@@ -1597,7 +1641,7 @@ def infer_library_from_model( cache_dir (`Optional[str]`, *optional*): Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. library_name (`Optional[str]`, *optional*): - The library name of the model. + The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". Returns: `str`: The library name automatically detected from the model repo. """ @@ -1620,18 +1664,28 @@ def infer_library_from_model( if "model_index.json" in all_files: library_name = "diffusers" elif CONFIG_NAME in all_files: - model_config = PretrainedConfig.from_pretrained( - model_name_or_path, subfolder=subfolder, revision=revision - ) + # We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type. + kwargs = { + "subfolder": subfolder, + "revision": revision, + "cache_dir": cache_dir, + } + config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs) + model_config = PretrainedConfig.from_dict(config_dict, **kwargs) if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"): library_name = "timm" elif hasattr(model_config, "_diffusers_version"): library_name = "diffusers" - elif any(file_path.startswith("sentence_") for file_path in all_files): - library_name = "sentence_transformers" else: library_name = "transformers" + elif ( + any(file_path.startswith("sentence_") for file_path in all_files) + or "config_sentence_transformers.json" in all_files + ): + library_name = "sentence_transformers" + else: + library_name = "transformers" if library_name is None: raise ValueError( @@ -1648,7 +1702,7 @@ def standardize_model_attributes( subfolder: str = "", revision: Optional[str] = None, cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE, - library_name: str = None, + library_name: Optional[str] = None, ): """ Updates the model for export. This function is suitable to make required changes to the models from different @@ -1668,7 +1722,7 @@ def standardize_model_attributes( cache_dir (`Optional[str]`, *optional*): Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. library_name (`Optional[str]`, *optional*):: - The library name of the model. + The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". """ library_name = TasksManager.infer_library_from_model( model_name_or_path, subfolder, revision, cache_dir, library_name @@ -1677,7 +1731,9 @@ def standardize_model_attributes( full_model_path = Path(model_name_or_path) / subfolder is_local = full_model_path.is_dir() - if library_name == "timm": + if library_name == "diffusers": + model.config.export_model_type = "stable-diffusion" + elif library_name == "timm": # Retrieve model config config_path = full_model_path / "config.json" @@ -1698,17 +1754,18 @@ def standardize_model_attributes( with open(config_path) as fp: model_type = json.load(fp)["architecture"] - setattr(model.config, "model_type", model_type) + # `model_type` is a class attribute in Transformers, let's avoid modifying it. 
+ model.config.export_model_type = model_type elif library_name == "sentence_transformers": if "Transformer" in model[0].__class__.__name__: model.config = model[0].auto_model.config - model.config.model_type = "sentence-transformers-transformer" + model.config.export_model_type = "transformer" elif "CLIP" in model[0].__class__.__name__: model.config = model[0].model.config - model.config.model_type = "sentence-transformers-clip" + model.config.export_model_type = "clip" else: raise ValueError( - f"The export of a sentence-transformers model with the first module being {model[0].__class__.__name__} is currently not supported in Optimum. Please open an issue or submit a PR to add the support." + f"The export of a sentence_transformers model with the first module being {model[0].__class__.__name__} is currently not supported in Optimum. Please open an issue or submit a PR to add support for it." ) @staticmethod @@ -1771,8 +1828,8 @@ def get_model_from_task( Device to initialize the model on. PyTorch-only argument. For PyTorch, defaults to "cpu". model_kwargs (`Dict[str, Any]`, *optional*): Keyword arguments to pass to the model `.from_pretrained()` method. - library_name (`Optional[str]`, *optional*): - The library name of the model. See `TasksManager.infer_library_from_model` for the priority should + library_name (`Optional[str]`, defaults to `None`): + The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". See `TasksManager.infer_library_from_model` for the priority should none be provided. Returns: @@ -1861,7 +1918,7 @@ def get_exporter_config_constructor( model_type: Optional[str] = None, model_name: Optional[str] = None, exporter_config_kwargs: Optional[Dict[str, Any]] = None, - library_name: str = "transformers", + library_name: Optional[str] = None, ) -> ExportConfigConstructor: """ Gets the `ExportConfigConstructor` for a model (or alternatively for a model type) and task combination. @@ -1877,17 +1934,39 @@ def get_exporter_config_constructor( The model type to retrieve the config for. model_name (`Optional[str]`, defaults to `None`): The name attribute of the model object, only used for the exception message. - exporter_config_kwargs(`Optional[Dict[str, Any]]`, defaults to `None`): + exporter_config_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Arguments that will be passed to the exporter config class when building the config constructor. + library_name (`Optional[str]`, defaults to `None`): + The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". Returns: `ExportConfigConstructor`: The `ExportConfig` constructor for the requested backend. """ + if library_name is None: + logger.warning( + "Passing the argument `library_name` to `get_exporter_config_constructor` is required, but got library_name=None. Defaulting to `transformers`. An error will be raised in a future version of Optimum if `library_name` is not provided." + ) + + # Caveat: if different libraries register the same model type, the merged mapping below is ambiguous.
+ supported_model_type_for_library = { + **TasksManager._DIFFUSERS_SUPPORTED_MODEL_TYPE, + **TasksManager._TIMM_SUPPORTED_MODEL_TYPE, + **TasksManager._SENTENCE_TRANSFORMERS_SUPPORTED_MODEL_TYPE, + **TasksManager._SUPPORTED_MODEL_TYPE, + } + library_name = "transformers" + else: + supported_model_type_for_library = TasksManager._LIBRARY_TO_SUPPORTED_MODEL_TYPES[library_name] + if model is None and model_type is None: raise ValueError("Either a model_type or model should be provided to retrieve the export config.") if model_type is None: - model_type = getattr(model.config, "model_type", model_type) + if hasattr(model.config, "export_model_type"): + # We can specify a custom `export_model_type` attribute in the config. Useful for timm and sentence_transformers. + model_type = model.config.export_model_type + else: + model_type = getattr(model.config, "model_type", None) if model_type is None: raise ValueError("Model type cannot be inferred. Please provide the model_type for the model!") @@ -1911,10 +1990,10 @@ def get_exporter_config_constructor( f" Supported tasks are: {', '.join(model_tasks.keys())}." ) - if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: + if model_type not in supported_model_type_for_library: model_type = TasksManager._MODEL_TYPE_FOR_DEFAULT_CONFIG[library_name] - exporter_config_constructor = TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter][task] + exporter_config_constructor = supported_model_type_for_library[model_type][exporter][task] if exporter_config_kwargs is not None: exporter_config_constructor = partial(exporter_config_constructor, **exporter_config_kwargs) diff --git a/optimum/exporters/tflite/__main__.py b/optimum/exporters/tflite/__main__.py index af577a4736..b3c90cb63f 100644 --- a/optimum/exporters/tflite/__main__.py +++ b/optimum/exporters/tflite/__main__.py @@ -62,7 +62,9 @@ def main(): task, args.model, framework="tf", cache_dir=args.cache_dir, trust_remote_code=args.trust_remote_code ) - tflite_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="tflite", task=task) + tflite_config_constructor = TasksManager.get_exporter_config_constructor( + model=model, exporter="tflite", task=task, library_name="transformers" + ) # TODO: find a cleaner way to do this.
shapes = {name: getattr(args, name) for name in tflite_config_constructor.func.get_mandatory_axes_for_task(task)} tflite_config = tflite_config_constructor(model.config, **shapes) diff --git a/tests/exporters/common/test_tasks_manager.py b/tests/exporters/common/test_tasks_manager.py index a480cab99d..fc0d3eb8db 100644 --- a/tests/exporters/common/test_tasks_manager.py +++ b/tests/exporters/common/test_tasks_manager.py @@ -33,6 +33,15 @@ def _check_all_models_are_registered( for mappings in TasksManager._SUPPORTED_MODEL_TYPE.values(): for class_ in mappings.get(backend, {}).values(): registered_classes.add(class_.func.__name__) + for mappings in TasksManager._TIMM_SUPPORTED_MODEL_TYPE.values(): + for class_ in mappings.get(backend, {}).values(): + registered_classes.add(class_.func.__name__) + for mappings in TasksManager._SENTENCE_TRANSFORMERS_SUPPORTED_MODEL_TYPE.values(): + for class_ in mappings.get(backend, {}).values(): + registered_classes.add(class_.func.__name__) + for mappings in TasksManager._DIFFUSERS_SUPPORTED_MODEL_TYPE.values(): + for class_ in mappings.get(backend, {}).values(): + registered_classes.add(class_.func.__name__) if classes_to_ignore is None: classes_to_ignore = set() diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 8c65c3b635..95b069cecb 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -329,8 +329,8 @@ } PYTORCH_SENTENCE_TRANSFORMERS_MODEL = { - "sentence-transformers-clip": "sentence-transformers/all-MiniLM-L6-v2", - "sentence-transformers-transformer": "sentence-transformers/clip-ViT-B-32-multilingual-v1", + "clip": "sentence-transformers/clip-ViT-B-32", + "transformer": "sentence-transformers/all-MiniLM-L6-v2", } diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 397ecd16e7..d6a89afe41 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -50,7 +50,7 @@ ) -def _get_models_to_test(export_models_dict: Dict, library_name: str = "transformers"): +def _get_models_to_test(export_models_dict: Dict, library_name: str): models_to_test = [] if is_torch_available(): for model_type, model_names_tasks in export_models_dict.items(): @@ -270,7 +270,9 @@ def test_exporters_cli_pytorch_gpu_stable_diffusion(self, model_type: str, model def test_exporters_cli_fp16_stable_diffusion(self, model_type: str, model_name: str): self._onnx_export(model_name, model_type, device="cuda", fp16=True) - @parameterized.expand(_get_models_to_test(PYTORCH_SENTENCE_TRANSFORMERS_MODEL)) + @parameterized.expand( + _get_models_to_test(PYTORCH_SENTENCE_TRANSFORMERS_MODEL, library_name="sentence_transformers") + ) @require_torch @require_vision @require_sentence_transformers @@ -369,7 +371,7 @@ def test_exporters_cli_fp16_timm( ): self._onnx_export(model_name, task, monolith, no_post_process, device="cuda", fp16=True) - @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY, library_name="transformers")) @require_torch @require_vision def test_exporters_cli_pytorch_cpu( @@ -395,7 +397,7 @@ def test_exporters_cli_pytorch_cpu( self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, model_kwargs=model_kwargs) - @parameterized.expand(_get_models_to_test(PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES)) + @parameterized.expand(_get_models_to_test(PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES, 
library_name="transformers")) @require_torch @require_vision def test_exporters_cli_pytorch_cpu_no_dynamic_axes( @@ -425,7 +427,7 @@ def test_exporters_cli_pytorch_cpu_no_dynamic_axes( model_name, task, input_shape, input_shape_for_validation, monolith, no_post_process, variant=variant ) - @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY, library_name="transformers")) @require_vision @require_torch_gpu @pytest.mark.gpu_test @@ -455,7 +457,7 @@ def test_exporters_cli_pytorch_gpu( model_name, task, monolith, no_post_process, device="cuda", variant=variant, model_kwargs=model_kwargs ) - @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY, library_name="transformers")) @require_torch @require_vision @slow @@ -493,7 +495,7 @@ def test_exporters_cli_pytorch_with_optimization( else: raise e - @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY, library_name="transformers")) @require_torch_gpu @require_vision @slow @@ -608,7 +610,7 @@ def test_legacy(self): model = onnx.load(Path(tmpdirname) / ONNX_DECODER_MERGED_NAME) self.assertNotIn("position_ids", {node.name for node in model.graph.input}) - @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY, library_name="transformers")) @require_vision @require_torch_gpu @slow