# Understanding the TransCoder idea

In [1]:
import ast
import copy
import inspect
import logging
import functools
import sys
from typing import Callable, Type, Dict, Tuple, Any, Union, Iterable, List
import numpy as np
from functools import reduce
import numba as nb
import hashlib


def source_parser(f):
    """Return Python source whose first statement defines a function equivalent to *f*.

    Three cases are handled:

    * the source of *f* starts with ``apply``, ``aggregate`` or
      ``source_parser`` (presumably an inline lambda in such a call): the text
      between ``x:`` and ``mapping`` is wrapped in a generated ``def f<uid>(x)``;
    * *f* comes from a module other than ``__main__``: a generated
      ``def f<uid>(x)`` imports that module and forwards to ``f``;
    * otherwise the source of *f* itself is returned, dedented so that
      functions defined inside another scope still parse.

    ``<uid>`` is the SHA-1 of the generated body reduced modulo ``10**8``, so
    identical bodies always receive identical names.

    Parameters
    ----------
    f : callable
        Function whose source is retrievable with :func:`inspect.getsource`.

    Returns
    -------
    str
        Source code suitable for :func:`ast.parse`.
    """
    import textwrap

    source = inspect.getsource(f)
    if source.startswith(("apply", "aggregate", "source_parser")):
        start = source.find('x:') + 2
        end = source.find('mapping')
        if end == -1:
            # No ``mapping`` marker: keep the whole tail instead of letting the
            # -1 slice bound silently drop the last character.
            end = len(source)
        code = "    return " + source[start:end].strip().strip(',').strip(':').strip(')')
    elif f.__module__ != '__main__':
        code = "    import " + f.__module__ + '\n'
        code += "    return " + f.__module__ + '.' + f.__qualname__ + '(x)'
    else:
        # Nested functions / methods come back indented and ast.parse would
        # reject them; dedent is a no-op for top-level definitions.
        return textwrap.dedent(source)
    uid = int(hashlib.sha1(code.encode("utf-8")).hexdigest(), 16) % (10 ** 8)
    return (f"""
def f{uid}(x):
{code} 
    """)
    

def create_callmap_function_ast(index: Iterable[str], fuid: str) -> ast.FunctionDef:


    body = []
    for value, key in enumerate(index):
        try:
            key = int(key)
            continue
        except ValueError:
            key = ast.Constant(value=key)
        compare = ast.Compare(
            left=ast.Name(id='x', ctx=ast.Load()),
            ops=[ast.Eq()],
            comparators=[key]
        )
        body.append(
            ast.If(
                test=compare,
                body=[ast.Return(value=ast.Constant(value=value))],
                orelse=[]
            )
        )
    
    # Add a default return statement
    body.append(ast.Return(value=ast.Name(id='x', ctx=ast.Load())))

    # Create the Numba JIT decorator with specified options
    numba_decorator = ast.Call(
        func=ast.Attribute(value=ast.Name(id='nb', ctx=ast.Load()), attr='njit', ctx=ast.Load()),
        args=[
            ast.Call(
                func=ast.Attribute(value=ast.Name(id='nb', ctx=ast.Load()), attr='int8', ctx=ast.Load()),
                args=[ast.Attribute(value=ast.Name(id='nb', ctx=ast.Load()), attr='types.string', ctx=ast.Load())],
                keywords=[]
            )
        ],
        keywords=[
            ast.keyword(arg='cache', value=ast.Constant(value=False)),
            ast.keyword(arg='parallel', value=ast.Constant(value=True)),
            ast.keyword(arg='fastmath', value=ast.Constant(value=True)),
            ast.keyword(arg='forceinline', value=ast.Constant(value=True)),
            ast.keyword(arg='looplift', value=ast.Constant(value=True)),
            ast.keyword(arg='inline', value=ast.Constant(value='always')),
            ast.keyword(arg='target_backend', value=ast.Constant(value='host')),
            ast.keyword(arg='no_cfunc_wrapper', value=ast.Constant(value=True)),
            ast.keyword(arg='no_rewrites', value=ast.Constant(value=True)),
            ast.keyword(arg='nopython', value=ast.Constant(value=True)),
            ast.keyword(arg='nogil', value=ast.Constant(value=True)),
        ]
    )
    
    # Create the function definition
    func_def = ast.FunctionDef(
        name=(callmap_name:=f'callmap{fuid}'),
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg='x')],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[]
        ),
        body=body,
        decorator_list=[numba_decorator],
        returns=None
    )
    return func_def, callmap_name


