-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Refactor the get_weights API
#5006
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
761959e
3339918
ed77fb0
6d0cc8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,3 +15,4 @@ | |
| from . import quantization | ||
| from . import segmentation | ||
| from . import video | ||
| from ._api import get_weight | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,9 @@ | ||
| import importlib | ||
| import inspect | ||
| import sys | ||
| from collections import OrderedDict | ||
| from dataclasses import dataclass, fields | ||
| from enum import Enum | ||
| from inspect import signature | ||
| from typing import Any, Callable, Dict | ||
|
|
||
| from ..._internally_replaced_utils import load_state_dict_from_url | ||
|
|
@@ -30,7 +32,6 @@ class Weights: | |
| url: str | ||
| transforms: Callable | ||
| meta: Dict[str, Any] | ||
| default: bool | ||
|
|
||
|
|
||
| class WeightsEnum(Enum): | ||
|
|
@@ -50,7 +51,7 @@ def __init__(self, value: Weights): | |
| def verify(cls, obj: Any) -> Any: | ||
| if obj is not None: | ||
| if type(obj) is str: | ||
| obj = cls.from_str(obj) | ||
| obj = cls.from_str(obj.replace(cls.__name__ + ".", "")) | ||
| elif not isinstance(obj, cls): | ||
| raise TypeError( | ||
| f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}." | ||
|
|
@@ -59,8 +60,8 @@ def verify(cls, obj: Any) -> Any: | |
|
|
||
| @classmethod | ||
| def from_str(cls, value: str) -> "WeightsEnum": | ||
| for v in cls: | ||
| if v._name_ == value or (value == "default" and v.default): | ||
| for k, v in cls.__members__.items(): | ||
| if k == value: | ||
| return v | ||
| raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") | ||
|
|
||
|
|
@@ -78,41 +79,35 @@ def __getattr__(self, name): | |
| return super().__getattr__(name) | ||
|
|
||
|
|
||
| def get_weight(fn: Callable, weight_name: str) -> WeightsEnum: | ||
| def get_weight(name: str) -> WeightsEnum: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering: shoudn't this return a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to return a |
||
| """ | ||
| Gets the weight enum of a specific model builder method and weight name combination. | ||
| Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1" | ||
|
|
||
| Args: | ||
| fn (Callable): The builder method used to create the model. | ||
| weight_name (str): The name of the weight enum entry of the specific model. | ||
| name (str): The name of the weight enum entry. | ||
|
|
||
| Returns: | ||
| WeightsEnum: The requested weight enum. | ||
| """ | ||
| sig = signature(fn) | ||
| if "weights" not in sig.parameters: | ||
| raise ValueError("The method is missing the 'weights' parameter.") | ||
| try: | ||
| enum_name, value_name = name.split(".") | ||
| except ValueError: | ||
| raise ValueError(f"Invalid weight name provided: '{name}'.") | ||
|
|
||
| base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1]) | ||
| base_module = importlib.import_module(base_module_name) | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| model_modules = [base_module] + [ | ||
| x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py") | ||
| ] | ||
|
|
||
| ann = signature(fn).parameters["weights"].annotation | ||
| weights_enum = None | ||
| if isinstance(ann, type) and issubclass(ann, WeightsEnum): | ||
| weights_enum = ann | ||
| else: | ||
| # handle cases like Union[Optional, T] | ||
| # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 | ||
| for t in ann.__args__: # type: ignore[union-attr] | ||
| if isinstance(t, type) and issubclass(t, WeightsEnum): | ||
| # ensure the name exists. handles builders with multiple types of weights like in quantization | ||
| try: | ||
| t.from_str(weight_name) | ||
| except ValueError: | ||
| continue | ||
| weights_enum = t | ||
| break | ||
| for m in model_modules: | ||
| potential_class = m.__dict__.get(enum_name, None) | ||
| if potential_class is not None and issubclass(potential_class, WeightsEnum): | ||
| weights_enum = potential_class | ||
| break | ||
|
|
||
| if weights_enum is None: | ||
| raise ValueError( | ||
| "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct." | ||
| ) | ||
| raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.") | ||
|
|
||
| return weights_enum.from_str(weight_name) | ||
| return weights_enum.from_str(value_name) | ||
Uh oh!
There was an error while loading. Please reload this page.