diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index 9d824841..3149c752 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import Any from pythonbpf.helper import HelperHandlerRegistry +from .expr import VmlinuxHandlerRegistry from pythonbpf.type_deducer import ctypes_to_ir logger = logging.getLogger(__name__) @@ -49,6 +50,15 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): logger.debug(f"Variable {var_name} already allocated, skipping") return + # When allocating a variable, check if it's a vmlinux struct type + if isinstance(stmt.value, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct( + stmt.value.id + ): + # Handle vmlinux struct allocation + # This requires more implementation + print(stmt.value) + pass + # Determine type and allocate based on rval if isinstance(rval, ast.Call): _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab) diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py index beac470c..e97b1944 100644 --- a/pythonbpf/codegen.py +++ b/pythonbpf/codegen.py @@ -5,6 +5,8 @@ from .maps import maps_proc from .structs import structs_proc from .vmlinux_parser import vmlinux_proc +from pythonbpf.vmlinux_parser.vmlinux_exports_handler import VmlinuxHandler +from .expr import VmlinuxHandlerRegistry from .globals_pass import ( globals_list_creation, globals_processing, @@ -56,10 +58,13 @@ def processor(source_code, filename, module): logger.info(f"Found BPF function/struct: {func_node.name}") vmlinux_symtab = vmlinux_proc(tree, module) + if vmlinux_symtab: + handler = VmlinuxHandler.initialize(vmlinux_symtab) + VmlinuxHandlerRegistry.set_handler(handler) + populate_global_symbol_table(tree, module) license_processing(tree, module) globals_processing(tree, module) - print("DEBUG:", vmlinux_symtab) structs_sym_tab = structs_proc(tree, module, bpf_chunks) map_sym_tab = maps_proc(tree, 
module, bpf_chunks) func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab) diff --git a/pythonbpf/expr/__init__.py b/pythonbpf/expr/__init__.py index 3c403ddf..ac3a9751 100644 --- a/pythonbpf/expr/__init__.py +++ b/pythonbpf/expr/__init__.py @@ -2,6 +2,7 @@ from .type_normalization import convert_to_bool, get_base_type_and_depth from .ir_ops import deref_to_depth from .call_registry import CallHandlerRegistry +from .vmlinux_registry import VmlinuxHandlerRegistry __all__ = [ "eval_expr", @@ -11,4 +12,5 @@ "deref_to_depth", "get_operand_value", "CallHandlerRegistry", + "VmlinuxHandlerRegistry", ] diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index 8bbd5242..2a7cd5fe 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -12,6 +12,7 @@ get_base_type_and_depth, deref_to_depth, ) +from .vmlinux_registry import VmlinuxHandlerRegistry logger: Logger = logging.getLogger(__name__) @@ -27,8 +28,12 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder val = builder.load(var) return val, local_sym_tab[expr.id].ir_type else: - logger.info(f"Undefined variable {expr.id}") - return None + # Check if it's a vmlinux enum/constant + vmlinux_result = VmlinuxHandlerRegistry.handle_name(expr.id) + if vmlinux_result is not None: + return vmlinux_result + + raise SyntaxError(f"Undefined variable {expr.id}") def _handle_constant_expr(module, builder, expr: ast.Constant): @@ -74,6 +79,13 @@ def _handle_attribute_expr( val = builder.load(gep) field_type = metadata.field_type(attr_name) return val, field_type + + # Try vmlinux handler as fallback + vmlinux_result = VmlinuxHandlerRegistry.handle_attribute( + expr, local_sym_tab, None, builder + ) + if vmlinux_result is not None: + return vmlinux_result return None @@ -130,7 +142,12 @@ def get_operand_value( logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}") val = deref_to_depth(func, builder, var, depth) return val - raise 
class VmlinuxHandlerRegistry:
    """Class-level holder that forwards vmlinux lookups to one handler.

    The handler is stored on the class itself (``_handler``), so every
    caller of these classmethods shares the same handler object.  While no
    handler is set, each lookup reports "not a vmlinux symbol" (``None`` or
    ``False``) instead of raising.
    """

    # Handler installed through set_handler(); None means nothing registered.
    _handler = None

    @classmethod
    def set_handler(cls, handler):
        """Install ``handler`` as the object all lookups are forwarded to."""
        cls._handler = handler

    @classmethod
    def get_handler(cls):
        """Return the currently installed handler (``None`` if unset)."""
        return cls._handler

    @classmethod
    def handle_name(cls, name):
        """Resolve ``name`` via the handler's ``handle_vmlinux_enum``.

        Returns whatever the handler returns, or ``None`` when no handler
        is installed.
        """
        active = cls._handler
        return None if active is None else active.handle_vmlinux_enum(name)

    @classmethod
    def handle_attribute(cls, expr, local_sym_tab, module, builder):
        """Resolve ``<name>.<attr>`` via ``handle_vmlinux_struct_field``.

        Only attribute expressions whose base is a plain ``ast.Name`` are
        forwarded; any other base, or a missing handler, yields ``None``.
        """
        active = cls._handler
        # The handler check comes first so ``expr`` is never inspected
        # while the registry is empty.
        if active is None or not isinstance(expr.value, ast.Name):
            return None
        return active.handle_vmlinux_struct_field(
            expr.value.id, expr.attr, module, builder, local_sym_tab
        )

    @classmethod
    def is_vmlinux_struct(cls, name):
        """Report whether the handler considers ``name`` a vmlinux struct.

        Returns ``False`` when no handler is installed.
        """
        active = cls._handler
        return active.is_vmlinux_struct(name) if active is not None else False
