diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py index 5de23a5..772c2b2 100644 --- a/pythonbpf/codegen.py +++ b/pythonbpf/codegen.py @@ -4,7 +4,11 @@ from .functions_pass import func_proc from .maps import maps_proc from .structs import structs_proc -from .globals_pass import globals_processing +from .globals_pass import ( + globals_list_creation, + globals_processing, + populate_global_symbol_table, +) from .debuginfo import DW_LANG_C11, DwarfBehaviorEnum, DebugInfoGenerator import os import subprocess @@ -40,12 +44,15 @@ def processor(source_code, filename, module): for func_node in bpf_chunks: logger.info(f"Found BPF function/struct: {func_node.name}") + populate_global_symbol_table(tree, module) + license_processing(tree, module) + globals_processing(tree, module) + structs_sym_tab = structs_proc(tree, module, bpf_chunks) map_sym_tab = maps_proc(tree, module, bpf_chunks) func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab) - license_processing(tree, module) - globals_processing(tree, module) + globals_list_creation(tree, module) def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING): diff --git a/pythonbpf/globals_pass.py b/pythonbpf/globals_pass.py index 1228809..1e97763 100644 --- a/pythonbpf/globals_pass.py +++ b/pythonbpf/globals_pass.py @@ -1,8 +1,121 @@ from llvmlite import ir import ast +from logging import Logger +import logging +from .type_deducer import ctypes_to_ir -def emit_globals(module: ir.Module, names: list[str]): +logger: Logger = logging.getLogger(__name__) + +# TODO: this is going to be a huge fuck of a headache in the future. 
# Module-level registry of the top-level AST function nodes recognised as
# sections, @bpfglobal values or maps. It is shared mutable state that is only
# ever appended to by populate_global_symbol_table().
global_sym_tab = []


def _is_global_symbol_decorator(dec) -> bool:
    """Return True if *dec* is ``@section("<str>")``, ``@bpfglobal`` or ``@map``."""
    if isinstance(dec, ast.Name):
        return dec.id in ("bpfglobal", "map")
    return (
        isinstance(dec, ast.Call)
        and isinstance(dec.func, ast.Name)
        and dec.func.id == "section"
        and len(dec.args) == 1
        and isinstance(dec.args[0], ast.Constant)
        and isinstance(dec.args[0].value, str)
    )


def _literal_value(expr):
    """Return ``(found, value)`` for a literal, including signed numbers like ``-1``.

    A negative number is parsed as ``ast.UnaryOp`` wrapping an ``ast.Constant``,
    so checking for ``ast.Constant`` alone would miss it.
    """
    if isinstance(expr, ast.Constant):
        return True, expr.value
    if (
        isinstance(expr, ast.UnaryOp)
        and isinstance(expr.op, (ast.USub, ast.UAdd))
        and isinstance(expr.operand, ast.Constant)
        and isinstance(expr.operand.value, (int, float))
        and not isinstance(expr.operand.value, bool)
    ):
        magnitude = expr.operand.value
        return True, -magnitude if isinstance(expr.op, ast.USub) else magnitude
    return False, None


def populate_global_symbol_table(tree, module: "ir.Module"):
    """Register every top-level section, bpfglobal and map function of *tree*.

    Args:
        tree: Parsed module whose ``body`` is scanned (top level only).
        module: Unused here; accepted so the signature matches the other passes.

    Returns:
        Always ``False``.

    Each matching function node is appended to ``global_sym_tab`` exactly once,
    even when more than one of its decorators qualifies.
    """
    for node in tree.body:
        if not isinstance(node, ast.FunctionDef):
            continue
        if any(_is_global_symbol_decorator(dec) for dec in node.decorator_list):
            global_sym_tab.append(node)
    return False


