diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py index cf20f06..1c9b9dd 100644 --- a/pythonbpf/codegen.py +++ b/pythonbpf/codegen.py @@ -1,7 +1,7 @@ import ast from llvmlite import ir from .license_pass import license_processing -from .functions_pass import func_proc +from .functions import func_proc from .maps import maps_proc from .structs import structs_proc from .globals_pass import globals_processing diff --git a/pythonbpf/expr_pass.py b/pythonbpf/expr_pass.py index 40d0800..56d047e 100644 --- a/pythonbpf/expr_pass.py +++ b/pythonbpf/expr_pass.py @@ -4,6 +4,8 @@ import logging from typing import Dict +from .type_deducer import ctypes_to_ir, is_ctypes + logger: Logger = logging.getLogger(__name__) @@ -88,6 +90,48 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde return val, local_sym_tab[arg.id].ir_type +def _handle_ctypes_call( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab=None, +): + """Handle ctypes type constructor calls.""" + if len(expr.args) != 1: + logger.info("ctypes constructor takes exactly one argument") + return None + + arg = expr.args[0] + val = eval_expr( + func, + module, + builder, + arg, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + if val is None: + logger.info("Failed to evaluate argument to ctypes constructor") + return None + call_type = expr.func.id + expected_type = ctypes_to_ir(call_type) + + if val[1] != expected_type: + # NOTE: We are only considering casting to and from int types for now + if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType): + if val[1].width < expected_type.width: + val = (builder.sext(val[0], expected_type), expected_type) + else: + val = (builder.trunc(val[0], expected_type), expected_type) + else: + raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}") + return val + + def eval_expr( func, module, @@ -106,6 +150,17 @@ def eval_expr( if isinstance(expr.func, ast.Name) and expr.func.id == "deref": return _handle_deref_call(expr, local_sym_tab, builder) + if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id): + return _handle_ctypes_call( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + # delayed import to avoid circular dependency from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call @@ -153,6 +208,10 @@ def eval_expr( ) elif isinstance(expr, ast.Attribute): return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder) + elif isinstance(expr, ast.BinOp): + from pythonbpf.binary_ops import handle_binary_op + + return handle_binary_op(expr, builder, None, local_sym_tab) logger.info("Unsupported expression evaluation") return None diff --git a/pythonbpf/functions/__init__.py b/pythonbpf/functions/__init__.py new file mode 100644 index 0000000..df99da1 --- /dev/null +++ b/pythonbpf/functions/__init__.py @@ -0,0 +1,3 @@ +from .functions_pass import func_proc + +__all__ = ["func_proc"] diff --git a/pythonbpf/functions/func_registry_handlers.py b/pythonbpf/functions/func_registry_handlers.py new file mode 100644 index 0000000..afe54f6 --- /dev/null +++ b/pythonbpf/functions/func_registry_handlers.py @@ -0,0 +1,22 @@ +from typing import Dict + + +class StatementHandlerRegistry: + """Registry for statement handlers.""" + + _handlers: Dict = {} + + @classmethod + def register(cls, stmt_type): + """Register a handler for a specific statement type.""" + + def decorator(handler): + cls._handlers[stmt_type] = handler + return handler + + return decorator + + @classmethod + def __getitem__(cls, stmt_type): + """Get the handler for a specific statement type.""" + return cls._handlers.get(stmt_type, None) diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions/functions_pass.py similarity index 90% rename from pythonbpf/functions_pass.py rename to pythonbpf/functions/functions_pass.py index d1ea526..18904ec 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -4,10 +4,13 @@ from typing import Any from dataclasses import dataclass -from .helper import HelperHandlerRegistry, handle_helper_call -from .type_deducer import ctypes_to_ir -from .binary_ops import handle_binary_op -from .expr_pass import eval_expr, handle_expr +from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call +from pythonbpf.type_deducer import ctypes_to_ir +from pythonbpf.binary_ops import handle_binary_op +from pythonbpf.expr_pass import eval_expr, handle_expr + +from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name + logger = logging.getLogger(__name__) @@ -350,6 +353,27 @@ def handle_if( builder.position_at_end(merge_block) +def handle_return(builder, stmt, local_sym_tab, ret_type): + logger.info(f"Handling return statement: {ast.dump(stmt)}") + if stmt.value is None: + return _handle_none_return(builder) + elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id): + return _handle_xdp_return(stmt, builder, ret_type) + else: + val = eval_expr( + func=None, + module=None, + builder=builder, + expr=stmt.value, + local_sym_tab=local_sym_tab, + map_sym_tab={}, + structs_sym_tab={}, + ) + logger.info(f"Evaluated return expression to {val}") + builder.ret(val[0]) + return True + + def process_stmt( func, module, @@ -383,61 +407,12 @@ def process_stmt( func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab ) elif isinstance(stmt, ast.Return): - if stmt.value is None: - builder.ret(ir.Constant(ir.IntType(64), 0)) - did_return = True - elif ( - isinstance(stmt.value, ast.Call) - and isinstance(stmt.value.func, ast.Name) - and len(stmt.value.args) == 1 - ): - if isinstance(stmt.value.args[0], ast.Constant) and isinstance( - stmt.value.args[0].value, int - ): - call_type = stmt.value.func.id - if ctypes_to_ir(call_type) != ret_type: - raise ValueError( - "Return type mismatch: expected" - f"{ctypes_to_ir(call_type)}, got {call_type}" - ) - else: - builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) - did_return = True - elif isinstance(stmt.value.args[0], ast.BinOp): - # TODO: Should be routed through eval_expr - val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab) - if val is None: - raise ValueError("Failed to evaluate return expression") - if val[1] != ret_type: - raise ValueError( - "Return type mismatch: expected " f"{ret_type}, got {val[1]}" - ) - builder.ret(val[0]) - did_return = True - elif isinstance(stmt.value.args[0], ast.Name): - if stmt.value.args[0].id in local_sym_tab: - var = local_sym_tab[stmt.value.args[0].id].var - val = builder.load(var) - if val.type != ret_type: - raise ValueError( - "Return type mismatch: expected" - f"{ret_type}, got {val.type}" - ) - builder.ret(val) - did_return = True - else: - raise ValueError("Failed to evaluate return expression") - elif isinstance(stmt.value, ast.Name): - if stmt.value.id == "XDP_PASS": - builder.ret(ir.Constant(ret_type, 2)) - did_return = True - elif stmt.value.id == "XDP_DROP": - builder.ret(ir.Constant(ret_type, 1)) - did_return = True - else: - raise ValueError("Failed to evaluate return expression") - else: - raise ValueError("Unsupported return value") + did_return = handle_return( + builder, + stmt, + local_sym_tab, + ret_type, + ) return did_return diff --git a/pythonbpf/functions/return_utils.py b/pythonbpf/functions/return_utils.py new file mode 100644 index 0000000..c69e416 --- /dev/null +++ b/pythonbpf/functions/return_utils.py @@ -0,0 +1,45 @@ +import logging +import ast + +from llvmlite import ir + +logger: logging.Logger = logging.getLogger(__name__) + +XDP_ACTIONS = { + "XDP_ABORTED": 0, + "XDP_DROP": 1, + "XDP_PASS": 2, + "XDP_TX": 3, + "XDP_REDIRECT": 4, +} + + +def _handle_none_return(builder) -> bool: + """Handle return or return None -> returns 0.""" + builder.ret(ir.Constant(ir.IntType(64), 0)) + logger.debug("Generated default return: 0") + return True + + +def _is_xdp_name(name: str) -> bool: + """Check if a name is an XDP action""" + return name in XDP_ACTIONS + + +def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool: + """Handle XDP returns""" + if not isinstance(stmt.value, ast.Name): + return False + + action_name = stmt.value.id + + if action_name not in XDP_ACTIONS: + raise ValueError( + f"Unknown XDP action: {action_name}. Available: {XDP_ACTIONS.keys()}" + ) + return False + + value = XDP_ACTIONS[action_name] + builder.ret(ir.Constant(ret_type, value)) + logger.debug(f"Generated XDP action return: {action_name} = {value}") + return True diff --git a/pythonbpf/type_deducer.py b/pythonbpf/type_deducer.py index 909d33c..9867cc6 100644 --- a/pythonbpf/type_deducer.py +++ b/pythonbpf/type_deducer.py @@ -1,24 +1,28 @@ from llvmlite import ir # TODO: THIS IS NOT SUPPOSED TO MATCH STRINGS :skull: +mapping = { + "c_int8": ir.IntType(8), + "c_uint8": ir.IntType(8), + "c_int16": ir.IntType(16), + "c_uint16": ir.IntType(16), + "c_int32": ir.IntType(32), + "c_uint32": ir.IntType(32), + "c_int64": ir.IntType(64), + "c_uint64": ir.IntType(64), + "c_float": ir.FloatType(), + "c_double": ir.DoubleType(), + "c_void_p": ir.IntType(64), + # Not so sure about this one + "str": ir.PointerType(ir.IntType(8)), +} def ctypes_to_ir(ctype: str): - mapping = { - "c_int8": ir.IntType(8), - "c_uint8": ir.IntType(8), - "c_int16": ir.IntType(16), - "c_uint16": ir.IntType(16), - "c_int32": ir.IntType(32), - "c_uint32": ir.IntType(32), - "c_int64": ir.IntType(64), - "c_uint64": ir.IntType(64), - "c_float": ir.FloatType(), - "c_double": ir.DoubleType(), - "c_void_p": ir.IntType(64), - # Not so sure about this one - "str": ir.PointerType(ir.IntType(8)), - } if ctype in mapping: return mapping[ctype] raise NotImplementedError(f"No mapping for {ctype}") + + +def is_ctypes(ctype: str) -> bool: + return ctype in mapping diff --git a/tests/passing_tests/return/binop_const.py b/tests/passing_tests/return/binop_const.py new file mode 100644 index 0000000..faafd1f --- /dev/null +++ b/tests/passing_tests/return/binop_const.py @@ -0,0 +1,18 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + print("Hello, World!") + return 1 + 1 - 2 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/binop_var.py b/tests/passing_tests/return/binop_var.py new file mode 100644 index 0000000..32b5784 --- /dev/null +++ b/tests/passing_tests/return/binop_var.py @@ -0,0 +1,19 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + print("Hello, World!") + a = 2 + return a - 2 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/int.py b/tests/passing_tests/return/int.py new file mode 100644 index 0000000..b20b4a0 --- /dev/null +++ b/tests/passing_tests/return/int.py @@ -0,0 +1,18 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + print("Hello, World!") + return 1 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/null.py b/tests/passing_tests/return/null.py new file mode 100644 index 0000000..34a1492 --- /dev/null +++ b/tests/passing_tests/return/null.py @@ -0,0 +1,18 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + print("Hello, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/typecast_binops.py b/tests/passing_tests/return/typecast_binops.py new file mode 100644 index 0000000..c58ba41 --- /dev/null +++ b/tests/passing_tests/return/typecast_binops.py @@ -0,0 +1,20 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int32 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int32: + print("Hello, World!") + a = 1 # int64 + x = 1 # int64 + return c_int32(a - x) # typecast to int32 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/typecast_const.py b/tests/passing_tests/return/typecast_const.py new file mode 100644 index 0000000..50cc26f --- /dev/null +++ b/tests/passing_tests/return/typecast_const.py @@ -0,0 +1,18 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int32 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int32: + print("Hello, World!") + return c_int32(1) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/typecast_var.py b/tests/passing_tests/return/typecast_var.py new file mode 100644 index 0000000..1960edd --- /dev/null +++ b/tests/passing_tests/return/typecast_var.py @@ -0,0 +1,19 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int32 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int32: + print("Hello, World!") + a = 1 # int64 + return c_int32(a) # typecast to int32 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/var.py b/tests/passing_tests/return/var.py new file mode 100644 index 0000000..26fb34d --- /dev/null +++ b/tests/passing_tests/return/var.py @@ -0,0 +1,19 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + print("Hello, World!") + a = 1 + return a + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/xdp.py b/tests/passing_tests/return/xdp.py new file mode 100644 index 0000000..3c0f5d8 --- /dev/null +++ b/tests/passing_tests/return/xdp.py @@ -0,0 +1,19 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 +from pythonbpf.helper import XDP_PASS + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + print("Hello, World!") + return XDP_PASS + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile()