diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index 49c787f4..ae748dd1 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -199,17 +199,33 @@ def _allocate_for_binop(builder, var_name, local_sym_tab): logger.info(f"Pre-allocated {var_name} for binop result") +def _get_type_name(ir_type): + """Get a string representation of an IR type.""" + if isinstance(ir_type, ir.IntType): + return f"i{ir_type.width}" + elif isinstance(ir_type, ir.PointerType): + return "ptr" + elif isinstance(ir_type, ir.ArrayType): + return f"[{ir_type.count}x{_get_type_name(ir_type.element)}]" + else: + return str(ir_type).replace(" ", "") + + def allocate_temp_pool(builder, max_temps, local_sym_tab): """Allocate the temporary scratch space pool for helper arguments.""" - if max_temps == 0: + if not max_temps: + logger.info("No temp pool allocation needed") return - logger.info(f"Allocating temp pool of {max_temps} variables") - for i in range(max_temps): - temp_name = f"__helper_temp_{i}" - temp_var = builder.alloca(ir.IntType(64), name=temp_name) - temp_var.align = 8 - local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64)) + for tmp_type, cnt in max_temps.items(): + type_name = _get_type_name(tmp_type) + logger.info(f"Allocating temp pool of {cnt} variables of type {type_name}") + for i in range(cnt): + temp_name = f"__helper_temp_{type_name}_{i}" + temp_var = builder.alloca(tmp_type, name=temp_name) + temp_var.align = _get_alignment(tmp_type) + local_sym_tab[temp_name] = LocalSymbol(temp_var, tmp_type) + logger.debug(f"Allocated temp variable: {temp_name}") def _allocate_for_name(builder, var_name, rval, local_sym_tab): diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index e391092b..c9943682 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -33,7 +33,7 @@ def count_temps_in_call(call_node, local_sym_tab): """Count the number of temporary variables needed for a function call.""" - count = 0 + count = {} is_helper = False # NOTE: We exclude print calls for now @@ -43,21 +43,28 @@ def count_temps_in_call(call_node, local_sym_tab): and call_node.func.id != "print" ): is_helper = True + func_name = call_node.func.id elif isinstance(call_node.func, ast.Attribute): if HelperHandlerRegistry.has_handler(call_node.func.attr): is_helper = True + func_name = call_node.func.attr if not is_helper: - return 0 + return {} # No temps needed - for arg in call_node.args: + for arg_idx in range(len(call_node.args)): # NOTE: Count all non-name arguments # For struct fields, if it is being passed as an argument, # The struct object should already exist in the local_sym_tab - if not isinstance(arg, ast.Name) and not ( + arg = call_node.args[arg_idx] + if isinstance(arg, ast.Name) or ( isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab ): - count += 1 + continue + param_type = HelperHandlerRegistry.get_param_type(func_name, arg_idx) + if isinstance(param_type, ir.PointerType): + pointee_type = param_type.pointee + count[pointee_type] = count.get(pointee_type, 0) + 1 return count @@ -93,11 +100,15 @@ def handle_if_allocation( def allocate_mem( module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab ): - max_temps_needed = 0 + max_temps_needed = {} + + def merge_type_counts(count_dict): + nonlocal max_temps_needed + for typ, cnt in count_dict.items(): + max_temps_needed[typ] = max(max_temps_needed.get(typ, 0), cnt) def update_max_temps_for_stmt(stmt): nonlocal max_temps_needed - temps_needed = 0 if isinstance(stmt, ast.If): for s in stmt.body: @@ -106,10 +117,13 @@ def update_max_temps_for_stmt(stmt): update_max_temps_for_stmt(s) return + stmt_temps = {} for node in ast.walk(stmt): if isinstance(node, ast.Call): - temps_needed += count_temps_in_call(node, local_sym_tab) - max_temps_needed = max(max_temps_needed, temps_needed) + call_temps = count_temps_in_call(node, local_sym_tab) + for typ, cnt in call_temps.items(): + stmt_temps[typ] = stmt_temps.get(typ, 0) + cnt + merge_type_counts(stmt_temps) for stmt in body: update_max_temps_for_stmt(stmt) diff --git a/pythonbpf/helper/__init__.py b/pythonbpf/helper/__init__.py index 2f9c3473..4c1d2831 100644 --- a/pythonbpf/helper/__init__.py +++ b/pythonbpf/helper/__init__.py @@ -1,7 +1,20 @@ from .helper_registry import HelperHandlerRegistry from .helper_utils import reset_scratch_pool from .bpf_helper_handler import handle_helper_call, emit_probe_read_kernel_str_call -from .helpers import ktime, pid, deref, comm, probe_read_str, XDP_DROP, XDP_PASS +from .helpers import ( + ktime, + pid, + deref, + comm, + probe_read_str, + random, + probe_read, + smp_processor_id, + uid, + skb_store_bytes, + XDP_DROP, + XDP_PASS, +) # Register the helper handler with expr module @@ -65,6 +78,11 @@ def helper_call_handler( "deref", "comm", "probe_read_str", + "random", + "probe_read", + "smp_processor_id", + "uid", + "skb_store_bytes", "XDP_DROP", "XDP_PASS", ] diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index 78686778..acfedefa 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -8,8 +8,8 @@ get_flags_val, get_data_ptr_and_size, get_buffer_ptr_and_size, - get_char_array_ptr_and_size, get_ptr_from_arg, + get_int_value_from_arg, ) from .printk_formatter import simple_string_print, handle_fstring_print @@ -23,15 +23,24 @@ class BPFHelperID(Enum): BPF_MAP_LOOKUP_ELEM = 1 BPF_MAP_UPDATE_ELEM = 2 BPF_MAP_DELETE_ELEM = 3 + BPF_PROBE_READ = 4 BPF_KTIME_GET_NS = 5 BPF_PRINTK = 6 + BPF_GET_PRANDOM_U32 = 7 + BPF_GET_SMP_PROCESSOR_ID = 8 + BPF_SKB_STORE_BYTES = 9 BPF_GET_CURRENT_PID_TGID = 14 + BPF_GET_CURRENT_UID_GID = 15 BPF_GET_CURRENT_COMM = 16 BPF_PERF_EVENT_OUTPUT = 25 BPF_PROBE_READ_KERNEL_STR = 115 -@HelperHandlerRegistry.register("ktime") +@HelperHandlerRegistry.register( + "ktime", + param_types=[], + return_type=ir.IntType(64), +) def bpf_ktime_get_ns_emitter( call, map_ptr, @@ -54,7 +63,11 @@ def bpf_ktime_get_ns_emitter( return result, ir.IntType(64) -@HelperHandlerRegistry.register("lookup") +@HelperHandlerRegistry.register( + "lookup", + param_types=[ir.PointerType(ir.IntType(64))], + return_type=ir.PointerType(ir.IntType(64)), +) def bpf_map_lookup_elem_emitter( call, map_ptr, @@ -96,6 +109,7 @@ def bpf_map_lookup_elem_emitter( return result, ir.PointerType() +# NOTE: This has special handling so we won't reflect the signature here. @HelperHandlerRegistry.register("print") def bpf_printk_emitter( call, @@ -144,7 +158,15 @@ def bpf_printk_emitter( return True -@HelperHandlerRegistry.register("update") +@HelperHandlerRegistry.register( + "update", + param_types=[ + ir.PointerType(ir.IntType(64)), + ir.PointerType(ir.IntType(64)), + ir.IntType(64), + ], + return_type=ir.PointerType(ir.IntType(64)), +) def bpf_map_update_elem_emitter( call, map_ptr, @@ -199,7 +221,11 @@ def bpf_map_update_elem_emitter( return result, None -@HelperHandlerRegistry.register("delete") +@HelperHandlerRegistry.register( + "delete", + param_types=[ir.PointerType(ir.IntType(64))], + return_type=ir.PointerType(ir.IntType(64)), +) def bpf_map_delete_elem_emitter( call, map_ptr, @@ -239,7 +265,11 @@ def bpf_map_delete_elem_emitter( return result, None -@HelperHandlerRegistry.register("comm") +@HelperHandlerRegistry.register( + "comm", + param_types=[ir.PointerType(ir.IntType(8))], + return_type=ir.IntType(64), +) def bpf_get_current_comm_emitter( call, map_ptr, @@ -296,7 +326,11 @@ def bpf_get_current_comm_emitter( return result, None -@HelperHandlerRegistry.register("pid") +@HelperHandlerRegistry.register( + "pid", + param_types=[], + return_type=ir.IntType(64), +) def bpf_get_current_pid_tgid_emitter( call, map_ptr, @@ -318,12 +352,17 @@ def bpf_get_current_pid_tgid_emitter( result = builder.call(fn_ptr, [], tail=False) # Extract the lower 32 bits (PID) using bitwise AND with 0xFFFFFFFF + # TODO: return both PID and TGID if we end up needing TGID somewhere mask = ir.Constant(ir.IntType(64), 0xFFFFFFFF) pid = builder.and_(result, mask) return pid, ir.IntType(64) -@HelperHandlerRegistry.register("output") +@HelperHandlerRegistry.register( + "output", + param_types=[ir.PointerType(ir.IntType(8))], + return_type=ir.IntType(64), +) def bpf_perf_event_output_handler( call, map_ptr, @@ -398,7 +437,14 @@ def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr): return result -@HelperHandlerRegistry.register("probe_read_str") +@HelperHandlerRegistry.register( + "probe_read_str", + param_types=[ + ir.PointerType(ir.IntType(8)), + ir.PointerType(ir.IntType(8)), + ], + return_type=ir.IntType(64), +) def bpf_probe_read_kernel_str_emitter( call, map_ptr, @@ -417,8 +463,8 @@ def bpf_probe_read_kernel_str_emitter( ) # Get destination buffer (char array -> i8*) - dst_ptr, dst_size = get_char_array_ptr_and_size( - call.args[0], builder, local_sym_tab, struct_sym_tab + dst_ptr, dst_size = get_or_create_ptr_from_arg( + func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab ) # Get source pointer (evaluate expression) @@ -433,6 +479,263 @@ def bpf_probe_read_kernel_str_emitter( return result, ir.IntType(64) +@HelperHandlerRegistry.register( + "random", + param_types=[], + return_type=ir.IntType(32), +) +def bpf_get_prandom_u32_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_get_prandom_u32 helper function call. + """ + helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_PRANDOM_U32.value) + fn_type = ir.FunctionType(ir.IntType(32), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + return result, ir.IntType(32) + + +@HelperHandlerRegistry.register( + "probe_read", + param_types=[ + ir.PointerType(ir.IntType(8)), + ir.IntType(32), + ir.PointerType(ir.IntType(8)), + ], + return_type=ir.IntType(64), +) +def bpf_probe_read_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_probe_read helper function + """ + + if len(call.args) != 3: + logger.warn("Expected 3 args for probe_read helper") + return + dst_ptr = get_or_create_ptr_from_arg( + func, + module, + call.args[0], + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ir.IntType(8), + ) + size_val = get_int_value_from_arg( + call.args[1], + func, + module, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ) + src_ptr = get_or_create_ptr_from_arg( + func, + module, + call.args[2], + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ir.IntType(8), + ) + fn_type = ir.FunctionType( + ir.IntType(64), + [ir.PointerType(), ir.IntType(32), ir.PointerType()], + var_arg=False, + ) + fn_ptr = builder.inttoptr( + ir.Constant(ir.IntType(64), BPFHelperID.BPF_PROBE_READ.value), + ir.PointerType(fn_type), + ) + result = builder.call( + fn_ptr, + [ + builder.bitcast(dst_ptr, ir.PointerType()), + builder.trunc(size_val, ir.IntType(32)), + builder.bitcast(src_ptr, ir.PointerType()), + ], + tail=False, + ) + logger.info(f"Emitted bpf_probe_read (size={size_val})") + return result, ir.IntType(64) + + +@HelperHandlerRegistry.register( + "smp_processor_id", + param_types=[], + return_type=ir.IntType(32), +) +def bpf_get_smp_processor_id_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_get_smp_processor_id helper function call. + """ + helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_SMP_PROCESSOR_ID.value) + fn_type = ir.FunctionType(ir.IntType(32), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + logger.info("Emitted bpf_get_smp_processor_id call") + return result, ir.IntType(32) + + +@HelperHandlerRegistry.register( + "uid", + param_types=[], + return_type=ir.IntType(64), +) +def bpf_get_current_uid_gid_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_get_current_uid_gid helper function call. + """ + helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_CURRENT_UID_GID.value) + fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + + # Extract the lower 32 bits (UID) using bitwise AND with 0xFFFFFFFF + # TODO: return both UID and GID if we end up needing GID somewhere + mask = ir.Constant(ir.IntType(64), 0xFFFFFFFF) + pid = builder.and_(result, mask) + return pid, ir.IntType(64) + + +@HelperHandlerRegistry.register( + "skb_store_bytes", + param_types=[ + ir.IntType(32), + ir.PointerType(ir.IntType(8)), + ir.IntType(32), + ir.IntType(64), + ], + return_type=ir.IntType(64), +) +def bpf_skb_store_bytes_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_skb_store_bytes helper function call. + Expected call signature: skb_store_bytes(skb, offset, from, len, flags) + """ + + args_signature = [ + ir.PointerType(), # skb pointer + ir.IntType(32), # offset + ir.PointerType(), # from + ir.IntType(32), # len + ir.IntType(64), # flags + ] + + if len(call.args) not in (3, 4): + raise ValueError( + f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}" + ) + + skb_ptr = func.args[0] # First argument to the function is skb + offset_val = get_int_value_from_arg( + call.args[0], + func, + module, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ) + from_ptr = get_or_create_ptr_from_arg( + func, + module, + call.args[1], + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + args_signature[2], + ) + len_val = get_int_value_from_arg( + call.args[2], + func, + module, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ) + if len(call.args) == 4: + flags_val = get_flags_val(call.args[3], builder, local_sym_tab) + else: + flags_val = 0 + flags = ir.Constant(ir.IntType(64), flags_val) + fn_type = ir.FunctionType( + ir.IntType(64), + args_signature, + var_arg=False, + ) + fn_ptr = builder.inttoptr( + ir.Constant(ir.IntType(64), BPFHelperID.BPF_SKB_STORE_BYTES.value), + ir.PointerType(fn_type), + ) + result = builder.call( + fn_ptr, + [ + builder.bitcast(skb_ptr, ir.PointerType()), + builder.trunc(offset_val, ir.IntType(32)), + builder.bitcast(from_ptr, ir.PointerType()), + builder.trunc(len_val, ir.IntType(32)), + flags, + ], + tail=False, + ) + logger.info("Emitted bpf_skb_store_bytes call") + return result, ir.IntType(64) + + def handle_helper_call( call, module, diff --git a/pythonbpf/helper/helper_registry.py b/pythonbpf/helper/helper_registry.py index 476e3b60..0e09d70c 100644 --- a/pythonbpf/helper/helper_registry.py +++ b/pythonbpf/helper/helper_registry.py @@ -1,17 +1,31 @@ +from dataclasses import dataclass +from llvmlite import ir from typing import Callable +@dataclass +class HelperSignature: + """Signature of a BPF helper function""" + + arg_types: list[ir.Type] + return_type: ir.Type + func: Callable + + class HelperHandlerRegistry: """Registry for BPF helpers""" - _handlers: dict[str, Callable] = {} + _handlers: dict[str, HelperSignature] = {} @classmethod - def register(cls, helper_name): + def register(cls, helper_name, param_types=None, return_type=None): """Decorator to register a handler function for a helper""" def decorator(func): - cls._handlers[helper_name] = func + helper_sig = HelperSignature( + arg_types=param_types, return_type=return_type, func=func + ) + cls._handlers[helper_name] = helper_sig return func return decorator @@ -19,9 +33,29 @@ def decorator(func): @classmethod def get_handler(cls, helper_name): """Get the handler function for a helper""" - return cls._handlers.get(helper_name) + handler = cls._handlers.get(helper_name) + return handler.func if handler else None @classmethod def has_handler(cls, helper_name): """Check if a handler function is registered for a helper""" return helper_name in cls._handlers + + @classmethod + def get_signature(cls, helper_name): + """Get the signature of a helper function""" + return cls._handlers.get(helper_name) + + @classmethod + def get_param_type(cls, helper_name, index): + """Get the type of a parameter of a helper function by the index""" + signature = cls.get_signature(helper_name) + if signature and signature.arg_types and 0 <= index < len(signature.arg_types): + return signature.arg_types[index] + return None + + @classmethod + def get_return_type(cls, helper_name): + """Get the return type of a helper function""" + signature = cls.get_signature(helper_name) + return signature.return_type if signature else None diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index fdfd4524..06d3cf14 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -14,26 +14,43 @@ class ScratchPoolManager: """Manage the temporary helper variables in local_sym_tab""" def __init__(self): - self._counter = 0 + self._counters = {} @property def counter(self): - return self._counter + return sum(self._counters.values()) def reset(self): - self._counter = 0 + self._counters.clear() logger.debug("Scratch pool counter reset to 0") - def get_next_temp(self, local_sym_tab): - temp_name = f"__helper_temp_{self._counter}" - self._counter += 1 + def _get_type_name(self, ir_type): + if isinstance(ir_type, ir.PointerType): + return "ptr" + elif isinstance(ir_type, ir.IntType): + return f"i{ir_type.width}" + elif isinstance(ir_type, ir.ArrayType): + return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]" + else: + return str(ir_type).replace(" ", "") + + def get_next_temp(self, local_sym_tab, expected_type=None): + # Default to i64 if no expected type provided + type_name = self._get_type_name(expected_type) if expected_type else "i64" + if type_name not in self._counters: + self._counters[type_name] = 0 + + counter = self._counters[type_name] + temp_name = f"__helper_temp_{type_name}_{counter}" + self._counters[type_name] += 1 if temp_name not in local_sym_tab: raise ValueError( f"Scratch pool exhausted or inadequate: {temp_name}. " - f"Current counter: {self._counter}" + f"Type: {type_name} Counter: {counter}" ) + logger.debug(f"Using {temp_name} for type {type_name}") return local_sym_tab[temp_name].var, temp_name @@ -60,24 +77,73 @@ def get_var_ptr_from_name(var_name, local_sym_tab): def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): """Create a pointer to an integer constant.""" - # Default to 64-bit integer - ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) + int_type = ir.IntType(int_width) + ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type) logger.info(f"Using temp variable '{temp_name}' for int constant {value}") - const_val = ir.Constant(ir.IntType(int_width), value) + const_val = ir.Constant(int_type, value) builder.store(const_val, ptr) return ptr def get_or_create_ptr_from_arg( - func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None + func, + module, + arg, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab=None, + expected_type=None, ): """Extract or create pointer from the call arguments.""" + logger.info(f"Getting pointer from arg: {ast.dump(arg)}") + sz = None if isinstance(arg, ast.Name): + # Stack space is already allocated ptr = get_var_ptr_from_name(arg.id, local_sym_tab) elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): - ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab) + int_width = 64 # Default to i64 + if expected_type and isinstance(expected_type, ir.IntType): + int_width = expected_type.width + ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width) + elif isinstance(arg, ast.Attribute): + # A struct field + struct_name = arg.value.id + field_name = arg.attr + + if not local_sym_tab or struct_name not in local_sym_tab: + raise ValueError(f"Struct '{struct_name}' not found") + + struct_type = local_sym_tab[struct_name].metadata + if not struct_sym_tab or struct_type not in struct_sym_tab: + raise ValueError(f"Struct type '{struct_type}' not found") + + struct_info = struct_sym_tab[struct_type] + if field_name not in struct_info.fields: + raise ValueError( + f"Field '{field_name}' not found in struct '{struct_name}'" + ) + + field_type = struct_info.field_type(field_name) + struct_ptr = local_sym_tab[struct_name].var + + # Special handling for char arrays + if ( + isinstance(field_type, ir.ArrayType) + and isinstance(field_type.element, ir.IntType) + and field_type.element.width == 8 + ): + ptr, sz = get_char_array_ptr_and_size( + arg, builder, local_sym_tab, struct_sym_tab + ) + if not ptr: + raise ValueError("Failed to get char array pointer from struct field") + else: + ptr = struct_info.gep(builder, struct_ptr, field_name) + else: + # NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop # Evaluate the expression and store the result in a temp variable val = get_operand_value( func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab @@ -85,13 +151,20 @@ def get_or_create_ptr_from_arg( if val is None: raise ValueError("Failed to evaluate expression for helper arg.") - # NOTE: We assume the result is an int64 for now - # if isinstance(arg, ast.Attribute): - # return val - ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) + ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type) logger.info(f"Using temp variable '{temp_name}' for expression result") + if ( + isinstance(val.type, ir.IntType) + and expected_type + and val.type.width > expected_type.width + ): + val = builder.trunc(val, expected_type) builder.store(val, ptr) + # NOTE: For char arrays, also return size + if sz: + return ptr, sz + return ptr @@ -274,3 +347,23 @@ def get_ptr_from_arg( raise ValueError(f"Expected pointer type, got {val_type}") return val, val_type + + +def get_int_value_from_arg( + arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab +): + """Evaluate argument and return integer value""" + + result = eval_expr( + func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab + ) + + if not result: + raise ValueError("Failed to evaluate argument") + + val, val_type = result + + if not isinstance(val_type, ir.IntType): + raise ValueError(f"Expected integer type, got {val_type}") + + return val diff --git a/pythonbpf/helper/helpers.py b/pythonbpf/helper/helpers.py index cb1a8e12..302b5263 100644 --- a/pythonbpf/helper/helpers.py +++ b/pythonbpf/helper/helpers.py @@ -27,6 +27,31 @@ def probe_read_str(dst, src): return ctypes.c_int64(0) +def random(): + """get a pseudorandom u32 number""" + return ctypes.c_int32(0) + + +def probe_read(dst, size, src): + """Safely read data from kernel memory""" + return ctypes.c_int64(0) + + +def smp_processor_id(): + """get the current CPU id""" + return ctypes.c_int32(0) + + +def uid(): + """get current user id""" + return ctypes.c_int32(0) + + +def skb_store_bytes(offset, from_buf, size, flags=0): + """store bytes into a socket buffer""" + return ctypes.c_int64(0) + + XDP_ABORTED = ctypes.c_int64(0) XDP_DROP = ctypes.c_int64(1) XDP_PASS = ctypes.c_int64(2) diff --git a/pythonbpf/helper/printk_formatter.py b/pythonbpf/helper/printk_formatter.py index a18f1354..58990c05 100644 --- a/pythonbpf/helper/printk_formatter.py +++ b/pythonbpf/helper/printk_formatter.py @@ -4,6 +4,7 @@ from llvmlite import ir from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry +from pythonbpf.helper.helper_utils import get_char_array_ptr_and_size logger = logging.getLogger(__name__) @@ -219,7 +220,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta """Evaluate and prepare an expression to use as an arg for bpf_printk.""" # Special case: struct field char array needs pointer to first element - char_array_ptr = _get_struct_char_array_ptr( + char_array_ptr, _ = get_char_array_ptr_and_size( expr, builder, local_sym_tab, struct_sym_tab ) if char_array_ptr: @@ -242,52 +243,6 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta return ir.Constant(ir.IntType(64), 0) -def _get_struct_char_array_ptr(expr, builder, local_sym_tab, struct_sym_tab): - """Get pointer to first element of char array in struct field, or None.""" - if not (isinstance(expr, ast.Attribute) and isinstance(expr.value, ast.Name)): - return None - - var_name = expr.value.id - field_name = expr.attr - - # Check if it's a valid struct field - if not ( - local_sym_tab - and var_name in local_sym_tab - and struct_sym_tab - and local_sym_tab[var_name].metadata in struct_sym_tab - ): - return None - - struct_type = local_sym_tab[var_name].metadata - struct_info = struct_sym_tab[struct_type] - - if field_name not in struct_info.fields: - return None - - field_type = struct_info.field_type(field_name) - - # Check if it's a char array - is_char_array = ( - isinstance(field_type, ir.ArrayType) - and isinstance(field_type.element, ir.IntType) - and field_type.element.width == 8 - ) - - if not is_char_array: - return None - - # Get field pointer and GEP to first element: [N x i8]* -> i8* - struct_ptr = local_sym_tab[var_name].var - field_ptr = struct_info.gep(builder, struct_ptr, field_name) - - return builder.gep( - field_ptr, - [ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)], - inbounds=True, - ) - - def _handle_pointer_arg(val, func, builder): """Convert pointer type for bpf_printk.""" target, depth = get_base_type_and_depth(val.type) diff --git a/tests/passing_tests/helpers/bpf_probe_read.py b/tests/passing_tests/helpers/bpf_probe_read.py new file mode 100644 index 00000000..fcece4d6 --- /dev/null +++ b/tests/passing_tests/helpers/bpf_probe_read.py @@ -0,0 +1,29 @@ +from pythonbpf import bpf, section, bpfglobal, compile, struct +from ctypes import c_void_p, c_int64, c_uint64, c_uint32 +from pythonbpf.helper import probe_read + + +@bpf +@struct +class data_t: + pid: c_uint32 + value: c_uint64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def test_probe_read(ctx: c_void_p) -> c_int64: + """Test bpf_probe_read helper function""" + data = data_t() + probe_read(data.value, 8, ctx) + probe_read(data.pid, 4, ctx) + return 0 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/helpers/prandom.py b/tests/passing_tests/helpers/prandom.py new file mode 100644 index 00000000..396927ba --- /dev/null +++ b/tests/passing_tests/helpers/prandom.py @@ -0,0 +1,25 @@ +from pythonbpf import bpf, bpfglobal, section, BPF, trace_pipe +from ctypes import c_void_p, c_int64 +from pythonbpf.helper import random + + +@bpf +@section("tracepoint/syscalls/sys_enter_clone") +def hello_world(ctx: c_void_p) -> c_int64: + r = random() + print(f"Hello, World!, {r}") + return 0 # type: ignore [return-value] + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +# Compile and load +b = BPF() +b.load() +b.attach_all() + +trace_pipe() diff --git a/tests/passing_tests/helpers/smp_processor_id.py b/tests/passing_tests/helpers/smp_processor_id.py new file mode 100644 index 00000000..8c17a756 --- /dev/null +++ b/tests/passing_tests/helpers/smp_processor_id.py @@ -0,0 +1,40 @@ +from pythonbpf import bpf, section, bpfglobal, compile, struct +from ctypes import c_void_p, c_int64, c_uint32, c_uint64 +from pythonbpf.helper import smp_processor_id, ktime + + +@bpf +@struct +class cpu_event_t: + cpu_id: c_uint32 + timestamp: c_uint64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def trace_with_cpu(ctx: c_void_p) -> c_int64: + """Test bpf_get_smp_processor_id helper function""" + + # Get the current CPU ID + cpu = smp_processor_id() + + # Print it + print(f"Running on CPU {cpu}") + + # Use it in a struct + event = cpu_event_t() + event.cpu_id = smp_processor_id() + event.timestamp = ktime() + + print(f"Event on CPU {event.cpu_id} at time {event.timestamp}") + + return 0 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/helpers/uid_gid.py b/tests/passing_tests/helpers/uid_gid.py new file mode 100644 index 00000000..e4e50b44 --- /dev/null +++ b/tests/passing_tests/helpers/uid_gid.py @@ -0,0 +1,31 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 +from pythonbpf.helper import uid, pid + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def filter_by_user(ctx: c_void_p) -> c_int64: + """Filter events by specific user ID""" + + current_uid = uid() + + # Only trace root user (UID 0) + if current_uid == 0: + process_id = pid() + print(f"Root process {process_id} executed") + + # Or trace specific user (e.g., UID 1000) + if current_uid == 1002: + print("User 1002 executed something") + + return 0 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile()