diff --git a/pythonbpf/binary_ops.py b/pythonbpf/binary_ops.py index 5ab393f..c0ca0ec 100644 --- a/pythonbpf/binary_ops.py +++ b/pythonbpf/binary_ops.py @@ -9,6 +9,7 @@ def recursive_dereferencer(var, builder): """dereference until primitive type comes out""" # TODO: Not worrying about stack overflow for now + logger.info(f"Dereferencing {var}, type is {var.type}") if isinstance(var.type, ir.PointerType): a = builder.load(var) return recursive_dereferencer(a, builder) @@ -18,7 +19,7 @@ def recursive_dereferencer(var, builder): raise TypeError(f"Unsupported type for dereferencing: {var.type}") -def get_operand_value(operand, module, builder, local_sym_tab): +def get_operand_value(operand, builder, local_sym_tab): """Extract the value from an operand, handling variables and constants.""" if isinstance(operand, ast.Name): if operand.id in local_sym_tab: @@ -29,14 +30,14 @@ def get_operand_value(operand, module, builder, local_sym_tab): return ir.Constant(ir.IntType(64), operand.value) raise TypeError(f"Unsupported constant type: {type(operand.value)}") elif isinstance(operand, ast.BinOp): - return handle_binary_op_impl(operand, module, builder, local_sym_tab) + return handle_binary_op_impl(operand, builder, local_sym_tab) raise TypeError(f"Unsupported operand type: {type(operand)}") -def handle_binary_op_impl(rval, module, builder, local_sym_tab): +def handle_binary_op_impl(rval, builder, local_sym_tab): op = rval.op - left = get_operand_value(rval.left, module, builder, local_sym_tab) - right = get_operand_value(rval.right, module, builder, local_sym_tab) + left = get_operand_value(rval.left, builder, local_sym_tab) + right = get_operand_value(rval.right, builder, local_sym_tab) logger.info(f"left is {left}, right is {right}, op is {op}") # Map AST operation nodes to LLVM IR builder methods @@ -61,6 +62,11 @@ def handle_binary_op_impl(rval, module, builder, local_sym_tab): raise SyntaxError("Unsupported binary operation") -def handle_binary_op(rval, module, builder, var_name, local_sym_tab): - result = handle_binary_op_impl(rval, module, builder, local_sym_tab) - builder.store(result, local_sym_tab[var_name].var) +def handle_binary_op(rval, builder, var_name, local_sym_tab): + result = handle_binary_op_impl(rval, builder, local_sym_tab) + if var_name and var_name in local_sym_tab: + logger.info( + f"Storing result {result} into variable {local_sym_tab[var_name].var}" + ) + builder.store(result, local_sym_tab[var_name].var) + return result, result.type diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py index 5de23a5..cf20f06 100644 --- a/pythonbpf/codegen.py +++ b/pythonbpf/codegen.py @@ -48,7 +48,7 @@ def processor(source_code, filename, module): globals_processing(tree, module) -def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING): +def compile_to_ir(filename: str, output: str, loglevel=logging.INFO): logging.basicConfig( level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" ) @@ -121,7 +121,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING): return output -def compile(loglevel=logging.WARNING) -> bool: +def compile(loglevel=logging.INFO) -> bool: # Look one level up the stack to the caller of this function caller_frame = inspect.stack()[1] caller_file = Path(caller_frame.filename).resolve() @@ -154,7 +154,7 @@ def compile(loglevel=logging.WARNING) -> bool: return success -def BPF(loglevel=logging.WARNING) -> BpfProgram: +def BPF(loglevel=logging.INFO) -> BpfProgram: caller_frame = inspect.stack()[1] src = inspect.getsource(caller_frame.frame) with tempfile.NamedTemporaryFile( diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions_pass.py index 430b55f..d5537da 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions_pass.py @@ -233,7 +233,7 @@ def handle_assign( else: logger.info("Unsupported assignment call function type") elif isinstance(rval, ast.BinOp): - handle_binary_op(rval, module, builder, var_name, local_sym_tab) + handle_binary_op(rval, builder, var_name, local_sym_tab) else: logger.info("Unsupported assignment value type") @@ -385,24 +385,49 @@ def process_stmt( ) elif isinstance(stmt, ast.Return): if stmt.value is None: - builder.ret(ir.Constant(ir.IntType(32), 0)) + builder.ret(ir.Constant(ir.IntType(64), 0)) did_return = True elif ( isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and len(stmt.value.args) == 1 - and isinstance(stmt.value.args[0], ast.Constant) - and isinstance(stmt.value.args[0].value, int) ): - call_type = stmt.value.func.id - if ctypes_to_ir(call_type) != ret_type: - raise ValueError( - "Return type mismatch: expected" - f"{ctypes_to_ir(call_type)}, got {call_type}" - ) - else: - builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) + if isinstance(stmt.value.args[0], ast.Constant) and isinstance( + stmt.value.args[0].value, int + ): + call_type = stmt.value.func.id + if ctypes_to_ir(call_type) != ret_type: + raise ValueError( + "Return type mismatch: expected" + f"{ctypes_to_ir(call_type)}, got {call_type}" + ) + else: + builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) + did_return = True + elif isinstance(stmt.value.args[0], ast.BinOp): + # TODO: Should be routed through eval_expr + val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab) + if val is None: + raise ValueError("Failed to evaluate return expression") + if val[1] != ret_type: + raise ValueError( + "Return type mismatch: expected " f"{ret_type}, got {val[1]}" + ) + builder.ret(val[0]) did_return = True + elif isinstance(stmt.value.args[0], ast.Name): + if stmt.value.args[0].id in local_sym_tab: + var = local_sym_tab[stmt.value.args[0].id].var + val = builder.load(var) + if val.type != ret_type: + raise ValueError( + "Return type mismatch: expected" + f"{ret_type}, got {val.type}" + ) + builder.ret(val) + did_return = True + else: + raise ValueError("Failed to evaluate return expression") elif isinstance(stmt.value, ast.Name): if stmt.value.id == "XDP_PASS": builder.ret(ir.Constant(ret_type, 2)) @@ -455,6 +480,9 @@ def allocate_mem( continue var_name = target.id rval = stmt.value + if var_name in local_sym_tab: + logger.info(f"Variable {var_name} already allocated") + continue if isinstance(rval, ast.Call): if isinstance(rval.func, ast.Name): call_type = rval.func.id @@ -568,7 +596,7 @@ def process_func_body( ) if not did_return: - builder.ret(ir.Constant(ir.IntType(32), 0)) + builder.ret(ir.Constant(ir.IntType(64), 0)) def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab): diff --git a/tests/failing_tests/direct_assign.py b/tests/failing_tests/direct_assign.py index 18ff266..a784313 100644 --- a/tests/failing_tests/direct_assign.py +++ b/tests/failing_tests/direct_assign.py @@ -4,6 +4,18 @@ from ctypes import c_void_p, c_int64 +# NOTE: I have decided to not fix this example for now. +# The issue is in line 31, where we are passing an expression. +# The update helper expects a pointer type. But the problem is +# that we must allocate the space for said pointer in the first +# basic block. As that usage is in a different basic block, we +# are unable to cast the expression to a pointer type. (as we never +# allocated space for it). +# Shall we change our space allocation logic? That allows users to +# spam the same helper with the same args, and still run out of +# stack space. So we consider this usage invalid for now. +# Might fix it later. + @bpf @map @@ -14,12 +26,12 @@ def count() -> HashMap: @bpf @section("xdp") def hello_world(ctx: c_void_p) -> c_int64: - prev = count().lookup(0) + prev = count.lookup(0) if prev: - count().update(0, prev + 1) + count.update(0, prev + 1) return XDP_PASS else: - count().update(0, 1) + count.update(0, 1) return XDP_PASS diff --git a/tests/failing_tests/named_arg.py b/tests/failing_tests/named_arg.py new file mode 100644 index 0000000..79ac830 --- /dev/null +++ b/tests/failing_tests/named_arg.py @@ -0,0 +1,40 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from pythonbpf.helper import XDP_PASS +from pythonbpf.maps import HashMap + +from ctypes import c_void_p, c_int64 + +# NOTE: This example exposes the problems with our typing system. +# We can't do steps on line 25 and 27. +# prev is of type i64**. For prev + 1, we deref it down to i64 +# To assign it back to prev, we need to go back to i64**. +# We cannot allocate space for the intermediate type now. +# We probably need to track the ref/deref chain for each variable. + +@bpf +@map +def count() -> HashMap: + return HashMap(key=c_int64, value=c_int64, max_entries=1) + + +@bpf +@section("xdp") +def hello_world(ctx: c_void_p) -> c_int64: + prev = count.lookup(0) + if prev: + prev = prev + 1 + count.update(0, prev) + return XDP_PASS + else: + count.update(0, 1) + + return XDP_PASS + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/failing_tests/return.py b/tests/passing_tests/return.py similarity index 100% rename from tests/failing_tests/return.py rename to tests/passing_tests/return.py diff --git a/tests/passing_tests/var_rval.py b/tests/passing_tests/var_rval.py new file mode 100644 index 0000000..ee1735e --- /dev/null +++ b/tests/passing_tests/var_rval.py @@ -0,0 +1,20 @@ +import logging + +from pythonbpf import compile, bpf, section, bpfglobal +from ctypes import c_void_p, c_int64 + + +@bpf +@section("sometag1") +def sometag(ctx: c_void_p) -> c_int64: + a = 1 - 1 + return c_int64(a) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile(loglevel=logging.INFO)