diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index b5fa37c0..4fc6f5e4 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -118,6 +118,18 @@ def _allocate_for_call( local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type) logger.info(f"Pre-allocated {var_name} for struct {call_type}") + elif VmlinuxHandlerRegistry.is_vmlinux_struct(call_type): + # When calling struct_name(pointer), we're doing a cast, not construction + # So we allocate as a pointer (i64) not as the actual struct + var = builder.alloca(ir.IntType(64), name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol( + var, ir.IntType(64), VmlinuxHandlerRegistry.get_struct_type(call_type) + ) + logger.info( + f"Pre-allocated {var_name} for vmlinux struct pointer cast to {call_type}" + ) + else: logger.warning(f"Unknown call type for allocation: {call_type}") @@ -325,13 +337,6 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_ VmlinuxHandlerRegistry.get_field_type(vmlinux_struct_name, field_name) ) field_ir, field = field_type - # TODO: For now, we only support integer type allocations. - # This always assumes first argument of function to be the context struct - base_ptr = builder.function.args[0] - local_sym_tab[ - struct_var - ].var = base_ptr # This is repurposing of var to store the pointer of the base type - local_sym_tab[struct_var].ir_type = field_ir # Determine the actual IR type based on the field's type actual_ir_type = None @@ -386,12 +391,14 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_ ) actual_ir_type = ir.IntType(64) - # Allocate with the actual IR type, not the GlobalVariable + # Allocate with the actual IR type var = _allocate_with_type(builder, var_name, actual_ir_type) - local_sym_tab[var_name] = LocalSymbol(var, actual_ir_type, field) + local_sym_tab[var_name] = LocalSymbol( + var, actual_ir_type, field + ) # <-- Store Field metadata logger.info( - f"Pre-allocated {var_name} from vmlinux struct {vmlinux_struct_name}.{field_name}" + f"Pre-allocated {var_name} as {actual_ir_type} from vmlinux struct {vmlinux_struct_name}.{field_name}" ) return else: diff --git a/pythonbpf/assign_pass.py b/pythonbpf/assign_pass.py index 0bd48c61..5d73cf3e 100644 --- a/pythonbpf/assign_pass.py +++ b/pythonbpf/assign_pass.py @@ -1,5 +1,7 @@ import ast import logging +from inspect import isclass + from llvmlite import ir from pythonbpf.expr import eval_expr from pythonbpf.helper import emit_probe_read_kernel_str_call @@ -148,8 +150,30 @@ def handle_variable_assignment( return False val, val_type = val_result - logger.info(f"Evaluated value for {var_name}: {val} of type {val_type}, {var_type}") + logger.info( + f"Evaluated value for {var_name}: {val} of type {val_type}, expected {var_type}" + ) + if val_type != var_type: + # Handle vmlinux struct pointers - they're represented as Python classes but are i64 pointers + if isclass(val_type) and (val_type.__module__ == "vmlinux"): + logger.info("Handling vmlinux struct pointer assignment") + # vmlinux struct pointers: val is a pointer, need to convert to i64 + if isinstance(var_type, ir.IntType) and var_type.width == 64: + # Convert pointer to i64 using ptrtoint + if isinstance(val.type, ir.PointerType): + val = builder.ptrtoint(val, ir.IntType(64)) + logger.info( + "Converted vmlinux struct pointer to i64 using ptrtoint" + ) + builder.store(val, var_ptr) + logger.info(f"Assigned vmlinux struct pointer to {var_name} (i64)") + return True + else: + logger.error( + f"Type mismatch: vmlinux struct pointer requires i64, got {var_type}" + ) + return False if isinstance(val_type, Field): logger.info("Handling assignment to struct field") # Special handling for struct_xdp_md i32 fields that are zero-extended to i64 diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index a9eab987..c510c969 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -12,8 +12,8 @@ get_base_type_and_depth, deref_to_depth, ) -from pythonbpf.vmlinux_parser.assignment_info import Field from .vmlinux_registry import VmlinuxHandlerRegistry +from ..vmlinux_parser.dependency_node import Field logger: Logger = logging.getLogger(__name__) @@ -89,8 +89,16 @@ def _handle_attribute_expr( return vmlinux_result else: raise RuntimeError("Vmlinux struct did not process successfully") - metadata = structs_sym_tab[var_metadata] - if attr_name in metadata.fields: + + elif isinstance(var_metadata, Field): + logger.error( + f"Cannot access field '{attr_name}' on already-loaded field value '{var_name}'" + ) + return None + + # Regular user-defined struct + metadata = structs_sym_tab.get(var_metadata) + if metadata and attr_name in metadata.fields: gep = metadata.gep(builder, var_ptr, attr_name) val = builder.load(gep) field_type = metadata.field_type(attr_name) @@ -525,6 +533,66 @@ def _handle_boolean_op( return None +# ============================================================================ +# VMLinux casting +# ============================================================================ + + +def _handle_vmlinux_cast( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab=None, +): + # handle expressions such as struct_request(ctx.di) where struct_request is a vmlinux + # struct and ctx.di is a pointer to a struct but is actually represented as a c_uint64 + # which needs to be cast to a pointer. This is also a field of another vmlinux struct + """Handle vmlinux struct cast expressions like struct_request(ctx.di).""" + if len(expr.args) != 1: + logger.info("vmlinux struct cast takes exactly one argument") + return None + + # Get the struct name + struct_name = expr.func.id + + # Evaluate the argument (e.g., ctx.di which is a c_uint64) + arg_result = eval_expr( + func, + module, + builder, + expr.args[0], + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + + if arg_result is None: + logger.info("Failed to evaluate argument to vmlinux struct cast") + return None + + arg_val, arg_type = arg_result + # Get the vmlinux struct type + vmlinux_struct_type = VmlinuxHandlerRegistry.get_struct_type(struct_name) + if vmlinux_struct_type is None: + logger.error(f"Failed to get vmlinux struct type for {struct_name}") + return None + # Cast the integer/value to a pointer to the struct + # If arg_val is an integer type, we need to inttoptr it + ptr_type = ir.PointerType() + # TODO: add a integer check here later + if ctypes_to_ir(arg_type.type.__name__): + # Cast integer to pointer + casted_ptr = builder.inttoptr(arg_val, ptr_type) + else: + logger.error(f"Unsupported type for vmlinux cast: {arg_type}") + return None + + return casted_ptr, vmlinux_struct_type + + # ============================================================================ # Expression Dispatcher # ============================================================================ @@ -545,6 +613,18 @@ def eval_expr( elif isinstance(expr, ast.Constant): return _handle_constant_expr(module, builder, expr) elif isinstance(expr, ast.Call): + if isinstance(expr.func, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct( + expr.func.id + ): + return _handle_vmlinux_cast( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) if isinstance(expr.func, ast.Name) and expr.func.id == "deref": return _handle_deref_call(expr, local_sym_tab, builder) diff --git a/pythonbpf/helper/__init__.py b/pythonbpf/helper/__init__.py index 6d38e791..17306355 100644 --- a/pythonbpf/helper/__init__.py +++ b/pythonbpf/helper/__init__.py @@ -1,6 +1,10 @@ from .helper_registry import HelperHandlerRegistry from .helper_utils import reset_scratch_pool -from .bpf_helper_handler import handle_helper_call, emit_probe_read_kernel_str_call +from .bpf_helper_handler import ( + handle_helper_call, + emit_probe_read_kernel_str_call, + emit_probe_read_kernel_call, +) from .helpers import ( ktime, pid, @@ -74,6 +78,7 @@ def helper_call_handler( "reset_scratch_pool", "handle_helper_call", "emit_probe_read_kernel_str_call", + "emit_probe_read_kernel_call", "ktime", "pid", "deref", diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index ba35cc45..f52e87a9 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -34,6 +34,7 @@ class BPFHelperID(Enum): BPF_PERF_EVENT_OUTPUT = 25 BPF_GET_STACK = 67 BPF_PROBE_READ_KERNEL_STR = 115 + BPF_PROBE_READ_KERNEL = 113 BPF_RINGBUF_OUTPUT = 130 BPF_RINGBUF_RESERVE = 131 BPF_RINGBUF_SUBMIT = 132 @@ -574,6 +575,75 @@ def bpf_probe_read_kernel_str_emitter( return result, ir.IntType(64) +def emit_probe_read_kernel_call(builder, dst_ptr, dst_size, src_ptr): + """Emit LLVM IR call to bpf_probe_read_kernel""" + + fn_type = ir.FunctionType( + ir.IntType(64), + [ir.PointerType(), ir.IntType(32), ir.PointerType()], + var_arg=False, + ) + fn_ptr = builder.inttoptr( + ir.Constant(ir.IntType(64), BPFHelperID.BPF_PROBE_READ_KERNEL.value), + ir.PointerType(fn_type), + ) + + result = builder.call( + fn_ptr, + [ + builder.bitcast(dst_ptr, ir.PointerType()), + ir.Constant(ir.IntType(32), dst_size), + builder.bitcast(src_ptr, ir.PointerType()), + ], + tail=False, + ) + + logger.info(f"Emitted bpf_probe_read_kernel (size={dst_size})") + return result + + +@HelperHandlerRegistry.register( + "probe_read_kernel", + param_types=[ + ir.PointerType(ir.IntType(8)), + ir.PointerType(ir.IntType(8)), + ], + return_type=ir.IntType(64), +) +def bpf_probe_read_kernel_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """Emit LLVM IR for bpf_probe_read_kernel helper.""" + + if len(call.args) != 2: + raise ValueError( + f"probe_read_kernel expects 2 args (dst, src), got {len(call.args)}" + ) + + # Get destination buffer (char array -> i8*) + dst_ptr, dst_size = get_or_create_ptr_from_arg( + func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab + ) + + # Get source pointer (evaluate expression) + src_ptr, src_type = get_ptr_from_arg( + call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab + ) + + # Emit the helper call + result = emit_probe_read_kernel_call(builder, dst_ptr, dst_size, src_ptr) + + logger.info(f"Emitted bpf_probe_read_kernel (size={dst_size})") + return result, ir.IntType(64) + + @HelperHandlerRegistry.register( "random", param_types=[], diff --git a/pythonbpf/vmlinux_parser/class_handler.py b/pythonbpf/vmlinux_parser/class_handler.py index a508ff75..0c66ba21 100644 --- a/pythonbpf/vmlinux_parser/class_handler.py +++ b/pythonbpf/vmlinux_parser/class_handler.py @@ -16,6 +16,33 @@ def get_module_symbols(module_name: str): return [name for name in dir(imported_module)], imported_module +def unwrap_pointer_type(type_obj: Any) -> Any: + """ + Recursively unwrap all pointer layers to get the base type. + + This handles multiply nested pointers like LP_LP_struct_attribute_group + and returns the base type (struct_attribute_group). + + Stops unwrapping when reaching a non-pointer type (one without _type_ attribute). + + Args: + type_obj: The type object to unwrap + + Returns: + The base type after unwrapping all pointer layers + """ + current_type = type_obj + # Keep unwrapping while it's a pointer/array type (has _type_) + # But stop if _type_ is just a string or basic type marker + while hasattr(current_type, "_type_"): + next_type = current_type._type_ + # Stop if _type_ is a string (like 'c' for c_char) + if isinstance(next_type, str): + break + current_type = next_type + return current_type + + def process_vmlinux_class( node, llvm_module, @@ -158,13 +185,90 @@ def process_vmlinux_post_ast( if hasattr(elem_type, "_length_") and is_complex_type: type_length = elem_type._length_ - if containing_type.__module__ == "vmlinux": - new_dep_node.add_dependent( - elem_type._type_.__name__ - if hasattr(elem_type._type_, "__name__") - else str(elem_type._type_) + # Unwrap all pointer layers to get the base type for dependency tracking + base_type = unwrap_pointer_type(elem_type) + base_type_module = getattr(base_type, "__module__", None) + + if base_type_module == "vmlinux": + base_type_name = ( + base_type.__name__ + if hasattr(base_type, "__name__") + else str(base_type) + ) + # ONLY add vmlinux types as dependencies + new_dep_node.add_dependent(base_type_name) + + logger.debug( + f"{containing_type} containing type of parent {elem_name} with {elem_type} and ctype {ctype_complex_type} and length {type_length}" + ) + new_dep_node.set_field_containing_type( + elem_name, containing_type + ) + new_dep_node.set_field_type_size(elem_name, type_length) + new_dep_node.set_field_ctype_complex_type( + elem_name, ctype_complex_type + ) + new_dep_node.set_field_type(elem_name, elem_type) + + # Check the containing_type module to decide whether to recurse + containing_type_module = getattr( + containing_type, "__module__", None + ) + if containing_type_module == "vmlinux": + # Also unwrap containing_type to get base type name + base_containing_type = unwrap_pointer_type( + containing_type + ) + containing_type_name = ( + base_containing_type.__name__ + if hasattr(base_containing_type, "__name__") + else str(base_containing_type) + ) + + # Check for self-reference or already processed + if containing_type_name == current_symbol_name: + # Self-referential pointer + logger.debug( + f"Self-referential pointer in {current_symbol_name}.{elem_name}" + ) + new_dep_node.set_field_ready(elem_name, True) + elif handler.has_node(containing_type_name): + # Already processed + logger.debug( + f"Reusing already processed {containing_type_name}" + ) + new_dep_node.set_field_ready(elem_name, True) + else: + # Process recursively - use base containing type, not the pointer wrapper + new_dep_node.add_dependent(containing_type_name) + process_vmlinux_post_ast( + base_containing_type, + llvm_handler, + handler, + processing_stack, + ) + new_dep_node.set_field_ready(elem_name, True) + elif ( + containing_type_module == ctypes.__name__ + or containing_type_module is None + ): + logger.debug( + f"Processing ctype internal{containing_type}" + ) + new_dep_node.set_field_ready(elem_name, True) + else: + raise TypeError( + f"Module not supported in recursive resolution: {containing_type_module}" + ) + elif ( + base_type_module == ctypes.__name__ + or base_type_module is None + ): + # Handle ctypes or types with no module (like some internal ctypes types) + # DO NOT add ctypes as dependencies - just set field metadata and mark ready + logger.debug( + f"Base type {base_type} is ctypes - NOT adding as dependency, just processing field" ) - elif containing_type.__module__ == ctypes.__name__: if isinstance(elem_type, type): if issubclass(elem_type, ctypes.Array): ctype_complex_type = ctypes.Array @@ -176,57 +280,20 @@ def process_vmlinux_post_ast( ) else: raise TypeError("Unsupported ctypes subclass") - else: - raise ImportError( - f"Unsupported module of {containing_type}" + + # Set field metadata but DO NOT add dependency or recurse + new_dep_node.set_field_containing_type( + elem_name, containing_type ) - logger.debug( - f"{containing_type} containing type of parent {elem_name} with {elem_type} and ctype {ctype_complex_type} and length {type_length}" - ) - new_dep_node.set_field_containing_type( - elem_name, containing_type - ) - new_dep_node.set_field_type_size(elem_name, type_length) - new_dep_node.set_field_ctype_complex_type( - elem_name, ctype_complex_type - ) - new_dep_node.set_field_type(elem_name, elem_type) - if containing_type.__module__ == "vmlinux": - containing_type_name = ( - containing_type.__name__ - if hasattr(containing_type, "__name__") - else str(containing_type) + new_dep_node.set_field_type_size(elem_name, type_length) + new_dep_node.set_field_ctype_complex_type( + elem_name, ctype_complex_type ) - - # Check for self-reference or already processed - if containing_type_name == current_symbol_name: - # Self-referential pointer - logger.debug( - f"Self-referential pointer in {current_symbol_name}.{elem_name}" - ) - new_dep_node.set_field_ready(elem_name, True) - elif handler.has_node(containing_type_name): - # Already processed - logger.debug( - f"Reusing already processed {containing_type_name}" - ) - new_dep_node.set_field_ready(elem_name, True) - else: - # Process recursively - THIS WAS MISSING - new_dep_node.add_dependent(containing_type_name) - process_vmlinux_post_ast( - containing_type, - llvm_handler, - handler, - processing_stack, - ) - new_dep_node.set_field_ready(elem_name, True) - elif containing_type.__module__ == ctypes.__name__: - logger.debug(f"Processing ctype internal{containing_type}") + new_dep_node.set_field_type(elem_name, elem_type) new_dep_node.set_field_ready(elem_name, True) else: - raise TypeError( - "Module not supported in recursive resolution" + raise ImportError( + f"Unsupported module of {base_type}: {base_type_module}" ) else: new_dep_node.add_dependent( @@ -245,9 +312,12 @@ def process_vmlinux_post_ast( raise ValueError( f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver" ) - + elif module_name == ctypes.__name__ or module_name is None: + # Handle ctypes types - these don't need processing, just return + logger.debug(f"Skipping ctypes type {current_symbol_name}") + return True else: - raise ImportError("UNSUPPORTED Module") + raise ImportError(f"UNSUPPORTED Module {module_name}") logger.info( f"{current_symbol_name} processed and handler readiness {handler.is_ready}" diff --git a/pythonbpf/vmlinux_parser/import_detector.py b/pythonbpf/vmlinux_parser/import_detector.py index d90c4789..b0da40ae 100644 --- a/pythonbpf/vmlinux_parser/import_detector.py +++ b/pythonbpf/vmlinux_parser/import_detector.py @@ -11,7 +11,9 @@ logger = logging.getLogger(__name__) -def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]: +def detect_import_statement( + tree: ast.AST, +) -> list[tuple[str, ast.ImportFrom, str, str]]: """ Parse AST and detect import statements from vmlinux. @@ -25,7 +27,7 @@ def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]: List of tuples containing (module_name, imported_item) for each vmlinux import Raises: - SyntaxError: If multiple imports from vmlinux are attempted or import * is used + SyntaxError: If import * is used """ vmlinux_imports = [] @@ -40,28 +42,19 @@ def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]: "Please import specific types explicitly." ) - # Check for multiple imports: from vmlinux import A, B, C - if len(node.names) > 1: - imported_names = [alias.name for alias in node.names] - raise SyntaxError( - f"Multiple imports from vmlinux are not supported. " - f"Found: {', '.join(imported_names)}. " - f"Please use separate import statements for each type." - ) - # Check if no specific import is specified (should not happen with valid Python) if len(node.names) == 0: raise SyntaxError( "Import from vmlinux must specify at least one type." ) - # Valid single import + # Support multiple imports: from vmlinux import A, B, C for alias in node.names: import_name = alias.name - # Use alias if provided, otherwise use the original name (commented) - # as_name = alias.asname if alias.asname else alias.name - vmlinux_imports.append(("vmlinux", node)) - logger.info(f"Found vmlinux import: {import_name}") + # Use alias if provided, otherwise use the original name + as_name = alias.asname if alias.asname else alias.name + vmlinux_imports.append(("vmlinux", node, import_name, as_name)) + logger.info(f"Found vmlinux import: {import_name} as {as_name}") # Handle "import vmlinux" statements (not typical but should be rejected) elif isinstance(node, ast.Import): @@ -103,40 +96,37 @@ def vmlinux_proc(tree: ast.AST, module): with open(source_file, "r") as f: mod_ast = ast.parse(f.read(), filename=source_file) - for import_mod, import_node in import_statements: - for alias in import_node.names: - imported_name = alias.name - found = False - for mod_node in mod_ast.body: - if ( - isinstance(mod_node, ast.ClassDef) - and mod_node.name == imported_name - ): - process_vmlinux_class(mod_node, module, handler) - found = True - break - if isinstance(mod_node, ast.Assign): - for target in mod_node.targets: - if isinstance(target, ast.Name) and target.id == imported_name: - process_vmlinux_assign(mod_node, module, assignments) - found = True - break - if found: - break - if not found: - logger.info( - f"{imported_name} not found as ClassDef or Assign in vmlinux" - ) + for import_mod, import_node, imported_name, as_name in import_statements: + found = False + for mod_node in mod_ast.body: + if isinstance(mod_node, ast.ClassDef) and mod_node.name == imported_name: + process_vmlinux_class(mod_node, module, handler) + found = True + break + if isinstance(mod_node, ast.Assign): + for target in mod_node.targets: + if isinstance(target, ast.Name) and target.id == imported_name: + process_vmlinux_assign(mod_node, module, assignments, as_name) + found = True + break + if found: + break + if not found: + logger.info(f"{imported_name} not found as ClassDef or Assign in vmlinux") IRGenerator(module, handler, assignments) return assignments -def process_vmlinux_assign(node, module, assignments: dict[str, AssignmentInfo]): +def process_vmlinux_assign( + node, module, assignments: dict[str, AssignmentInfo], target_name=None +): """Process assignments from vmlinux module.""" # Only handle single-target assignments if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): - target_name = node.targets[0].id + # Use provided target_name (for aliased imports) or fall back to original name + if target_name is None: + target_name = node.targets[0].id # Handle constant value assignments if isinstance(node.value, ast.Constant): diff --git a/pythonbpf/vmlinux_parser/ir_gen/debug_info_gen.py b/pythonbpf/vmlinux_parser/ir_gen/debug_info_gen.py index 232cf10a..c4f5642c 100644 --- a/pythonbpf/vmlinux_parser/ir_gen/debug_info_gen.py +++ b/pythonbpf/vmlinux_parser/ir_gen/debug_info_gen.py @@ -21,7 +21,7 @@ def debug_info_generation( generated_debug_info: List of tuples (struct, debug_info) to track generated debug info Returns: - The generated global variable debug info + The generated global variable debug info, or None for unsupported types """ # Set up debug info generator generator = DebugInfoGenerator(llvm_module) @@ -31,23 +31,42 @@ def debug_info_generation( if existing_struct.name == struct.name: return debug_info + # Check if this is a union (not supported yet) + if not struct.name.startswith("struct_"): + logger.warning(f"Skipping debug info generation for union: {struct.name}") + # Create a minimal forward declaration for unions + union_type = generator.create_struct_type( + [], struct.__sizeof__() * 8, is_distinct=True + ) + return union_type + # Process all fields and create members for the struct members = [] - for field_name, field in struct.fields.items(): - # Get appropriate debug type for this field - field_type = _get_field_debug_type( - field_name, field, generator, struct, generated_debug_info - ) - # Create struct member with proper offset - member = generator.create_struct_member_vmlinux( - field_name, field_type, field.offset * 8 - ) - members.append(member) - if struct.name.startswith("struct_"): - struct_name = struct.name.removeprefix("struct_") - else: - raise ValueError("Unions are not supported in the current version") + sorted_fields = sorted(struct.fields.items(), key=lambda item: item[1].offset) + + for field_name, field in sorted_fields: + try: + # Get appropriate debug type for this field + field_type = _get_field_debug_type( + field_name, field, generator, struct, generated_debug_info + ) + + # Ensure field_type is a tuple + if not isinstance(field_type, tuple) or len(field_type) != 2: + logger.error(f"Invalid field_type for {field_name}: {field_type}") + continue + + # Create struct member with proper offset + member = generator.create_struct_member_vmlinux( + field_name, field_type, field.offset * 8 + ) + members.append(member) + except Exception as e: + logger.error(f"Failed to process field {field_name} in {struct.name}: {e}") + continue + + struct_name = struct.name.removeprefix("struct_") # Create struct type with all members struct_type = generator.create_struct_type_with_name( struct_name, members, struct.__sizeof__() * 8, is_distinct=True @@ -74,11 +93,19 @@ def _get_field_debug_type( generated_debug_info: List of already generated debug info Returns: - The debug info type for this field + A tuple of (debug_type, size_in_bits) """ - # Handle complex types (arrays, pointers) + # Handle complex types (arrays, pointers, function pointers) if field.ctype_complex_type is not None: - if issubclass(field.ctype_complex_type, ctypes.Array): + # Handle function pointer types (CFUNCTYPE) + if callable(field.ctype_complex_type): + # Function pointers are represented as void pointers + logger.warning( + f"Field {field_name} is a function pointer, using void pointer" + ) + void_ptr = generator.create_pointer_type(None, 64) + return void_ptr, 64 + elif issubclass(field.ctype_complex_type, ctypes.Array): # Handle array types element_type, base_type_size = _get_basic_debug_type( field.containing_type, generator @@ -100,11 +127,13 @@ def _get_field_debug_type( for existing_struct, debug_info in generated_debug_info: if existing_struct.name == struct_name: # Use existing debug info - return debug_info, existing_struct.__sizeof__() + return debug_info, existing_struct.__sizeof__() * 8 # If not found, create a forward declaration # This will be completed when the actual struct is processed - logger.warning("Forward declaration in struct created") + logger.info( + f"Forward declaration created for {struct_name} in {parent_struct.name}" + ) forward_type = generator.create_struct_type([], 0, is_distinct=True) return forward_type, 0 diff --git a/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py b/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py index 14a74ad0..6a7088cd 100644 --- a/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py +++ b/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py @@ -11,6 +11,10 @@ class IRGenerator: + # This field keeps track of the non_struct names to avoid duplicate name errors. + type_number = 0 + unprocessed_store: list[str] = [] + # get the assignments dict and add this stuff to it. def __init__(self, llvm_module, handler: DependencyHandler, assignments): self.llvm_module = llvm_module @@ -129,7 +133,19 @@ def gen_ir(self, struct, generated_debug_info): for field_name, field in struct.fields.items(): # does not take arrays and similar types into consideration yet. - if field.ctype_complex_type is not None and issubclass( + if callable(field.ctype_complex_type): + # Function pointer case - generate a simple field accessor + field_co_re_name, returned = self._struct_name_generator( + struct, field, field_index + ) + field_index += 1 + globvar = ir.GlobalVariable( + self.llvm_module, ir.IntType(64), name=field_co_re_name + ) + globvar.linkage = "external" + globvar.set_metadata("llvm.preserve.access.index", debug_info) + self.generated_field_names[struct.name][field_name] = globvar + elif field.ctype_complex_type is not None and issubclass( field.ctype_complex_type, ctypes.Array ): array_size = field.type_size @@ -137,7 +153,7 @@ def gen_ir(self, struct, generated_debug_info): if containing_type.__module__ == ctypes.__name__: containing_type_size = ctypes.sizeof(containing_type) if array_size == 0: - field_co_re_name = self._struct_name_generator( + field_co_re_name, returned = self._struct_name_generator( struct, field, field_index, True, 0, containing_type_size ) globvar = ir.GlobalVariable( @@ -149,7 +165,7 @@ def gen_ir(self, struct, generated_debug_info): field_index += 1 continue for i in range(0, array_size): - field_co_re_name = self._struct_name_generator( + field_co_re_name, returned = self._struct_name_generator( struct, field, field_index, True, i, containing_type_size ) globvar = ir.GlobalVariable( @@ -163,12 +179,28 @@ def gen_ir(self, struct, generated_debug_info): array_size = field.type_size containing_type = field.containing_type if containing_type.__module__ == "vmlinux": - containing_type_size = self.handler[ - containing_type.__name__ - ].current_offset - for i in range(0, array_size): - field_co_re_name = self._struct_name_generator( - struct, field, field_index, True, i, containing_type_size + # Unwrap all pointer layers to get the base struct type + base_containing_type = containing_type + while hasattr(base_containing_type, "_type_"): + next_type = base_containing_type._type_ + # Stop if _type_ is a string (like 'c' for c_char) + # TODO: stacked pointers not handl;ing ctypes check here as well + if isinstance(next_type, str): + break + base_containing_type = next_type + + # Get the base struct name + base_struct_name = ( + base_containing_type.__name__ + if hasattr(base_containing_type, "__name__") + else str(base_containing_type) + ) + + # Look up the size using the base struct name + containing_type_size = self.handler[base_struct_name].current_offset + if array_size == 0: + field_co_re_name, returned = self._struct_name_generator( + struct, field, field_index, True, 0, containing_type_size ) globvar = ir.GlobalVariable( self.llvm_module, ir.IntType(64), name=field_co_re_name @@ -176,9 +208,30 @@ def gen_ir(self, struct, generated_debug_info): globvar.linkage = "external" globvar.set_metadata("llvm.preserve.access.index", debug_info) self.generated_field_names[struct.name][field_name] = globvar - field_index += 1 + field_index += 1 + else: + for i in range(0, array_size): + field_co_re_name, returned = self._struct_name_generator( + struct, + field, + field_index, + True, + i, + containing_type_size, + ) + globvar = ir.GlobalVariable( + self.llvm_module, ir.IntType(64), name=field_co_re_name + ) + globvar.linkage = "external" + globvar.set_metadata( + "llvm.preserve.access.index", debug_info + ) + self.generated_field_names[struct.name][field_name] = ( + globvar + ) + field_index += 1 else: - field_co_re_name = self._struct_name_generator( + field_co_re_name, returned = self._struct_name_generator( struct, field, field_index ) field_index += 1 @@ -198,7 +251,7 @@ def _struct_name_generator( is_indexed: bool = False, index: int = 0, containing_type_size: int = 0, - ) -> str: + ) -> tuple[str, bool]: # TODO: Does not support Unions as well as recursive pointer and array type naming if is_indexed: name = ( @@ -208,7 +261,7 @@ def _struct_name_generator( + "$" + f"0:{field_index}:{index}" ) - return name + return name, True elif struct.name.startswith("struct_"): name = ( "llvm." @@ -217,9 +270,18 @@ def _struct_name_generator( + "$" + f"0:{field_index}" ) - return name + return name, True else: - print(self.handler[struct.name]) - raise TypeError( - "Name generation cannot occur due to type name not starting with struct" + logger.warning( + "Blindly handling non-struct type to avoid type errors in vmlinux IR generation. Possibly a union." ) + self.type_number += 1 + unprocessed_type = "unprocessed_type_" + str(self.handler[struct.name].name) + if self.unprocessed_store.__contains__(unprocessed_type): + return unprocessed_type + "_" + str(self.type_number), False + else: + self.unprocessed_store.append(unprocessed_type) + return unprocessed_type, False + # raise TypeError( + # "Name generation cannot occur due to type name not starting with struct" + # ) diff --git a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py index 30f30589..c26cac9e 100644 --- a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py +++ b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py @@ -94,17 +94,140 @@ def handle_vmlinux_struct_field( f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}" ) python_type: type = var_info.metadata - struct_name = python_type.__name__ - globvar_ir, field_data = self.get_field_type(struct_name, field_name) - builder.function.args[0].type = ir.PointerType(ir.IntType(8)) - field_ptr = self.load_ctx_field( - builder, builder.function.args[0], globvar_ir, field_data, struct_name - ) - # Return pointer to field and field type - return field_ptr, field_data + # Check if this is a context field (ctx) or a cast struct + is_context_field = var_info.var is None + + if is_context_field: + # Handle context field access (original behavior) + struct_name = python_type.__name__ + globvar_ir, field_data = self.get_field_type(struct_name, field_name) + builder.function.args[0].type = ir.PointerType(ir.IntType(8)) + field_ptr = self.load_ctx_field( + builder, + builder.function.args[0], + globvar_ir, + field_data, + struct_name, + ) + return field_ptr, field_data + else: + # Handle cast struct field access + struct_name = python_type.__name__ + globvar_ir, field_data = self.get_field_type(struct_name, field_name) + + # Handle cast struct field access (use bpf_probe_read_kernel) + # Load the struct pointer from the local variable + struct_ptr = builder.load(var_info.var) + + # Use bpf_probe_read_kernel for non-context struct field access + field_value = self.load_struct_field( + builder, struct_ptr, globvar_ir, field_data, struct_name + ) + # Return field value and field type + return field_value, field_data else: raise RuntimeError("Variable accessed not found in symbol table") + @staticmethod + def load_struct_field( + builder, struct_ptr_int, offset_global, field_data, struct_name=None + ): + """ + Generate LLVM IR to load a field from a regular (non-context) struct using bpf_probe_read_kernel. + + Args: + builder: llvmlite IRBuilder instance + struct_ptr_int: The struct pointer as an i64 value (already loaded from alloca) + offset_global: Global variable containing the field offset (i64) + field_data: contains data about the field + struct_name: Name of the struct being accessed (optional) + Returns: + The loaded value + """ + + # Load the offset value + offset = builder.load(offset_global) + + # Convert i64 to pointer type (BPF stores pointers as i64) + i8_ptr_type = ir.PointerType(ir.IntType(8)) + struct_ptr = builder.inttoptr(struct_ptr_int, i8_ptr_type) + + # GEP with offset to get field pointer + field_ptr = builder.gep( + struct_ptr, + [offset], + inbounds=False, + ) + + # Determine the appropriate field size based on field information + field_size_bytes = 8 # Default to 8 bytes (64-bit) + int_width = 64 # Default to 64-bit + needs_zext = False + + if field_data is not None: + # Try to determine the size from field metadata + if field_data.type.__module__ == ctypes.__name__: + try: + field_size_bytes = ctypes.sizeof(field_data.type) + field_size_bits = field_size_bytes * 8 + + if field_size_bits in [8, 16, 32, 64]: + int_width = field_size_bits + logger.info( + f"Determined field size: {int_width} bits ({field_size_bytes} bytes)" + ) + + # Special handling for struct_xdp_md i32 fields + if struct_name == "struct_xdp_md" and int_width == 32: + needs_zext = True + logger.info( + "struct_xdp_md i32 field detected, will zero-extend to i64" + ) + else: + logger.warning( + f"Unusual field size {field_size_bits} bits, using default 64" + ) + except Exception as e: + logger.warning( + f"Could not determine field size: {e}, using default 64" + ) + + elif field_data.type.__module__ == "vmlinux": + # For pointers to structs or complex vmlinux types + if field_data.ctype_complex_type is not None and issubclass( + field_data.ctype_complex_type, ctypes._Pointer + ): + int_width = 64 # Pointers are always 64-bit + field_size_bytes = 8 + logger.info("Field is a pointer type, using 64 bits") + else: + logger.warning("Complex vmlinux field type, using default 64 bits") + + # Allocate local storage for the field value + local_storage = builder.alloca(ir.IntType(int_width)) + local_storage_i8_ptr = builder.bitcast(local_storage, i8_ptr_type) + + # Use bpf_probe_read_kernel to safely read the field + # This generates: + # %gep = getelementptr i8, ptr %struct_ptr, i64 %offset (already done above as field_ptr) + # %passed = tail call ptr @llvm.bpf.passthrough.p0.p0(i32 2, ptr %gep) + # %result = call i64 inttoptr (i64 113 to ptr)(ptr %local_storage, i32 %size, ptr %passed) + from pythonbpf.helper import emit_probe_read_kernel_call + + emit_probe_read_kernel_call( + builder, local_storage_i8_ptr, field_size_bytes, field_ptr + ) + + # Load the value from local storage + value = builder.load(local_storage) + + # Zero-extend i32 to i64 if needed + if needs_zext: + value = builder.zext(value, ir.IntType(64)) + logger.info("Zero-extended i32 value to i64") + + return value + @staticmethod def load_ctx_field(builder, ctx_arg, offset_global, field_data, struct_name=None): """ diff --git a/tests/c-form/Makefile b/tests/c-form/Makefile index 03f8ef2f..a34debac 100644 --- a/tests/c-form/Makefile +++ b/tests/c-form/Makefile @@ -3,21 +3,20 @@ CFLAGS := -emit-llvm -target bpf -c SRC := $(wildcard *.bpf.c) LL := $(SRC:.bpf.c=.bpf.ll) -LL2 := $(SRC:.bpf.c=.bpf.o2.ll) OBJ := $(SRC:.bpf.c=.bpf.o) - +LL0 := $(SRC:.bpf.c=.bpf.o0.ll) .PHONY: all clean -all: $(LL) $(OBJ) $(LL2) +all: $(LL) $(OBJ) $(LL0) %.bpf.o: %.bpf.c $(BPF_CLANG) -O2 -g -target bpf -c $< -o $@ %.bpf.ll: %.bpf.c - $(BPF_CLANG) -O0 $(CFLAGS) -g -S $< -o $@ + $(BPF_CLANG) $(CFLAGS) -O2 -g -S $< -o $@ -%.bpf.o2.ll: %.bpf.c - $(BPF_CLANG) -O2 $(CFLAGS) -g -S $< -o $@ +%.bpf.o0.ll: %.bpf.c + $(BPF_CLANG) $(CFLAGS) -O0 -g -S $< -o $@ clean: - rm -f $(LL) $(OBJ) $(LL2) + rm -f $(LL) $(OBJ) $(LL0) diff --git a/tests/c-form/requests.bpf.c b/tests/c-form/requests.bpf.c new file mode 100644 index 00000000..55b12397 --- /dev/null +++ b/tests/c-form/requests.bpf.c @@ -0,0 +1,18 @@ +#include "vmlinux.h" +#include +#include +#include + +char LICENSE[] SEC("license") = "GPL"; + +SEC("kprobe/blk_mq_start_request") +int example(struct pt_regs *ctx) +{ + u64 a = ctx->r15; + struct request *req = (struct request *)(ctx->di); + unsigned int something_ns = BPF_CORE_READ(req, timeout); + unsigned int data_len = BPF_CORE_READ(req, __data_len); + bpf_printk("data length %lld %ld %ld\n", data_len, something_ns, a); + + return 0; +} diff --git a/tests/c-form/requests2.bpf.c b/tests/c-form/requests2.bpf.c new file mode 100644 index 00000000..c0cbf9f2 --- /dev/null +++ b/tests/c-form/requests2.bpf.c @@ -0,0 +1,18 @@ +#include "vmlinux.h" +#include +#include +#include + +char LICENSE[] SEC("license") = "GPL"; + +SEC("kprobe/blk_mq_start_request") +int example(struct pt_regs *ctx) +{ + u64 a = ctx->r15; + struct request *req = (struct request *)(ctx->di); + unsigned int something_ns = req->timeout; + unsigned int data_len = req->__data_len; + bpf_printk("data length %lld %ld %ld\n", data_len, something_ns, a); + + return 0; +} diff --git a/tests/failing_tests/vmlinux/assignment_handling.py b/tests/failing_tests/vmlinux/assignment_handling.py new file mode 100644 index 00000000..b8fe43ee --- /dev/null +++ b/tests/failing_tests/vmlinux/assignment_handling.py @@ -0,0 +1,22 @@ +from vmlinux import XDP_PASS +from pythonbpf import bpf, section, bpfglobal, compile_to_ir +import logging +from ctypes import c_int64, c_void_p + + +@bpf +@section("kprobe/blk_mq_start_request") +def example(ctx: c_void_p) -> c_int64: + d = XDP_PASS # This gives an error, but + e = XDP_PASS + 0 # this does not + print(f"test1 {e} test2 {d}") + return c_int64(0) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("assignment_handling.py", "assignment_handling.ll", loglevel=logging.INFO) diff --git a/tests/passing_tests/vmlinux/requests.py b/tests/passing_tests/vmlinux/requests.py new file mode 100644 index 00000000..bb7fb9d9 --- /dev/null +++ b/tests/passing_tests/vmlinux/requests.py @@ -0,0 +1,27 @@ +from vmlinux import struct_request, struct_pt_regs +from pythonbpf import bpf, section, bpfglobal, compile_to_ir, compile +import logging +from ctypes import c_int64 + + +@bpf +@section("kprobe/blk_mq_start_request") +def example(ctx: struct_pt_regs) -> c_int64: + a = ctx.r15 + req = struct_request(ctx.di) + d = req.__data_len + b = ctx.r12 + c = req.timeout + print(f"data length {d} and {c} and {a}") + print(f"ctx arg {b}") + return c_int64(0) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("requests.py", "requests.ll", loglevel=logging.INFO) +compile() diff --git a/tests/passing_tests/vmlinux/requests2.py b/tests/passing_tests/vmlinux/requests2.py new file mode 100644 index 00000000..63e90c71 --- /dev/null +++ b/tests/passing_tests/vmlinux/requests2.py @@ -0,0 +1,21 @@ +from vmlinux import struct_pt_regs +from pythonbpf import bpf, section, bpfglobal, compile_to_ir +import logging +from ctypes import c_int64 + + +@bpf +@section("kprobe/blk_mq_start_request") +def example(ctx: struct_pt_regs) -> c_int64: + req = ctx.di + print(f"data length {req}") + return c_int64(0) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("requests2.py", "requests2.ll", loglevel=logging.INFO)