# Experiments: class decorators that automatically TorchScript-compile `nn.Module` instances

In [None]:
# IPython display: show the value of the last expression *or assignment* in each cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically reload imported modules before executing cell code (mode 2).
%load_ext autoreload
%autoreload 2
# Draw matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Print NumPy floats with 4 fixed decimals and no scientific notation for small values.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Unseeded random Generator shared by the notebook cells.
rng = np.random.default_rng()

In [None]:
class AutoJit:
    """Class-decorator factory that derives a subclass with extra members.

    Despite its name, nothing is scripted here. ``AutoJit(arg)(cls)`` returns
    a new subclass ``Wrapped`` of ``cls``. ``Wrapped.classattr`` is ``arg``,
    and ``Wrapped.new_method(value)`` returns ``value * 2``.
    """

    def __init__(self, arg):
        # Value exposed as ``classattr`` on every class this instance decorates.
        self.arg = arg

    def __call__(self, cls):
        attr_value = self.arg

        class Wrapped(cls):
            classattr = attr_value

            def new_method(self, value):
                return value * 2

        return Wrapped

In [None]:
import torch
from torch import Tensor, nn, jit


class MyModule(nn.Module):
    """Multiply the input by a tensor built from the constructor argument ``a``."""

    # Class-level annotation declaring ``a`` as a Tensor attribute (presumably for TorchScript).
    a: Tensor

    def __init__(self, a: float = 1.0):
        super().__init__()
        scale = torch.tensor(a)
        self.a = scale

    def forward(self, x: Tensor) -> Tensor:
        """Return ``self.a * x``."""
        product = self.a * x
        return product

In [None]:
import functools
from functools import wraps


def autojit(cls):
    """Class decorator: instantiating the decorated class yields a scripted module.

    ``Wrapper(*args, **kwargs)`` constructs a regular ``cls`` instance with the
    given arguments and returns ``jit.script`` applied to it. The returned
    object is not a ``Wrapper`` instance, so Python does not call
    ``Wrapper.__init__`` on it afterwards.

    Args:
        cls: The class to wrap. Presumably an ``nn.Module`` subclass, since its
            instances are passed to ``jit.script``.

    Returns:
        ``Wrapper``, a subclass of ``cls`` that carries ``cls``'s name,
        qualname, module and docstring.
    """

    @functools.wraps(cls, updated=())
    class Wrapper(cls):
        def __new__(wrapper_cls, *args, **kwargs):
            # ``jit.script`` rejects nn.Module *classes*, so build a fully
            # initialised instance of the original class and script that.
            # (Allocating via ``object.__new__`` with extra args would raise,
            # and calling the instance would run ``forward`` instead.)
            instance = cls(*args, **kwargs)
            return jit.script(instance)

    return Wrapper

In [None]:
def append_func(other):
    """Decorator factory that post-processes a function's result.

    The decorated function returns ``other(func(*args, **kwargs))`` and keeps
    ``func``'s metadata via ``functools.wraps``.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            return other(result)

        return wrapper

    return decorator


def prepend_func(other):
    """Decorator factory that pre-processes a function's arguments.

    All call arguments go to ``other``. Its single return value is then passed
    to the decorated function, i.e. the call becomes
    ``func(other(*args, **kwargs))``. The ``func`` metadata is kept via
    ``functools.wraps``.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            transformed = other(*args, **kwargs)
            return func(transformed)

        return wrapper

    return decorator


@append_func(lambda x: 2 * x)
def f(x):
    """Return ``x + 1``; the decorator doubles the result, so f(x) == 2 * (x + 1)."""
    return x + 1


@prepend_func(lambda x: 2 * x)
def g(x):
    """Return ``x + 1``; the decorator doubles the argument first, so g(x) == 2 * x + 1."""
    return x + 1


# Displays (4, 3): post-processing vs. pre-processing with the same lambda.
(f(1), g(1))

In [None]:
def patched_jit(obj):
    """Like ``jit.script``, but also usable as a decorator on classes.

    Non-class objects are handed straight to ``jit.script``. For a class, a
    subclass is returned whose construction builds an ``obj`` instance with
    the given arguments and returns the scripted instance instead.
    """
    if not isinstance(obj, type):
        return jit.script(obj)

    @functools.wraps(obj, updated=())
    class WrappedClass(obj):
        def __new__(klass, *args, **kwargs):
            # Construct a plain instance of the original class, then script it.
            return jit.script(obj(*args, **kwargs))

    return WrappedClass

In [None]:
def autojit(class_):
    """Class decorator: calling the decorated class returns ``jit.script(instance)``.

    The returned ``WrappedClass`` subclasses ``class_`` and copies its name,
    qualname, module and docstring. Its ``__new__`` builds an ordinary
    ``class_`` instance from the constructor arguments and scripts it.
    """

    @functools.wraps(class_, updated=())
    class WrappedClass(class_):
        def __new__(cls, *args, **kwargs):
            return jit.script(class_(*args, **kwargs))

    return WrappedClass

In [None]:
class MyModule(nn.Module):
    """Multiply the input by a tensor built from the constructor argument ``a``."""

    # Class-level annotation declaring ``a`` as a Tensor attribute for TorchScript.
    a: Tensor

    def __init__(self, a: float = 1.0):
        super().__init__()
        self.a = torch.tensor(a)

    def forward(self, x: Tensor) -> Tensor:
        return self.a * x


