## imports

In [1]:
import ast
import inspect
import types
import numpy as np
import numba as nb
import pandas as pd
import time
import logging
import functools
import copy
import sys
from typing import Callable, Type
# Root logger (getLogger() with no name), opened to every level: a root logger
# at NOTSET (== 0) processes all records. This is a process-wide setting.
logger = logging.getLogger()
logger.setLevel(logging.NOTSET)

## dev imports

In [2]:
import astpretty

## Current state

In [3]:
def create_callmap_function_ast(mapping):
    """Build the AST of a ``callmap(x)`` function translating labels to positions.

    The generated function is equivalent to::

        def callmap(x):
            if x == <key_0>:
                return <value_0>
            ...
            return x

    i.e. it returns ``mapping[x]`` when ``x`` equals one of the keys and ``x``
    unchanged otherwise.

    Args:
        mapping: dict of label -> value (positional index) to hard-code.

    Returns:
        ast.FunctionDef: the definition of ``callmap``; callers are expected to
        run ``ast.fix_missing_locations`` before compiling it.
    """
    body = [
        ast.If(
            test=ast.Compare(
                left=ast.Name(id='x', ctx=ast.Load()),
                ops=[ast.Eq()],
                # ast.Constant replaces ast.Str / ast.Num (deprecated, removed in 3.14).
                comparators=[ast.Constant(value=key)],
            ),
            body=[ast.Return(value=ast.Constant(value=value))],
            orelse=[],
        )
        for key, value in mapping.items()
    ]
    # Default: unknown keys pass straight through.
    body.append(ast.Return(value=ast.Name(id='x', ctx=ast.Load())))

    # Parse a skeleton instead of calling ast.FunctionDef(...) directly so the
    # parser fills in version-specific fields (e.g. type_params on 3.12+).
    func_def = ast.parse('def callmap(x):\n    pass').body[0]
    func_def.body = body
    return func_def

class SubscriptReplacer(ast.NodeTransformer):
    """Rewrite ``arg[key]`` into ``arg[callmap(key)]`` for one argument name.

    Only subscripts whose value is the bare name ``arg_name`` are rewritten;
    subscripts on any other object are left alone. The transformed code
    expects a ``callmap`` name to be resolvable where it runs.
    """

    def __init__(self, arg_name):
        # Name of the parameter whose label-based indexing is rewritten.
        self.arg_name = arg_name

    @staticmethod
    def _wrap(key):
        """Return the expression ``callmap(<key>)``."""
        return ast.Call(
            func=ast.Name(id='callmap', ctx=ast.Load()),
            args=[key],
            keywords=[],
        )

    def visit_Subscript(self, node):
        if isinstance(node.value, ast.Name) and node.value.id == self.arg_name:
            if sys.version_info >= (3, 9):
                # 3.9+: the index expression is stored directly in node.slice.
                node.slice = self._wrap(node.slice)
            elif isinstance(node.slice, ast.Index):
                # <=3.8: Subscript.slice must be a slice node, so the call has
                # to be re-wrapped in ast.Index or compile() rejects the tree.
                node.slice = ast.Index(value=self._wrap(node.slice.value))
            # else (<=3.8 Slice/ExtSlice): those are not expressions and cannot
            # be call arguments, so they are left untouched.
        # Keep descending so subscripts nested inside the key are rewritten too.
        return self.generic_visit(node)