def emit_global(module: "ir.Module", node, name):
    """Emit an LLVM global variable for one ``@bpfglobal`` function.

    Args:
        module: LLVM module that receives the global.
        node: ``ast.FunctionDef`` whose return annotation names the type and
            whose first statement is ``return <initial value>``.
        name: Name given to the emitted global.

    Returns:
        The created ``ir.GlobalVariable``.

    Raises:
        ValueError: If the return annotation is missing or is not a plain name,
            the first statement is not a ``return`` with a value, the value is
            a bare name (not supported yet), or the value is any other
            unsupported expression.
    """
    logger.info("global identifier %s processing", name)

    # Deduce the LLVM type from the annotated return type. A missing annotation
    # leaves node.returns as None, which ast.dump() rejects with a TypeError,
    # so that case is reported on its own.
    returns = node.returns
    if returns is None:
        raise ValueError(f"Global '{name}' has no return annotation")
    if not isinstance(returns, ast.Name):
        raise ValueError(f"Unsupported return annotation {ast.dump(returns)}")
    ty = ctypes_to_ir(returns.id)

    # Extract the return expression.
    # TODO: turn this return extractor into a generic helper usable elsewhere.
    ret_stmt = node.body[0]
    if not isinstance(ret_stmt, ast.Return) or ret_stmt.value is None:
        raise ValueError(f"Global '{name}' has no valid return")
    init_val = ret_stmt.value

    if isinstance(init_val, ast.Constant):
        # Simple constant like "return 0".
        llvm_init = ir.Constant(ty, init_val.value)
    elif isinstance(init_val, ast.Name):
        # Variable reference like "return SOME_CONST"; needs symbol resolution.
        raise ValueError(f"Name reference {init_val.id} not yet supported")
    elif isinstance(init_val, ast.Call):
        # Constructor call like "return c_int64(0)": use the first positional
        # argument when it is a literal, otherwise fall back to zero.
        found, value = (
            _literal_value(init_val.args[0]) if init_val.args else (False, None)
        )
        if found:
            llvm_init = ir.Constant(ty, value)
        else:
            logger.info("Defaulting to zero as no constant argument found")
            llvm_init = ir.Constant(ty, 0)
    else:
        raise ValueError(f"Unsupported return expr {ast.dump(init_val)}")

    gvar = ir.GlobalVariable(module, ty, name=name)
    gvar.initializer = llvm_init
    gvar.align = 8
    gvar.linkage = "dso_local"
    gvar.global_constant = False  # the global stays writable
    return gvar


def globals_processing(tree, module):
    """Emit an LLVM global for every ``@bpf`` ``@bpfglobal`` function but LICENSE.

    Args:
        tree: Parsed module whose top-level functions are scanned.
        module: LLVM module handed to ``emit_global``.

    Returns:
        None.

    Raises:
        SyntaxError: If two top-level functions share a name, or a global's
            body is anything other than a single ``return`` of a constant,
            name or call.
    """
    seen_names = set()

    for node in tree.body:
        if not isinstance(node, ast.FunctionDef):
            continue

        # The duplicate check covers every top-level function, not only globals.
        name = node.name
        if name in seen_names:
            raise SyntaxError(f"ERROR: Global name '{name}' previously defined")
        seen_names.add(name)

        if name == "LICENSE":
            continue

        decorators = {
            dec.id for dec in node.decorator_list if isinstance(dec, ast.Name)
        }
        if not {"bpf", "bpfglobal"} <= decorators:
            continue

        body = node.body
        is_single_return = (
            len(body) == 1
            and isinstance(body[0], ast.Return)
            and isinstance(body[0].value, (ast.Constant, ast.Name, ast.Call))
        )
        if not is_single_return:
            raise SyntaxError(f"ERROR: Invalid syntax for {name} global")
        emit_global(module, node, name)

    return None
