# Experiments: a `wrap_func` decorator with pre/post hooks

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # display the last expression OR assignment of each cell
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload all (non-excluded) modules before executing each cell.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

logging.basicConfig(level=logging.INFO)
# NOTE: at INFO level the `__logger__.debug(...)` messages emitted by
# `wrap_func` are not shown; use logging.DEBUG to see them.
__logger__ = logging.getLogger(__name__)

In [None]:
import numpy as np

# Compact float display: exactly 4 fractional digits, fixed-point notation
# instead of scientific notation for small magnitudes.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Unseeded generator: draws differ from run to run.
rng = np.random.default_rng(seed=None)

In [None]:
# import tsdm
# from tsdm.utils.decorators import decorator, wrap_func

In [None]:
from functools import wraps
from typing import Callable, Optional

from decorator import decorator


# NOTE: deliberately *not* wrapped with `decorator.decorator`. That helper
# expects a "caller" ``caller(func, *call_args)`` that is invoked on every call
# and returns the call's result; this function instead builds a wrapper once.
# Combining the two made ``wrap_func(func, pre_func)("a")`` pass "a" as
# ``after`` and return a fresh wrapper instead of ever calling ``func``.
def wrap_func(
    func: Optional[Callable] = None,
    /,
    before: Optional[Callable] = None,
    after: Optional[Callable] = None,
) -> Callable:
    r"""Wrap a function with pre and post hooks.

    ``before(*args, **kwargs)`` runs right before ``func`` and
    ``after(*args, **kwargs)`` right after it, each receiving exactly the
    arguments the wrapped function was called with. Hook return values are
    discarded; the wrapper returns ``func``'s result.

    Two call styles are supported::

        g = wrap_func(func, before, after)   # direct

        @wrap_func(before=pre, after=post)   # decorator factory (use keywords!)
        def h(x): ...

    Positional arguments always fill ``func`` first, so ``@wrap_func(pre, post)``
    would wrap ``pre`` rather than the decorated function.

    Args:
        func: The function to wrap. If ``None``, a decorator is returned that
            applies ``before``/``after`` to the function it decorates.
        before: Optional hook called before ``func``.
        after: Optional hook called after ``func`` (only when ``func`` returns
            normally; an exception from ``func`` skips it).

    Returns:
        The wrapped function (``func`` itself when no hooks are given), or a
        decorator when ``func`` is ``None``.
    """
    if func is None:
        # Decorator-factory form, e.g. ``@wrap_func(before=pre, after=post)``.
        def _decorate(fn: Callable) -> Callable:
            return wrap_func(fn, before, after)

        return _decorate

    if before is None and after is None:
        __logger__.debug("No hooks added to %s", func)
        return func

    if before is not None:
        __logger__.debug("Adding pre hook %s to %s", before, func)
    if after is not None:
        __logger__.debug("Adding post hook %s to %s", after, func)

    @wraps(func)
    def _wrapper(*args, **kwargs):
        if before is not None:
            before(*args, **kwargs)
        result = func(*args, **kwargs)
        if after is not None:
            after(*args, **kwargs)
        return result

    return _wrapper

In [None]:
def pre_func(*args, **kwargs):
    """Demo pre-hook: print itself together with the arguments it received."""
    message = f"λ={pre_func} called with {args=} {kwargs=}"
    print(message)


def post_func(*args, **kwargs):
    """Demo post-hook: print itself together with the arguments it received."""
    message = f"λ={post_func} called with {args=} {kwargs=}"
    print(message)


def func(x):
    """Identity function: the demo target wrapped by `wrap_func` below."""
    value = x
    return value

In [None]:
# Wrap `func` with `pre_func` as its only (pre-)hook, then call the result.
# NOTE(review): confirm this really calls `func("a")` (expected value "a");
# how the extra call argument is treated depends on `wrap_func`'s dispatch.
g = wrap_func(func, pre_func)
g("a")

In [None]:
# The decorator must be on its own line; `@dec def f(...)` on a single line
# is a SyntaxError.
# NOTE(review): positional arguments fill `wrap_func`'s leading `func` slot
# first, so `pre_func` (not `f`) may end up as the wrapped callable — confirm
# the intended decorator syntax.
@wrap_func(pre_func, post_func)
def f(x):
    """Return ``x`` unchanged."""
    return x


f("a")

In [None]:
# Introspection probe: what does the global name `f` refer to after decoration?
# NOTE(review): positional arguments fill `wrap_func`'s leading `func` slot,
# so `pre_func` (not this `f`) may end up being the wrapped callable —
# check `f.__wrapped__` below.
@wrap_func(pre_func, post_func)
def f(a=None):
    print(f"{vars(f)=}")  # attribute dict of whatever the global name `f` is bound to
    print(f"{f.__qualname__=}")  # qualified name of that object
    print(f"{dir()=}")  # names in this call's local scope
    print(f"{locals()=}")  # local bindings of this call (the argument `a`)

In [None]:
# Call the decorated `f` with no arguments; the printed output shows which
# hooks and function bodies actually run.
f()

In [None]:
# Show the callable recorded in `__wrapped__` (presumably set by
# `functools.wraps` inside `wrap_func`) to see what `f` really wraps.
f.__wrapped__