def create_transformed_function_ast(original_func, mapping):
    """Build the ASTs used to compile ``original_func`` for positional arrays.

    The source of ``original_func`` is parsed, the function is renamed to
    ``temporary``, and every ``arg[key]`` on its first parameter is rewritten
    to ``arg[callmap(key)]`` by ``SubscriptReplacer``. Two wrapper modules are
    parsed from templates:

    * ``<qualname>_numba_compiled(Z)``: loops ``i`` over ``nb.prange(Z.shape[0])``
      and stores ``temporary(Z[i, :])`` into an ``(n, 1)`` zeros array;
    * ``<qualname>_vectorized(Z)``: returns ``temporary(Z)`` for the whole array.

    The wrappers do not contain ``callmap``/``temporary``; callers insert them.

    Args:
        original_func: function whose first positional parameter is indexed by
            label (e.g. ``z['A']``); its source must be available to
            ``inspect.getsource``.
        mapping: not used here; kept for interface compatibility (the label to
            position translation lives in the separately generated ``callmap``).

    Returns:
        tuple: ``(original_tree, new_func_tree, new_Vfunc_tree, numba_decorator)``.
        The first three are ``ast.Module`` objects. The last is an unattached
        ``ast.Call`` for ``nb.jit(nopython=True, nogil=True, parallel=True)``.
    """
    import textwrap

    # dedent so a function whose source is indented (e.g. defined inside an
    # `if` block) still parses.
    original_tree = ast.parse(textwrap.dedent(inspect.getsource(original_func)))
    arg_name = original_tree.body[0].args.args[0].arg

    # The rewritten function is nested inside each wrapper under this name.
    original_tree.body[0].name = 'temporary'

    original_tree = SubscriptReplacer(arg_name).visit(original_tree)
    ast.fix_missing_locations(original_tree)

    # NOTE: names derive from __qualname__, so methods / nested functions
    # (qualnames containing '.' or '<locals>') produce invalid identifiers.
    new_func_code = f"""
def {original_func.__qualname__}_numba_compiled(Z):
    n = Z.shape[0]
    res = np.zeros((n, 1))
    for i in nb.prange(n):
        res[i, 0] = temporary(Z[i, :])
    return res
    """
    vecv_func_code = f"""
def {original_func.__qualname__}_vectorized(Z):
    return temporary(Z)
    """
    new_func_tree = ast.parse(new_func_code)
    new_Vfunc_tree = ast.parse(vecv_func_code)

    # ast.Constant replaces ast.NameConstant (deprecated, removed in 3.14).
    numba_decorator = ast.Call(
        func=ast.Attribute(value=ast.Name(id='nb', ctx=ast.Load()), attr='jit', ctx=ast.Load()),
        args=[],
        keywords=[
            ast.keyword(arg='nopython', value=ast.Constant(value=True)),
            ast.keyword(arg='nogil', value=ast.Constant(value=True)),
            ast.keyword(arg='parallel', value=ast.Constant(value=True)),
        ],
    )
    return original_tree, new_func_tree, new_Vfunc_tree, numba_decorator

