# Understanding the idea behind the TransCoder part:

In [194]:
import ast
import copy
import inspect
import logging
import functools
import sys
from typing import Callable, Type, Dict, Tuple, Any, Union, Iterable, List, Optional
import numpy as np
from functools import reduce
import numba as nb
import hashlib


def source_parser(f):
    """
    Return Python source text that defines a one-argument function equivalent to ``f``.

    Three cases are handled, in this order:

    1. The text returned by ``inspect.getsource(f)`` starts with ``apply``,
       ``aggregate`` or ``source_parser``: the text between the first ``x:`` and
       the first ``mapping`` is extracted, stripped of surrounding whitespace,
       ``,``, ``:`` and ``)``, and used as the returned expression.
       NOTE(review): presumably this recovers a ``lambda x: ...`` written inline
       in a call to one of those methods -- confirm against callers.
    2. ``f`` lives in a module other than ``__main__``: a wrapper is generated
       that imports that module and calls ``f`` through its qualified name.
    3. Otherwise the source of ``f`` is returned unchanged.

    Generated wrappers (cases 1 and 2) are named ``f<uid>`` where ``uid`` is
    derived from the SHA-1 digest of the generated body, reduced modulo 10**8.

    Parameters
    ----------
    f : callable
        Function whose source must be retrievable by ``inspect.getsource``.

    Returns
    -------
    str
        Source code of a function definition.
    """
    raw = inspect.getsource(f)

    if raw.startswith(("apply", "aggregate", "source_parser")):
        start = raw.find('x:') + 2
        stop = raw.find('mapping')
        expression = raw[start:stop].strip().strip(',').strip(':').strip(')')
        code = "    return " + expression
    elif f.__module__ == '__main__':
        # Interactive definitions are already usable verbatim.
        return raw
    else:
        module = f.__module__
        code = "".join(["    import ", module, '\n',
                        "    return ", module, '.', f.__qualname__, '(x)'])

    digest = hashlib.sha1(code.encode("utf-8")).hexdigest()
    uid = int(digest, 16) % (10 ** 8)
    return f"\ndef f{uid}(x):\n{code} \n    "
    

def create_if_callmap_function_ast(index: Iterable[str], fuid: str) -> Tuple[str, str]:
    """
    Generate the source of a numba-decorated function mapping a label to its position.

    The generated function is a chain of ``if x == <label>: return <position>``
    statements, one per entry of ``index``, followed by a default
    ``return <number of labels>`` for labels that are not found.

    Parameters
    ----------
    index : Iterable[str]
        Labels, in positional order. May be empty, in which case the generated
        function only contains the default ``return 0``.
    fuid : str
        Identifier appended to ``callmap_`` to build the function name
        (the full name is truncated to 64 characters).

    Returns
    -------
    Tuple[str, str]
        ``(source, name)``: the decorated source code and the generated
        function's name.

    Notes
    -----
    NOTE(review): the decorator declares a ``uint8`` return type, so an index
    with more than 255 labels presumably cannot be represented -- confirm.
    """
    body = []
    n_labels = 0
    for position, label in enumerate(index):
        # numpy scalars (e.g. np.str_) are unparsed through repr(), which is
        # not a plain literal on recent numpy; store the native Python value.
        if isinstance(label, np.generic):
            label = label.item()
        compare = ast.Compare(
            left=ast.Name(id='x', ctx=ast.Load()),
            ops=[ast.Eq()],
            comparators=[ast.Constant(value=label)]
        )
        body.append(
            ast.If(
                test=compare,
                body=[ast.Return(value=ast.Constant(value=position))],
                orelse=[]
            )
        )
        n_labels = position + 1

    # Default return: one past the last valid position (0 for an empty index).
    body.append(ast.Return(value=ast.Constant(value=n_labels)))

    callmap_name = f'callmap_{fuid}'[:64]
    func_def = ast.FunctionDef(
        name=callmap_name,
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg='x')],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[]
        ),
        body=body,
        decorator_list=[],
        returns=None
    )

    callmap_str = ast.unparse(ast.fix_missing_locations(ast.Module(body=[func_def], type_ignores=[])))

    # The decorator is prepended as text rather than built as AST nodes.
    decorator = "@nb.njit(nb.uint8(nb.types.string), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nopython=True, nogil=True)\n"

    return decorator + callmap_str, callmap_name