class SubscriptReplacer(ast.NodeTransformer):
    """
    AST transformer turning label subscripts on one argument into positional ones.

    Every ``arg_name[<label>]`` is rewritten on ``new_var_name``:

    * a constant label found exactly once in *index* becomes its integer
      position;
    * any other subscript becomes ``callmap<fuid>(<label expr>)``; the helper
      definition is built once via ``create_callmap_function_ast`` and exposed
      as ``callmap_def``.

    With ``for_vectorize_form=True`` the position is placed at *axis* inside a
    tuple of ``:`` slices of length *ndims* (e.g. ``zX[:, 0]``). Bare uses of
    the argument are renamed too. Callmap use, bare uses, and
    ``if`` / ``while`` / ``with`` / conditional expressions clear
    ``vectorizable``.

    Attributes
    ----------
    arg_name : str
        Name of the argument whose subscripts are rewritten.
    new_var_name : str
        Name the argument is renamed to.
    index : array-like
        Labels in positional order (ndarray, list, pandas Index values, or a
        dict, whose keys are used).
    callmap_name : str or None
        Name of the callmap helper once one has been needed.
    callmap_def : ast.FunctionDef or None
        AST of the callmap helper once one has been needed.
    fuid : str
        ``'_general_'`` until the first matching subscript, then the SHA-1 hex
        digest of the concatenated labels (as strings).
    used : bool
        Initialised to False; not updated by this class.
    vectorizable : bool
        Cleared when a construct blocking the vectorized form is visited.

    Examples
    --------
    >>> replacer = SubscriptReplacer('z', 'zX', np.array(['A', 'B']),
    ...                              for_vectorize_form=False, axis=1, ndims=2)
    >>> modified_tree = replacer.visit(original_tree)
    """

    def __init__(self, arg_name: str, new_var_name: str, index: List[str], for_vectorize_form: bool, axis: int, ndims: int):
        """
        Initialize the replacer.

        Parameters
        ----------
        arg_name : str
            Argument whose subscripts are rewritten.
        new_var_name : str
            Replacement name for that argument.
        index : array-like
            Labels in positional order.
        for_vectorize_form : bool
            Emit ``[:, ..., pos, ...]`` tuples instead of a single position.
        axis : int
            Tuple slot receiving the position in vectorize form.
        ndims : int
            Tuple length in vectorize form.
        """
        self.arg_name = arg_name
        self.new_var_name = new_var_name
        self.callmap_name = None
        self.callmap_def = None
        self.index = index
        self.for_vectorize_form = for_vectorize_form
        self.axis = axis
        self.ndims = ndims
        self.used = False
        self.vectorizable = True
        self.fuid = '_general_'
        # Lazily built ndarray view of ``index`` (see ``_labels``); only
        # needed once a matching subscript is met.
        self._label_array = None

    def _labels(self) -> np.ndarray:
        """Return ``index`` as a NumPy array (a dict contributes its keys, in order)."""
        if self._label_array is None:
            self._label_array = (
                self.index if isinstance(self.index, np.ndarray) else np.asarray(list(self.index))
            )
        return self._label_array

    def visit_Subscript(self, node):
        """Rewrite ``arg_name[label]`` into a positional subscript on ``new_var_name``."""
        if isinstance(node.value, ast.Name) and node.value.id == self.arg_name:
            labels = self._labels()
            if self.fuid == '_general_':
                self.fuid = hashlib.sha1("".join(labels.astype(str)).encode("UTF-8")).digest().hex()
            if sys.version_info >= (3, 9):
                old_slice = node.slice
            else:
                # Python 3.8 and earlier wrap simple subscripts in ast.Index.
                old_slice = node.slice.value if isinstance(node.slice, ast.Index) else node.slice

            positions = np.where(labels == old_slice.value)[0] if isinstance(old_slice, ast.Constant) else ()
            if len(positions) == 1:
                # int() so the literal unparses as ``0`` rather than a NumPy
                # scalar repr such as ``np.int64(0)`` (NumPy >= 2).
                new_value = ast.Constant(value=int(positions[0]))
            else:
                # Dynamic or ambiguous label: resolve at run time via callmap.
                if self.callmap_name is None:
                    self.callmap_def, self.callmap_name = create_callmap_function_ast(labels, self.fuid)
                new_value = ast.Call(
                    func=ast.Name(id=self.callmap_name, ctx=ast.Load()),
                    args=[old_slice],
                    keywords=[],
                )
                self.vectorizable = False

            if self.for_vectorize_form:
                # ``:`` on every axis except *axis*, which receives the position.
                node.slice = ast.Tuple(
                    elts=[new_value if i == self.axis else ast.Slice() for i in range(self.ndims)],
                    ctx=ast.Load(),
                )
            else:
                node.slice = new_value
            node.value.id = self.new_var_name
        return self.generic_visit(node)

    def _reject_vectorization(self, node):
        """Mark the function as not vectorizable, then keep visiting children."""
        self.vectorizable = False
        return self.generic_visit(node)

    # These constructs are treated as preventing the vectorized form.
    visit_If = _reject_vectorization
    visit_IfExp = _reject_vectorization
    visit_With = _reject_vectorization
    visit_While = _reject_vectorization

    def visit_Name(self, node):
        """Rename bare uses of the argument; such uses prevent vectorization."""
        if node.id == self.arg_name:
            self.vectorizable = False
            node.id = self.new_var_name
        return self.generic_visit(node)

