diff --git a/pythonbpf/binary_ops.py b/pythonbpf/binary_ops.py index e0b69f3..5ab393f 100644 --- a/pythonbpf/binary_ops.py +++ b/pythonbpf/binary_ops.py @@ -8,68 +8,59 @@ def recursive_dereferencer(var, builder): """dereference until primitive type comes out""" - if var.type == ir.PointerType(ir.PointerType(ir.IntType(64))): + # TODO: Not worrying about stack overflow for now + if isinstance(var.type, ir.PointerType): a = builder.load(var) return recursive_dereferencer(a, builder) - elif var.type == ir.PointerType(ir.IntType(64)): - a = builder.load(var) - return recursive_dereferencer(a, builder) - elif var.type == ir.IntType(64): + elif isinstance(var.type, ir.IntType): return var else: raise TypeError(f"Unsupported type for dereferencing: {var.type}") -def handle_binary_op(rval, module, builder, var_name, local_sym_tab, map_sym_tab, func): - logger.info(f"module {module}") - left = rval.left - right = rval.right - op = rval.op - - # Handle left operand - if isinstance(left, ast.Name): - if left.id in local_sym_tab: - left = recursive_dereferencer(local_sym_tab[left.id].var, builder) - else: - raise SyntaxError(f"Undefined variable: {left.id}") - elif isinstance(left, ast.Constant): - left = ir.Constant(ir.IntType(64), left.value) - else: - raise SyntaxError("Unsupported left operand type") +def get_operand_value(operand, module, builder, local_sym_tab): + """Extract the value from an operand, handling variables and constants.""" + if isinstance(operand, ast.Name): + if operand.id in local_sym_tab: + return recursive_dereferencer(local_sym_tab[operand.id].var, builder) + raise ValueError(f"Undefined variable: {operand.id}") + elif isinstance(operand, ast.Constant): + if isinstance(operand.value, int): + return ir.Constant(ir.IntType(64), operand.value) + raise TypeError(f"Unsupported constant type: {type(operand.value)}") + elif isinstance(operand, ast.BinOp): + return handle_binary_op_impl(operand, module, builder, local_sym_tab) + raise TypeError(f"Unsupported operand type: {type(operand)}") - if isinstance(right, ast.Name): - if right.id in local_sym_tab: - right = recursive_dereferencer(local_sym_tab[right.id].var, builder) - else: - raise SyntaxError(f"Undefined variable: {right.id}") - elif isinstance(right, ast.Constant): - right = ir.Constant(ir.IntType(64), right.value) - else: - raise SyntaxError("Unsupported right operand type") +def handle_binary_op_impl(rval, module, builder, local_sym_tab): + op = rval.op + left = get_operand_value(rval.left, module, builder, local_sym_tab) + right = get_operand_value(rval.right, module, builder, local_sym_tab) logger.info(f"left is {left}, right is {right}, op is {op}") - if isinstance(op, ast.Add): - builder.store(builder.add(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.Sub): - builder.store(builder.sub(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.Mult): - builder.store(builder.mul(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.Div): - builder.store(builder.sdiv(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.Mod): - builder.store(builder.srem(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.LShift): - builder.store(builder.shl(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.RShift): - builder.store(builder.lshr(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.BitOr): - builder.store(builder.or_(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.BitXor): - builder.store(builder.xor(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.BitAnd): - builder.store(builder.and_(left, right), local_sym_tab[var_name].var) - elif isinstance(op, ast.FloorDiv): - builder.store(builder.udiv(left, right), local_sym_tab[var_name].var) + # Map AST operation nodes to LLVM IR builder methods + op_map = { + ast.Add: builder.add, + ast.Sub: builder.sub, + ast.Mult: builder.mul, + ast.Div: builder.sdiv, + ast.Mod: builder.srem, + ast.LShift: builder.shl, + ast.RShift: builder.lshr, + ast.BitOr: builder.or_, + ast.BitXor: builder.xor, + ast.BitAnd: builder.and_, + ast.FloorDiv: builder.udiv, + } + + if type(op) in op_map: + result = op_map[type(op)](left, right) + return result else: raise SyntaxError("Unsupported binary operation") + + +def handle_binary_op(rval, module, builder, var_name, local_sym_tab): + result = handle_binary_op_impl(rval, module, builder, local_sym_tab) + builder.store(result, local_sym_tab[var_name].var) diff --git a/pythonbpf/expr_pass.py b/pythonbpf/expr_pass.py index 1befbb4..40d0800 100644 --- a/pythonbpf/expr_pass.py +++ b/pythonbpf/expr_pass.py @@ -2,10 +2,92 @@ from llvmlite import ir from logging import Logger import logging +from typing import Dict logger: Logger = logging.getLogger(__name__) +def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder): + """Handle ast.Name expressions.""" + if expr.id in local_sym_tab: + var = local_sym_tab[expr.id].var + val = builder.load(var) + return val, local_sym_tab[expr.id].ir_type + else: + logger.info(f"Undefined variable {expr.id}") + return None + + +def _handle_constant_expr(expr: ast.Constant): + """Handle ast.Constant expressions.""" + if isinstance(expr.value, int): + return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64) + elif isinstance(expr.value, bool): + return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1) + else: + logger.info("Unsupported constant type") + return None + + +def _handle_attribute_expr( + expr: ast.Attribute, + local_sym_tab: Dict, + structs_sym_tab: Dict, + builder: ir.IRBuilder, +): + """Handle ast.Attribute expressions for struct field access.""" + if isinstance(expr.value, ast.Name): + var_name = expr.value.id + attr_name = expr.attr + if var_name in local_sym_tab: + var_ptr, var_type, var_metadata = local_sym_tab[var_name] + logger.info(f"Loading attribute {attr_name} from variable {var_name}") + logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}") + + metadata = structs_sym_tab[var_metadata] + if attr_name in metadata.fields: + gep = metadata.gep(builder, var_ptr, attr_name) + val = builder.load(gep) + field_type = metadata.field_type(attr_name) + return val, field_type + return None + + +def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder): + """Handle deref function calls.""" + logger.info(f"Handling deref {ast.dump(expr)}") + if len(expr.args) != 1: + logger.info("deref takes exactly one argument") + return None + + arg = expr.args[0] + if ( + isinstance(arg, ast.Call) + and isinstance(arg.func, ast.Name) + and arg.func.id == "deref" + ): + logger.info("Multiple deref not supported") + return None + + if isinstance(arg, ast.Name): + if arg.id in local_sym_tab: + arg_ptr = local_sym_tab[arg.id].var + else: + logger.info(f"Undefined variable {arg.id}") + return None + else: + logger.info("Unsupported argument type for deref") + return None + + if arg_ptr is None: + logger.info("Failed to evaluate deref argument") + return None + + # Load the value from pointer + val = builder.load(arg_ptr) + return val, local_sym_tab[arg.id].ir_type + + def eval_expr( func, module, @@ -17,64 +99,28 @@ def eval_expr( ): logger.info(f"Evaluating expression: {ast.dump(expr)}") if isinstance(expr, ast.Name): - if expr.id in local_sym_tab: - var = local_sym_tab[expr.id].var - val = builder.load(var) - return val, local_sym_tab[expr.id].ir_type # return value and type - else: - logger.info(f"Undefined variable {expr.id}") - return None + return _handle_name_expr(expr, local_sym_tab, builder) elif isinstance(expr, ast.Constant): - if isinstance(expr.value, int): - return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64) - elif isinstance(expr.value, bool): - return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1) - else: - logger.info("Unsupported constant type") - return None + return _handle_constant_expr(expr) elif isinstance(expr, ast.Call): + if isinstance(expr.func, ast.Name) and expr.func.id == "deref": + return _handle_deref_call(expr, local_sym_tab, builder) + # delayed import to avoid circular dependency from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call - if isinstance(expr.func, ast.Name): - # check deref - if expr.func.id == "deref": - logger.info(f"Handling deref {ast.dump(expr)}") - if len(expr.args) != 1: - logger.info("deref takes exactly one argument") - return None - arg = expr.args[0] - if ( - isinstance(arg, ast.Call) - and isinstance(arg.func, ast.Name) - and arg.func.id == "deref" - ): - logger.info("Multiple deref not supported") - return None - if isinstance(arg, ast.Name): - if arg.id in local_sym_tab: - arg = local_sym_tab[arg.id].var - else: - logger.info(f"Undefined variable {arg.id}") - return None - if arg is None: - logger.info("Failed to evaluate deref argument") - return None - # Since we are handling only name case, directly take type from sym tab - val = builder.load(arg) - return val, local_sym_tab[expr.args[0].id].ir_type - - # check for helpers - if HelperHandlerRegistry.has_handler(expr.func.id): - return handle_helper_call( - expr, - module, - builder, - func, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) + if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler( + expr.func.id + ): + return handle_helper_call( + expr, + module, + builder, + func, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) elif isinstance(expr.func, ast.Attribute): logger.info(f"Handling method call: {ast.dump(expr.func)}") if isinstance(expr.func.value, ast.Call) and isinstance( @@ -106,19 +152,7 @@ def eval_expr( structs_sym_tab, ) elif isinstance(expr, ast.Attribute): - if isinstance(expr.value, ast.Name): - var_name = expr.value.id - attr_name = expr.attr - if var_name in local_sym_tab: - var_ptr, var_type, var_metadata = local_sym_tab[var_name] - logger.info(f"Loading attribute {attr_name} from variable {var_name}") - logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}") - metadata = structs_sym_tab[var_metadata] - if attr_name in metadata.fields: - gep = metadata.gep(builder, var_ptr, attr_name) - val = builder.load(gep) - field_type = metadata.field_type(attr_name) - return val, field_type + return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder) logger.info("Unsupported expression evaluation") return None diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions_pass.py index 6653677..430b55f 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions_pass.py @@ -233,9 +233,7 @@ def handle_assign( else: logger.info("Unsupported assignment call function type") elif isinstance(rval, ast.BinOp): - handle_binary_op( - rval, module, builder, var_name, local_sym_tab, map_sym_tab, func - ) + handle_binary_op(rval, module, builder, var_name, local_sym_tab) else: logger.info("Unsupported assignment value type") diff --git a/tests/failing_tests/binops.py b/tests/passing_tests/binops.py similarity index 76% rename from tests/failing_tests/binops.py rename to tests/passing_tests/binops.py index 3a86765..1fdff54 100644 --- a/tests/failing_tests/binops.py +++ b/tests/passing_tests/binops.py @@ -3,9 +3,9 @@ @bpf -@section("sometag1") +@section("tracepoint/syscalls/sys_enter_sync") def sometag(ctx: c_void_p) -> c_int64: - a = 1 + 2 + 1 + a = 1 + 2 + 1 + 12 + 13 print(f"{a}") return c_int64(0) diff --git a/tests/failing_tests/binops1.py b/tests/passing_tests/binops1.py similarity index 72% rename from tests/failing_tests/binops1.py rename to tests/passing_tests/binops1.py index a3a06cc..a16158c 100644 --- a/tests/failing_tests/binops1.py +++ b/tests/passing_tests/binops1.py @@ -3,11 +3,12 @@ @bpf -@section("sometag1") +@section("tracepoint/syscalls/sys_enter_sync") def sometag(ctx: c_void_p) -> c_int64: b = 1 + 2 a = 1 + b - return c_int64(a) + print(f"{a}") + return c_int64(0) @bpf