def _prepare_funcs(original_func, mapping, numba_decorate=True):
    """Compile the positional-array variants of ``original_func`` that can be built.

    Each variant is defined in a copy of this module's globals (plus ``np`` and
    ``nb``) and embeds its own nested ``callmap`` and ``temporary`` functions.

    Args:
        original_func: label-indexed function (see create_transformed_function_ast).
        mapping: dict of label -> positional index used to generate ``callmap``.
        numba_decorate: when False the ``_numba_compiled`` variant is skipped.

    Returns:
        tuple: successfully defined callables in preference order
        ``<qualname>_numba_compiled`` (if numba_decorate), ``<qualname>_vectorized``,
        ``<qualname>_safe_loop``. Variants that failed to be defined are omitted
        (and logged). numba compiles lazily, so the ``_numba_compiled`` entry
        can still fail on its first call. The ``_vectorized`` variant calls
        ``temporary(Z)`` on the whole array, so it expects one row per label
        (the benchmark cell feeds it ``df.to_numpy().T``).
    """
    exec_globals = globals().copy()
    exec_globals.update({'np': np, 'nb': nb})
    callmap_func_def = create_callmap_function_ast(mapping)
    callmap_func_ast = ast.fix_missing_locations(ast.Module(body=[callmap_func_def], type_ignores=[]))
    exec(compile(callmap_func_ast, filename="<ast>", mode="exec"), exec_globals)
    original_func_ast, new_func_ast, new_Vfunc_ast, numba_decorator = create_transformed_function_ast(original_func, mapping)
    exec(compile(original_func_ast, filename="<ast>", mode="exec"), exec_globals)

    # Nest callmap and temporary at the top of both wrapper bodies.
    for wrapper_ast in (new_func_ast, new_Vfunc_ast):
        wrapper_ast.body[0].body.insert(0, callmap_func_ast.body[0])
        wrapper_ast.body[0].body.insert(1, original_func_ast.body[0])

    # Plain-Python copy of the row loop, taken before the numba decorator is added.
    new_func_ast_no_numba = copy.deepcopy(new_func_ast)
    new_func_ast_no_numba.body[0].name = new_func_ast.body[0].name.replace('_numba_compiled', '_safe_loop')
    new_func_ast.body[0].decorator_list.append(numba_decorator)

    base_name = original_func.__qualname__
    # Each tree is paired with the name it actually defines (the old code
    # looked up '_vectorized' after defining '_safe_loop' and vice versa, so the
    # vectorized variant was always lost).
    candidates = []
    if numba_decorate:
        candidates.append((new_func_ast, base_name + '_numba_compiled'))
    candidates.append((new_Vfunc_ast, base_name + '_vectorized'))
    candidates.append((new_func_ast_no_numba, base_name + '_safe_loop'))

    available_funcs = []
    for tree, func_name in candidates:
        try:
            tree = ast.fix_missing_locations(tree)
            exec(compile(tree, filename="<ast>", mode="exec"), exec_globals)
            available_funcs.append(exec_globals[func_name])
        except Exception:
            # Best effort: skip a variant that cannot be defined, but say why
            # instead of silently swallowing (and no longer catch KeyboardInterrupt).
            logger.warning('Could not build %s', func_name, exc_info=True)
    return tuple(available_funcs)


def dynamic_decorator(func, column_name_to_index, numba_decorate = True):
    """Rebuild ``func`` from its source after an AST rewrite by ``TransformFunction``.

    Args:
        func: function whose source is available to ``inspect.getsource``; its
            first positional parameter is passed to ``TransformFunction``.
        column_name_to_index: value formatted into a ``mapping = ...``
            statement that is handed to ``TransformFunction``.
        numba_decorate: when True the rebuilt function is wrapped with
            ``nb.jit(nopython=True, nogil=True, parallel=True)``.

    Returns:
        The rebuilt function, bound to this module's globals (a numba
        dispatcher when ``numba_decorate`` is True).

    NOTE(review): ``TransformFunction`` is not defined in the visible source;
    this function raises NameError unless it is defined elsewhere.
    """
    import textwrap

    # dedent so functions whose source is indented still parse.
    func_ast = ast.parse(textwrap.dedent(inspect.getsource(func)))
    function_def = func_ast.body[0]

    arg_name = function_def.args.args[0].arg
    mapping_assignment = ast.parse(f"mapping = {column_name_to_index}").body[0]

    transformer = TransformFunction(arg_name, mapping_assignment)
    transformed_ast = transformer.visit(func_ast)
    ast.fix_missing_locations(transformed_ast)

    compiled_code = compile(transformed_ast, filename="<ast>", mode="exec")

    # The function's code object is a constant of the compiled module. Building
    # a FunctionType from it skips the module-level `def` statement, so any
    # decorator_list on the AST is never evaluated.
    func_code = next(c for c in compiled_code.co_consts if isinstance(c, types.CodeType) and c.co_name == func.__name__)
    new_func = types.FunctionType(func_code, globals(), func.__name__)

    if numba_decorate:
        # Applied explicitly: the old code attached nb.jit(...) to the AST
        # decorator_list, which (see above) never ran. cache=True is not used
        # because numba's on-disk cache needs a real source file and this
        # code's filename is "<ast>".
        new_func = nb.jit(nopython=True, nogil=True, parallel=True)(new_func)

    return new_func