# ``jit.script`` raises RuntimeError when given an nn.Module *class*
# ("pass an instance instead"), so script an instance rather than
# decorating the class definition.
jit.script(MyModule(a=1.3))

In [None]:
def decorator(class_):
    """Wrap ``class_`` in a subclass that routes constructor arguments correctly.

    ``Wrapper`` overrides both ``__new__`` and ``__init__``, so ``object.__new__``
    and ``object.__init__`` reject any extra constructor arguments. The
    arguments are therefore forwarded to the next ``__new__`` only when a real
    ancestor defines one. They are dropped before ``__init__`` when the
    ancestors define ``__new__`` but no ``__init__``. This mirrors plain
    Python: a class defining only one of the two may still take arguments.

    Args:
        class_: The class to wrap.

    Returns:
        ``Wrapper``, a subclass of ``class_`` carrying its name, qualname,
        module and docstring.
    """

    def has_method(cls, meth):
        """Return True if an ancestor after ``Wrapper`` in ``cls.__mro__`` defines ``meth``.

        ``Wrapper`` itself, wrappers from other (stacked) applications of this
        decorator, and the trailing ``object`` are ignored.
        """
        mro = cls.__mro__
        following = mro[mro.index(Wrapper) + 1 : -1]  # [:-1] drops ``object``
        return any(
            meth in ancestor.__dict__
            for ancestor in following
            if not ancestor.__dict__.get("_is_decorator_wrapper", False)
        )

    def has_new(cls):
        return has_method(cls, "__new__")

    def has_init(cls):
        return has_method(cls, "__init__")

    @functools.wraps(class_, updated=())
    class Wrapper(class_):
        # Marker so stacked applications of this decorator skip each other in has_method.
        _is_decorator_wrapper = True

        def __new__(cls, *args, **kwargs):
            print("Wrapper.__new__", cls, args, kwargs)
            if has_new(cls):
                obj = super().__new__(cls, *args, **kwargs)
            else:
                # The next __new__ is object.__new__, which rejects extra args here.
                obj = super().__new__(cls)
            # ...
            return obj

        def __init__(self, *args, **kwargs):
            # object.__repr__: a custom __repr__ (e.g. nn.Module's) may fail
            # on a not-yet-initialised instance.
            print("Wrapper.__init__", object.__repr__(self), args, kwargs)
            cls = self.__class__
            if (args or kwargs) and not has_init(cls) and has_new(cls):
                # Arguments were consumed by __new__; object.__init__ would reject them.
                args, kwargs = (), {}
            super().__init__(*args, **kwargs)

    return Wrapper

In [None]:
@autojit
class MyModule(nn.Module):
    """Multiply the input by a tensor built from ``a``.

    Constructing this class goes through ``autojit``. With the ``autojit``
    defined above, that returns ``jit.script`` of an instance rather than a
    plain module.
    """

    # Class-level annotation declaring ``a`` as a Tensor attribute for TorchScript.
    a: Tensor

    def __init__(self, a: float = 1.0):
        super().__init__()
        scale = torch.tensor(a)
        self.a = scale

    def forward(self, x: Tensor) -> Tensor:
        scaled = self.a * x
        return scaled


MyModule(a=1.3)

In [None]:
class AutoJit:
    """Decorator factory: ``AutoJit(arg)(cls)`` returns a subclass of ``cls``.

    The subclass (named ``Wrapped``) gains a class attribute ``classattr``
    equal to ``arg`` and a method ``new_method(value)`` returning
    ``value * 2``. No scripting or compilation is performed.
    """

    def __init__(self, arg):
        self.arg = arg  # becomes ``classattr`` on decorated classes

    def __call__(self, cls):
        stored = self.arg

        class Wrapped(cls):
            classattr = stored

            def new_method(self, value):
                return value * 2

        return Wrapped

In [None]:
from functools import wraps


def autofunc(basecls=None, func=None):
    """Class decorator: constructing the decorated class returns ``func(instance)``.

    ``WrappedClass(*args, **kwargs)`` builds a plain ``basecls`` instance from
    the given arguments and returns ``func`` applied to it, e.g.
    ``func=jit.script`` yields a scripted module. With ``func=None`` the
    instance is returned unchanged.

    Usable directly as ``autofunc(cls, func)`` or as a decorator factory,
    ``@autofunc(func=...)``.

    Args:
        basecls: The class to wrap. If omitted, a decorator awaiting the
            class is returned.
        func: Callable applied to each freshly constructed instance.

    Returns:
        ``WrappedClass`` (a subclass of ``basecls`` with its name, qualname,
        module and docstring), or a decorator when ``basecls`` is None.
    """
    if basecls is None:
        # Called as ``@autofunc(func=...)``: defer until the class is supplied.
        return lambda cls: autofunc(cls, func)

    transform = func if func is not None else (lambda instance: instance)

    @wraps(basecls, updated=())
    class WrappedClass(basecls):  # type: ignore
        def __new__(cls, *args, **kwargs):
            instance = basecls(*args, **kwargs)
            # The result is not a WrappedClass instance, so Python will not
            # call __init__ on it a second time.
            return transform(instance)

    return WrappedClass


class A:
    """Plain class ``A`` with no members of its own."""


class B:
    """Plain class ``B``."""