def create_func_callmap_function_ast(funcs_list: Iterable[str]) -> Tuple[str, str]:
    """
    Generate the source of a dispatcher ``f(idx, x)`` that calls the ``idx``-th function on ``x``.

    The generated function is a chain of ``if idx == i: return <func_i>(x)``
    statements; any other ``idx`` falls through to a call of the last function.

    Parameters
    ----------
    funcs_list : Iterable[str]
        Names of the functions to dispatch to, in positional order. Any
        iterable is accepted, including one-shot iterators.

    Returns
    -------
    Tuple[str, str]
        ``(source, name)``: the decorated source code and the dispatcher's
        name, which is ``func_callmap_`` followed by the concatenated function
        names, truncated to 64 characters.

    Raises
    ------
    ValueError
        If ``funcs_list`` is empty (there would be nothing to dispatch to).

    Notes
    -----
    The decorator text still contains the ``PLACEHOLDERDTYPE`` token; it is
    expected to be substituted before the source is compiled.
    """
    # Materialize once: the names are needed for the body and for the name.
    funcs = list(funcs_list)
    if not funcs:
        raise ValueError("funcs_list must contain at least one function name")

    def call_on_x(func_name):
        # Build the AST for ``func_name(x)``.
        return ast.Call(func=ast.Name(id=func_name, ctx=ast.Load()),
                        args=[ast.Name(id='x', ctx=ast.Load())],
                        keywords=[])

    body = [
        ast.If(
            test=ast.Compare(
                left=ast.Name(id='idx', ctx=ast.Load()),
                ops=[ast.Eq()],
                comparators=[ast.Constant(value=idx)]
            ),
            body=[ast.Return(value=call_on_x(func))],
            orelse=[]
        )
        for idx, func in enumerate(funcs)
    ]
    # Default return: reuse the last function so every path returns a value.
    body.append(ast.Return(value=call_on_x(funcs[-1])))

    callmap_name = f'func_callmap_{"".join(funcs)}'[:64]
    func_def = ast.FunctionDef(
        name=callmap_name,
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg='idx'), ast.arg(arg='x')],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[]
        ),
        body=body,
        decorator_list=[],
        returns=None
    )
    decorator = "@nb.njit(nb.PLACEHOLDERDTYPE(nb.uint8, nb.PLACEHOLDERDTYPE[:]), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)\n"

    callmap_str = ast.unparse(ast.fix_missing_locations(ast.Module(body=[func_def], type_ignores=[])))

    return decorator + callmap_str, callmap_name

class SubscriptReplacer(ast.NodeTransformer):
    """
    AST transformer that turns label subscripts of one argument into positional ones.

    Every ``arg_name[<label>]`` found in the visited tree is rewritten to
    ``new_var_name[<position>]``, where the position is looked up in ``index``.
    When the subscript is not a constant that matches exactly one label, the
    subscript expression is instead wrapped in a call to a generated "callmap"
    function that resolves the label at run time.

    Attributes
    ----------
    arg_name : str
        Name of the argument whose subscripts are rewritten.
    new_var_name : str
        Name substituted for ``arg_name`` in the rewritten code.
    index : numpy.ndarray
        Labels, in positional order (the constructor argument converted with
        ``numpy.asarray``).
    for_vectorize_form : bool
        If true, the new subscript is a tuple of ``ndims`` elements with the
        position at ``axis`` and full slices elsewhere.
    axis, ndims : int
        Only used when ``for_vectorize_form`` is true.
    callmap_name, callmap_def : Optional[str]
        Name and source of the run-time label lookup function; created on
        first need unless a name was supplied to the constructor.
    vectorizable : bool
        Starts true; cleared when the tree contains ``if``, conditional
        expressions, ``with``, ``while``, a bare use of ``arg_name`` or a
        subscript that needs the callmap.
    fuid : str
        ``'_general_'`` until the first matching subscript is seen, then the
        SHA-1 hex digest of the concatenated labels.
    used : bool
        Initialised to ``False``; not modified by this class.

    Examples
    --------
    >>> tree = ast.parse("def f(z): return z['A'] + z['B']")
    >>> replacer = SubscriptReplacer('z', 'zX', ['A', 'B'], False, 1, 2)
    >>> print(ast.unparse(replacer.visit(tree)))
    def f(z):
        return zX[0] + zX[1]
    """

    def __init__(self, arg_name: str, new_var_name: str, index: List[str], for_vectorize_form: bool, axis: int, ndims : int, callmap_name: Optional[str] = None, callmap_def: Optional[str] = None):
        """
        Store the rewrite configuration.

        Parameters
        ----------
        arg_name : str
            Name of the argument to target for subscript replacement.
        new_var_name : str
            Replacement name for ``arg_name``.
        index : sequence of labels
            Labels in positional order; lists and numpy arrays are accepted.
        for_vectorize_form : bool
            Whether to emit ``ndims``-dimensional tuple subscripts.
        axis : int
            Position of the resolved index inside the tuple subscript.
        ndims : int
            Number of elements of the tuple subscript.
        callmap_name : Optional[str]
            Name of an already generated callmap function to reuse.
        callmap_def : Optional[str]
            Source of that already generated callmap function.
        """
        self.arg_name = arg_name
        self.new_var_name = new_var_name
        self.callmap_name = callmap_name
        self.callmap_def = callmap_def
        # The lookups below rely on ndarray semantics (astype, elementwise ==),
        # so plain lists are converted up front.
        self.index = np.asarray(index)
        self.for_vectorize_form = for_vectorize_form
        self.axis = axis
        self.ndims = ndims
        self.used = False
        self.vectorizable = True
        self.fuid = '_general_'

    def visit_Subscript(self, node):
        """Rewrite ``arg_name[...]`` to a positional subscript of ``new_var_name``."""
        if isinstance(node.value, ast.Name) and node.value.id == self.arg_name:
            if self.fuid == '_general_':
                labels = "".join(self.index.astype(str))
                self.fuid = hashlib.sha1(labels.encode("UTF-8")).hexdigest()
            # Python version compatibility: before 3.9 simple subscripts were
            # wrapped in ast.Index.
            if sys.version_info >= (3, 9):
                old_slice = node.slice
            else:
                old_slice = node.slice.value if isinstance(node.slice, ast.Index) else node.slice

            matches = ()
            if isinstance(old_slice, ast.Constant):
                matches = np.where(self.index == old_slice.value)[0]

            if len(matches) == 1:
                # Store a native int: numpy integers do not unparse to a
                # plain literal on recent numpy versions.
                new_value = ast.Constant(value=int(matches[0]))
            else:
                # Label only known at run time (or ambiguous): resolve it
                # through the generated callmap function.
                if self.callmap_name is None:
                    self.callmap_def, self.callmap_name = create_if_callmap_function_ast(self.index, self.fuid)
                new_value = ast.Call(
                    func=ast.Name(id=self.callmap_name, ctx=ast.Load()),
                    args=[old_slice],
                    keywords=[]
                )
                self.vectorizable = False

            if self.for_vectorize_form:
                node.slice = ast.Tuple(
                    elts=[
                        ast.Slice() if i != self.axis
                        else new_value
                        for i in range(self.ndims)],
                    ctx=ast.Load())
            else:
                node.slice = new_value
            # Rename before visiting children so visit_Name does not treat
            # this occurrence as a bare use of the argument.
            node.value.id = self.new_var_name
        return self.generic_visit(node)

    def visit_If(self, node):
        """Mark the function as not vectorizable and keep visiting."""
        self.vectorizable = False
        return self.generic_visit(node)

    def visit_IfExp(self, node):
        """Mark the function as not vectorizable and keep visiting."""
        self.vectorizable = False
        return self.generic_visit(node)

    def visit_With(self, node):
        """Mark the function as not vectorizable and keep visiting."""
        self.vectorizable = False
        return self.generic_visit(node)

    def visit_While(self, node):
        """Mark the function as not vectorizable and keep visiting."""
        self.vectorizable = False
        return self.generic_visit(node)

    def visit_Name(self, node):
        """Rename bare uses of the argument; such uses prevent vectorization."""
        if node.id == self.arg_name:
            self.vectorizable = False
            node.id = self.new_var_name
        return self.generic_visit(node)

