|
1 | 1 | import functools
|
| 2 | +import inspect |
2 | 3 | import warnings
|
3 | 4 | from collections import OrderedDict
|
4 | 5 | from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union
|
5 | 6 |
|
6 | 7 | from torch import nn
|
7 |
| -from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw |
8 | 8 |
|
| 9 | +from .._utils import sequence_to_str |
9 | 10 | from ._api import WeightsEnum
|
10 | 11 |
|
11 | 12 |
|
@@ -88,6 +89,60 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) ->
|
88 | 89 | return new_v
|
89 | 90 |
|
90 | 91 |
|
| 92 | +D = TypeVar("D") |
| 93 | + |
| 94 | + |
| 95 | +def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: |
| 96 | + """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. |
| 97 | +
|
| 98 | + For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: |
| 99 | +
|
| 100 | + .. code:: |
| 101 | +
|
| 102 | + def old_fn(foo, bar, baz=None): |
| 103 | + ... |
| 104 | +
|
| 105 | + def new_fn(foo, *, bar, baz=None): |
| 106 | + ... |
| 107 | +
|
| 108 | + Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC |
| 109 | + and at the same time warn the user of the deprecation, this decorator can be used: |
| 110 | +
|
| 111 | + .. code:: |
| 112 | +
|
| 113 | + @kwonly_to_pos_or_kw |
| 114 | + def new_fn(foo, *, bar, baz=None): |
| 115 | + ... |
| 116 | +
|
| 117 | + new_fn("foo", "bar, "baz") |
| 118 | + """ |
| 119 | + params = inspect.signature(fn).parameters |
| 120 | + |
| 121 | + try: |
| 122 | + keyword_only_start_idx = next( |
| 123 | + idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY |
| 124 | + ) |
| 125 | + except StopIteration: |
| 126 | + raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None |
| 127 | + |
| 128 | + keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] |
| 129 | + |
| 130 | + @functools.wraps(fn) |
| 131 | + def wrapper(*args: Any, **kwargs: Any) -> D: |
| 132 | + args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] |
| 133 | + if keyword_only_args: |
| 134 | + keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) |
| 135 | + warnings.warn( |
| 136 | + f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " |
| 137 | + f"parameter(s) is deprecated. Please use keyword parameter(s) instead." |
| 138 | + ) |
| 139 | + kwargs.update(keyword_only_kwargs) |
| 140 | + |
| 141 | + return fn(*args, **kwargs) |
| 142 | + |
| 143 | + return wrapper |
| 144 | + |
| 145 | + |
91 | 146 | W = TypeVar("W", bound=WeightsEnum)
|
92 | 147 | M = TypeVar("M", bound=nn.Module)
|
93 | 148 | V = TypeVar("V")
|
|
0 commit comments