b/pythonbpf/functions/functions_pass.py index 8d0bce1e..82433441 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -311,7 +311,13 @@ def process_stmt( def process_func_body( - module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab + module, + builder, + func_node, + func, + ret_type, + map_sym_tab, + structs_sym_tab, ): """Process the body of a bpf function""" # TODO: A lot. We just have print -> bpf_trace_printk for now @@ -384,7 +390,13 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t builder = ir.IRBuilder(block) process_func_body( - module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab + module, + builder, + func_node, + func, + ret_type, + map_sym_tab, + structs_sym_tab, ) return func diff --git a/pythonbpf/helper/printk_formatter.py b/pythonbpf/helper/printk_formatter.py index e0cd669f..66fcb502 100644 --- a/pythonbpf/helper/printk_formatter.py +++ b/pythonbpf/helper/printk_formatter.py @@ -3,6 +3,7 @@ from llvmlite import ir from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth +from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry logger = logging.getLogger(__name__) @@ -108,6 +109,16 @@ def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab): if local_sym_tab and name_node.id in local_sym_tab: _, var_type, tmp = local_sym_tab[name_node.id] _populate_fval(var_type, name_node, fmt_parts, exprs) + else: + # Try to resolve through vmlinux registry if not in local symbol table + result = VmlinuxHandlerRegistry.handle_name(name_node.id) + if result: + val, var_type = result + _populate_fval(var_type, name_node, fmt_parts, exprs) + else: + raise ValueError( + f"Variable '{name_node.id}' not found in symbol table or vmlinux" + ) def _process_attr_in_fval(attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab): diff --git a/pythonbpf/maps/maps_pass.py b/pythonbpf/maps/maps_pass.py index 
class VmlinuxHandler:
    """Resolve names imported from ``vmlinux`` against a symbol table.

    The symbol table is indexed by symbol name.  Each entry is subscripted
    with ``"value_type"`` (compared against ``AssignmentType`` members) and,
    for constants, ``"value"``.  The most recently initialized handler is
    kept on the class in ``_instance``.
    """

    # Handler created by the latest initialize() call; None before that.
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared handler.

        Logs a warning and returns ``None`` if ``initialize`` has not been
        called yet.
        """
        if cls._instance is None:
            logger.warning("VmlinuxHandler used before initialization")
            return None
        return cls._instance

    @classmethod
    def initialize(cls, vmlinux_symtab):
        """Create a handler for ``vmlinux_symtab``, store it and return it.

        Any previously stored instance is replaced.
        """
        cls._instance = cls(vmlinux_symtab)
        return cls._instance

    def __init__(self, vmlinux_symtab):
        """Keep a reference to ``vmlinux_symtab`` (may be empty or ``None``)."""
        self.vmlinux_symtab = vmlinux_symtab
        logger.info(
            f"VmlinuxHandler initialized with {len(vmlinux_symtab) if vmlinux_symtab else 0} symbols"
        )

    def _lookup(self, name):
        """Return the symbol-table entry for ``name``, or ``None`` if absent.

        A falsy table (``None`` or empty) is treated as "no symbols" so the
        predicates below never raise on ``name in None``; ``__init__``
        already accepts such a table.
        """
        symtab = self.vmlinux_symtab
        if not symtab or name not in symtab:
            return None
        return symtab[name]

    def is_vmlinux_enum(self, name: str) -> bool:
        """Return True if ``name`` is recorded as ``AssignmentType.CONSTANT``."""
        entry = self._lookup(name)
        return entry is not None and entry["value_type"] == AssignmentType.CONSTANT

    def is_vmlinux_struct(self, name: str) -> bool:
        """Return True if ``name`` is recorded as ``AssignmentType.STRUCT``."""
        entry = self._lookup(name)
        return entry is not None and entry["value_type"] == AssignmentType.STRUCT

    def handle_vmlinux_enum(self, name):
        """Return ``(constant, type)`` for a vmlinux constant, else ``None``.

        The constant is an ``ir.Constant`` of ``ir.IntType(64)`` built from
        the entry's ``"value"``; the second element is ``ir.IntType(64)``.
        """
        if self.is_vmlinux_enum(name):
            value = self.vmlinux_symtab[name]["value"]
            logger.info(f"Resolving vmlinux enum {name} = {value}")
            return ir.Constant(ir.IntType(64), value), ir.IntType(64)
        return None

    def get_vmlinux_enum_value(self, name):
        """Return the stored ``"value"`` of a vmlinux constant, else ``None``.

        Unlike ``handle_vmlinux_enum`` this returns the symbol-table value
        as-is rather than wrapping it in an LLVM IR constant.
        """
        if self.is_vmlinux_enum(name):
            value = self.vmlinux_symtab[name]["value"]
            logger.info(f"The value of vmlinux enum {name} = {value}")
            return value
        return None

    def handle_vmlinux_struct(self, struct_name, module, builder):
        """Stub for vmlinux struct initialization; currently returns ``None``.

        Known struct names are only logged.  ``module`` and ``builder`` are
        accepted but not used yet.
        """
        if self.is_vmlinux_struct(struct_name):
            # TODO: Implement core-specific struct handling
            # This will be more complex and depends on the BTF information
            logger.info(f"Handling vmlinux struct {struct_name}")
            # Should eventually return the struct type and allocated pointer.
            return None
        return None

    def handle_vmlinux_struct_field(
        self, struct_var_name, field_name, module, builder, local_sym_tab
    ):
        """Stub for vmlinux struct field access; currently returns ``None``.

        When ``struct_var_name`` is a known local variable the attempt is
        logged.  No field pointer or type is produced yet, so callers
        always receive ``None``.
        """
        # A missing/empty local symbol table simply means "not a local".
        if local_sym_tab and struct_var_name in local_sym_tab:
            # TODO: check whether this variable is actually a vmlinux struct;
            # that depends on how vmlinux struct types are tracked in the
            # local symbol table.
            logger.info(
                f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}"
            )
            # Should eventually return a pointer to the field and its type.
            return None
        return None