In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.

In [None]:
r"""Custom Decorators."""

import gc
import logging
import warnings
from dataclasses import dataclass
from functools import partial, wraps
from inspect import Parameter, signature
from time import perf_counter_ns
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)
__all__ = ["decorator", "DecoratorError", "timefun"]


KEYWORD_ONLY = Parameter.KEYWORD_ONLY
POSITIONAL_ONLY = Parameter.POSITIONAL_ONLY
POSITIONAL_OR_KEYWORD = Parameter.POSITIONAL_OR_KEYWORD
VAR_KEYWORD = Parameter.VAR_KEYWORD
VAR_POSITIONAL = Parameter.VAR_POSITIONAL
EMPTY = Parameter.empty


def rpartial(func: Callable, /, *fixed_args: Any, **fixed_kwargs: Any) -> Callable:
    r"""Apply positional arguments from the right.

    The returned callable invokes
    ``func(*call_args, *fixed_args, **{**call_kwargs, **fixed_kwargs})``:
    the fixed positional arguments are appended after the call-time ones, and on
    a keyword clash the fixed keyword value wins.
    """

    @wraps(func)
    def wrapper(*call_args, **call_kwargs):
        merged_kwargs = {**call_kwargs, **fixed_kwargs}
        return func(*call_args, *fixed_args, **merged_kwargs)

    return wrapper


@dataclass
class DecoratorError(Exception):
    r"""Raise Error related to decorator construction.

    Attributes
    ----------
    decorator: Callable
        The function that could not be turned into (or used as) a decorator.
    message: Optional[str], default=None
        Extra explanation appended after the signature line.
    """

    decorator: Callable
    message: Optional[str] = None

    def __str__(self) -> str:
        # Use the offending object stored on the instance (not any global name).
        try:
            sig = str(signature(self.decorator))
        except (TypeError, ValueError):  # non-callables / builtins without a signature
            sig = "<unknown>"
        default = f"{self.decorator} with signature {sig}\n"
        return default if self.message is None else default + self.message


def decorator(deco: Callable) -> Callable:
    r"""Meta-Decorator for constructing parametrized decorators.

    The decorated function ``deco`` must have a signature of the form::

        deco(func, po₁, …, poₙ, /, *, ko₁, …, koₘ, kod₁=d₁, …, **kwargs)

    i.e. every positional parameter (``func`` first) is POSITIONAL_ONLY without a
    default, and every other parameter is KEYWORD_ONLY (plus optional ``**kwargs``).
    Since the number of positional parameters is then fixed, the returned wrapper
    tells the call styles apart by counting the positional arguments it receives:

    - functional: ``deco(func, po₁, …, poₙ, **kw)`` applies ``deco`` directly;
    - bracketed:  ``@deco(po₁, …, poₙ, **kw)`` returns a decorator awaiting ``func``;
    - bare:       ``@deco`` is allowed only if ``func`` is the sole mandatory parameter.

    Parameters
    ----------
    deco: Callable
        The decorator implementation; takes the function to decorate first.

    Returns
    -------
    Callable
        A wrapper supporting the call styles above.

    Raises
    ------
    DecoratorError
        If the signature of ``deco`` is unsupported, or, in bracketed usage, if
        mandatory keyword-only arguments are missing.
    """
    mandatory_pos_args, mandatory_key_args = set(), set()

    for key, param in signature(deco).parameters.items():
        if param.kind is VAR_POSITIONAL:
            raise DecoratorError(
                deco, "Decorator does not support VAR_POSITIONAL arguments (*args)!!"
            )
        if param.kind is POSITIONAL_OR_KEYWORD:
            raise DecoratorError(
                deco,
                "Decorator does not support POSITIONAL_OR_KEYWORD arguments!!\n"
                f"Got {signature(deco)=}\n"
                "Separate positional and keyword arguments like fun(po, /, *, ko=None,)\n"
                "Cf. https://www.python.org/dev/peps/pep-0570/",
            )
        if param.kind is POSITIONAL_ONLY and param.default is not EMPTY:
            raise DecoratorError(
                deco, "Positional arguments are not allowed to be optional!"
            )
        if param.default is EMPTY and param.kind is POSITIONAL_ONLY:
            mandatory_pos_args.add(key)
        if param.default is EMPTY and param.kind is KEYWORD_ONLY:
            mandatory_key_args.add(key)

    if not mandatory_pos_args:
        raise DecoratorError(
            deco, "Decorator requires at least one POSITIONAL_ONLY argument, got zero."
        )

    # If anything besides `func` is mandatory, the bare form `@deco` cannot work.
    no_bare_decorator = len(mandatory_pos_args | mandatory_key_args) > 1
    # Positional arguments the bracketed form `@deco(...)` supplies (all but `func`).
    n_deco_pos_args = len(mandatory_pos_args) - 1
    # Sentinel: distinguishes "no positional argument at all" from an explicit None.
    omitted = object()

    @wraps(deco)
    def parametrized_decorator(
        __func__: Any = omitted, *args: Any, **kwargs: Any
    ) -> Callable:
        # All positional arguments the caller actually passed.
        given = args if __func__ is omitted else (__func__, *args)

        if no_bare_decorator:
            if len(given) == n_deco_pos_args:  # all pos args except func given
                if missing_keys := (mandatory_key_args - kwargs.keys()):
                    raise DecoratorError(
                        deco, f"Not enough kwargs supplied, {missing_keys=}"
                    )
                logger.info(">>> Generating bracket version of %s <<<", deco)
                return rpartial(deco, *given, **kwargs)
            logger.info(">>> Generating functional version of %s <<<", deco)
            return deco(*given, **kwargs)
        if __func__ is omitted or __func__ is None:
            # `@deco()` / `@deco(kw=...)`: wait for the function to decorate.
            logger.info(">>> Generating bracket version of %s <<<", deco)
            return rpartial(deco, *args, **kwargs)
        # `@deco` or `deco(func, kw=...)`: apply immediately.
        logger.info(">>> Generating bare/functional version of %s <<<", deco)
        return deco(__func__, *args, **kwargs)

    return parametrized_decorator


def timefun(
    fun: Optional[Callable] = None,
    append: bool = True,
    loglevel: int = logging.WARNING,
) -> Callable:
    r"""Log the execution time of the function. Use as decorator.

    By default appends the execution time (in seconds) to the function call.

    ``outputs, time_elapsed = timefun(f, append=True)(inputs)``

    With ``append=False`` only ``outputs`` is returned. If the function call failed,
    ``outputs=None`` and ``time_elapsed=float('nan')`` are returned, and a
    ``RuntimeWarning`` is emitted.

    Can also be used with options only, e.g. ``@timefun(append=False)``.

    Parameters
    ----------
    fun: Callable, optional
        The function to time. If omitted, a decorator with the given options is returned.
    append: bool, default=True
        Whether to append the time result to the function call
    loglevel: int, default=logging.WARNING (30)
        Level at which timings and failures are logged to the ``"timefun"`` logger.

    Returns
    -------
    Callable
    """
    if fun is None:
        return partial(timefun, append=append, loglevel=loglevel)

    timefun_logger = logging.getLogger("timefun")
    fun_name = getattr(fun, "__name__", repr(fun))

    @wraps(fun)
    def timed_fun(*args, **kwargs):
        gc.collect()
        gc_was_enabled = gc.isenabled()
        gc.disable()  # keep GC pauses out of the measurement
        try:
            start_time = perf_counter_ns()
            result = fun(*args, **kwargs)
            end_time = perf_counter_ns()
            elapsed = (end_time - start_time) / 10**9
            timefun_logger.log(loglevel, "%s executed in %.4f s", fun_name, elapsed)
        except Exception as exc:  # pylint: disable=W0703
            # Deliberate best-effort: report failure instead of propagating it.
            result = None
            elapsed = float("nan")
            warnings.warn(
                f"Function execution failed with Exception {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
            timefun_logger.log(loglevel, "%s failed with Exception %s", fun_name, exc)
        finally:
            # Restore the GC state even on KeyboardInterrupt / SystemExit.
            if gc_was_enabled:
                gc.enable()

        return (result, elapsed) if append else result

    return timed_fun

In [None]:
len("POSITIONAL_OR_KEYWORD")

In [None]:
repr(Parameter.empty)

In [None]:
str(p)

In [None]:
repr([(key, str(param.kind)) for key, param in signature(timefun).parameters.items()])

In [None]:
from tsdm.utils import timefun

In [None]:
@timefun(append=False)
def test(x):
    return x


test(3)

In [None]:
timefun(lambda x: x**2, append=False)(3)

In [None]:
"a".join("12345")

In [None]:
"""
po: positional-only
ko: keyword only
pk: positional or keyword
vk: variable keyword only
vp: variable positional only
_d: with default argument
"""

There are 3 different ways of calling parametrized decorators


1. Vanilla:   
    ```python
    deco(func, *deco_args, **deco_kwargs)
    ```
    - input is `args=(func, *deco_args)`, `kwargs=deco_kwargs`
    - in particular `args=(func,)` and `kwargs={}` is possible
2. Bare:  (only allowable if all arguments except `func` and `**kwargs` have default values)
    ```python
    @deco
    def func(..)
    ```
    - input is  `args=(func,)`, `kwargs={}`
3. Bracketed:
    ```python
    @deco(*deco_args, **deco_kwargs)
    def func(...)
    ```
    - input is `args=deco_args`, `kwargs=deco_kwargs`
    - in particular `args=(other_callable, )` and `kwargs={}` is possible
 
    
Main problem: how do we distinguish whether the input is `(func, )`, `(*deco_args,)` or `(func, *deco_args)`? In particular, consider the edge cases:



| Type    | Code                                                           | Return                     | `deco_args = ()` | `deco_args = (Callable,)` |
|---------|----------------------------------------------------------------|----------------------------|------------------|---------------------------|
| Bare    | <code>@deco<br>def func(..)</code>                             | `wrapper(func)`            | `(func, )`       | `(func, )`                |
| Bracket | <code>@deco(*deco_args, **deco_kwargs)<br>def func(...)</code> | `wrapper(decorator)(func)` | `()`             | `(Callable, )`            |
| Called  | <code>deco(func, *deco_args, **deco_kwargs)</code>             | `wrapped(func)`            | `(func, )`       | `(func, Callable)`        |


As we can see, there are multiple ambiguities: when the decorator gets passed a single callable as input, it could be any of the 3 cases above.

### Resolution: Allow only a subset of function signatures

We require that **all optional parameters are keyword-only** (i.e. parameters with defaults and `**kwargs`).
In particular, `*args` and `POSITIONAL_OR_KEYWORD` arguments are not allowed.

```python
def decorator(
    func, 
    po₁, po₂, …, poₙ,  # must all be mandatory (no defaults)!
    /, *,   # <- separator between po and ko args
    ko₁, ko₂, …, koₘ,
    kod₁=d₁, kod₂=d₂, …, kodₖ=dₖ,
    **kwargs
)
```

This has the main advantage that it is a-priori determined how many positional arguments will be consumed by `decorator`.

Thus, given the input, we can simply count the number of positional arguments and uniquely determine the calling mode that way.


Many people propose to check something along the lines of `callable(args[0])`, but that is error-prone and does not work for decorators that themselves take a callable argument, like

```python
def compose(func: Callable, outer: Callable, *outer_args, **outer_kwargs) -> Callable:
    @wraps(func)
    def wrapper(*func_args, **func_kwargs):
        y = func(*func_args, **func_kwargs)
        return outer(y, *outer_args, **outer_kwargs)
    return wrapper
```

```python
def compose(func: Callable, inner: Callable, *inner_args, **inner_kwargs) -> Callable:
    @wraps(func)
    def wrapper(x, *func_args, **func_kwargs):
        y = inner(x, *inner_args, **inner_kwargs)
        return func(y, *func_args, **func_kwargs)
    return wrapper
```

In [None]:
def decorator(
    func, po1, po2, poN, /, ko1, ko2, koM, kod1=None, kod2=None, kodK=None, **kwargs
):
    """Illustrative stub of the intended decorator signature shape; does nothing.

    NOTE(review): there is no ``*`` separator after ``/`` here, so ko1, ko2, koM,
    kod1, kod2 and kodK are POSITIONAL_OR_KEYWORD rather than keyword-only; the
    meta-decorator ``decorator`` defined earlier rejects such parameters.
    Running this cell also rebinds the name ``decorator``, shadowing that
    meta-decorator.
    """
    pass

In [None]:
"""
Note that there are two modi:


# naked mode
@decorator        -> returns decorator(fun)
def fun(...)

# param mode
@decorator(args)  -> returns decorator(args)(fun), i.e. decorator must return an inner_decorator = decorator(args)
def f(....)


- we require that the first argument to the decorator is __func__
- naked mode cannot be allowed if decorator has arguments (po/pk/ko) without default values.
- naked mode is not possible if the decorator can be called with *args, due to the ambiguity @dec:f = dec(f)

bottom line: only allow naked mode iff: (1) no *args is present and (2) all parameters other than `func` and **kwargs have default values


- naked mode requires that the first argument passed is interpreted as the function
    - We can support both modes simultaneously with a trick: give the parametrized decorator a first argument with a default (__func__=None)
    - Suppose the user writes something along the lines of
        ``` def deco(func, a, b, ..., *args, *kwargs): ...
                
        If there are no non-default (po/pk/kd) arguments except a single pos-only argument `func`, we transform the function into
        
        ``` def f(func=None, po_d=..., / pk_d=None, (*args or *), ko_d, **vk)

        If there are any non-default arguments, then only param_mode is allowed. The function gets transformed into a wrapper:
        
        def decorator_factory(argument):
            def decorator(function):
                def wrapper(*args, **kwargs):
                    funny_stuff()
                    something_with_argument(argument)
                    result = function(*args, **kwargs)
                    more_funny_stuff()
                    return result
                return wrapper
            return decorator

        
        deco(func, *args, **kwargs) -> partial (func, *args,**kwargs)
        
        def debug(func, args):
        @wraps(func)
        def wrapper(*func_args, **func_kwargs):
            # print("\n" + "-"*80 + "\n", flush=True)
            # print(F">>> Entering {func}", flush=True)
            # print(F"    {signature(func)=}", flush=True)
            # print(F"    {func_args=}", flush=True)
            # print(F"    {func_kwargs=}", flush=True)
            return_value = func(*func_args, **func_kwargs)
            # print(F"    {return_value=}", flush=True)
            # print(F"<<< Exiting  {func}", flush=True)
            return return_value
        return wrapper
    
Idea: have 

- H
"""

In [None]:
from reprlib import recursive_repr


class rpartial:
    """New function with partial application of the given arguments and keywords,
    with positional arguments applied from the *right*.

    ``rpartial(f, *fixed_args, **fixed_kwargs)(*args, **kwargs)`` calls
    ``f(*args, *fixed_args, **{**fixed_kwargs, **kwargs})``: fixed positional
    arguments are appended after the call-time ones, and call-time keywords
    override stored ones (same keyword semantics as ``functools.partial``).
    Modelled on the pure-Python implementation of ``functools.partial``.
    """

    __slots__ = "func", "args", "keywords", "__dict__", "__weakref__"

    def __new__(cls, func, /, *args, **keywords):
        if not callable(func):
            raise TypeError("the first argument must be callable")

        # Flatten nested right-partials. With right application, the outer (newer)
        # fixed args go *before* the inner ones:
        #     rpartial(rpartial(f, a), b)(x) == f(x, b, a)
        # Only exact rpartial instances are merged: duck-typing on a `.func`
        # attribute would also flatten functools.partial objects, whose stored
        # args belong on the *left*, silently reordering arguments.
        if type(func) is rpartial:
            args = args + func.args
            keywords = {**func.keywords, **keywords}
            func = func.func

        self = super(rpartial, cls).__new__(cls)

        self.func = func
        self.args = args
        self.keywords = keywords
        return self

    def __call__(self, /, *args, **keywords):
        keywords = {**self.keywords, **keywords}
        return self.func(*args, *self.args, **keywords)

    @recursive_repr()
    def __repr__(self):
        qualname = type(self).__qualname__
        args = [repr(self.func)]
        args.extend(repr(x) for x in self.args)
        args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
        if type(self).__module__ == "functools":
            return f"functools.{qualname}({', '.join(args)})"
        return f"{qualname}({', '.join(args)})"

    def __reduce__(self):
        # Rebuild as cls(func) and restore the rest via __setstate__ (pickle/copy).
        return (
            type(self),
            (self.func,),
            (self.func, self.args, self.keywords or None, self.__dict__ or None),
        )

    def __setstate__(self, state):
        if not isinstance(state, tuple):
            raise TypeError("argument to __setstate__ must be a tuple")
        if len(state) != 4:
            raise TypeError(f"expected 4 items in state, got {len(state)}")
        func, args, kwds, namespace = state
        if (
            not callable(func)
            or not isinstance(args, tuple)
            or (kwds is not None and not isinstance(kwds, dict))
            or (namespace is not None and not isinstance(namespace, dict))
        ):
            raise TypeError("invalid partial state")

        args = tuple(args)  # just in case it's a subclass
        if kwds is None:
            kwds = {}
        elif type(kwds) is not dict:  # normalize dict subclasses to a plain dict
            kwds = dict(kwds)
        if namespace is None:
            namespace = {}

        self.__dict__ = namespace
        self.func = func
        self.args = args
        self.keywords = kwds

In [None]:
if a := 3 - 3:
    print(a)

In [None]:
from functools import wraps
from inspect import Parameter, signature
from typing import Callable

KEYWORD_ONLY = Parameter.KEYWORD_ONLY
POSITIONAL_ONLY = Parameter.POSITIONAL_ONLY
POSITIONAL_OR_KEYWORD = Parameter.POSITIONAL_OR_KEYWORD
VAR_KEYWORD = Parameter.VAR_KEYWORD
VAR_POSITIONAL = Parameter.VAR_POSITIONAL
EMPTY = Parameter.empty


def debug(func):
    r"""Wrap ``func`` so that every call is traced on stdout.

    Before the call: a separator line, the wrapped function, its signature and
    the positional and keyword arguments. After the call returns: the return
    value and an exit line. The return value is passed through unchanged.
    """

    def say(message):
        print(message, flush=True)

    @wraps(func)
    def traced(*func_args, **func_kwargs):
        say("\n" + "-" * 80 + "\n")
        say(f">>> Entering {func}")
        say(f"    {signature(func)=}")
        say(f"    {func_args=}")
        say(f"    {func_kwargs=}")
        return_value = func(*func_args, **func_kwargs)
        say(f"    {return_value=}")
        say(f"<<< Exiting  {func}")
        return return_value

    return traced


# def rpartial(func, /, *fixed_args, **fixed_kwargs):
#     """Partially applying arguments from the right."""
#     @wraps(func)
#     def wrapper(*func_args, **func_kwargs):
#         return func(*(func_args + fixed_args), **(func_kwargs | fixed_kwargs))
#     return wrapper


# def decorator(deco: Callable) -> Callable:
#     """Meta-Decorator for cosntructing parametrized decorators."""
#     params = signature(deco).parameters
#     no_bare_decorator = any(  # check if some params do not have defaults
#         param.default is Parameter.empty and param.kind is not Parameter.VAR_KEYWORD
#         for param in islice(params.values(), 1, None)
#     )

#     @wraps(deco)
#     def parametrized_decorator(__func__=None, *args, **kwargs):
#         print(__func__, args, kwargs)
#         if no_bare_decorator:
#             if __func__ is None:
# #                 return rpartial(deco, *((__func__,)+ args), **kwargs)
#                 return rpartial(deco, __func__, *args, **kwargs)
#             else:
# #                 return rpartial(deco, *((__func__,)+ args), **kwargs)
#                 return rpartial(deco, __func__, *args, **kwargs)
#         if __func__ is None:
#             return rpartial(deco, *args, **kwargs)
#         return deco(__func__, *args, **kwargs)
#     return parametrized_decorator


def rpartial(func, /, *fixed_args, **fixed_kwargs):
    """Bind ``fixed_args`` to the rightmost positional slots of ``func``.

    Calling the result with ``(*args, **kwargs)`` runs
    ``func(*args, *fixed_args, **kwargs_updated_with_fixed_kwargs)``, so on a
    keyword clash the fixed keyword value takes precedence.
    """

    def right_applied(*func_args, **func_kwargs):
        keywords = dict(func_kwargs)
        keywords.update(fixed_kwargs)
        return func(*func_args, *fixed_args, **keywords)

    return wraps(func)(right_applied)


class DecoratorError(Exception):
    """Raised when a function is not usable with the ``decorator`` meta-decorator.

    Raised for unsupported decorator signatures and for missing mandatory
    keyword arguments when the decorator is used in bracketed form.
    """


def decorator(deco: Callable) -> Callable:
    """Meta-Decorator for constructing parametrized decorators (debug-print version).

    ``deco`` must look like ``deco(func, po₁, …, poₙ, /, *, ko…, kod…=…, **kwargs)``:
    positional parameters are POSITIONAL_ONLY without defaults, the rest KEYWORD_ONLY.
    The returned wrapper counts the positional arguments it receives to decide
    between the functional call ``deco(func, po…)``, the bracketed form
    ``@deco(po…)`` and (only if ``func`` is the sole mandatory parameter) the bare
    form ``@deco``.

    Raises
    ------
    DecoratorError
        For unsupported signatures, or missing mandatory kwargs in bracketed usage.
    """
    mandatory_pos_args, mandatory_key_args = set(), set()

    for key, param in signature(deco).parameters.items():
        if param.kind is VAR_POSITIONAL:
            raise DecoratorError(
                "Decorator does not support VAR_POSITIONAL arguments (*args)!!"
            )
        if param.kind is POSITIONAL_OR_KEYWORD:
            raise DecoratorError(
                "Decorator does not support POSITIONAL_OR_KEYWORD arguments!!\n"
                "Separate positional and keyword arguments like fun(po, /, *, ko=None,)\n"
                "Cf. https://www.python.org/dev/peps/pep-0570/"
            )
        if param.kind is POSITIONAL_ONLY and param.default is not EMPTY:
            raise DecoratorError("Positonal arguments are not allowed to be optional!")
        if param.default is EMPTY and param.kind is POSITIONAL_ONLY:
            mandatory_pos_args.add(key)
        if param.default is EMPTY and param.kind is KEYWORD_ONLY:
            mandatory_key_args.add(key)

    if not mandatory_pos_args:
        raise DecoratorError(
            "Decorator requires at least one POSITIONAL_ONLY argument, got zero."
        )

    # Anything mandatory besides `func` rules out the bare form `@deco`.
    no_bare_decorator = len(mandatory_pos_args | mandatory_key_args) > 1
    # Positional args the bracketed form must supply (everything except `func`).
    n_deco_pos_args = len(mandatory_pos_args) - 1
    # Sentinel: distinguishes "no positional argument" from an explicit None.
    omitted = object()

    @wraps(deco)
    def parametrized_decorator(__func__=omitted, *args, **kwargs):
        print(None if __func__ is omitted else __func__, args, kwargs)
        given = args if __func__ is omitted else (__func__, *args)
        if no_bare_decorator:
            if len(given) == n_deco_pos_args:
                if missing_keys := mandatory_key_args - kwargs.keys():
                    raise DecoratorError(f"Not enough kwargs supplied, {missing_keys=}")
                print(">>> Generating bracket verions <<<")
                return rpartial(deco, *given, **kwargs)
            print(">>> Generating functional verions <<<")
            return deco(*given, **kwargs)
        if __func__ is omitted or __func__ is None:
            # `@deco()` / `@deco(kw=...)`: wait for the function to decorate.
            print(">>> Generating bracket version <<<")
            return rpartial(deco, *args, **kwargs)
        # `@deco` or `deco(func, kw=...)`: apply right away.
        print(">>> Generating bare/functional version <<<")
        return deco(__func__, *args, **kwargs)

    return parametrized_decorator

In [None]:
from typing import Callable


@decorator
def clip(
    func: Callable[..., float], /, *, lower=-1, upper=+1
) -> Callable[..., float]:
    r"""Clip function values post-hoc."""
    # `Callable[[float, ...], float]` is not a valid type expression (Ellipsis is
    # only allowed as the *entire* argument list) and raises TypeError at
    # definition time on Python < 3.11, hence `Callable[..., float]`.

    @wraps(func)
    def wrapper(x, *func_args, **func_kwargs):
        y = func(x, *func_args, **func_kwargs)
        return max(lower, min(upper, y))

    return wrapper


@decorator
def modulo(func: Callable[..., int], m: int, /) -> Callable[..., int]:
    r"""Apply post-hoc modulo operation $x↦x 𝗆𝗈𝖽 m$."""

    @wraps(func)
    def wrapper(x, *func_args, **func_kwargs):
        y = func(x, *func_args, **func_kwargs)
        return y % m

    return wrapper


print(repr(modulo), modulo.__doc__, modulo.__annotations__, sep="\n")
print(repr(clip), clip.__doc__, clip.__annotations__, sep="\n")

# decorator usage

In [None]:
@clip
def identity(x: float) -> float:
    """identity function"""
    return x


@modulo(3)
def square(x: float) -> float:
    """square function"""
    return x**2


@clip(lower=-10, upper=+10)
def cube(x: float) -> float:
    """cube function"""
    return x**3


# Show that metadata survives decoration, then sample each decorated function.
print("", repr(identity), identity.__doc__, identity.__annotations__, sep="\n")
print([identity(k) for k in range(-5, +5)])
print("", repr(square), square.__doc__, square.__annotations__, sep="\n")
print([square(k) for k in range(-5, +5)])
print("", repr(cube), cube.__doc__, cube.__annotations__, sep="\n")
print([cube(k) for k in range(-5, +5)])

# functional usage

In [None]:
def square(x: float) -> float:
    """square function"""
    return x**2


def identity(x: float) -> float:
    """identity function"""
    return x


def cube(x: float) -> float:
    """cube function"""
    return x**3


# Functional (non-decorator) usage of the parametrized decorators.
print([clip(identity)(k) for k in range(-5, +5)])
print([clip(cube, lower=-9, upper=+9)(k) for k in range(-5, +5)])
print([modulo(square, 3)(k) for k in range(-5, +5)])

In [None]:
identity.__repr__()

In [None]:
print([identity(k) for k in range(10)])
print([square(k) for k in range(10)])

In [None]:
@parametrize
def pre_linear(func, /, a=1, b=0):
    """Pre-compose an affine map on the first argument: ``x ↦ f(a*x + b)``."""

    @wraps(func)
    def shifted(x, *rest, **options):
        return func(a * x + b, *rest, **options)

    return shifted

In [None]:
@parametrize
def post_linear(func, /, a=1, b=0):
    """Post-compose an affine map on the output: ``x ↦ a*f(x) + b``."""

    @wraps(func)
    def scaled(x, *rest, **options):
        value = func(x, *rest, **options)
        return a * value + b

    return scaled

In [None]:
@parametrize
def post_linear(func, /, a=1, b=0):
    """Post-hoc linear transformation: x -> a*f(x) + b.

    All positional and keyword arguments are forwarded unchanged to ``func``;
    only its return value is transformed.
    """

    @wraps(func)
    def wrapper(*func_args, **func_kwargs):
        # Use this decorator's own parameters `a`, `b` (previously referenced
        # undefined names `n`, `m`, contradicting the docstring).
        return a * func(*func_args, **func_kwargs) + b

    return wrapper

In [None]:
multiply.__doc__

In [None]:
@post_linear
def function(a):
    """adds 10"""
    return 10 + a


# NOTE(review): this line was annotated "Prints 26"; with post_linear's defaults
# a=1, b=0 the expected value is 1 * (10 + 3) + 0 = 13 (displayed as the cell's
# last expression, not printed) — confirm once `parametrize` is defined.
function(3)

In [None]:
function.__doc__

In [None]:
# NOTE(review): exploratory scratch cell. `bar`, `funny_stuff`,
# `something_with_argument` and `more_funny_stuff` are not defined in any visible
# cell, so this raises NameError as written. The loop merely (re)defines
# `decorator_factory` once per parameter of `bar` that has no default; the factory
# shows the three-level "decorator with arguments" shape:
# decorator_factory(argument) -> decorator(function) -> wrapper(*args, **kwargs).
sig = signature(bar)

for key, param in sig.parameters.items():
    if param.default is Parameter.empty:

        def decorator_factory(argument):
            def decorator(function):
                def wrapper(*args, **kwargs):
                    funny_stuff()
                    something_with_argument(argument)
                    result = function(*args, **kwargs)
                    more_funny_stuff()
                    return result

                return wrapper

            return decorator

In [None]:
KEYWORD_ONLY = Parameter.KEYWORD_ONLY
POSITIONAL_ONLY = Parameter.POSITIONAL_ONLY
POSITIONAL_OR_KEYWORD = Parameter.POSITIONAL_OR_KEYWORD
VAR_KEYWORD = Parameter.VAR_KEYWORD
VAR_POSITIONAL = Parameter.VAR_POSITIONAL
EMPTY = Parameter.empty

In [None]:
?sig.parameters.items

In [None]:
next(iter(sig.parameters))

In [None]:
# NOTE(review): draft pseudo-code, not runnable as a cell: `return` outside a
# function is a SyntaxError, and `param_mode_decorator` / `naked_mode_decorator`
# are not defined in any visible cell.
# Intended logic: if any parameter after the first (ignoring *args / **kwargs)
# lacks a default, only the parametrized mode is possible; otherwise the
# bare ("naked") mode is.
params = iter(sig.parameters.items())

f_key, f_target = next(params)

# NOTE(review): `f_target` is an inspect.Parameter, which is never Callable, so this
# assert always fails; presumably the intent was to check the first parameter's
# kind (e.g. POSITIONAL_ONLY) — confirm.
assert isinstance(
    f_target, Callable
), "The first argument must be a handle to the function to be decorated!"

for key, param in params:
    if param.default is EMPTY and param.kind not in (VAR_POSITIONAL, VAR_KEYWORD):
        return param_mode_decorator

return naked_mode_decorator

In [None]:
bar.__defaults__

In [None]:
signature(bar).parameters["ko_d"].default

In [None]:
signature(bar).return_annotation

In [None]:
dir(l[0][-1])

In [None]:
l[0][1].annotation

In [None]:
signature(bar)

In [None]:
bar

In [None]:
import functools


def my_decorator(*args_or_func, **decorator_kwargs):
    """Decorator usable both bare (``@my_decorator``) and with arguments.

    Bare usage is detected heuristically: if the first positional argument is
    callable it is taken to be the decorated function, so a callable can never
    be passed as the first decorator argument. Each call of the decorated
    function prints the decorator arguments before delegating.
    """
    is_bare = bool(args_or_func) and callable(args_or_func[0])
    # Positional decorator arguments only exist in the bracketed form.
    decorator_args = () if is_bare else args_or_func

    def _decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            print("Available inside the wrapper:", decorator_args, decorator_kwargs)
            return func(*args, **kwargs)

        return wrapper

    if is_bare:
        return _decorator(args_or_func[0])
    return _decorator


@my_decorator
def func_1(arg):
    print(arg)


func_1("test")
# Available inside the wrapper: () {}
# test


@my_decorator()
def func_2(arg):
    print(arg)


func_2("test")
# Available inside the wrapper: () {}
# test


@my_decorator("any arg")
def func_3(arg):
    print(arg)


func_3("test")
# Available inside the wrapper: ('any arg',) {}
# test


@my_decorator("arg_1", 2, [3, 4, 5], kwarg_1=1, kwarg_2="2")
def func_4(arg):
    print(arg)


func_4("test")
# Available inside the wrapper: ('arg_1', 2, [3, 4, 5]) {'kwarg_1': 1, 'kwarg_2': '2'}
# test
# Available inside the wrapper: ('arg_1', 2, [3, 4, 5]) {'kwarg_1': 1, 'kwarg_2': '2'}
# test