def make_class_decorator(function_decorator: Callable) -> Callable:
    """
    Creates a class decorator from a given function decorator.

    The class decorator applies ``function_decorator`` exactly once to:

    * every public callable in the first base class's own ``__dict__`` that the
      class does not override (only ``cls.__bases__[0]`` is scanned, not the
      full MRO), and
    * every public callable defined in the class body itself.

    Names starting with ``_`` are skipped.

    Args:
        function_decorator (Callable): A function decorator to be applied to class methods.

    Returns:
        Callable: A class decorator.
    """
    @functools.wraps(function_decorator)
    def class_decorator(cls: Type) -> Type:
        """
        The class decorator generated from the function decorator.

        Args:
            cls (Type): The class to which the decorator is applied.

        Returns:
            Type: The decorated class.
        """
        own_attrs = cls.__dict__
        # Collect targets first: the old version iterated cls.__dict__ after
        # already setting wrapped base methods on cls, so those got wrapped twice.
        to_wrap = {}
        for attr_name, attr_value in cls.__bases__[0].__dict__.items():
            if callable(attr_value) and not attr_name.startswith('_') and attr_name not in own_attrs:
                to_wrap[attr_name] = attr_value
        for attr_name, attr_value in own_attrs.items():
            if callable(attr_value) and not attr_name.startswith('_'):
                to_wrap[attr_name] = attr_value
        for attr_name, attr_value in to_wrap.items():
            setattr(cls, attr_name, function_decorator(attr_value))
        return cls
    return class_decorator

def autowrap_pandas_return(fn: Callable) -> Callable:
    """
    Decorator that re-wraps DataFrame results of outermost method calls as ``dwrap``.

    ``self._outside_call`` marks whether the call is top-level. A top-level call
    clears the flag while ``fn`` runs, so nested calls between wrapped methods
    return their results untouched. A ``pd.DataFrame`` result is then converted
    with ``dwrap(res)``, and the flag is restored.

    Args:
        fn (Callable): The original method of the class.

    Returns:
        Callable: The decorated method.
    """
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if not self._outside_call:
            # Nested call from inside another wrapped method: run as-is.
            return fn(self, *args, **kwargs)
        self._outside_call = False
        try:
            res = fn(self, *args, **kwargs)
            if isinstance(res, pd.DataFrame):
                res = dwrap(res)
            return res
        finally:
            # Restore even when fn raises; otherwise the object stays in
            # "nested" mode forever and never wraps results again.
            self._outside_call = True
    return wrapper

