diff --git a/pythonbpf/expr/__init__.py b/pythonbpf/expr/__init__.py index 94cf330e..4b1dac80 100644 --- a/pythonbpf/expr/__init__.py +++ b/pythonbpf/expr/__init__.py @@ -1,4 +1,4 @@ -from .expr_pass import eval_expr, handle_expr, get_operand_value +from .expr_pass import eval_expr, handle_expr, get_operand_value, CallHandlerRegistry from .type_normalization import convert_to_bool, get_base_type_and_depth, deref_to_depth __all__ = [ @@ -8,4 +8,5 @@ "get_base_type_and_depth", "deref_to_depth", "get_operand_value", + "CallHandlerRegistry", ] diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index f16fd46c..e662984a 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -15,6 +15,28 @@ logger: Logger = logging.getLogger(__name__) +class CallHandlerRegistry: + """Registry for handling different types of calls (helpers, etc.)""" + + _handler = None + + @classmethod + def set_handler(cls, handler): + """Set the handler for unknown calls""" + cls._handler = handler + + @classmethod + def handle_call( + cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab + ): + """Handle a call using the registered handler""" + if cls._handler is None: + return None + return cls._handler( + call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab + ) + + def get_operand_value( func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None ): @@ -478,51 +500,14 @@ def eval_expr( structs_sym_tab, ) - # delayed import to avoid circular dependency - from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call + result = CallHandlerRegistry.handle_call( + expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab + ) + if result is not None: + return result - if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler( - expr.func.id - ): - return handle_helper_call( - expr, - module, - builder, - func, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) - elif isinstance(expr.func, ast.Attribute): - logger.info(f"Handling method call: {ast.dump(expr.func)}") - if isinstance(expr.func.value, ast.Call) and isinstance( - expr.func.value.func, ast.Name - ): - method_name = expr.func.attr - if HelperHandlerRegistry.has_handler(method_name): - return handle_helper_call( - expr, - module, - builder, - func, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) - elif isinstance(expr.func.value, ast.Name): - obj_name = expr.func.value.id - method_name = expr.func.attr - if obj_name in map_sym_tab: - if HelperHandlerRegistry.has_handler(method_name): - return handle_helper_call( - expr, - module, - builder, - func, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) + logger.warning(f"Unknown call: {ast.dump(expr)}") + return None elif isinstance(expr, ast.Attribute): return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder) elif isinstance(expr, ast.BinOp): diff --git a/pythonbpf/helper/__init__.py b/pythonbpf/helper/__init__.py index 007724f7..265da51d 100644 --- a/pythonbpf/helper/__init__.py +++ b/pythonbpf/helper/__init__.py @@ -2,6 +2,58 @@ from .bpf_helper_handler import handle_helper_call from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS + +# Register the helper handler with expr module +def _register_helper_handler(): + """Register helper call handler with the expression evaluator""" + from pythonbpf.expr.expr_pass import CallHandlerRegistry + + def helper_call_handler( + call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab + ): + """Check if call is a helper and handle it""" + import ast + + # Check for direct helper calls (e.g., ktime(), print()) + if isinstance(call.func, ast.Name): + if HelperHandlerRegistry.has_handler(call.func.id): + return handle_helper_call( + call, + module, + builder, + func, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + + # Check for method calls (e.g., map.lookup()) + elif isinstance(call.func, ast.Attribute): + method_name = call.func.attr + + # Handle: my_map.lookup(key) + if isinstance(call.func.value, ast.Name): + obj_name = call.func.value.id + if map_sym_tab and obj_name in map_sym_tab: + if HelperHandlerRegistry.has_handler(method_name): + return handle_helper_call( + call, + module, + builder, + func, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + + return None + + CallHandlerRegistry.set_handler(helper_call_handler) + + +# Register on module import +_register_helper_handler() + __all__ = [ "HelperHandlerRegistry", "reset_scratch_pool",