def AstModifier(original_funcs, index, axis, ndims):
    """
    Transcode label-indexed functions into positional, numba-decorated source code.

    For each function its source is obtained with ``source_parser``, its first
    argument is renamed, and every label subscript of that argument is
    rewritten with ``SubscriptReplacer``. A dispatcher selecting among the
    rewritten functions is appended, and the label lookup ("callmap")
    function, if any rewrite needed one, is prepended.

    Parameters
    ----------
    original_funcs : iterable of callables
        Functions of one argument to transcode.
    index : array of labels
        Labels, in positional order, used to resolve subscripts.
    axis, ndims : int
        Forwarded to ``SubscriptReplacer``.

    Returns
    -------
    Tuple[str, str]
        ``(dispatcher_name, source)``. The source still contains the
        ``PLACEHOLDERDTYPE`` token in its decorators.
    """
    callmap_name, callmap_def = None, None
    modified_funcs, modified_names = [], []
    for original_func in original_funcs:
        source_code = source_parser(original_func)
        original_tree = ast.parse(source_code)

        arg_name = original_tree.body[0].args.args[0].arg
        # Pick a replacement name that does not occur anywhere in the source.
        # The candidate must be rebound on each attempt, otherwise a collision
        # would loop forever.
        new_name = arg_name + "X"
        while new_name in source_code:
            new_name += "X"

        # Pass the callmap along so all functions share a single lookup.
        replacer = SubscriptReplacer(arg_name, new_name, index, for_vectorize_form=False, axis = axis, ndims = ndims, callmap_name = callmap_name, callmap_def = callmap_def)

        modified_tree = copy.deepcopy(original_tree)
        modified_tree = replacer.visit(modified_tree)
        callmap_name, callmap_def = replacer.callmap_name, replacer.callmap_def
        fuid = original_func.__name__ + replacer.fuid
        modified_tree.body[0].args.args[0].arg = new_name
        modified_tree.body[0].name = (fuid +"_modified")[:64]
        ast.fix_missing_locations(modified_tree)
        modified_funcs.append(ast.unparse(modified_tree) + "\n\n" )
        modified_names.append(modified_tree.body[0].name)

    final_func_callmap, final_func_callmap_name = create_func_callmap_function_ast(modified_names)

    decorator = "@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)\n"

    # One decorator in front of every rewritten function.
    final_text = decorator + decorator.join(modified_funcs) + "\n"
    final_text += final_func_callmap
    if callmap_def is not None:
        final_text = callmap_def + "\n\n" + final_text

    return final_func_callmap_name, final_text



In [178]:
# Demo input for the transcoder: reduces its whole argument with np.sum and
# never subscripts it by label.
def func_a(x):
    return np.sum(x)

# Demo input for the transcoder: reduces its whole argument with np.mean and
# never subscripts it by label.
def func_b(x):
    return np.mean(x)
    
# Transcode both reducers against the frame's column labels (axis=1, ndims=2)
# and print the generated source.
final_func_callmap_name, final_form = AstModifier([func_a, func_b], df.columns.values, 1, 2)
print(final_form)


@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def func_a_general__modified(xX):
    return np.sum(xX)

@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def func_b_general__modified(xX):
    return np.mean(xX)


@nb.njit(nb.PLACEHOLDERDTYPE(nb.uint8, nb.PLACEHOLDERDTYPE[:]), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def func_callmap_func_a_general__modifiedfunc_b_general__modified(idx, x):
    if idx == 0:
        return func_a_general__modified(x)
    if idx == 1:
        return func_b_general__modified(x)
    return fu

## Step A: Transform the code so that it works with array indices instead of pandas column names (or, in fact, any non-RangeIndex labels)

### NB: For applicable functions, offer a direct vectorization (no compilation, leveraging NumPy directly)
#### "Applicable" means: no `if`, no need for a callmap, and every use of the input refers to one of its columns/rows.


In [204]:
import pandas as pd
# 100,000 x 3 float32 frame of np.random.randn samples with columns A, B, C.
df = pd.DataFrame(np.random.randn(100_000,3), columns=['A','B', 'C']).astype(np.float32)
# Module-level axis value; later cells read it when choosing labels.
axis = 1

# Demo input for the transcoder: only constant label subscripts ('A', 'B').
def f0(z):
    return z['A'] + z['B']
    
# Transcode f0 against the frame's column labels (axis=1, ndims=2) and print
# the generated source.
final_func_callmap_name, final_form = AstModifier([f0], df.columns.values, 1, 2)
print(final_form)


@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(zX):
    return zX[0] + zX[1]


@nb.njit(nb.PLACEHOLDERDTYPE(nb.uint8, nb.PLACEHOLDERDTYPE[:]), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def func_callmap_f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(idx, x):
    if idx == 0:
        return f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(x)
    return f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(x)


In [180]:
# Demo input for the transcoder: only constant label subscripts ('A', 'B').
def f0(z):
    return z['A'] + z['B']
    
# Demo input for the transcoder: the first subscript is a conditional
# expression, so its label is only known at run time (not an ast.Constant).
def f1(z):
    return z['A' if z['B'] > 0 else 'B'] * z['B']

# Demo input for the transcoder: uses the argument as a whole (np.sum), with
# no label subscripts.
def f2(z):
    return np.sum(z)
    
# Transcode the three demo functions together (axis=1, ndims=2) and print the
# generated source.
final_func_callmap_name, final_form = AstModifier([f0, f1, f2], df.columns.values, 1, 2)
print(final_form)


@nb.njit(nb.uint8(nb.types.string), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nopython=True, nogil=True)
def callmap_3c01bdbb26f358bab27f267924aa2c9a03fcfdb8(x):
    if x == 'A':
        return 0
    if x == 'B':
        return 1
    if x == 'C':
        return 2
    return 3

@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(zX):
    return zX[0] + zX[1]

@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def f13c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(zX):
    return zX[ca

In [205]:
# Demo input for the transcoder: uses the argument as a whole (np.sum), with
# no label subscripts.
def f2(z):
    return np.sum(z)
    
# Transcode f2 alone (axis=1, ndims=2) and print the generated source.
final_func_callmap_name, final_form = AstModifier([f2], df.columns.values, 1, 2)
print(final_form)

@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def f2_general__modified(zX):
    return np.sum(zX)


@nb.njit(nb.PLACEHOLDERDTYPE(nb.uint8, nb.PLACEHOLDERDTYPE[:]), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def func_callmap_f2_general__modified(idx, x):
    if idx == 0:
        return f2_general__modified(x)
    return f2_general__modified(x)


## Step B: Wrap the code inside the pattern developed to optimize numba usage

In [182]:

# Source template for the two-dimensional case; _make_func fills in the
# PLACEHOLDER* tokens by plain text substitution.
#   PLACEHOLDERFUNC       - generated function definitions, inserted verbatim
#   PLACEHOLDERNAME       - name of the dispatcher called as NAME(j, <slice of z>)
#   PLACEHOLDERDTYPE      - dtype name used for both numba and numpy types
#   PLACEHOLDERAXIS       - which entry of z.shape gives the loop length n
#   PLACEHOLDERAXISSELECT - subscript text appended to ``z`` inside the loop
#   PLACEHOLDERNFUNC      - number of functions (second dimension of the result)
# The template defines cvlopt_<NAME>, which fills r[i, j] inside two nb.prange
# loops, and vlopt_<NAME>, which allocates r, calls cvlopt_<NAME> and returns r.
two_dim = """
PLACEHOLDERFUNC

@nb.njit((nb.PLACEHOLDERDTYPE[:,:], nb.PLACEHOLDERDTYPE[:, :],nb.types.uint32), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def cvlopt_PLACEHOLDERNAME(z, r, n):
    for i in nb.prange(n):
        for j in nb.prange(PLACEHOLDERNFUNC):
            r[i, j] = PLACEHOLDERNAME(j, zPLACEHOLDERAXISSELECT)

def vlopt_PLACEHOLDERNAME(z):
    n = np.uint32(z.shape[PLACEHOLDERAXIS])  # Determine the number of rows in z
    r = np.empty((n, PLACEHOLDERNFUNC), dtype=np.PLACEHOLDERDTYPE)  # Initialize the result array
    cvlopt_PLACEHOLDERNAME(z, r, n)
    return r
"""

def _make_func(base_model: str, name: str, func: str, dtype: str, axis: str, axis_select: str, n_funcs: int):
    """
    Fill a source template by substituting its ``PLACEHOLDER*`` tokens.

    Substitutions are applied one after another, in a fixed order:
    ``PLACEHOLDERNAME``, ``PLACEHOLDERFUNC``, ``PLACEHOLDERDTYPE``,
    ``PLACEHOLDERAXISSELECT``, ``PLACEHOLDERAXIS``, ``PLACEHOLDERNFUNC``.
    The order matters: tokens contained in ``func`` (inserted second) are
    still substituted by later steps, and ``PLACEHOLDERAXISSELECT`` must be
    handled before its prefix ``PLACEHOLDERAXIS``.

    Parameters
    ----------
    base_model : str
        Template text.
    name, func, dtype, axis, axis_select : str
        Replacement texts for the corresponding tokens.
    n_funcs : int
        Replacement for ``PLACEHOLDERNFUNC`` (converted with ``str``).

    Returns
    -------
    str
        The template with every token replaced.
    """
    ordered_substitutions = (
        ('PLACEHOLDERNAME', name),
        ('PLACEHOLDERFUNC', func),
        ('PLACEHOLDERDTYPE', dtype),
        ('PLACEHOLDERAXISSELECT', axis_select),
        ('PLACEHOLDERAXIS', axis),
        ('PLACEHOLDERNFUNC', str(n_funcs)),
    )
    text = base_model
    for token, replacement in ordered_substitutions:
        text = text.replace(token, replacement)
    return text
    

def _compile_tree(func: Dict[str, str], exec_globals: Dict[str, Any]) -> Callable[[Any], Any]:
    """
    Return a callable that compiles generated source on first use, then delegates to it.

    Parameters
    ----------
    func : Dict[str, str]
        Must provide ``'name'`` (the function to call once the source has been
        executed) and ``'source_code'`` (the source that defines it).
    exec_globals : Dict[str, Any]
        Namespace the source is executed in; it is mutated. The source is only
        executed if ``func['name']`` is not already present in it.

    Returns
    -------
    Callable[[Any], Any]
        A one-argument wrapper whose ``__name__`` is ``func['name']`` and whose
        ``__qualname__`` is that name prefixed with ``"__numba__."``.
    """
    def wrapped(x):
        if func['name'] not in exec_globals:
            # SECURITY: exec runs arbitrary code; 'source_code' must come from
            # the transcoder, never from untrusted input.
            exec(compile(ast.parse(func['source_code']), filename="/tmp/numbacache", mode="exec"), exec_globals)
        return exec_globals[func['name']](x)
    wrapped.__name__ = func['name']
    wrapped.__qualname__ = "__numba__." + func['name']
    return wrapped

def _prepare_func(original_funcs: Union[Callable, List[Callable]], index: Iterable[str], axis: int = 0, ndims: int = 2, dtype: str = 'float32') -> Dict[str, Any]:
    """
    Transcode one or several functions and wrap them in the two-dimensional driver template.

    Parameters
    ----------
    original_funcs : callable or list of callables
        Function(s) of one argument to transcode; a single callable is treated
        as a one-element list.
    index : Iterable[str]
        Labels used to resolve subscripts in the functions.
    axis : int, default 0
        Must not exceed 1. A truthy value selects ``z[i, :]`` and loops over
        ``z.shape[0]``; a falsy value selects ``z[:, i]`` and loops over
        ``z.shape[1]``.
    ndims : int, default 2
        Number of dimensions of the data; only 2 is implemented.
    dtype : str, default 'float32'
        dtype name substituted into the template.

    Returns
    -------
    Dict[str, Any]
        ``'name'`` (name of the generated entry point), ``'source_code'``
        (the complete generated source) and ``'function'`` (a callable that
        compiles the source into this module's globals on first call).

    Raises
    ------
    NotImplementedError
        If ``ndims`` is not 2.
    ValueError
        If ``axis`` is greater than 1.
    """
    # Validate before doing any transcoding work.
    if ndims != 2:
        raise NotImplementedError(f"only ndims == 2 is implemented, got ndims={ndims!r}")
    if axis > 1:
        raise ValueError(f"axis must be 0 or 1 for two-dimensional data, got axis={axis!r}")

    if callable(original_funcs):
        original_funcs = [original_funcs]
    final_func_callmap_name, final_form = AstModifier(original_funcs = original_funcs, index = index, axis = axis, ndims = ndims)

    result = {
        "name": "vlopt_" + final_func_callmap_name,
        "source_code": _make_func(two_dim, final_func_callmap_name, final_form, dtype, "0" if axis else "1", "[i, :]" if axis else "[:, i]", len(original_funcs)),
    }
    result["function"] = _compile_tree(result, globals())
    return result

In [198]:
# Benchmark input: straight-line arithmetic on constant labels A, B, C, D,
# with no branching.
def simple_start(z):
    x = (z['A'] + z['B']) / z['C']
    x += z['B'] * z['D']
    return x / z['B']

# Benchmark input: same labels as simple_start, plus an `if` with an early
# return.
def harder_func(z):
    x = (z['A'] + z['B']) / z['C']
    if x > 0:
        return x / z['B']
    x += z['B'] * z['D']
    return x * z['B']

# Benchmark input: adds intermediate variables, an `if` with an early return
# and a conditional expression in the final return.
def harder2_func(z):
    x = (z['A'] + z['B']) / z['C']
    k=z['A']-z['C']
    j=z['B']/z['D']
    if k > j:
        return x / k
    x *= j
    return x - k if k > z['C'] else x + k


# Benchmark input: adds a local lambda and a subscript whose label is chosen
# at run time ('A' or 'D'), so it is not a constant label.
# NOTE(review): the lambda raises b to the power -a without integer
# conversion; the repeated lambda lines echoed in the recorded pandas timing
# output are presumably runtime warnings from this expression - confirm.
def harder3_func(z):
    g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
    x = (z['A'] + z['B']) / z['C']
    k=z['A' if z['B'] > 0 else 'D']-z['C']
    j=g(z['B'],z['D'])
    if k > j:
        return j / k
    x *= j
    return x - k if k > z['C'] else x + k

# Benchmark input: identical to harder3_func except that the lambda converts
# the exponent with int() before negating it.
def harder4_func(z):
    g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-int(a)))
    x = (z['A'] + z['B']) / z['C']
    k=z['A' if z['B'] > 0 else 'D']-z['C']
    j=g(z['B'],z['D'])
    if k > j:
        return j / k
    x *= j
    return x - k if k > z['C'] else x + k


# 10,000 x 4 float32 frame of np.random.randn samples with columns A, B, C, D.
df = pd.DataFrame(np.random.randn(int(1e4), 4), columns=['A', 'B', 'C', 'D']).astype(np.float32)

func_list = [simple_start, harder_func, harder2_func, harder3_func, harder4_func]

# Labels are taken from the columns when axis == 1, otherwise from the row index.
# NOTE(review): `axis` is not assigned in this cell; presumably it is the
# module-level value set by an earlier cell - confirm.
prepared = _prepare_func(func_list, df.columns.values.astype(str) if axis==1 else df.index.values.astype(str), axis = axis, ndims = 2, dtype = "float32")

print(prepared['source_code'])




@nb.njit(nb.uint8(nb.types.string), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nopython=True, nogil=True)
def callmap_fb2f85c88567f3c8ce9b799c7c54642d0c7b41f6(x):
    if x == 'A':
        return 0
    if x == 'B':
        return 1
    if x == 'C':
        return 2
    if x == 'D':
        return 3
    return 4

@nb.njit(nb.float32(nb.float32[:]), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def simple_startfb2f85c88567f3c8ce9b799c7c54642d0c7b41f6_modified(zX):
    x = (zX[0] + zX[1]) / zX[2]
    x += zX[1] * zX[3]
    return x / zX[1]

@nb.njit(nb.float32(nb.float32[:]), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def harder_funcfb2f85c88567f3c8ce9b799c7c54642d

In [199]:
# Call the prepared function on the frame's ndarray and relabel the result:
# original row index, one column per function name.
pd.DataFrame(prepared['function'](df.to_numpy()), index=df.index, columns=[func.__name__ for func in func_list])

Unnamed: 0,simple_start,harder_func,harder2_func,harder3_func,harder4_func
0,-0.801731,-0.699993,10.030642,-0.438783,-0.438783
1,-3.405339,-2.399083,5.226829,5.529442,5.529442
2,-0.564130,-1.301725,-1.160487,0.590154,0.590154
3,-2.598243,-1.199597,-1.710738,-3.996047,-5.040822
4,-0.725178,-1.813862,8.013164,-1.557819,-1.557819
...,...,...,...,...,...
9995,0.619576,0.934484,-0.342752,0.686632,0.686632
9996,-74.669060,-75.971443,-16.956228,-1.715873,-1.493752
9997,-10.619404,-1.277432,-6.756509,-2.959808,-2.959808
9998,0.492677,-0.953591,2.849889,-10.831484,-8.474146


In [208]:
# pandas baseline: the same five functions through DataFrame.aggregate, axis=1.
# NOTE(review): the recorded run raised KeyError: 'D'; presumably `df` had been
# rebound to the three-column frame (A, B, C) before this cell ran - confirm.
df.aggregate([simple_start, harder_func, harder2_func, harder3_func, harder4_func], axis=1)

KeyError: 'D'

In [201]:
%timeit pd.DataFrame(prepared['function'](df.to_numpy()), index=df.index, columns=[func.__name__ for func in func_list])

127 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [202]:
%timeit df.apply([simple_start, harder_func, harder2_func, harder3_func, harder4_func], axis=1)

  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))
  g=lambda a, b: a if abs(a) > abs(b) else - 2 * (b**(-a))


3.3 s ± 18.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