@make_class_decorator(autowrap_pandas_return)
class dwrap(pd.DataFrame):
    """DataFrame subclass whose ``apply`` runs AST-compiled variants of ``func``.

    Public methods are wrapped by ``autowrap_pandas_return`` (via
    ``make_class_decorator``), so DataFrame results of outermost calls come
    back as ``dwrap``.
    """
    _compiled_func = None  # per-instance cache, replaced by a dict in __init__
    _outside_call = True   # re-entrancy flag toggled by autowrap_pandas_return

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._compiled_func = {}

    @property
    def __name__(self):
        # Concatenated column labels. The old code read a non-existent
        # `self.name_to_index`, so this property always raised AttributeError.
        return ''.join(map(str, self.columns))

    @property
    def colname_to_colnum(self):
        return {k: i for i, k in enumerate(self.columns)}

    @property
    def rowname_to_rownum(self):
        return {k: i for i, k in enumerate(self.index)}

    def _compiled_qualifier(self, func_qualifier, mapper):
        # '&'.join also handles an empty mapper and non-string labels, where
        # functools.reduce raised TypeError.
        return hash('&'.join(map(str, mapper)) + func_qualifier)

    def apply(self, func, axis = 0, *args, **kwargs):
        """Apply ``func`` along ``axis``, preferring a compiled variant.

        ``func`` indexes its argument by label (``row['A']``). Compiled variants
        are built once per (function qualname, labels) and cached. Each is tried
        in order; the first that succeeds is used and moved to the front of
        the cache. If none succeeds (or none can be built), the call falls back
        to ``pd.DataFrame.apply``.

        Returns:
            dwrap indexed by ``self.index`` (axis=1) or ``self.columns`` (axis=0)
            on the compiled path; whatever pandas returns on the fallback path.
        """
        if args or kwargs:
            logger.warning('%s apply only supports func and axis arguments, using default pandas apply', __class__)
            # Forward the caller's axis (was hard-coded to 0) positionally so
            # extra positional args keep pandas' parameter order.
            return super().apply(func, axis, *args, **kwargs)

        # Accept pandas' string aliases: 'index' is truthy and was treated as 1.
        axis = {'index': 0, 'rows': 0, 'columns': 1}.get(axis, axis)
        if axis:
            mapper, data, out_index = self.colname_to_colnum, self.to_numpy(), self.index
        else:
            mapper, data, out_index = self.rowname_to_rownum, self.to_numpy().T, self.columns

        name = self._compiled_qualifier(func_qualifier=func.__qualname__, mapper=mapper)
        if name not in self._compiled_func:
            try:
                self.build_apply(func, mapper, name)
            except Exception:
                logger.warning('Could not compile %s, using default pandas apply', func.__qualname__, exc_info=True)
                self._compiled_func[name] = ()

        # The builder returns a tuple of candidates; the old code called the
        # tuple itself, which always raised TypeError.
        candidates = self._compiled_func[name]
        for candidate in candidates:
            # *_vectorized variants call temporary(Z) on the whole array, so
            # they need one row per label: the transposed layout.
            is_vectorized = getattr(candidate, '__name__', '').endswith('_vectorized')
            try:
                result = dwrap(candidate(data.T if is_vectorized else data), index=out_index)
            except Exception:
                logger.debug('%s failed for %s, trying next variant',
                             getattr(candidate, '__name__', candidate), func.__qualname__, exc_info=True)
                continue
            # Try the working variant first next time (numba compiles lazily
            # and a failing variant would otherwise be retried every call).
            self._compiled_func[name] = (candidate,) + tuple(c for c in candidates if c is not candidate)
            return result

        logger.warning('%s: no compiled variant of %s succeeded, using default pandas apply', __class__, func.__qualname__)
        return super().apply(func, axis=axis)

    def build_apply(self, func, map, name):
        logger.debug('Building compiled apply for %s (%s) -> %s', func, map, name)
        self._compiled_func[name] = _prepare_funcs(func, map)
        return self._compiled_func[name]


## Func compilation / live modification performance test

In [4]:
def create_callmap_function_ast(mapping):
    """Build the AST of a ``callmap(x)`` function translating labels to positions.

    The generated function is equivalent to::

        def callmap(x):
            if x == <key_0>:
                return <value_0>
            ...
            return x

    i.e. it returns ``mapping[x]`` when ``x`` equals one of the keys and ``x``
    unchanged otherwise.

    Args:
        mapping: dict of label -> value (positional index) to hard-code.

    Returns:
        ast.FunctionDef: the definition of ``callmap``; callers are expected to
        run ``ast.fix_missing_locations`` before compiling it.
    """
    body = [
        ast.If(
            test=ast.Compare(
                left=ast.Name(id='x', ctx=ast.Load()),
                ops=[ast.Eq()],
                # ast.Constant replaces ast.Str / ast.Num (deprecated, removed in 3.14).
                comparators=[ast.Constant(value=key)],
            ),
            body=[ast.Return(value=ast.Constant(value=value))],
            orelse=[],
        )
        for key, value in mapping.items()
    ]
    # Default: unknown keys pass straight through.
    body.append(ast.Return(value=ast.Name(id='x', ctx=ast.Load())))

    # Parse a skeleton instead of calling ast.FunctionDef(...) directly so the
    # parser fills in version-specific fields (e.g. type_params on 3.12+).
    func_def = ast.parse('def callmap(x):\n    pass').body[0]
    func_def.body = body
    return func_def

