diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index b96a9cf0..3b25b6f8 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -1,6 +1,6 @@ import ast import logging - +import ctypes from llvmlite import ir from .local_symbol import LocalSymbol from pythonbpf.helper import HelperHandlerRegistry @@ -81,7 +81,7 @@ def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): call_type = rval.func.id # C type constructors - if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): + if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64", "c_void_p"): ir_type = ctypes_to_ir(call_type) var = builder.alloca(ir_type, name=var_name) var.align = ir_type.width // 8 @@ -249,7 +249,58 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_ ].var = base_ptr # This is repurposing of var to store the pointer of the base type local_sym_tab[struct_var].ir_type = field_ir - actual_ir_type = ir.IntType(64) + # Determine the actual IR type based on the field's type + actual_ir_type = None + + # Check if it's a ctypes primitive + if field.type.__module__ == ctypes.__name__: + try: + field_size_bytes = ctypes.sizeof(field.type) + field_size_bits = field_size_bytes * 8 + + if field_size_bits in [8, 16, 32, 64]: + # Special case: struct_xdp_md i32 fields should allocate as i64 + # because load_ctx_field will zero-extend them to i64 + if ( + vmlinux_struct_name == "struct_xdp_md" + and field_size_bits == 32 + ): + actual_ir_type = ir.IntType(64) + logger.info( + f"Allocating {var_name} as i64 for i32 field from struct_xdp_md.{field_name} " + "(will be zero-extended during load)" + ) + else: + actual_ir_type = ir.IntType(field_size_bits) + else: + logger.warning( + f"Unusual field size {field_size_bits} bits for {field_name}" + ) + actual_ir_type = ir.IntType(64) + except Exception as e: + logger.warning( + f"Could not determine size for ctypes field {field_name}: {e}" + ) + 
actual_ir_type = ir.IntType(64) + + # Check if it's a nested vmlinux struct or complex type + elif field.type.__module__ == "vmlinux": + # For pointers to structs, use pointer type (64-bit) + if field.ctype_complex_type is not None and issubclass( + field.ctype_complex_type, ctypes._Pointer + ): + actual_ir_type = ir.IntType(64) # Pointer is always 64-bit + # For embedded structs, this is more complex - might need different handling + else: + logger.warning( + f"Field {field_name} is a nested vmlinux struct, using i64 for now" + ) + actual_ir_type = ir.IntType(64) + else: + logger.warning( + f"Unknown field type module {field.type.__module__} for {field_name}" + ) + actual_ir_type = ir.IntType(64) # Allocate with the actual IR type, not the GlobalVariable var = _allocate_with_type(builder, var_name, actual_ir_type) diff --git a/pythonbpf/assign_pass.py b/pythonbpf/assign_pass.py index a1c2798d..0bd48c61 100644 --- a/pythonbpf/assign_pass.py +++ b/pythonbpf/assign_pass.py @@ -152,15 +152,30 @@ def handle_variable_assignment( if val_type != var_type: if isinstance(val_type, Field): logger.info("Handling assignment to struct field") + # Special handling for struct_xdp_md i32 fields that are zero-extended to i64 + # The load_ctx_field already extended them, so val is i64 but val_type.type shows c_uint + if ( + hasattr(val_type, "type") + and val_type.type.__name__ == "c_uint" + and isinstance(var_type, ir.IntType) + and var_type.width == 64 + ): + # This is the struct_xdp_md case - value is already i64 + builder.store(val, var_ptr) + logger.info( + f"Assigned zero-extended struct_xdp_md i32 field to {var_name} (i64)" + ) + return True # TODO: handling only ctype struct fields for now. Handle other stuff too later. 
- if var_type == ctypes_to_ir(val_type.type.__name__): + elif var_type == ctypes_to_ir(val_type.type.__name__): builder.store(val, var_ptr) logger.info(f"Assigned ctype struct field to {var_name}") return True - logger.error( - f"Failed to assign ctype struct field to {var_name}: {val_type} != {var_type}" - ) - return False + else: + logger.error( + f"Failed to assign ctype struct field to {var_name}: {val_type} != {var_type}" + ) + return False elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType): # Allow implicit int widening if val_type.width < var_type.width: diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index 1d10fcbd..a9eab987 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -12,6 +12,7 @@ get_base_type_and_depth, deref_to_depth, ) +from pythonbpf.vmlinux_parser.assignment_info import Field from .vmlinux_registry import VmlinuxHandlerRegistry logger: Logger = logging.getLogger(__name__) @@ -279,16 +280,45 @@ def _handle_ctypes_call( call_type = expr.func.id expected_type = ctypes_to_ir(call_type) - if val[1] != expected_type: + # Extract the actual IR value and type + # val could be (value, ir_type) or (value, Field) + value, val_type = val + + # If val_type is a Field object (from vmlinux struct), get the actual IR type of the value + if isinstance(val_type, Field): + # The value is already the correct IR value (potentially zero-extended) + # Get the IR type from the value itself + actual_ir_type = value.type + logger.info( + f"Converting vmlinux field {val_type.name} (IR type: {actual_ir_type}) to {call_type}" + ) + else: + actual_ir_type = val_type + + if actual_ir_type != expected_type: # NOTE: We are only considering casting to and from int types for now - if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType): - if val[1].width < expected_type.width: - val = (builder.sext(val[0], expected_type), expected_type) + if isinstance(actual_ir_type, ir.IntType) and 
isinstance( + expected_type, ir.IntType + ): + if actual_ir_type.width < expected_type.width: + value = builder.sext(value, expected_type) + logger.info( + f"Sign-extended from i{actual_ir_type.width} to i{expected_type.width}" + ) + elif actual_ir_type.width > expected_type.width: + value = builder.trunc(value, expected_type) + logger.info( + f"Truncated from i{actual_ir_type.width} to i{expected_type.width}" + ) else: - val = (builder.trunc(val[0], expected_type), expected_type) + # Same width, just use as-is (e.g., both i64) + pass else: - raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}") - return val + raise ValueError( + f"Type mismatch: expected {expected_type}, got {actual_ir_type} (original type: {val_type})" + ) + + return value, expected_type def _handle_compare( diff --git a/pythonbpf/functions/function_debug_info.py b/pythonbpf/functions/function_debug_info.py index f924ebcf..985eb92a 100644 --- a/pythonbpf/functions/function_debug_info.py +++ b/pythonbpf/functions/function_debug_info.py @@ -49,17 +49,27 @@ def generate_function_debug_info( "The first argument should always be a pointer to a struct or a void pointer" ) context_debug_info = VmlinuxHandlerRegistry.get_struct_debug_info(annotation.id) + + # Create pointer to context; this must be created fresh for each function + # to avoid circular reference issues when the same struct is used in multiple functions pointer_to_context_debug_info = generator.create_pointer_type( context_debug_info, 64 ) + + # Create subroutine type - also fresh for each function subroutine_type = generator.create_subroutine_type( return_type, pointer_to_context_debug_info ) + + # Create local variable - fresh for each function with unique name context_local_variable = generator.create_local_variable_debug_info( leading_argument_name, 1, pointer_to_context_debug_info ) + retained_nodes = [context_local_variable] - print("function name", func_node.name) + logger.info(f"Generating debug info for 
function {func_node.name}") + + # Create the subprogram so each function gets its own debug info (NOTE(review): is_distinct is not passed in this call - confirm create_subprogram marks the node distinct) subprogram_debug_info = generator.create_subprogram( func_node.name, subroutine_type, retained_nodes ) diff --git a/pythonbpf/type_deducer.py b/pythonbpf/type_deducer.py index a6834a9b..fd589ae0 100644 --- a/pythonbpf/type_deducer.py +++ b/pythonbpf/type_deducer.py @@ -16,6 +16,8 @@ "c_long": ir.IntType(64), "c_ulong": ir.IntType(64), "c_longlong": ir.IntType(64), + "c_uint": ir.IntType(32), + "c_int": ir.IntType(32), # Not so sure about this one "str": ir.PointerType(ir.IntType(8)), } diff --git a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py index 62c03278..30f30589 100644 --- a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py +++ b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py @@ -1,6 +1,6 @@ import logging from typing import Any - +import ctypes from llvmlite import ir from pythonbpf.local_symbol import LocalSymbol @@ -94,22 +94,19 @@ def handle_vmlinux_struct_field( f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}" ) python_type: type = var_info.metadata - globvar_ir, field_data = self.get_field_type( - python_type.__name__, field_name - ) + struct_name = python_type.__name__ + globvar_ir, field_data = self.get_field_type(struct_name, field_name) builder.function.args[0].type = ir.PointerType(ir.IntType(8)) - print(builder.function.args[0]) field_ptr = self.load_ctx_field( - builder, builder.function.args[0], globvar_ir + builder, builder.function.args[0], globvar_ir, field_data, struct_name ) - print(field_ptr) # Return pointer to field and field type return field_ptr, field_data else: raise RuntimeError("Variable accessed not found in symbol table") @staticmethod - def load_ctx_field(builder, ctx_arg, offset_global): + def load_ctx_field(builder, ctx_arg, offset_global, field_data, struct_name=None): """ Generate LLVM IR to load a 
field from BPF context using offset. @@ -117,9 +114,10 @@ def load_ctx_field(builder, ctx_arg, offset_global): builder: llvmlite IRBuilder instance ctx_arg: The context pointer argument (ptr/i8*) offset_global: Global variable containing the field offset (i64) - + field_data: contains data about the field + struct_name: Name of the struct being accessed (optional) Returns: - The loaded value (i64 register) + The loaded value (i64 register or appropriately sized) """ # Load the offset value @@ -164,13 +162,61 @@ def load_ctx_field(builder, ctx_arg, offset_global): passthrough_fn, [ir.Constant(ir.IntType(32), 0), field_ptr], tail=True ) - # Bitcast to i64* (assuming field is 64-bit, adjust if needed) - i64_ptr_type = ir.PointerType(ir.IntType(64)) - typed_ptr = builder.bitcast(verified_ptr, i64_ptr_type) + # Determine the appropriate IR type based on field information + int_width = 64 # Default to 64-bit + needs_zext = False # Track if we need zero-extension for xdp_md + + if field_data is not None: + # Try to determine the size from field metadata + if field_data.type.__module__ == ctypes.__name__: + try: + field_size_bytes = ctypes.sizeof(field_data.type) + field_size_bits = field_size_bytes * 8 + + if field_size_bits in [8, 16, 32, 64]: + int_width = field_size_bits + logger.info(f"Determined field size: {int_width} bits") + + # Special handling for struct_xdp_md i32 fields + # Load as i32 but extend to i64 before storing + if struct_name == "struct_xdp_md" and int_width == 32: + needs_zext = True + logger.info( + "struct_xdp_md i32 field detected, will zero-extend to i64" + ) + else: + logger.warning( + f"Unusual field size {field_size_bits} bits, using default 64" + ) + except Exception as e: + logger.warning( + f"Could not determine field size: {e}, using default 64" + ) + + elif field_data.type.__module__ == "vmlinux": + # For pointers to structs or complex vmlinux types + if field_data.ctype_complex_type is not None and issubclass( + 
field_data.ctype_complex_type, ctypes._Pointer + ): + int_width = 64 # Pointers are always 64-bit + logger.info("Field is a pointer type, using 64 bits") + # TODO: Add handling for other complex types (arrays, embedded structs, etc.) + else: + logger.warning("Complex vmlinux field type, using default 64 bits") + + # Bitcast to appropriate pointer type based on determined width + ptr_type = ir.PointerType(ir.IntType(int_width)) + + typed_ptr = builder.bitcast(verified_ptr, ptr_type) # Load and return the value value = builder.load(typed_ptr) + # Zero-extend i32 to i64 for struct_xdp_md fields + if needs_zext: + value = builder.zext(value, ir.IntType(64)) + logger.info("Zero-extended i32 value to i64 for struct_xdp_md field") + return value def has_field(self, struct_name, field_name): diff --git a/tests/c-form/Makefile b/tests/c-form/Makefile index 64ff9006..03f8ef2f 100644 --- a/tests/c-form/Makefile +++ b/tests/c-form/Makefile @@ -1,19 +1,23 @@ BPF_CLANG := clang -CFLAGS := -O0 -emit-llvm -target bpf -c +CFLAGS := -emit-llvm -target bpf -c SRC := $(wildcard *.bpf.c) LL := $(SRC:.bpf.c=.bpf.ll) +LL2 := $(SRC:.bpf.c=.bpf.o2.ll) OBJ := $(SRC:.bpf.c=.bpf.o) .PHONY: all clean -all: $(LL) $(OBJ) +all: $(LL) $(OBJ) $(LL2) %.bpf.o: %.bpf.c $(BPF_CLANG) -O2 -g -target bpf -c $< -o $@ %.bpf.ll: %.bpf.c - $(BPF_CLANG) $(CFLAGS) -g -S $< -o $@ + $(BPF_CLANG) -O0 $(CFLAGS) -g -S $< -o $@ + +%.bpf.o2.ll: %.bpf.c + $(BPF_CLANG) -O2 $(CFLAGS) -g -S $< -o $@ clean: - rm -f $(LL) $(OBJ) + rm -f $(LL) $(OBJ) $(LL2) diff --git a/tests/c-form/i32test.bpf.c b/tests/c-form/i32test.bpf.c new file mode 100644 index 00000000..8457babc --- /dev/null +++ b/tests/c-form/i32test.bpf.c @@ -0,0 +1,15 @@ +#include <linux/bpf.h> +#include <bpf/bpf_helpers.h> + +SEC("xdp") +int print_xdp_data(struct xdp_md *ctx) +{ + // 'data' is a pointer to the start of packet data + long data = (long)ctx->data; + + bpf_printk("ctx->data = %lld\n", data); + + return XDP_PASS; +} + +char LICENSE[] SEC("license") = "GPL"; diff --git 
a/tests/failing_tests/vmlinux/args_test.py b/tests/failing_tests/vmlinux/args_test.py new file mode 100644 index 00000000..7acca25f --- /dev/null +++ b/tests/failing_tests/vmlinux/args_test.py @@ -0,0 +1,30 @@ +import logging + +from pythonbpf import bpf, section, bpfglobal, compile_to_ir +from pythonbpf import compile # noqa: F401 +from vmlinux import TASK_COMM_LEN # noqa: F401 +from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401 +from ctypes import c_int64, c_int32, c_void_p # noqa: F401 + + +# from vmlinux import struct_uinput_device +# from vmlinux import struct_blk_integrity_iter + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64: + b = ctx.args + c = b[0] + print(f"This is context args field {c}") + return c_int64(0) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("args_test.py", "args_test.ll", loglevel=logging.INFO) +compile() diff --git a/tests/passing_tests/vmlinux/i32_test.py b/tests/passing_tests/vmlinux/i32_test.py new file mode 100644 index 00000000..4ba09693 --- /dev/null +++ b/tests/passing_tests/vmlinux/i32_test.py @@ -0,0 +1,31 @@ +from ctypes import c_int64, c_void_p +from pythonbpf import bpf, section, bpfglobal, compile_to_ir, compile +from vmlinux import struct_xdp_md +from vmlinux import XDP_PASS + + +@bpf +@section("xdp") +def print_xdp_dat2a(ct2x: struct_xdp_md) -> c_int64: + data = ct2x.data # 32-bit field: packet start pointer + print(f"ct2x->data = {data}") + return c_int64(XDP_PASS) + + +@bpf +@section("xdp") +def print_xdp_data(ctx: struct_xdp_md) -> c_int64: + data = ctx.data # 32-bit field: packet start pointer + something = c_void_p(data) + print(f"ctx->data = {something}") + return c_int64(XDP_PASS) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("i32_test.py", "i32_test.ll") +compile() diff --git a/tests/passing_tests/vmlinux/i32_test_fail_1.py 
b/tests/passing_tests/vmlinux/i32_test_fail_1.py new file mode 100644 index 00000000..3f6d3c1f --- /dev/null +++ b/tests/passing_tests/vmlinux/i32_test_fail_1.py @@ -0,0 +1,24 @@ +from ctypes import c_int64 +from pythonbpf import bpf, section, bpfglobal, compile +from vmlinux import struct_xdp_md +from vmlinux import XDP_PASS +import logging + + +@bpf +@section("xdp") +def print_xdp_data(ctx: struct_xdp_md) -> c_int64: + data = 0 + data = ctx.data # 32-bit field: packet start pointer + something = 2 + data + print(f"ctx->data = {something}") + return c_int64(XDP_PASS) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile(logging.INFO) diff --git a/tests/passing_tests/vmlinux/i32_test_fail_2.py b/tests/passing_tests/vmlinux/i32_test_fail_2.py new file mode 100644 index 00000000..4792fd69 --- /dev/null +++ b/tests/passing_tests/vmlinux/i32_test_fail_2.py @@ -0,0 +1,24 @@ +from ctypes import c_int64 +from pythonbpf import bpf, section, bpfglobal, compile, compile_to_ir +from vmlinux import struct_xdp_md +from vmlinux import XDP_PASS +import logging + + +@bpf +@section("xdp") +def print_xdp_data(ctx: struct_xdp_md) -> c_int64: + data = c_int64(ctx.data) # 32-bit field: packet start pointer + something = 2 + data + print(f"ctx->data = {something}") + return c_int64(XDP_PASS) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("i32_test_fail_2.py", "i32_test_fail_2.ll") +compile(logging.INFO) diff --git a/tests/passing_tests/vmlinux/simple_struct_test.py b/tests/passing_tests/vmlinux/simple_struct_test.py index 97ab54a1..2f34ba4f 100644 --- a/tests/passing_tests/vmlinux/simple_struct_test.py +++ b/tests/passing_tests/vmlinux/simple_struct_test.py @@ -44,4 +44,4 @@ def LICENSE() -> str: compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG) -# compile() +compile()