""" @@ -24,7 +137,7 @@ def emit_globals(module: ir.Module, names: list[str]): gv.section = "llvm.metadata" -def globals_processing(tree, module: ir.Module): +def globals_list_creation(tree, module: ir.Module): collected = ["LICENSE"] for node in tree.body: @@ -40,10 +153,11 @@ def globals_processing(tree, module: ir.Module): ): collected.append(node.name) - elif isinstance(dec, ast.Name) and dec.id == "bpfglobal": - collected.append(node.name) + # NOTE: all globals other than + # elif isinstance(dec, ast.Name) and dec.id == "bpfglobal": + # collected.append(node.name) elif isinstance(dec, ast.Name) and dec.id == "map": collected.append(node.name) - emit_globals(module, collected) + emit_llvm_compiler_used(module, collected) diff --git a/tests/c-form/globals.bpf.c b/tests/c-form/globals.bpf.c new file mode 100644 index 0000000..588cac5 --- /dev/null +++ b/tests/c-form/globals.bpf.c @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +#include +#include +#include +#include + +struct test_struct { + __u64 a; + __u64 b; +}; + +struct test_struct w = {}; +volatile __u64 prev_time = 0; + +SEC("tracepoint/syscalls/sys_enter_execve") +int trace_execve(void *ctx) +{ + bpf_printk("previous %ul now %ul", w.b, w.a); + __u64 ts = bpf_ktime_get_ns(); + bpf_printk("prev %ul now %ul", prev_time, ts); + w.a = ts; + w.b = prev_time; + prev_time = ts; + return 0; +} + +char LICENSE[] SEC("license") = "GPL"; diff --git a/tests/failing_tests/globals.py b/tests/failing_tests/globals.py new file mode 100644 index 0000000..55d9740 --- /dev/null +++ b/tests/failing_tests/globals.py @@ -0,0 +1,101 @@ +import logging + +from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir +from ctypes import c_void_p, c_int64, c_int32 + +@bpf +@bpfglobal +def somevalue() -> c_int32: + return c_int32(42) + +@bpf +@bpfglobal +def somevalue2() -> c_int64: + return c_int64(69) + +@bpf +@bpfglobal +def somevalue1() -> c_int32: + return c_int32(42) + + +# --- Passing examples --- 
+ +# Simple constant return +@bpf +@bpfglobal +def g1() -> c_int64: + return c_int64(42) + +# Constructor with one constant argument +@bpf +@bpfglobal +def g2() -> c_int64: + return c_int64(69) + + +# --- Failing examples --- + +# No return annotation +# @bpf +# @bpfglobal +# def g3(): +# return 42 + +# Return annotation is complex +# @bpf +# @bpfglobal +# def g4() -> List[int]: +# return [] + +# # Return is missing +# @bpf +# @bpfglobal +# def g5() -> c_int64: +# pass + +# # Return is a variable reference +# #TODO: maybe fix this sometime later. It defaults to 0 +# CONST = 5 +# @bpf +# @bpfglobal +# def g6() -> c_int64: +# return c_int64(CONST) + +# Constructor with multiple args +#TODO: this is not working. should it work ? +@bpf +@bpfglobal +def g7() -> c_int64: + return c_int64(1) + +# Dataclass call +#TODO: fails with dataclass +# @dataclass +# class Point: +# x: c_int64 +# y: c_int64 + +# @bpf +# @bpfglobal +# def g8() -> Point: +# return Point(1, 2) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def sometag(ctx: c_void_p) -> c_int64: + print("test") + global somevalue + somevalue = 2 + print(f"{somevalue}") + return c_int64(1) + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("globals.py", "globals.ll", loglevel=logging.INFO) +compile() diff --git a/tests/failing_tests/undeclared_values.py b/tests/failing_tests/undeclared_values.py new file mode 100644 index 0000000..02f5184 --- /dev/null +++ b/tests/failing_tests/undeclared_values.py @@ -0,0 +1,21 @@ +import logging + +from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir +from ctypes import c_void_p, c_int64 + +# This should not pass as somevalue is not declared at all. 
# Tracepoint program that reads `somevalue`, which is never declared anywhere
# in this file; the compiler is expected to reject it.
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def sometag(ctx: c_void_p) -> c_int64:
    print("test")
    print(f"{somevalue}")  # type: ignore
    return c_int64(1)


# License string for the BPF object.
@bpf
@bpfglobal
def LICENSE() -> str:
    return "GPL"


# Compile THIS file. The driver was copied from globals.py and still pointed at
# "globals.py"/"globals.ll", so it exercised the wrong source and would have
# overwritten that test's output.
compile_to_ir("undeclared_values.py", "undeclared_values.ll", loglevel=logging.INFO)
compile()