class SubscriptReplacer(ast.NodeTransformer):
    """Rewrite ``arg[key]`` into ``arg[callmap(key)]`` for one argument name.

    Only subscripts whose value is the bare name ``arg_name`` are rewritten;
    subscripts on any other object are left alone. The transformed code
    expects a ``callmap`` name to be resolvable where it runs.
    """

    def __init__(self, arg_name):
        # Name of the parameter whose label-based indexing is rewritten.
        self.arg_name = arg_name

    @staticmethod
    def _wrap(key):
        """Return the expression ``callmap(<key>)``."""
        return ast.Call(
            func=ast.Name(id='callmap', ctx=ast.Load()),
            args=[key],
            keywords=[],
        )

    def visit_Subscript(self, node):
        if isinstance(node.value, ast.Name) and node.value.id == self.arg_name:
            if sys.version_info >= (3, 9):
                # 3.9+: the index expression is stored directly in node.slice.
                node.slice = self._wrap(node.slice)
            elif isinstance(node.slice, ast.Index):
                # <=3.8: Subscript.slice must be a slice node, so the call has
                # to be re-wrapped in ast.Index or compile() rejects the tree.
                node.slice = ast.Index(value=self._wrap(node.slice.value))
            # else (<=3.8 Slice/ExtSlice): those are not expressions and cannot
            # be call arguments, so they are left untouched.
        # Keep descending so subscripts nested inside the key are rewritten too.
        return self.generic_visit(node)

def create_transformed_function_ast(original_func, mapping):
    """Build the ASTs used to compile ``original_func`` for positional arrays.

    The source of ``original_func`` is parsed, the function is renamed to
    ``temporary``, and every ``arg[key]`` on its first parameter is rewritten
    to ``arg[callmap(key)]`` by ``SubscriptReplacer``. Two wrapper modules are
    parsed from templates:

    * ``<qualname>_numba_compiled(Z)``: loops ``i`` over ``nb.prange(Z.shape[0])``
      and stores ``temporary(Z[i, :])`` into an ``(n, 1)`` zeros array;
    * ``<qualname>_vectorized(Z)``: returns ``temporary(Z)`` for the whole array.

    The wrappers do not contain ``callmap``/``temporary``; callers insert them.

    Args:
        original_func: function whose first positional parameter is indexed by
            label (e.g. ``z['A']``); its source must be available to
            ``inspect.getsource``.
        mapping: not used here; kept for interface compatibility (the label to
            position translation lives in the separately generated ``callmap``).

    Returns:
        tuple: ``(original_tree, new_func_tree, new_Vfunc_tree, numba_decorator)``.
        The first three are ``ast.Module`` objects. The last is an unattached
        ``ast.Call`` for ``nb.jit(nopython=True, nogil=True, parallel=True)``.
    """
    import textwrap

    # dedent so a function whose source is indented (e.g. defined inside an
    # `if` block) still parses.
    original_tree = ast.parse(textwrap.dedent(inspect.getsource(original_func)))
    arg_name = original_tree.body[0].args.args[0].arg

    # The rewritten function is nested inside each wrapper under this name.
    original_tree.body[0].name = 'temporary'

    original_tree = SubscriptReplacer(arg_name).visit(original_tree)
    ast.fix_missing_locations(original_tree)

    # NOTE: names derive from __qualname__, so methods / nested functions
    # (qualnames containing '.' or '<locals>') produce invalid identifiers.
    new_func_code = f"""
def {original_func.__qualname__}_numba_compiled(Z):
    n = Z.shape[0]
    res = np.zeros((n, 1))
    for i in nb.prange(n):
        res[i, 0] = temporary(Z[i, :])
    return res
    """
    vecv_func_code = f"""
def {original_func.__qualname__}_vectorized(Z):
    return temporary(Z)
    """
    new_func_tree = ast.parse(new_func_code)
    new_Vfunc_tree = ast.parse(vecv_func_code)

    # ast.Constant replaces ast.NameConstant (deprecated, removed in 3.14).
    numba_decorator = ast.Call(
        func=ast.Attribute(value=ast.Name(id='nb', ctx=ast.Load()), attr='jit', ctx=ast.Load()),
        args=[],
        keywords=[
            ast.keyword(arg='nopython', value=ast.Constant(value=True)),
            ast.keyword(arg='nogil', value=ast.Constant(value=True)),
            ast.keyword(arg='parallel', value=ast.Constant(value=True)),
        ],
    )
    return original_tree, new_func_tree, new_Vfunc_tree, numba_decorator