def AstModifier(original_func, index, axis, ndims):
    """Produce original, positional and (when possible) vectorized sources of a function.

    The source returned by ``source_parser`` is parsed. The first parameter of
    its first definition is renamed to a fresh name, and label subscripts on it
    are rewritten by ``SubscriptReplacer`` against *index*.

    Parameters
    ----------
    original_func : callable
        One-argument function to transcode.
    index : array-like
        Labels in positional order, forwarded to ``SubscriptReplacer``.
    axis, ndims : int
        Forwarded to ``SubscriptReplacer``; they shape the vectorized subscripts.

    Returns
    -------
    tuple
        ``(fuid, original_src, modified_src, vectorized_src_or_None)``:

        * ``fuid`` is ``original_func.__name__`` followed by the replacer's id;
        * each definition is renamed by appending ``fuid`` plus
          ``_original`` / ``_modified`` / ``_vectorized``;
        * ``modified_src`` also holds the callmap helper when one was needed;
        * the vectorized source is None unless the replacer left
          ``vectorizable`` set.
    """
    vectorize_tree = None
    source_code = source_parser(original_func)
    original_tree = ast.parse(source_code)

    arg_name = original_tree.body[0].args.args[0].arg
    # Fresh parameter name: append "X" until it occurs nowhere in the source.
    # The concatenation must be assigned back or the loop never terminates.
    new_name = arg_name + "X"
    while new_name in source_code:
        new_name += "X"

    replacer = SubscriptReplacer(arg_name, new_name, index, for_vectorize_form=False, axis=axis, ndims=ndims)
    modified_tree = replacer.visit(copy.deepcopy(original_tree))
    fuid = original_func.__name__ + replacer.fuid
    if replacer.callmap_name is not None:
        # The callmap helper goes right after the rewritten function.
        modified_tree.body.insert(1, replacer.callmap_def)
    modified_tree.body[0].args.args[0].arg = new_name
    modified_tree.body[0].name += fuid + "_modified"
    ast.fix_missing_locations(modified_tree)

    if replacer.vectorizable:
        vector_replacer = SubscriptReplacer(arg_name, new_name, index, for_vectorize_form=True, axis=axis, ndims=ndims)
        vectorize_tree = vector_replacer.visit(copy.deepcopy(original_tree))
        vectorize_tree.body[0].args.args[0].arg = new_name
        vectorize_tree.body[0].name += fuid + "_vectorized"
        ast.fix_missing_locations(vectorize_tree)

    original_tree.body[0].name += fuid + "_original"
    return (
        fuid,
        ast.unparse(original_tree),
        ast.unparse(modified_tree),
        ast.unparse(vectorize_tree) if vectorize_tree is not None else None,
    )



## Step A: Transform the code so that it works with positional array indices instead of pandas column names (or, more generally, any labels that are not a `RangeIndex`)

### NB: For applicable functions, also offer a directly vectorized version (no compilation; it relies on NumPy directly)
#### "Applicable" means: no `if` (statement or conditional expression), no `while`/`with`, no callmap needed, and every use of the input refers to one of its columns/rows.


In [2]:
# Demo data shared by the following cells: 100_000 rows x 3 float32 columns.
import pandas as pd
df = pd.DataFrame(np.random.randn(100_000,3), columns=['A','B', 'C']).astype(np.float32)
# Reused by later cells for df.apply / _prepare_func (axis=1: one call per row).
axis = 1

# Only constant label subscripts on the argument -> a vectorized form is produced.
def f0(z):
    return z['A'] + z['B']
    
# Positional args: index = column labels, axis = 1, ndims = 2.
# NOTE(review): the literal 1 is passed rather than the `axis` variable above.
# NOTE(review): `original_tree` receives source text, not an AST.
fuid, original_tree, modified, vectorized = AstModifier(f0, df.columns.values, 1, 2)
print(original_tree)
print()
print(modified)
print()
if vectorized is not None:
    print(vectorized)
else:
    print('Not vectorizable')

def f0f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_original(z):
    return z['A'] + z['B']

def f0f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(zX):
    return zX[0] + zX[1]

def f0f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_vectorized(zX):
    return zX[:, 0] + zX[:, 1]


In [3]:

# The label is computed by a conditional expression: it is routed through the
# generated callmap helper and the function is reported as not vectorizable.
def f1(z):
    return z['A' if z['B'] > 0 else 'B'] * z['B']
    
fuid, original_tree, modified, vectorized = AstModifier(f1, df.columns.values, 1, 2)
print(original_tree)
print()
print(modified)
print()
if vectorized is not None:
    print(vectorized)
else:
    print('Not vectorizable')

def f1f13c01bdbb26f358bab27f267924aa2c9a03fcfdb8_original(z):
    return z['A' if z['B'] > 0 else 'B'] * z['B']

def f1f13c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(zX):
    return zX[callmap3c01bdbb26f358bab27f267924aa2c9a03fcfdb8('A' if zX[1] > 0 else 'B')] * zX[1]

@nb.njit(nb.int8(nb.types.string), cache=False, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nopython=True, nogil=True)
def callmap3c01bdbb26f358bab27f267924aa2c9a03fcfdb8(x):
    if x == 'A':
        return 0
    if x == 'B':
        return 1
    if x == 'C':
        return 2
    return x

Not vectorizable


In [4]:
# The argument is used bare (no subscript): it is only renamed, fuid stays
# '_general_', and the index is never consulted (so a dict is accepted here).
def f2(z):
    return np.sum(z)
    
fuid, original_tree, modified, vectorized = AstModifier(f2, {"A": 0, "B": 1}, 0, 2)
print(original_tree)
print()
print(modified)
print()
if vectorized is not None:
    print(vectorized)
else:
    print('Not vectorizable')

def f2f2_general__original(z):
    return np.sum(z)

def f2f2_general__modified(zX):
    return np.sum(zX)

Not vectorizable


## Step B: Wrap the transformed code inside the template pattern developed to optimize Numba usage

In [5]:

# Code templates instantiated by _make_func. Placeholders:
#   PLACEHOLDERFUNC  - source of the rewritten (positional) function
#   PLACEHOLDERNAME  - name of that function, called once per slice by vl_<uid>
#   PLACEHOLDERDTYPE - dtype name used in the numba signatures and np.empty
#   PLACEHOLDERUID   - suffix making the helper names unique
# vlopt_<uid>(z) is the entry point: it allocates the result array and runs
# the nb.prange loop in cvlopt_<uid>, which writes one result per slice.

# axis=0: the function is applied to each column z[:, i]; n = z.shape[1].
# NOTE(review): the generated comment "number of rows" is inaccurate here
# (shape[1] counts columns); it is part of the emitted code, left unchanged.
two_dim_axis_0 = """
@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', nogil=True)
PLACEHOLDERFUNC

@nb.njit((nb.PLACEHOLDERDTYPE[:], nb.PLACEHOLDERDTYPE[:],nb.types.uint32), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', nogil=True)
def vl_PLACEHOLDERUID(z, r, i):
    r[i] = PLACEHOLDERNAME(z)

@nb.njit((nb.PLACEHOLDERDTYPE[:,:], nb.PLACEHOLDERDTYPE[:],nb.types.uint32), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def cvlopt_PLACEHOLDERUID(z, r, n):
    for i in nb.prange(n):
        vl_PLACEHOLDERUID(z[:,i], r, i)

def vlopt_PLACEHOLDERUID(z):
    n = np.uint32(z.shape[1])  # Determine the number of rows in z
    r = np.empty(n, dtype=np.PLACEHOLDERDTYPE)  # Initialize the result array
    cvlopt_PLACEHOLDERUID(z, r, n)
    return r
"""

# axis=1: the function is applied to each row z[i, :]; n = z.shape[0].
# NOTE(review): unlike two_dim_axis_0, the user function is compiled with
# fastmath=False here - confirm whether the difference is intentional.
two_dim_axis_1 = """
@nb.njit(nb.PLACEHOLDERDTYPE(nb.PLACEHOLDERDTYPE[:]), cache=True, fastmath=False, forceinline=True, looplift=True, inline='always', target_backend='host', nogil=True)
PLACEHOLDERFUNC

@nb.njit((nb.PLACEHOLDERDTYPE[:], nb.PLACEHOLDERDTYPE[:],nb.types.uint32), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', nogil=True)
def vl_PLACEHOLDERUID(z, r, i):
    r[i] = PLACEHOLDERNAME(z)

@nb.njit((nb.PLACEHOLDERDTYPE[:,:], nb.PLACEHOLDERDTYPE[:],nb.types.uint32), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def cvlopt_PLACEHOLDERUID(z, r, n):
    for i in nb.prange(n):
        vl_PLACEHOLDERUID(z[i,:], r, i)

def vlopt_PLACEHOLDERUID(z):
    n = np.uint32(z.shape[0])  # Determine the number of rows in z
    r = np.empty(n, dtype=np.PLACEHOLDERDTYPE)  # Initialize the result array
    cvlopt_PLACEHOLDERUID(z, r, n)
    return r
"""

def _make_func(base_model: str, name: str, func: str, dtype: str, uid: str):
    final = base_model.replace('PLACEHOLDERNAME', name)
    final = final.replace('PLACEHOLDERFUNC', func)
    final = final.replace('PLACEHOLDERDTYPE', dtype)
    final = final.replace('PLACEHOLDERUID', uid)
    return final
    

def _compile_tree(func: Dict[str, str], exec_globals: Dict[str, Any]) -> Dict:
    def wrapped(x):
        if func['name'] not in exec_globals:
            exec(compile(ast.parse(func['source_code']), filename="/tmp/numbacache", mode="exec"), exec_globals)
        return exec_globals[func['name']](x)
    wrapped.__name__ = func['name']
    wrapped.__qualname__ = "__numba__." + func['name']
    return wrapped

