diff --git a/examples/struct_and_perf.py b/examples/struct_and_perf.py index 329d5ed..c545de0 100644 --- a/examples/struct_and_perf.py +++ b/examples/struct_and_perf.py @@ -23,13 +23,13 @@ def events() -> PerfEventArray: @section("tracepoint/syscalls/sys_enter_clone") def hello(ctx: c_void_p) -> c_int32: dataobj = data_t() - ts = ktime() strobj = "hellohellohello" dataobj.pid = pid() dataobj.ts = ktime() # dataobj.comm = strobj print( - f"clone called at {dataobj.ts} by pid {dataobj.pid}, comm {strobj} at time {ts}" + f"clone called at {dataobj.ts} by pid { + dataobj.pid}, comm {strobj} at time {ts}" ) events.output(dataobj) return c_int32(0) diff --git a/pythonbpf/bpf_helper_handler.py b/pythonbpf/bpf_helper_handler.py deleted file mode 100644 index 7e4a2ce..0000000 --- a/pythonbpf/bpf_helper_handler.py +++ /dev/null @@ -1,654 +0,0 @@ -import ast -from llvmlite import ir -from .expr_pass import eval_expr - - -def bpf_ktime_get_ns_emitter( - call, - map_ptr, - module, - builder, - func, - local_sym_tab=None, - struct_sym_tab=None, - local_var_metadata=None, -): - """ - Emit LLVM IR for bpf_ktime_get_ns helper function call. - """ - # func is an arg to just have a uniform signature with other emitters - helper_id = ir.Constant(ir.IntType(64), 5) - fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) - fn_ptr_type = ir.PointerType(fn_type) - fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) - result = builder.call(fn_ptr, [], tail=False) - return result, ir.IntType(64) - - -def bpf_map_lookup_elem_emitter( - call, - map_ptr, - module, - builder, - func, - local_sym_tab=None, - struct_sym_tab=None, - local_var_metadata=None, -): - """ - Emit LLVM IR for bpf_map_lookup_elem helper function call. - """ - if call.args and len(call.args) != 1: - raise ValueError( - "Map lookup expects exactly one argument, got " f"{len(call.args)}" - ) - key_arg = call.args[0] - if isinstance(key_arg, ast.Name): - key_name = key_arg.id - if local_sym_tab and key_name in local_sym_tab: - key_ptr = local_sym_tab[key_name][0] - else: - raise ValueError( - f"Key variable {key_name} not found in local symbol table." - ) - elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int): - # handle constant integer keys - key_val = key_arg.value - key_type = ir.IntType(64) - key_ptr = builder.alloca(key_type) - key_ptr.align = key_type // 8 - builder.store(ir.Constant(key_type, key_val), key_ptr) - else: - raise NotImplementedError( - "Only simple variable names are supported as keys in map lookup." - ) - - if key_ptr is None: - raise ValueError("Key pointer is None.") - - map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) - - fn_type = ir.FunctionType( - ir.PointerType(), # Return type: void* - [ir.PointerType(), ir.PointerType()], # Args: (void*, void*) - var_arg=False, - ) - fn_ptr_type = ir.PointerType(fn_type) - - # Helper ID 1 is bpf_map_lookup_elem - fn_addr = ir.Constant(ir.IntType(64), 1) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - - result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) - - return result, ir.PointerType() - - -def bpf_printk_emitter( - call, - map_ptr, - module, - builder, - func, - local_sym_tab=None, - struct_sym_tab=None, - local_var_metadata=None, -): - if not hasattr(func, "_fmt_counter"): - func._fmt_counter = 0 - - if not call.args: - raise ValueError("print expects at least one argument") - - if isinstance(call.args[0], ast.JoinedStr): - fmt_parts = [] - exprs = [] - - for value in call.args[0].values: - print("Value in f-string:", ast.dump(value)) - if isinstance(value, ast.Constant): - if isinstance(value.value, str): - fmt_parts.append(value.value) - elif isinstance(value.value, int): - fmt_parts.append("%lld") - exprs.append(ir.Constant(ir.IntType(64), value.value)) - else: - raise NotImplementedError( - "Only string and integer constants are supported in f-string." - ) - elif isinstance(value, ast.FormattedValue): - print("Formatted value:", ast.dump(value)) - # TODO: Dirty handling here, only checks for int or str - if isinstance(value.value, ast.Name): - if local_sym_tab and value.value.id in local_sym_tab: - var_ptr, var_type = local_sym_tab[value.value.id] - if isinstance(var_type, ir.IntType): - fmt_parts.append("%lld") - exprs.append(value.value) - elif var_type == ir.PointerType(ir.IntType(8)): - # Case with string - fmt_parts.append("%s") - exprs.append(value.value) - else: - raise NotImplementedError( - "Only integer and pointer types are supported in formatted values." - ) - else: - raise ValueError( - f"Variable {value.value.id} not found in local symbol table." - ) - elif isinstance(value.value, ast.Attribute): - # object field access from struct - if ( - isinstance(value.value.value, ast.Name) - and local_sym_tab - and value.value.value.id in local_sym_tab - ): - var_name = value.value.value.id - field_name = value.value.attr - if local_var_metadata and var_name in local_var_metadata: - var_type = local_var_metadata[var_name] - if var_type in struct_sym_tab: - struct_info = struct_sym_tab[var_type] - if field_name in struct_info.fields: - field_type = struct_info.field_type(field_name) - if isinstance(field_type, ir.IntType): - fmt_parts.append("%lld") - exprs.append(value.value) - elif field_type == ir.PointerType(ir.IntType(8)): - fmt_parts.append("%s") - exprs.append(value.value) - else: - raise NotImplementedError( - "Only integer and pointer types are supported in formatted values." - ) - else: - raise ValueError( - f"Field {field_name} not found in struct {var_type}." - ) - else: - raise ValueError( - f"Struct type {var_type} for variable {var_name} not found in struct symbol table." - ) - else: - raise ValueError( - f"Metadata for variable {var_name} not found in local variable metadata." - ) - else: - raise ValueError( - f"Variable {value.value.value.id} not found in local symbol table." - ) - else: - raise NotImplementedError( - "Only simple variable names are supported in formatted values." - ) - else: - raise NotImplementedError("Unsupported value type in f-string.") - - fmt_str = "".join(fmt_parts) + "\n" + "\0" - fmt_name = f"{func.name}____fmt{func._fmt_counter}" - func._fmt_counter += 1 - - fmt_gvar = ir.GlobalVariable( - module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name - ) - fmt_gvar.global_constant = True - fmt_gvar.initializer = ir.Constant( # type: ignore - ir.ArrayType(ir.IntType(8), len(fmt_str)), bytearray(fmt_str.encode("utf8")) - ) - fmt_gvar.linkage = "internal" - fmt_gvar.align = 1 # type: ignore - - fmt_ptr = builder.bitcast(fmt_gvar, ir.PointerType()) - - args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))] - - # Only 3 args supported in bpf_printk - if len(exprs) > 3: - print( - "Warning: bpf_printk supports up to 3 arguments, extra arguments will be ignored." - ) - - for expr in exprs[:3]: - print(f"{ast.dump(expr)}") - val, _ = eval_expr( - func, - module, - builder, - expr, - local_sym_tab, - None, - struct_sym_tab, - local_var_metadata, - ) - if val: - if isinstance(val.type, ir.PointerType): - val = builder.ptrtoint(val, ir.IntType(64)) - elif isinstance(val.type, ir.IntType): - if val.type.width < 64: - val = builder.sext(val, ir.IntType(64)) - else: - print( - "Warning: Only integer and pointer types are supported in bpf_printk arguments. Others will be converted to 0." - ) - val = ir.Constant(ir.IntType(64), 0) - args.append(val) - else: - print( - "Warning: Failed to evaluate expression for bpf_printk argument. It will be converted to 0." - ) - args.append(ir.Constant(ir.IntType(64), 0)) - fn_type = ir.FunctionType( - ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True - ) - fn_ptr_type = ir.PointerType(fn_type) - fn_addr = ir.Constant(ir.IntType(64), 6) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - return builder.call(fn_ptr, args, tail=True) - - for arg in call.args: - if isinstance(arg, ast.Constant) and isinstance(arg.value, str): - fmt_str = arg.value + "\n" + "\0" - fmt_name = f"{func.name}____fmt{func._fmt_counter}" - func._fmt_counter += 1 - - fmt_gvar = ir.GlobalVariable( - module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name - ) - fmt_gvar.global_constant = True - fmt_gvar.initializer = ir.Constant( # type: ignore - ir.ArrayType(ir.IntType(8), len(fmt_str)), - bytearray(fmt_str.encode("utf8")), - ) - fmt_gvar.linkage = "internal" - fmt_gvar.align = 1 # type: ignore - - fmt_ptr = builder.bitcast(fmt_gvar, ir.PointerType()) - - fn_type = ir.FunctionType( - ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True - ) - fn_ptr_type = ir.PointerType(fn_type) - fn_addr = ir.Constant(ir.IntType(64), 6) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - - builder.call( - fn_ptr, [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))], tail=True - ) - return None - - -def bpf_map_update_elem_emitter( - call, - map_ptr, - module, - builder, - func, - local_sym_tab=None, - struct_sym_tab=None, - local_var_metadata=None, -): - """ - Emit LLVM IR for bpf_map_update_elem helper function call. - Expected call signature: map.update(key, value, flags=0) - """ - if not call.args or len(call.args) < 2 or len(call.args) > 3: - raise ValueError( - "Map update expects 2 or 3 arguments (key, value, flags), got " - f"{len(call.args)}" - ) - - key_arg = call.args[0] - value_arg = call.args[1] - flags_arg = call.args[2] if len(call.args) > 2 else None - - # Handle key - if isinstance(key_arg, ast.Name): - key_name = key_arg.id - if local_sym_tab and key_name in local_sym_tab: - key_ptr = local_sym_tab[key_name][0] - else: - raise ValueError( - f"Key variable {key_name} not found in local symbol table." - ) - elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int): - # Handle constant integer keys - key_val = key_arg.value - key_type = ir.IntType(64) - key_ptr = builder.alloca(key_type) - key_ptr.align = key_type.width // 8 - builder.store(ir.Constant(key_type, key_val), key_ptr) - else: - raise NotImplementedError( - "Only simple variable names and integer constants are supported as keys in map update." - ) - - # Handle value - if isinstance(value_arg, ast.Name): - value_name = value_arg.id - if local_sym_tab and value_name in local_sym_tab: - value_ptr = local_sym_tab[value_name][0] - else: - raise ValueError( - f"Value variable {value_name} not found in local symbol table." - ) - elif isinstance(value_arg, ast.Constant) and isinstance(value_arg.value, int): - # Handle constant integers - value_val = value_arg.value - value_type = ir.IntType(64) - value_ptr = builder.alloca(value_type) - value_ptr.align = value_type.width // 8 - builder.store(ir.Constant(value_type, value_val), value_ptr) - else: - raise NotImplementedError( - "Only simple variable names and integer constants are supported as values in map update." - ) - - # Handle flags argument (defaults to 0) - if flags_arg is not None: - if isinstance(flags_arg, ast.Constant) and isinstance(flags_arg.value, int): - flags_val = flags_arg.value - elif isinstance(flags_arg, ast.Name): - flags_name = flags_arg.id - if local_sym_tab and flags_name in local_sym_tab: - # Assume it's a stored integer value, load it - flags_ptr = local_sym_tab[flags_name][0] - flags_val = builder.load(flags_ptr) - else: - raise ValueError( - f"Flags variable {flags_name} not found in local symbol table." - ) - else: - raise NotImplementedError( - "Only integer constants and simple variable names are supported as flags in map update." - ) - else: - flags_val = 0 - - if key_ptr is None or value_ptr is None: - raise ValueError("Key pointer or value pointer is None.") - - map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) - fn_type = ir.FunctionType( - ir.IntType(64), - [ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)], - var_arg=False, - ) - fn_ptr_type = ir.PointerType(fn_type) - - # helper id - fn_addr = ir.Constant(ir.IntType(64), 2) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - - if isinstance(flags_val, int): - flags_const = ir.Constant(ir.IntType(64), flags_val) - else: - flags_const = flags_val - - result = builder.call( - fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False - ) - - return result, None - - -def bpf_map_delete_elem_emitter( - call, - map_ptr, - module, - builder, - func, - local_sym_tab=None, - struct_sym_tab=None, - local_var_metadata=None, -): - """ - Emit LLVM IR for bpf_map_delete_elem helper function call. - Expected call signature: map.delete(key) - """ - # Check for correct number of arguments - if not call.args or len(call.args) != 1: - raise ValueError( - "Map delete expects exactly 1 argument (key), got " f"{len(call.args)}" - ) - - key_arg = call.args[0] - - # Handle key argument - if isinstance(key_arg, ast.Name): - key_name = key_arg.id - if local_sym_tab and key_name in local_sym_tab: - key_ptr = local_sym_tab[key_name][0] - else: - raise ValueError( - f"Key variable {key_name} not found in local symbol table." - ) - elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int): - # Handle constant integer keys - key_val = key_arg.value - key_type = ir.IntType(64) - key_ptr = builder.alloca(key_type) - key_ptr.align = key_type.width // 8 - builder.store(ir.Constant(key_type, key_val), key_ptr) - else: - raise NotImplementedError( - "Only simple variable names and integer constants are supported as keys in map delete." - ) - - if key_ptr is None: - raise ValueError("Key pointer is None.") - - # Cast map pointer to void* - map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) - - # Define function type for bpf_map_delete_elem - fn_type = ir.FunctionType( - ir.IntType(64), # Return type: int64 (status code) - [ir.PointerType(), ir.PointerType()], # Args: (void*, void*) - var_arg=False, - ) - fn_ptr_type = ir.PointerType(fn_type) - - # Helper ID 3 is bpf_map_delete_elem - fn_addr = ir.Constant(ir.IntType(64), 3) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - - # Call the helper function - result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) - - return result, None - - -def bpf_get_current_pid_tgid_emitter( - call, - map_ptr, - module, - builder, - func, - local_sym_tab=None, - struct_sym_tab=None, - local_var_metadata=None, -): - """ - Emit LLVM IR for bpf_get_current_pid_tgid helper function call. - """ - # func is an arg to just have a uniform signature with other emitters - helper_id = ir.Constant(ir.IntType(64), 14) - fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) - fn_ptr_type = ir.PointerType(fn_type) - fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) - result = builder.call(fn_ptr, [], tail=False) - - # Extract the lower 32 bits (PID) using bitwise AND with 0xFFFFFFFF - mask = ir.Constant(ir.IntType(64), 0xFFFFFFFF) - pid = builder.and_(result, mask) - return pid, ir.IntType(64) - - -def bpf_perf_event_output_handler( - call, - map_ptr, - module, - builder, - func, - local_sym_tab=None, - struct_sym_tab=None, - local_var_metadata=None, -): - if len(call.args) != 1: - raise ValueError( - "Perf event output expects exactly one argument (data), got " - f"{len(call.args)}" - ) - data_arg = call.args[0] - ctx_ptr = func.args[0] # First argument to the function is ctx - - if isinstance(data_arg, ast.Name): - data_name = data_arg.id - if local_sym_tab and data_name in local_sym_tab: - data_ptr = local_sym_tab[data_name][0] - else: - raise ValueError( - f"Data variable {data_name} not found in local symbol table." - ) - # Check is data_name is a struct - if local_var_metadata and data_name in local_var_metadata: - data_type = local_var_metadata[data_name] - if data_type in struct_sym_tab: - struct_info = struct_sym_tab[data_type] - size_val = ir.Constant(ir.IntType(64), struct_info.size) - else: - raise ValueError( - f"Struct type {data_type} for variable {data_name} not found in struct symbol table." - ) - else: - raise ValueError( - f"Metadata for variable {data_name} not found in local variable metadata." - ) - - # BPF_F_CURRENT_CPU is -1 in 32 bit - flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF) - - map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) - data_void_ptr = builder.bitcast(data_ptr, ir.PointerType()) - fn_type = ir.FunctionType( - ir.IntType(64), - [ - ir.PointerType(ir.IntType(8)), - ir.PointerType(), - ir.IntType(64), - ir.PointerType(), - ir.IntType(64), - ], - var_arg=False, - ) - fn_ptr_type = ir.PointerType(fn_type) - - # helper id - fn_addr = ir.Constant(ir.IntType(64), 25) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - - result = builder.call( - fn_ptr, - [ctx_ptr, map_void_ptr, flags_val, data_void_ptr, size_val], - tail=False, - ) - return result, None - else: - raise NotImplementedError( - "Only simple object names are supported as data in perf event output." - ) - - -helper_func_list = { - "lookup": bpf_map_lookup_elem_emitter, - "print": bpf_printk_emitter, - "ktime": bpf_ktime_get_ns_emitter, - "update": bpf_map_update_elem_emitter, - "delete": bpf_map_delete_elem_emitter, - "pid": bpf_get_current_pid_tgid_emitter, - "output": bpf_perf_event_output_handler, -} - - -def handle_helper_call( - call, - module, - builder, - func, - local_sym_tab=None, - map_sym_tab=None, - struct_sym_tab=None, - local_var_metadata=None, -): - print(local_var_metadata) - if isinstance(call.func, ast.Name): - func_name = call.func.id - if func_name in helper_func_list: - # it is not a map method call - return helper_func_list[func_name]( - call, - None, - module, - builder, - func, - local_sym_tab, - struct_sym_tab, - local_var_metadata, - ) - else: - raise NotImplementedError( - f"Function {func_name} is not implemented as a helper function." - ) - elif isinstance(call.func, ast.Attribute): - # likely a map method call - if isinstance(call.func.value, ast.Call) and isinstance( - call.func.value.func, ast.Name - ): - map_name = call.func.value.func.id - method_name = call.func.attr - if map_sym_tab and map_name in map_sym_tab: - map_ptr = map_sym_tab[map_name] - if method_name in helper_func_list: - print(local_var_metadata) - return helper_func_list[method_name]( - call, - map_ptr, - module, - builder, - func, - local_sym_tab, - struct_sym_tab, - local_var_metadata, - ) - else: - raise NotImplementedError( - f"Map method {method_name} is not implemented as a helper function." - ) - else: - raise ValueError(f"Map variable {map_name} not found in symbol tables.") - elif isinstance(call.func.value, ast.Name): - obj_name = call.func.value.id - method_name = call.func.attr - if map_sym_tab and obj_name in map_sym_tab: - map_ptr = map_sym_tab[obj_name] - if method_name in helper_func_list: - return helper_func_list[method_name]( - call, - map_ptr, - module, - builder, - func, - local_sym_tab, - struct_sym_tab, - local_var_metadata, - ) - else: - raise NotImplementedError( - f"Map method {method_name} is not implemented as a helper function." - ) - else: - raise ValueError(f"Map variable {obj_name} not found in symbol tables.") - else: - raise NotImplementedError("Attribute not supported for map method calls.") - return None diff --git a/pythonbpf/expr_pass.py b/pythonbpf/expr_pass.py index 03f6651..f81ba89 100644 --- a/pythonbpf/expr_pass.py +++ b/pythonbpf/expr_pass.py @@ -32,7 +32,7 @@ def eval_expr( return None elif isinstance(expr, ast.Call): # delayed import to avoid circular dependency - from .bpf_helper_handler import helper_func_list, handle_helper_call + from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call if isinstance(expr.func, ast.Name): # check deref @@ -63,7 +63,7 @@ def eval_expr( return val, local_sym_tab[expr.args[0].id][1] # check for helpers - if expr.func.id in helper_func_list: + if HelperHandlerRegistry.has_handler(expr.func.id): return handle_helper_call( expr, module, @@ -80,7 +80,7 @@ def eval_expr( expr.func.value.func, ast.Name ): method_name = expr.func.attr - if method_name in helper_func_list: + if HelperHandlerRegistry.has_handler(method_name): return handle_helper_call( expr, module, @@ -95,7 +95,7 @@ def eval_expr( obj_name = expr.func.value.id method_name = expr.func.attr if obj_name in map_sym_tab: - if method_name in helper_func_list: + if HelperHandlerRegistry.has_handler(method_name): return handle_helper_call( expr, module, diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions_pass.py index 2aef65f..f4a071c 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions_pass.py @@ -2,7 +2,7 @@ import ast -from .bpf_helper_handler import helper_func_list, handle_helper_call +from .helper import HelperHandlerRegistry, handle_helper_call from .type_deducer import ctypes_to_ir from .binary_ops import handle_binary_op from .expr_pass import eval_expr, handle_expr @@ -83,16 +83,19 @@ def handle_assign( elif isinstance(rval, ast.Constant): if isinstance(rval.value, bool): if rval.value: - builder.store(ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name][0]) + builder.store(ir.Constant(ir.IntType(1), 1), + local_sym_tab[var_name][0]) else: - builder.store(ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name][0]) + builder.store(ir.Constant(ir.IntType(1), 0), + local_sym_tab[var_name][0]) print(f"Assigned constant {rval.value} to {var_name}") elif isinstance(rval.value, int): # Assume c_int64 for now # var = builder.alloca(ir.IntType(64), name=var_name) # var.align = 8 builder.store( - ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name][0] + ir.Constant(ir.IntType(64), + rval.value), local_sym_tab[var_name][0] ) # local_sym_tab[var_name] = var print(f"Assigned constant {rval.value} to {var_name}") @@ -107,7 +110,8 @@ def handle_assign( global_str.linkage = "internal" global_str.global_constant = True global_str.initializer = str_const - str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8))) + str_ptr = builder.bitcast( + global_str, ir.PointerType(ir.IntType(8))) builder.store(str_ptr, local_sym_tab[var_name][0]) print(f"Assigned string constant '{rval.value}' to {var_name}") else: @@ -126,14 +130,15 @@ def handle_assign( # var = builder.alloca(ir_type, name=var_name) # var.align = ir_type.width // 8 builder.store( - ir.Constant(ir_type, rval.args[0].value), local_sym_tab[var_name][0] + ir.Constant( + ir_type, rval.args[0].value), local_sym_tab[var_name][0] ) print( f"Assigned {call_type} constant " f"{rval.args[0].value} to {var_name}" ) # local_sym_tab[var_name] = var - elif call_type in helper_func_list: + elif HelperHandlerRegistry.has_handler(call_type): # var = builder.alloca(ir.IntType(64), name=var_name) # var.align = 8 val = handle_helper_call( @@ -172,7 +177,8 @@ def handle_assign( ir_type = struct_info.ir_type # var = builder.alloca(ir_type, name=var_name) # Null init - builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name][0]) + builder.store(ir.Constant(ir_type, None), + local_sym_tab[var_name][0]) local_var_metadata[var_name] = call_type print(f"Assigned struct {call_type} to {var_name}") # local_sym_tab[var_name] = var @@ -189,8 +195,7 @@ def handle_assign( map_name = rval.func.value.func.id method_name = rval.func.attr if map_name in map_sym_tab: - # map_ptr = map_sym_tab[map_name] - if method_name in helper_func_list: + if HelperHandlerRegistry.has_handler(method_name): val = handle_helper_call( rval, module, @@ -244,7 +249,8 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab): print(f"Undefined variable {cond.id} in condition") return None elif isinstance(cond, ast.Compare): - lhs = eval_expr(func, module, builder, cond.left, local_sym_tab, map_sym_tab)[0] + lhs = eval_expr(func, module, builder, cond.left, + local_sym_tab, map_sym_tab)[0] if len(cond.ops) != 1 or len(cond.comparators) != 1: print("Unsupported complex comparison") return None @@ -297,7 +303,8 @@ def handle_if( else: else_block = None - cond = handle_cond(func, module, builder, stmt.test, local_sym_tab, map_sym_tab) + cond = handle_cond(func, module, builder, stmt.test, + local_sym_tab, map_sym_tab) if else_block: builder.cbranch(cond, then_block, else_block) else: @@ -442,8 +449,9 @@ def allocate_mem( ir_type = ctypes_to_ir(call_type) var = builder.alloca(ir_type, name=var_name) var.align = ir_type.width // 8 - print(f"Pre-allocated variable {var_name} of type {call_type}") - elif call_type in helper_func_list: + print( + f"Pre-allocated variable {var_name} of type {call_type}") + elif HelperHandlerRegistry.has_handler(call_type): # Assume return type is int64 for now ir_type = ir.IntType(64) var = builder.alloca(ir_type, name=var_name) @@ -461,7 +469,8 @@ def allocate_mem( var = builder.alloca(ir_type, name=var_name) local_var_metadata[var_name] = call_type print( - f"Pre-allocated variable {var_name} for struct {call_type}" + f"Pre-allocated variable { + var_name} for struct {call_type}" ) elif isinstance(rval.func, ast.Attribute): ir_type = ir.PointerType(ir.IntType(64)) @@ -662,7 +671,8 @@ def _expr_type(e): if found_type is None: found_type = t elif found_type != t: - raise ValueError("Conflicting return types:" f"{found_type} vs {t}") + raise ValueError("Conflicting return types:" f"{ + found_type} vs {t}") return found_type or "None" @@ -699,7 +709,8 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l char = builder.load(src_ptr) # Store character in target - dst_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx]) + dst_ptr = builder.gep( + target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx]) builder.store(char, dst_ptr) # Increment counter @@ -710,5 +721,6 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l # Ensure null termination last_idx = ir.Constant(ir.IntType(32), array_length - 1) - null_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx]) + null_ptr = builder.gep( + target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx]) builder.store(ir.Constant(ir.IntType(8), 0), null_ptr) diff --git a/pythonbpf/helper/__init__.py b/pythonbpf/helper/__init__.py new file mode 100644 index 0000000..5e538d3 --- /dev/null +++ b/pythonbpf/helper/__init__.py @@ -0,0 +1,2 @@ +from .helper_utils import HelperHandlerRegistry +from .bpf_helper_handler import handle_helper_call diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py new file mode 100644 index 0000000..a24d9ee --- /dev/null +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -0,0 +1,273 @@ +import ast +from llvmlite import ir +from enum import Enum +from .helper_utils import (HelperHandlerRegistry, + get_or_create_ptr_from_arg, get_flags_val, + handle_fstring_print, simple_string_print, + get_data_ptr_and_size) + + +class BPFHelperID(Enum): + BPF_MAP_LOOKUP_ELEM = 1 + BPF_MAP_UPDATE_ELEM = 2 + BPF_MAP_DELETE_ELEM = 3 + BPF_KTIME_GET_NS = 5 + BPF_PRINTK = 6 + BPF_GET_CURRENT_PID_TGID = 14 + BPF_PERF_EVENT_OUTPUT = 25 + + +@HelperHandlerRegistry.register("ktime") +def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func, + local_sym_tab=None, struct_sym_tab=None, + local_var_metadata=None): + """ + Emit LLVM IR for bpf_ktime_get_ns helper function call. + """ + # func is an arg to just have a uniform signature with other emitters + helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_KTIME_GET_NS.value) + fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + return result, ir.IntType(64) + + +@HelperHandlerRegistry.register("lookup") +def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, + local_sym_tab=None, struct_sym_tab=None, + local_var_metadata=None): + """ + Emit LLVM IR for bpf_map_lookup_elem helper function call. + """ + if not call.args or len(call.args) != 1: + raise ValueError("Map lookup expects exactly one argument (key), got " + f"{len(call.args)}") + key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + + fn_type = ir.FunctionType( + ir.PointerType(), # Return type: void* + [ir.PointerType(), ir.PointerType()], # Args: (void*, void*) + var_arg=False + ) + fn_ptr_type = ir.PointerType(fn_type) + + fn_addr = ir.Constant(ir.IntType( + 64), BPFHelperID.BPF_MAP_LOOKUP_ELEM.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) + + return result, ir.PointerType() + + +@HelperHandlerRegistry.register("print") +def bpf_printk_emitter(call, map_ptr, module, builder, func, + local_sym_tab=None, struct_sym_tab=None, + local_var_metadata=None): + """Emit LLVM IR for bpf_printk helper function call.""" + if not hasattr(func, "_fmt_counter"): + func._fmt_counter = 0 + + if not call.args: + raise ValueError( + "bpf_printk expects at least one argument (format string)") + + args = [] + if isinstance(call.args[0], ast.JoinedStr): + args = handle_fstring_print(call.args[0], module, builder, func, + local_sym_tab, struct_sym_tab, + local_var_metadata) + elif (isinstance(call.args[0], ast.Constant) and + isinstance(call.args[0].value, str)): + # TODO: We are only supporting single arguments for now. + # In case of multiple args, the first one will be taken. + args = simple_string_print(call.args[0].value, module, builder, func) + else: + raise NotImplementedError( + "Only simple strings or f-strings are supported in bpf_printk.") + + fn_type = ir.FunctionType( + ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True) + fn_ptr_type = ir.PointerType(fn_type) + fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PRINTK.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + builder.call(fn_ptr, args, tail=True) + return None + + +@HelperHandlerRegistry.register("update") +def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func, + local_sym_tab=None, struct_sym_tab=None, + local_var_metadata=None): + """ + Emit LLVM IR for bpf_map_update_elem helper function call. + Expected call signature: map.update(key, value, flags=0) + """ + if (not call.args or + len(call.args) < 2 or + len(call.args) > 3): + raise ValueError("Map update expects 2 or 3 args (key, value, flags), " + f"got {len(call.args)}") + + key_arg = call.args[0] + value_arg = call.args[1] + flags_arg = call.args[2] if len(call.args) > 2 else None + + key_ptr = get_or_create_ptr_from_arg(key_arg, builder, local_sym_tab) + value_ptr = get_or_create_ptr_from_arg(value_arg, builder, local_sym_tab) + flags_val = get_flags_val(flags_arg, builder, local_sym_tab) + + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + fn_type = ir.FunctionType( + ir.IntType(64), + [ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)], + var_arg=False + ) + fn_ptr_type = ir.PointerType(fn_type) + + fn_addr = ir.Constant(ir.IntType( + 64), BPFHelperID.BPF_MAP_UPDATE_ELEM.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + if isinstance(flags_val, int): + flags_const = ir.Constant(ir.IntType(64), flags_val) + else: + flags_const = flags_val + + result = builder.call( + fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False) + + return result, None + + +@HelperHandlerRegistry.register("delete") +def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func, + local_sym_tab=None, struct_sym_tab=None, + local_var_metadata=None): + """ + Emit LLVM IR for bpf_map_delete_elem helper function call. + Expected call signature: map.delete(key) + """ + if not call.args or len(call.args) != 1: + raise ValueError("Map delete expects exactly one argument (key), got " + f"{len(call.args)}") + key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + + # Define function type for bpf_map_delete_elem + fn_type = ir.FunctionType( + ir.IntType(64), # Return type: int64 (status code) + [ir.PointerType(), ir.PointerType()], # Args: (void*, void*) + var_arg=False + ) + fn_ptr_type = ir.PointerType(fn_type) + + fn_addr = ir.Constant(ir.IntType( + 64), BPFHelperID.BPF_MAP_DELETE_ELEM.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) + + return result, None + + +@HelperHandlerRegistry.register("pid") +def bpf_get_current_pid_tgid_emitter(call, map_ptr, module, builder, func, + local_sym_tab=None, struct_sym_tab=None, + local_var_metadata=None): + """ + Emit LLVM IR for bpf_get_current_pid_tgid helper function call. + """ + # func is an arg to just have a uniform signature with other emitters + helper_id = ir.Constant(ir.IntType( + 64), BPFHelperID.BPF_GET_CURRENT_PID_TGID.value) + fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + + # Extract the lower 32 bits (PID) using bitwise AND with 0xFFFFFFFF + mask = ir.Constant(ir.IntType(64), 0xFFFFFFFF) + pid = builder.and_(result, mask) + return pid, ir.IntType(64) + + +@HelperHandlerRegistry.register("output") +def bpf_perf_event_output_handler(call, map_ptr, module, builder, func, + local_sym_tab=None, struct_sym_tab=None, + local_var_metadata=None): + if len(call.args) != 1: + raise ValueError("Perf event output expects exactly one argument, " + f"got {len(call.args)}") + data_arg = call.args[0] + ctx_ptr = func.args[0] # First argument to the function is ctx + + data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, + struct_sym_tab, + local_var_metadata) + + # BPF_F_CURRENT_CPU is -1 in 32 bit + flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF) + + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + data_void_ptr = builder.bitcast(data_ptr, ir.PointerType()) + fn_type = ir.FunctionType( + ir.IntType(64), + [ir.PointerType(ir.IntType(8)), ir.PointerType(), ir.IntType(64), + ir.PointerType(), ir.IntType(64)], + var_arg=False + ) + fn_ptr_type = ir.PointerType(fn_type) + + # helper id + fn_addr = ir.Constant(ir.IntType(64), + BPFHelperID.BPF_PERF_EVENT_OUTPUT.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + result = builder.call( + fn_ptr, + [ctx_ptr, map_void_ptr, flags_val, data_void_ptr, size_val], + tail=False) + return result, None + + +def handle_helper_call(call, module, builder, func, + local_sym_tab=None, map_sym_tab=None, + struct_sym_tab=None, local_var_metadata=None): + """Process a BPF helper function call and emit the appropriate LLVM IR.""" + # Helper function to get map pointer and invoke handler + def invoke_helper(method_name, map_ptr=None): + handler = HelperHandlerRegistry.get_handler(method_name) + if not handler: + raise NotImplementedError( + f"Helper function '{method_name}' is not implemented.") + return handler(call, map_ptr, module, builder, func, + local_sym_tab, struct_sym_tab, local_var_metadata) + + # Handle direct function calls (e.g., print(), ktime()) + if isinstance(call.func, ast.Name): + return invoke_helper(call.func.id) + + # Handle method calls (e.g., map.lookup(), map.update()) + elif isinstance(call.func, ast.Attribute): + method_name = call.func.attr + value = call.func.value + + # Get map pointer from different styles of map access + if isinstance(value, ast.Call) and isinstance(value.func, ast.Name): + # Variable style: my_map.lookup(key) + map_name = value.func.id + else: + raise NotImplementedError( + f"Unsupported map access pattern: {ast.dump(value)}") + + # Verify map exists and get pointer + if not map_sym_tab or map_name not in map_sym_tab: + raise ValueError(f"Map '{map_name}' not found in symbol table") + + return invoke_helper(method_name, map_sym_tab[map_name]) + + return None diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py new file mode 100644 index 0000000..adb766e --- /dev/null +++ b/pythonbpf/helper/helper_utils.py @@ -0,0 +1,288 @@ +import ast +import logging +from llvmlite import ir +from pythonbpf.expr_pass import eval_expr + +logger = logging.getLogger(__name__) + + +class HelperHandlerRegistry: + """Registry for BPF helpers""" + _handlers = {} + + @classmethod + def register(cls, helper_name): + """Decorator to register a handler function for a helper""" + def decorator(func): + cls._handlers[helper_name] = func + return func + return decorator + + @classmethod + def get_handler(cls, helper_name): + """Get the handler function for a helper""" + return cls._handlers.get(helper_name) + + @classmethod + def has_handler(cls, helper_name): + """Check if a handler function is registered for a helper""" + return helper_name in cls._handlers + + +def get_var_ptr_from_name(var_name, local_sym_tab): + """Get a pointer to a variable from the symbol table.""" + if local_sym_tab and var_name in local_sym_tab: + return local_sym_tab[var_name][0] + raise ValueError(f"Variable '{var_name}' not found in local symbol table") + + +def create_int_constant_ptr(value, builder, int_width=64): + """Create a pointer to an integer constant.""" + # Default to 64-bit integer + int_type = ir.IntType(int_width) + ptr = builder.alloca(int_type) + ptr.align = int_type.width // 8 + builder.store(ir.Constant(int_type, value), ptr) + return ptr + + +def get_or_create_ptr_from_arg(arg, builder, local_sym_tab): + """Extract or create pointer from the call arguments.""" + + if isinstance(arg, ast.Name): + ptr = get_var_ptr_from_name(arg.id, local_sym_tab) + elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): + ptr = create_int_constant_ptr(arg.value, builder) + else: + raise NotImplementedError( + "Only simple variable names are supported as args in map helpers.") + return ptr + + +def get_flags_val(arg, builder, local_sym_tab): + """Extract or create flags value from the call arguments.""" + if not arg: + return 0 + + if isinstance(arg, ast.Name): + if local_sym_tab and arg.id in local_sym_tab: + flags_ptr = local_sym_tab[arg.id][0] + return builder.load(flags_ptr) + else: + raise ValueError( + f"Variable '{arg.id}' not found in local symbol table") + elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): + return arg.value + + raise NotImplementedError( + "Only var names or int consts are supported as map helpers flags.") + + +def simple_string_print(string_value, module, builder, func): + """Prepare arguments for bpf_printk from a simple string value""" + fmt_str = string_value + "\n\0" + fmt_ptr = _create_format_string_global(fmt_str, func, module, builder) + + args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))] + return args + + +def handle_fstring_print(joined_str, module, builder, func, + local_sym_tab=None, struct_sym_tab=None, + local_var_metadata=None): + """Handle f-string formatting for bpf_printk emitter.""" + fmt_parts = [] + exprs = [] + + for value in joined_str.values: + logger.debug(f"Processing f-string value: {ast.dump(value)}") + + if isinstance(value, ast.Constant): + _process_constant_in_fstring(value, fmt_parts, exprs) + elif isinstance(value, ast.FormattedValue): + _process_fval(value, fmt_parts, exprs, + local_sym_tab, struct_sym_tab, + local_var_metadata) + else: + raise NotImplementedError( + f"Unsupported f-string value type: {type(value)}") + + fmt_str = "".join(fmt_parts) + args = simple_string_print(fmt_str, module, builder, func) + + # NOTE: Process expressions (limited to 3 due to BPF constraints) + if len(exprs) > 3: + logger.warning( + "bpf_printk supports up to 3 args, extra args will be ignored.") + + for expr in exprs[:3]: + arg_value = _prepare_expr_args(expr, func, module, builder, + local_sym_tab, struct_sym_tab, + local_var_metadata) + args.append(arg_value) + + return args + + +def _process_constant_in_fstring(cst, fmt_parts, exprs): + """Process constant values in f-string.""" + if isinstance(cst.value, str): + fmt_parts.append(cst.value) + elif isinstance(cst.value, int): + fmt_parts.append("%lld") + exprs.append(ir.Constant(ir.IntType(64), cst.value)) + else: + raise NotImplementedError( + f"Unsupported constant type in f-string: {type(cst.value)}") + + +def _process_fval(fval, fmt_parts, exprs, + local_sym_tab, struct_sym_tab, + local_var_metadata): + """Process formatted values in f-string.""" + logger.debug(f"Processing formatted value: {ast.dump(fval)}") + + if isinstance(fval.value, ast.Name): + _process_name_in_fval(fval.value, fmt_parts, exprs, local_sym_tab) + elif isinstance(fval.value, ast.Attribute): + _process_attr_in_fval(fval.value, fmt_parts, exprs, + local_sym_tab, struct_sym_tab, + local_var_metadata) + else: + raise NotImplementedError( + f"Unsupported formatted value in f-string: {type(fval.value)}") + + +def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab): + """Process name nodes in formatted values.""" + if local_sym_tab and name_node.id in local_sym_tab: + _, var_type = local_sym_tab[name_node.id] + _populate_fval(var_type, name_node, fmt_parts, exprs) + + +def _process_attr_in_fval(attr_node, fmt_parts, exprs, + local_sym_tab, struct_sym_tab, + local_var_metadata): + """Process attribute nodes in formatted values.""" + if (isinstance(attr_node.value, ast.Name) and + local_sym_tab and attr_node.value.id in local_sym_tab): + var_name = attr_node.value.id + field_name = attr_node.attr + + if not local_var_metadata or var_name not in local_var_metadata: + raise ValueError( + f"Metadata for '{var_name}' not found in local var metadata") + + var_type = local_var_metadata[var_name] + if var_type not in struct_sym_tab: + raise ValueError( + f"Struct '{var_type}' for '{var_name}' not in symbol table") + + struct_info = struct_sym_tab[var_type] + if field_name not in struct_info.fields: + raise ValueError( + f"Field '{field_name}' not found in struct '{var_type}'") + + field_type = struct_info.field_type(field_name) + _populate_fval(field_type, attr_node, fmt_parts, exprs) + else: + raise NotImplementedError( + "Only simple attribute on local vars is supported in f-strings.") + + +def _populate_fval(ftype, node, fmt_parts, exprs): + """Populate format parts and expressions based on field type.""" + if isinstance(ftype, ir.IntType): + # TODO: We print as signed integers only for now + if ftype.width == 64: + fmt_parts.append("%lld") + exprs.append(node) + elif ftype.width == 32: + fmt_parts.append("%d") + exprs.append(node) + else: + raise NotImplementedError( + f"Unsupported integer width in f-string: {ftype.width}") + elif ftype == ir.PointerType(ir.IntType(8)): + # NOTE: We assume i8* is a string + fmt_parts.append("%s") + exprs.append(node) + else: + raise NotImplementedError( + f"Unsupported field type in f-string: {ftype}") + + +def _create_format_string_global(fmt_str, func, module, builder): + """Create a global variable for the format string.""" + fmt_name = f"{func.name}____fmt{func._fmt_counter}" + func._fmt_counter += 1 + + fmt_gvar = ir.GlobalVariable( + module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name) + fmt_gvar.global_constant = True + fmt_gvar.initializer = ir.Constant( + ir.ArrayType(ir.IntType(8), len(fmt_str)), + bytearray(fmt_str.encode("utf8")) + ) + fmt_gvar.linkage = "internal" + fmt_gvar.align = 1 + + return builder.bitcast(fmt_gvar, ir.PointerType()) + + +def _prepare_expr_args(expr, func, module, builder, + local_sym_tab, struct_sym_tab, + local_var_metadata): + """Evaluate and prepare an expression to use as an arg for bpf_printk.""" + val, _ = eval_expr(func, module, builder, expr, + local_sym_tab, None, struct_sym_tab, + local_var_metadata) + + if val: + if isinstance(val.type, ir.PointerType): + val = builder.ptrtoint(val, ir.IntType(64)) + elif isinstance(val.type, ir.IntType): + if val.type.width < 64: + val = builder.sext(val, ir.IntType(64)) + else: + logger.warning( + "Only int and ptr supported in bpf_printk args. " + "Others default to 0.") + val = ir.Constant(ir.IntType(64), 0) + return val + else: + logger.warning( + "Failed to evaluate expression for bpf_printk argument. " + "It will be converted to 0.") + return ir.Constant(ir.IntType(64), 0) + + +def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab, + local_var_metadata): + """Extract data pointer and size information for perf event output.""" + if isinstance(data_arg, ast.Name): + data_name = data_arg.id + if local_sym_tab and data_name in local_sym_tab: + data_ptr = local_sym_tab[data_name][0] + else: + raise ValueError( + f"Data variable {data_name} not found in local symbol table.") + + # Check if data_name is a struct + if local_var_metadata and data_name in local_var_metadata: + data_type = local_var_metadata[data_name] + if data_type in struct_sym_tab: + struct_info = struct_sym_tab[data_type] + size_val = ir.Constant(ir.IntType(64), struct_info.size) + return data_ptr, size_val + else: + raise ValueError( + f"Struct {data_type} for {data_name} not in symbol table.") + else: + raise ValueError( + f"Metadata for variable {data_name} " + "not found in local variable metadata.") + else: + raise NotImplementedError( + "Only simple object names are supported " + "as data in perf event output.")