def create_module_ast(original_func, mapping):
    """Compile and return the three positional-array variants of ``original_func``.

    Everything is executed in a private copy of this module's globals (with
    ``np`` and ``nb`` ensured). Each wrapper embeds nested ``callmap`` and
    ``temporary`` definitions.

    Returns:
        tuple: ``(<qualname>_numba_compiled, <qualname>_vectorized,
        <qualname>_safe_loop)``. Lookups raise KeyError if a name was not defined.
    """
    namespace = dict(globals())
    namespace['np'] = np
    namespace['nb'] = nb

    def _run(tree):
        # Compile an AST module and execute it into the private namespace.
        exec(compile(tree, filename="<ast>", mode="exec"), namespace)

    callmap_module = ast.fix_missing_locations(
        ast.Module(body=[create_callmap_function_ast(mapping)], type_ignores=[])
    )
    _run(callmap_module)

    temporary_module, loop_module, vector_module, jit_call = create_transformed_function_ast(original_func, mapping)
    _run(temporary_module)

    # Prepend the callmap and temporary defs to both wrapper bodies.
    callmap_def = callmap_module.body[0]
    temporary_def = temporary_module.body[0]
    for wrapper_module in (loop_module, vector_module):
        wrapper_module.body[0].body[:0] = [callmap_def, temporary_def]

    # Undecorated copy of the loop wrapper, then jit the original loop wrapper.
    safe_module = copy.deepcopy(loop_module)
    safe_module.body[0].name = loop_module.body[0].name.replace('_numba_compiled', '_safe_loop')
    loop_module.body[0].decorator_list.append(jit_call)

    ordered = (loop_module, safe_module, vector_module)
    for tree in ordered:
        ast.fix_missing_locations(tree)
    for tree in ordered:
        _run(tree)

    prefix = original_func.__qualname__
    return tuple(namespace[prefix + suffix] for suffix in ('_numba_compiled', '_vectorized', '_safe_loop'))


In [5]:
# Example usage
def simple_start(z):
    # Benchmark fixture. `z` is indexed by label ('A'..'D'): a pandas row/frame
    # here, or a positional array once the AST rewrite turns z['A'] into
    # z[callmap('A')]. Branch-free arithmetic, so it also runs on whole columns.
    # Comments only: this exact source is read back via inspect.getsource.
    x = (z['A'] + z['B']) / z['C']
    x += z['B'] * z['D']
    return x / z['B']

def harder_func(z):
    # Benchmark fixture with a data-dependent early return. The `if x > 0`
    # needs a scalar truth value, so calling it on whole columns (Series or
    # arrays) raises "truth value ... is ambiguous"; it only works row by row.
    x = (z['A'] + z['B']) / z['C']
    if x > 0:
        return x / z['B']
    x += z['B'] * z['D']
    return x * z['B']