def _prepare_func(original_func: Callable, index: Iterable[str], axis: int = 0, ndims: int = 2, dtype: str = 'float32') -> Dict[str, Any]:
    """Transcode *original_func* and wrap the results in lazily compiled callables.

    Parameters
    ----------
    original_func : callable
        One-argument function to transcode (see ``AstModifier``).
    index : Iterable[str]
        Labels along *axis*, in positional order.
    axis : int, default 0
        For ``ndims == 2``: 0 or 1, negative values counting from the end.
    ndims : int, default 2
        Dimensionality of the input array; only 1 and 2 are supported.
    dtype : str, default 'float32'
        dtype name substituted into the generated Numba signatures.

    Returns
    -------
    dict
        ``'original'``, ``'modified'`` and, when available, ``'vectorized'``
        entries. Each is a dict with ``'name'``, ``'source_code'`` and
        ``'function'`` (a wrapper from ``_compile_tree``).

    Raises
    ------
    ValueError
        If *axis* is out of range for a 2-D input.
    NotImplementedError
        If *ndims* is neither 1 nor 2.
    """
    if ndims not in (1, 2):
        # NotImplemented is a constant, not an exception: raising it would
        # itself fail with TypeError.
        raise NotImplementedError(f"ndims={ndims} is not supported (only 1 or 2)")
    if ndims == 2:
        if axis < 0:
            axis += ndims
        if axis not in (0, 1):
            raise ValueError(f"axis must be 0 or 1 for 2-D inputs, got {axis}")

    fuid, original, modified, vectorized = AstModifier(original_func=original_func, index=index, axis=axis, ndims=ndims)
    name_with_id = original_func.__name__ + fuid
    modified_name = name_with_id + "_modified"
    exec_globals = globals()

    result = {}
    result['original'] = {"name": name_with_id + "_original", "source_code": original}
    result['original']['function'] = _compile_tree(result['original'], exec_globals)

    if ndims == 1:
        # The rewritten function is jitted directly: 1-D array in, scalar out.
        decorator = (
            f"@nb.njit(nb.{dtype}(nb.{dtype}[:]), cache = False, fastmath=True, forceinline=True, "
            "looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, "
            "no_rewrites=True, nogil=True)\n"
        )
        result['modified'] = {"name": modified_name, "source_code": decorator + modified}
        result['modified']['function'] = _compile_tree(result['modified'], exec_globals)
        return result

    template = two_dim_axis_0 if axis == 0 else two_dim_axis_1
    result['modified'] = {
        "name": "vlopt_" + fuid,
        "source_code": _make_func(template, modified_name, modified, dtype, fuid),
    }
    result['modified']['function'] = _compile_tree(result['modified'], exec_globals)
    if vectorized is not None:
        result['vectorized'] = {"name": name_with_id + '_vectorized', "source_code": vectorized}
        result['vectorized']['function'] = _compile_tree(result['vectorized'], exec_globals)
    return result

In [7]:

# Redefines f0 with parameter X. Generated names depend only on the function
# name and a hash of the index labels, so they match the earlier f0 cell.
def f0(X):
    return X['A'] + X['B']


# Labels along the applied axis: column labels for axis=1, row labels otherwise.
prepared = _prepare_func(f0, df.columns.values.astype(str) if axis==1 else df.index.values.astype(str), axis = axis, ndims = 2, dtype = "float32")

print(prepared['modified']['source_code'])




@nb.njit(nb.float32(nb.float32[:]), cache=True, fastmath=False, forceinline=True, looplift=True, inline='always', target_backend='host', nogil=True)
def f0f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(XX):
    return XX[0] + XX[1]

@nb.njit((nb.float32[:], nb.float32[:],nb.types.uint32), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', nogil=True)
def vl_f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8(z, r, i):
    r[i] = f0f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8_modified(z)

@nb.njit((nb.float32[:,:], nb.float32[:],nb.types.uint32), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def cvlopt_f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8(z, r, n):
    for i in nb.prange(n):
        vl_f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8(z[i,:], r, i)

def vlopt_f03c01bdbb26f358bab27f267924aa2c9a03fcfdb8(z):
    n = np.uint32(z

In [8]:
# Compare pandas apply with the exec'd original function against the
# Numba-compiled row loop and, if produced, the vectorized form, both of which
# operate on the raw ndarray. `%timeit` is an IPython magic (notebook only).
print('original')
print(df.apply(prepared['original']['function'], axis = axis))
%timeit df.apply(prepared['original']['function'], axis = axis)
print('modified')
print(prepared['modified']['function'](df.to_numpy()))
%timeit prepared['modified']['function'](df.to_numpy())
if "vectorized" in prepared:
    print('vectorized')
    print(prepared['vectorized']['function'](df.to_numpy()))
    %timeit prepared['vectorized']['function'](df.to_numpy())


original
0        3.133342
1        1.109728
2        2.519707
3       -1.649337
4       -0.933895
           ...   
99995    0.523224
99996   -0.708453
99997    2.436834
99998   -2.116399
99999    3.808775
Length: 100000, dtype: float32
462 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
modified
[ 3.1333416  1.1097283  2.5197074 ...  2.4368343 -2.1163986  3.808775 ]
62.3 µs ± 403 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
vectorized
[ 3.1333416  1.1097283  2.5197074 ...  2.4368343 -2.1163986  3.808775 ]
33.8 µs ± 49.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:

# Bare use of X (no subscript): fuid stays '_general_', so the generated names
# depend only on the function name; no vectorized form is produced.
def f1(X):
    return np.sum(X)

prepared = _prepare_func(f1, df.columns.values.astype(str) if axis==1 else df.index.values.astype(str), axis = axis, ndims = 2, dtype = "float32")

print(prepared['modified']['source_code'])


@nb.njit(nb.float32(nb.float32[:]), cache=True, fastmath=False, forceinline=True, looplift=True, inline='always', target_backend='host', nogil=True)
def f1f1_general__modified(XX):
    return np.sum(XX)

@nb.njit((nb.float32[:], nb.float32[:],nb.types.uint32), cache=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', nogil=True)
def vl_f1_general_(z, r, i):
    r[i] = f1f1_general__modified(z)

@nb.njit((nb.float32[:,:], nb.float32[:],nb.types.uint32), cache=True, parallel=True, fastmath=True, forceinline=True, looplift=True, inline='always', target_backend='host', no_cfunc_wrapper=True, no_rewrites=True, nogil=True)
def cvlopt_f1_general_(z, r, n):
    for i in nb.prange(n):
        vl_f1_general_(z[i,:], r, i)

def vlopt_f1_general_(z):
    n = np.uint32(z.shape[0])  # Determine the number of rows in z
    r = np.empty(n, dtype=np.float32)  # Initialize the result array
    cvlopt_f1_general_(z, r, n)
    return r



In [10]:
# Baseline: pandas' built-in reduction along the same axis.
# `%timeit` is an IPython magic (notebook only).
print('pandas specific')
print(df.sum(axis = axis))
%timeit df.sum(axis = axis)

# Same comparison as for f0: pandas apply vs compiled / vectorized versions.
print('original')
print(df.apply(prepared['original']['function'], axis = axis))
%timeit df.apply(prepared['original']['function'], axis = axis)
print('modified')
print(prepared['modified']['function'](df.to_numpy()))
%timeit prepared['modified']['function'](df.to_numpy())
if "vectorized" in prepared:
    print('vectorized')
    print(prepared['vectorized']['function'](df.to_numpy()))
    %timeit prepared['vectorized']['function'](df.to_numpy())

pandas specific
0        2.938186
1        2.063428
2        3.001939
3       -0.614513
4       -1.568000
           ...   
99995    1.317062
99996   -1.882635
99997    4.044487
99998   -2.143098
99999    3.679650
Length: 100000, dtype: float32
8.63 ms ± 62.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
original
0        2.938186
1        2.063428
2        3.001939
3       -0.614513
4       -1.568000
           ...   
99995    1.317062
99996   -1.882635
99997    4.044487
99998   -2.143098
99999    3.679650
Length: 100000, dtype: float32
2.21 s ± 64.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
modified
[ 2.9381864  2.0634277  3.0019388 ...  4.0444875 -2.143098   3.6796496]
294 µs ± 3.38 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