def harder2_func(z):
    # Benchmark fixture exercising assignment expressions (`:=`) inside the
    # condition and a conditional expression in the return; like harder_func
    # it needs scalar comparisons, i.e. row-by-row evaluation.
    x = (z['A'] + z['B']) / z['C']
    # k and j are bound here and reused after the branch.
    if (k:=z['A']-z['C']) > (j:=z['B']/z['D']):
        return x / k
    x *= j
    return x - k if k > z['C'] else x + k

def harder3_func(z):
    # Benchmark fixture adding a locally defined lambda (a closure inside the
    # rewritten `temporary` function).
    # NOTE(review): the numba failure in the recorded output mentions a
    # closure inside `temporary`; presumably this lambda — confirm.
    # b**(-a) with negative b and non-integer exponent gives NaN for numpy
    # floats (a complex number for plain Python floats).
    g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
    x = (z['A'] + z['B']) / z['C']
    if (k:=z['A']-z['C']) > (j:=g(z['B'],z['D'])):
        return j / k
    x *= j
    return x - k if k > z['C'] else x + k

df=pd.DataFrame(data=np.random.randn(100000,4), columns = ['A', 'B', 'C', 'D']).astype(np.float32)
mapping = {k: i for i, k in enumerate(df.columns)}

for func_test in [simple_start, harder_func, harder2_func, harder3_func]:
    print('Testing ', func_test.__name__)
    opt_func, vectorized_func, safe_func = create_module_ast(func_test, mapping)
    try:
        print('df.apply(f, axis=1):', end=' ')
        %timeit df.apply(func_test, axis=1)
        print('checksum: ', np.sum(df.apply(func_test, axis=1)))
    except Exception as e:
        print('df.apply(f, axis=1) FAILED : ', e)
    try:
        print('test_func(df):', end=' ')
        %timeit func_test(df)
        print('checksum: ', np.sum(func_test(df)))
    except Exception as e:
        print('test_func(df) FAILED : ', e)
    try:
        print('safe_func(df):', end=' ')
        %timeit safe_func(df.to_numpy())
        print('checksum: ', np.sum(safe_func(df.to_numpy())))
    except Exception as e:
        print('safe_func(df) FAILED : ', e)
    try:
        print('vectorized_func(df):', end=' ')
        %timeit vectorized_func(df.to_numpy().T)
        print('checksum: ', np.sum(vectorized_func(df.to_numpy().T)))
    except Exception as e:
        print('vectorized_func(df) FAILED : ', e)
    try:
        print('opt_func(df):', end=' ')
        %timeit opt_func(df.to_numpy())
        print('checksum: ', np.sum(opt_func(df.to_numpy())))
    except Exception as e:
        print('opt_func(df) FAILED : ', e)

    print('\n\n\n')
    

Testing  simple_start
df.apply(f, axis=1): 694 ms ± 9.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
checksum:  -422046.25
test_func(df): 425 µs ± 8.64 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
checksum:  -422046.25
safe_func(df): 88 ms ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
checksum:  -422046.3152218703
vectorized_func(df): 236 µs ± 2.73 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
checksum:  -422046.25
opt_func(df): 202 µs ± 126 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
checksum:  -422046.3152218703




Testing  harder_func
df.apply(f, axis=1): 722 ms ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
checksum:  -745761.56
test_func(df): test_func(df) FAILED :  The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
safe_func(df): 186 ms ± 1.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
checksum:  -745761.6245374094
vectorized_func(df): vectorized

  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))


956 ms ± 4.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
checksum:  -178217.00026388597
test_func(df): test_func(df) FAILED :  The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
safe_func(df): 



203 ms ± 2.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
checksum:  nan
vectorized_func(df): vectorized_func(df) FAILED :  The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
opt_func(df): opt_func(df) FAILED :  Failed in nopython mode pipeline (step: nopython frontend)
Type of variable 'closure__locals__temporary_v39__v262call_49' cannot be determined, operation: unknown operation, location: unknown location (0:0)

File "unknown location", line 0:
<source missing, REPL/exec in use?>





