diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py new file mode 100644 index 00000000..5ec631a2 --- /dev/null +++ b/pythonbpf/allocation_pass.py @@ -0,0 +1,191 @@ +import ast +import logging + +from llvmlite import ir +from dataclasses import dataclass +from typing import Any +from pythonbpf.helper import HelperHandlerRegistry +from pythonbpf.type_deducer import ctypes_to_ir + +logger = logging.getLogger(__name__) + + +@dataclass +class LocalSymbol: + var: ir.AllocaInstr + ir_type: ir.Type + metadata: Any = None + + def __iter__(self): + yield self.var + yield self.ir_type + yield self.metadata + + +def _is_helper_call(call_node): + """Check if a call node is a BPF helper function call.""" + if isinstance(call_node.func, ast.Name): + # Exclude print from requiring temps (handles f-strings differently) + func_name = call_node.func.id + return HelperHandlerRegistry.has_handler(func_name) and func_name != "print" + + elif isinstance(call_node.func, ast.Attribute): + return HelperHandlerRegistry.has_handler(call_node.func.attr) + + return False + + +def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): + """Handle memory allocation for assignment statements.""" + + # Validate assignment + if len(stmt.targets) != 1: + logger.warning("Multi-target assignment not supported, skipping allocation") + return + + target = stmt.targets[0] + + # Skip non-name targets (e.g., struct field assignments) + if isinstance(target, ast.Attribute): + logger.debug(f"Struct field assignment to {target.attr}, no allocation needed") + return + + if not isinstance(target, ast.Name): + logger.warning(f"Unsupported assignment target type: {type(target).__name__}") + return + + var_name = target.id + rval = stmt.value + + # Skip if already allocated + if var_name in local_sym_tab: + logger.debug(f"Variable {var_name} already allocated, skipping") + return + + # Determine type and allocate based on rval + if isinstance(rval, ast.Call): + _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab) + elif isinstance(rval, ast.Constant): + _allocate_for_constant(builder, var_name, rval, local_sym_tab) + elif isinstance(rval, ast.BinOp): + _allocate_for_binop(builder, var_name, local_sym_tab) + else: + logger.warning( + f"Unsupported assignment value type for {var_name}: {type(rval).__name__}" + ) + + +def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): + """Allocate memory for variable assigned from a call.""" + + if isinstance(rval.func, ast.Name): + call_type = rval.func.id + + # C type constructors + if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): + ir_type = ctypes_to_ir(call_type) + var = builder.alloca(ir_type, name=var_name) + var.align = ir_type.width // 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as {call_type}") + + # Helper functions + elif HelperHandlerRegistry.has_handler(call_type): + ir_type = ir.IntType(64) # Assume i64 return type + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for helper {call_type}") + + # Deref function + elif call_type == "deref": + ir_type = ir.IntType(64) # Assume i64 return type + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for deref") + + # Struct constructors + elif call_type in structs_sym_tab: + struct_info = structs_sym_tab[call_type] + var = builder.alloca(struct_info.ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type) + logger.info(f"Pre-allocated {var_name} for struct {call_type}") + + else: + logger.warning(f"Unknown call type for allocation: {call_type}") + + elif isinstance(rval.func, ast.Attribute): + # Map method calls - need double allocation for ptr handling + _allocate_for_map_method(builder, var_name, local_sym_tab) + + else: + logger.warning(f"Unsupported call function type for {var_name}") + + +def _allocate_for_map_method(builder, var_name, local_sym_tab): + """Allocate memory for variable assigned from map method (double alloc).""" + + # Main variable (pointer to pointer) + ir_type = ir.PointerType(ir.IntType(64)) + var = builder.alloca(ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + + # Temporary variable for computed values + tmp_ir_type = ir.IntType(64) + var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") + local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) + + logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method") + + +def _allocate_for_constant(builder, var_name, rval, local_sym_tab): + """Allocate memory for variable assigned from a constant.""" + + if isinstance(rval.value, bool): + ir_type = ir.IntType(1) + var = builder.alloca(ir_type, name=var_name) + var.align = 1 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as bool") + + elif isinstance(rval.value, int): + ir_type = ir.IntType(64) + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as i64") + + elif isinstance(rval.value, str): + ir_type = ir.PointerType(ir.IntType(8)) + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as string") + + else: + logger.warning( + f"Unsupported constant type for {var_name}: {type(rval.value).__name__}" + ) + + +def _allocate_for_binop(builder, var_name, local_sym_tab): + """Allocate memory for variable assigned from a binary operation.""" + ir_type = ir.IntType(64) # Assume i64 result + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for binop result") + + +def allocate_temp_pool(builder, max_temps, local_sym_tab): + """Allocate the temporary scratch space pool for helper arguments.""" + if max_temps == 0: + return + + logger.info(f"Allocating temp pool of {max_temps} variables") + for i in range(max_temps): + temp_name = f"__helper_temp_{i}" + temp_var = builder.alloca(ir.IntType(64), name=temp_name) + temp_var.align = 8 + local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64)) diff --git a/pythonbpf/assign_pass.py b/pythonbpf/assign_pass.py new file mode 100644 index 00000000..ab091415 --- /dev/null +++ b/pythonbpf/assign_pass.py @@ -0,0 +1,108 @@ +import ast +import logging +from llvmlite import ir +from pythonbpf.expr import eval_expr + +logger = logging.getLogger(__name__) + + +def handle_struct_field_assignment( + func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab +): + """Handle struct field assignment (obj.field = value).""" + + var_name = target.value.id + field_name = target.attr + + if var_name not in local_sym_tab: + logger.error(f"Variable '{var_name}' not found in symbol table") + return + + struct_type = local_sym_tab[var_name].metadata + struct_info = structs_sym_tab[struct_type] + + if field_name not in struct_info.fields: + logger.error(f"Field '{field_name}' not found in struct '{struct_type}'") + return + + # Get field pointer and evaluate value + field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name) + val = eval_expr( + func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab + ) + + if val is None: + logger.error(f"Failed to evaluate value for {var_name}.{field_name}") + return + + # TODO: Handle string assignment to char array (not a priority) + field_type = struct_info.field_type(field_name) + if isinstance(field_type, ir.ArrayType) and val[1] == ir.PointerType(ir.IntType(8)): + logger.warning( + f"String to char array assignment not implemented for {var_name}.{field_name}" + ) + return + + # Store the value + builder.store(val[0], field_ptr) + logger.info(f"Assigned to struct field {var_name}.{field_name}") + + +def handle_variable_assignment( + func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab +): + """Handle single named variable assignment.""" + + if var_name not in local_sym_tab: + logger.error(f"Variable {var_name} not declared.") + return False + + var_ptr = local_sym_tab[var_name].var + var_type = local_sym_tab[var_name].ir_type + + # NOTE: Special case for struct initialization + if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name): + struct_name = rval.func.id + if struct_name in structs_sym_tab and len(rval.args) == 0: + struct_info = structs_sym_tab[struct_name] + ir_struct = struct_info.ir_type + + builder.store(ir.Constant(ir_struct, None), var_ptr) + logger.info(f"Initialized struct {struct_name} for variable {var_name}") + return True + + val_result = eval_expr( + func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab + ) + if val_result is None: + logger.error(f"Failed to evaluate value for {var_name}") + return False + + val, val_type = val_result + logger.info(f"Evaluated value for {var_name}: {val} of type {val_type}, {var_type}") + if val_type != var_type: + if isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType): + # Allow implicit int widening + if val_type.width < var_type.width: + val = builder.sext(val, var_type) + logger.info(f"Implicitly widened int for variable {var_name}") + elif val_type.width > var_type.width: + val = builder.trunc(val, var_type) + logger.info(f"Implicitly truncated int for variable {var_name}") + elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.PointerType): + # NOTE: This is assignment to a PTR_TO_MAP_VALUE_OR_NULL + logger.info( + f"Creating temporary variable for pointer assignment to {var_name}" + ) + var_ptr_tmp = local_sym_tab[f"{var_name}_tmp"].var + builder.store(val, var_ptr_tmp) + val = var_ptr_tmp + else: + logger.error( + f"Type mismatch for variable {var_name}: {val_type} vs {var_type}" + ) + return False + + builder.store(val, var_ptr) + logger.info(f"Assigned value to variable {var_name}") + return True diff --git a/pythonbpf/binary_ops.py b/pythonbpf/binary_ops.py index c0ca0ec0..6ea534b3 100644 --- a/pythonbpf/binary_ops.py +++ b/pythonbpf/binary_ops.py @@ -3,43 +3,70 @@ from logging import Logger import logging -logger: Logger = logging.getLogger(__name__) - +from pythonbpf.expr import get_base_type_and_depth, deref_to_depth, eval_expr -def recursive_dereferencer(var, builder): - """dereference until primitive type comes out""" - # TODO: Not worrying about stack overflow for now - logger.info(f"Dereferencing {var}, type is {var.type}") - if isinstance(var.type, ir.PointerType): - a = builder.load(var) - return recursive_dereferencer(a, builder) - elif isinstance(var.type, ir.IntType): - return var - else: - raise TypeError(f"Unsupported type for dereferencing: {var.type}") +logger: Logger = logging.getLogger(__name__) -def get_operand_value(operand, builder, local_sym_tab): +def get_operand_value( + func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None +): """Extract the value from an operand, handling variables and constants.""" + logger.info(f"Getting operand value for: {ast.dump(operand)}") if isinstance(operand, ast.Name): if operand.id in local_sym_tab: - return recursive_dereferencer(local_sym_tab[operand.id].var, builder) + var = local_sym_tab[operand.id].var + var_type = var.type + base_type, depth = get_base_type_and_depth(var_type) + logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}") + val = deref_to_depth(func, builder, var, depth) + return val raise ValueError(f"Undefined variable: {operand.id}") elif isinstance(operand, ast.Constant): if isinstance(operand.value, int): - return ir.Constant(ir.IntType(64), operand.value) + cst = ir.Constant(ir.IntType(64), int(operand.value)) + return cst raise TypeError(f"Unsupported constant type: {type(operand.value)}") elif isinstance(operand, ast.BinOp): - return handle_binary_op_impl(operand, builder, local_sym_tab) + res = handle_binary_op_impl( + func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab + ) + return res + else: + res = eval_expr( + func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab + ) + if res is None: + raise ValueError(f"Failed to evaluate call expression: {operand}") + val, _ = res + logger.info(f"Evaluated expr to {val} of type {val.type}") + base_type, depth = get_base_type_and_depth(val.type) + if depth > 0: + val = deref_to_depth(func, builder, val, depth) + return val raise TypeError(f"Unsupported operand type: {type(operand)}") -def handle_binary_op_impl(rval, builder, local_sym_tab): +def handle_binary_op_impl( + func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None +): op = rval.op - left = get_operand_value(rval.left, builder, local_sym_tab) - right = get_operand_value(rval.right, builder, local_sym_tab) + left = get_operand_value( + func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab + ) + right = get_operand_value( + func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab + ) logger.info(f"left is {left}, right is {right}, op is {op}") + # NOTE: Before doing the operation, if the operands are integers + # we always extend them to i64. The assignment to LHS will take + # care of truncation if needed. + if isinstance(left.type, ir.IntType) and left.type.width < 64: + left = builder.sext(left, ir.IntType(64)) + if isinstance(right.type, ir.IntType) and right.type.width < 64: + right = builder.sext(right, ir.IntType(64)) + # Map AST operation nodes to LLVM IR builder methods op_map = { ast.Add: builder.add, @@ -62,8 +89,19 @@ def handle_binary_op_impl(rval, builder, local_sym_tab): raise SyntaxError("Unsupported binary operation") -def handle_binary_op(rval, builder, var_name, local_sym_tab): - result = handle_binary_op_impl(rval, builder, local_sym_tab) +def handle_binary_op( + func, + module, + rval, + builder, + var_name, + local_sym_tab, + map_sym_tab, + structs_sym_tab=None, +): + result = handle_binary_op_impl( + func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab + ) if var_name and var_name in local_sym_tab: logger.info( f"Storing result {result} into variable {local_sym_tab[var_name].var}" diff --git a/pythonbpf/expr/__init__.py b/pythonbpf/expr/__init__.py index d58c543a..dd5b4802 100644 --- a/pythonbpf/expr/__init__.py +++ b/pythonbpf/expr/__init__.py @@ -1,4 +1,10 @@ from .expr_pass import eval_expr, handle_expr -from .type_normalization import convert_to_bool +from .type_normalization import convert_to_bool, get_base_type_and_depth, deref_to_depth -__all__ = ["eval_expr", "handle_expr", "convert_to_bool"] +__all__ = [ + "eval_expr", + "handle_expr", + "convert_to_bool", + "get_base_type_and_depth", + "deref_to_depth", +] diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index 21be1961..ecf11192 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -26,7 +26,7 @@ def _handle_constant_expr(expr: ast.Constant): if isinstance(expr.value, int) or isinstance(expr.value, bool): return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) else: - logger.error("Unsupported constant type") + logger.error(f"Unsupported constant type {ast.dump(expr)}") return None @@ -176,21 +176,28 @@ def _handle_unary_op( structs_sym_tab=None, ): """Handle ast.UnaryOp expressions.""" - if not isinstance(expr.op, ast.Not): - logger.error("Only 'not' unary operator is supported") + if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub): + logger.error("Only 'not' and '-' unary operators are supported") return None - operand = eval_expr( - func, module, builder, expr.operand, local_sym_tab, map_sym_tab, structs_sym_tab + from pythonbpf.binary_ops import get_operand_value + + operand = get_operand_value( + func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab ) if operand is None: logger.error("Failed to evaluate operand for unary operation") return None - operand_val, operand_type = operand - true_const = ir.Constant(ir.IntType(1), 1) - result = builder.xor(convert_to_bool(builder, operand_val), true_const) - return result, ir.IntType(1) + if isinstance(expr.op, ast.Not): + true_const = ir.Constant(ir.IntType(1), 1) + result = builder.xor(convert_to_bool(builder, operand), true_const) + return result, ir.IntType(1) + elif isinstance(expr.op, ast.USub): + # Multiply by -1 + neg_one = ir.Constant(ir.IntType(64), -1) + result = builder.mul(operand, neg_one) + return result, ir.IntType(64) def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): @@ -402,7 +409,16 @@ def eval_expr( elif isinstance(expr, ast.BinOp): from pythonbpf.binary_ops import handle_binary_op - return handle_binary_op(expr, builder, None, local_sym_tab) + return handle_binary_op( + func, + module, + expr, + builder, + None, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) elif isinstance(expr, ast.Compare): return _handle_compare( func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab diff --git a/pythonbpf/expr/type_normalization.py b/pythonbpf/expr/type_normalization.py index 7a2fb574..fec53a41 100644 --- a/pythonbpf/expr/type_normalization.py +++ b/pythonbpf/expr/type_normalization.py @@ -16,7 +16,7 @@ } -def _get_base_type_and_depth(ir_type): +def get_base_type_and_depth(ir_type): """Get the base type for pointer types.""" cur_type = ir_type depth = 0 @@ -26,7 +26,7 @@ def _get_base_type_and_depth(ir_type): return cur_type, depth -def _deref_to_depth(func, builder, val, target_depth): +def deref_to_depth(func, builder, val, target_depth): """Dereference a pointer to a certain depth.""" cur_val = val @@ -88,13 +88,13 @@ def _normalize_types(func, builder, lhs, rhs): logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}") return None, None else: - lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type) - rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type) + lhs_base, lhs_depth = get_base_type_and_depth(lhs.type) + rhs_base, rhs_depth = get_base_type_and_depth(rhs.type) if lhs_base == rhs_base: if lhs_depth < rhs_depth: - rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth) + rhs = deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth) elif rhs_depth < lhs_depth: - lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth) + lhs = deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth) return _normalize_types(func, builder, lhs, rhs) diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index 7fc3febc..45d7b0ae 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -1,13 +1,18 @@ from llvmlite import ir import ast import logging -from typing import Any -from dataclasses import dataclass -from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call +from pythonbpf.helper import ( + HelperHandlerRegistry, + reset_scratch_pool, +) from pythonbpf.type_deducer import ctypes_to_ir -from pythonbpf.binary_ops import handle_binary_op from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool +from pythonbpf.assign_pass import ( + handle_variable_assignment, + handle_struct_field_assignment, +) +from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name @@ -15,18 +20,6 @@ logger = logging.getLogger(__name__) -@dataclass -class LocalSymbol: - var: ir.AllocaInstr - ir_type: ir.Type - metadata: Any = None - - def __iter__(self): - yield self.var - yield self.ir_type - yield self.metadata - - def get_probe_string(func_node): """Extract the probe string from the decorator of the function node.""" # TODO: right now we have the whole string in the section decorator @@ -48,196 +41,49 @@ def handle_assign( func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab ): """Handle assignment statements in the function body.""" + + # TODO: Support this later + # GH #37 if len(stmt.targets) != 1: - logger.info("Unsupported multiassignment") + logger.error("Multi-target assignment is not supported for now") return - num_types = ("c_int32", "c_int64", "c_uint32", "c_uint64") - target = stmt.targets[0] - logger.info(f"Handling assignment to {ast.dump(target)}") - if not isinstance(target, ast.Name) and not isinstance(target, ast.Attribute): - logger.info("Unsupported assignment target") - return - var_name = target.id if isinstance(target, ast.Name) else target.value.id rval = stmt.value + + if isinstance(target, ast.Name): + # NOTE: Simple variable assignment case: x = 5 + var_name = target.id + result = handle_variable_assignment( + func, + module, + builder, + var_name, + rval, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + if not result: + logger.error(f"Failed to handle assignment to {var_name}") + return + if isinstance(target, ast.Attribute): - # struct field assignment - field_name = target.attr - if var_name in local_sym_tab: - struct_type = local_sym_tab[var_name].metadata - struct_info = structs_sym_tab[struct_type] - if field_name in struct_info.fields: - field_ptr = struct_info.gep( - builder, local_sym_tab[var_name].var, field_name - ) - val = eval_expr( - func, - module, - builder, - rval, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) - if isinstance(struct_info.field_type(field_name), ir.ArrayType) and val[ - 1 - ] == ir.PointerType(ir.IntType(8)): - # TODO: Figure it out, not a priority rn - # Special case for string assignment to char array - # str_len = struct_info["field_types"][field_idx].count - # assign_string_to_array(builder, field_ptr, val[0], str_len) - # print(f"Assigned to struct field {var_name}.{field_name}") - pass - if val is None: - logger.info("Failed to evaluate struct field assignment") - return - logger.info(field_ptr) - builder.store(val[0], field_ptr) - logger.info(f"Assigned to struct field {var_name}.{field_name}") - return - elif isinstance(rval, ast.Constant): - if isinstance(rval.value, bool): - if rval.value: - builder.store( - ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name].var - ) - else: - builder.store( - ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name].var - ) - logger.info(f"Assigned constant {rval.value} to {var_name}") - elif isinstance(rval.value, int): - # Assume c_int64 for now - # var = builder.alloca(ir.IntType(64), name=var_name) - # var.align = 8 - builder.store( - ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name].var - ) - logger.info(f"Assigned constant {rval.value} to {var_name}") - elif isinstance(rval.value, str): - str_val = rval.value.encode("utf-8") + b"\x00" - str_const = ir.Constant( - ir.ArrayType(ir.IntType(8), len(str_val)), bytearray(str_val) - ) - global_str = ir.GlobalVariable( - module, str_const.type, name=f"{var_name}_str" - ) - global_str.linkage = "internal" - global_str.global_constant = True - global_str.initializer = str_const - str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8))) - builder.store(str_ptr, local_sym_tab[var_name].var) - logger.info(f"Assigned string constant '{rval.value}' to {var_name}") - else: - logger.info("Unsupported constant type") - elif isinstance(rval, ast.Call): - if isinstance(rval.func, ast.Name): - call_type = rval.func.id - logger.info(f"Assignment call type: {call_type}") - if ( - call_type in num_types - and len(rval.args) == 1 - and isinstance(rval.args[0], ast.Constant) - and isinstance(rval.args[0].value, int) - ): - ir_type = ctypes_to_ir(call_type) - # var = builder.alloca(ir_type, name=var_name) - # var.align = ir_type.width // 8 - builder.store( - ir.Constant(ir_type, rval.args[0].value), - local_sym_tab[var_name].var, - ) - logger.info( - f"Assigned {call_type} constant {rval.args[0].value} to {var_name}" - ) - elif HelperHandlerRegistry.has_handler(call_type): - # var = builder.alloca(ir.IntType(64), name=var_name) - # var.align = 8 - val = handle_helper_call( - rval, - module, - builder, - func, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) - builder.store(val[0], local_sym_tab[var_name].var) - logger.info(f"Assigned constant {rval.func.id} to {var_name}") - elif call_type == "deref" and len(rval.args) == 1: - logger.info(f"Handling deref assignment {ast.dump(rval)}") - val = eval_expr( - func, - module, - builder, - rval, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) - if val is None: - logger.info("Failed to evaluate deref argument") - return - logger.info(f"Dereferenced value: {val}, storing in {var_name}") - builder.store(val[0], local_sym_tab[var_name].var) - logger.info(f"Dereferenced and assigned to {var_name}") - elif call_type in structs_sym_tab and len(rval.args) == 0: - struct_info = structs_sym_tab[call_type] - ir_type = struct_info.ir_type - # var = builder.alloca(ir_type, name=var_name) - # Null init - builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name].var) - logger.info(f"Assigned struct {call_type} to {var_name}") - else: - logger.info(f"Unsupported assignment call type: {call_type}") - elif isinstance(rval.func, ast.Attribute): - logger.info(f"Assignment call attribute: {ast.dump(rval.func)}") - if isinstance(rval.func.value, ast.Name): - if rval.func.value.id in map_sym_tab: - map_name = rval.func.value.id - method_name = rval.func.attr - if HelperHandlerRegistry.has_handler(method_name): - val = handle_helper_call( - rval, - module, - builder, - func, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) - builder.store(val[0], local_sym_tab[var_name].var) - else: - # TODO: probably a struct access - logger.info(f"TODO STRUCT ACCESS {ast.dump(rval)}") - elif isinstance(rval.func.value, ast.Call) and isinstance( - rval.func.value.func, ast.Name - ): - map_name = rval.func.value.func.id - method_name = rval.func.attr - if map_name in map_sym_tab: - if HelperHandlerRegistry.has_handler(method_name): - val = handle_helper_call( - rval, - module, - builder, - func, - local_sym_tab, - map_sym_tab, - structs_sym_tab, - ) - # var = builder.alloca(ir.IntType(64), name=var_name) - # var.align = 8 - builder.store(val[0], local_sym_tab[var_name].var) - else: - logger.info("Unsupported assignment call structure") - else: - logger.info("Unsupported assignment call function type") - elif isinstance(rval, ast.BinOp): - handle_binary_op(rval, builder, var_name, local_sym_tab) - else: - logger.info("Unsupported assignment value type") + # NOTE: Struct field assignment case: pkt.field = value + handle_struct_field_assignment( + func, + module, + builder, + target, + rval, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + return + + # Unsupported target type + logger.error(f"Unsupported assignment target: {ast.dump(target)}") def handle_cond( @@ -330,6 +176,7 @@ def process_stmt( ret_type=ir.IntType(64), ): logger.info(f"Processing statement: {ast.dump(stmt)}") + reset_scratch_pool() if isinstance(stmt, ast.Expr): handle_expr( func, @@ -360,119 +207,107 @@ def process_stmt( return did_return +def handle_if_allocation( + module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab +): + """Recursively handle allocations in if/else branches.""" + if stmt.body: + allocate_mem( + module, + builder, + stmt.body, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) + if stmt.orelse: + allocate_mem( + module, + builder, + stmt.orelse, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) + + +def count_temps_in_call(call_node, local_sym_tab): + """Count the number of temporary variables needed for a function call.""" + + count = 0 + is_helper = False + + # NOTE: We exclude print calls for now + if isinstance(call_node.func, ast.Name): + if ( + HelperHandlerRegistry.has_handler(call_node.func.id) + and call_node.func.id != "print" + ): + is_helper = True + elif isinstance(call_node.func, ast.Attribute): + if HelperHandlerRegistry.has_handler(call_node.func.attr): + is_helper = True + + if not is_helper: + return 0 + + for arg in call_node.args: + # NOTE: Count all non-name arguments + # For struct fields, if it is being passed as an argument, + # The struct object should already exist in the local_sym_tab + if not isinstance(arg, ast.Name) and not ( + isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab + ): + count += 1 + + return count + + def allocate_mem( module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab ): + max_temps_needed = 0 + + def update_max_temps_for_stmt(stmt): + nonlocal max_temps_needed + temps_needed = 0 + + if isinstance(stmt, ast.If): + for s in stmt.body: + update_max_temps_for_stmt(s) + for s in stmt.orelse: + update_max_temps_for_stmt(s) + return + + for node in ast.walk(stmt): + if isinstance(node, ast.Call): + temps_needed += count_temps_in_call(node, local_sym_tab) + max_temps_needed = max(max_temps_needed, temps_needed) + for stmt in body: - has_metadata = False + update_max_temps_for_stmt(stmt) + + # Handle allocations if isinstance(stmt, ast.If): - if stmt.body: - local_sym_tab = allocate_mem( - module, - builder, - stmt.body, - func, - ret_type, - map_sym_tab, - local_sym_tab, - structs_sym_tab, - ) - if stmt.orelse: - local_sym_tab = allocate_mem( - module, - builder, - stmt.orelse, - func, - ret_type, - map_sym_tab, - local_sym_tab, - structs_sym_tab, - ) + handle_if_allocation( + module, + builder, + stmt, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) elif isinstance(stmt, ast.Assign): - if len(stmt.targets) != 1: - logger.info("Unsupported multiassignment") - continue - target = stmt.targets[0] - if not isinstance(target, ast.Name): - logger.info("Unsupported assignment target") - continue - var_name = target.id - rval = stmt.value - if var_name in local_sym_tab: - logger.info(f"Variable {var_name} already allocated") - continue - if isinstance(rval, ast.Call): - if isinstance(rval.func, ast.Name): - call_type = rval.func.id - if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): - ir_type = ctypes_to_ir(call_type) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info( - f"Pre-allocated variable {var_name} of type {call_type}" - ) - elif HelperHandlerRegistry.has_handler(call_type): - # Assume return type is int64 for now - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} for helper") - elif call_type == "deref" and len(rval.args) == 1: - # Assume return type is int64 for now - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} for deref") - elif call_type in structs_sym_tab: - struct_info = structs_sym_tab[call_type] - ir_type = struct_info.ir_type - var = builder.alloca(ir_type, name=var_name) - has_metadata = True - logger.info( - f"Pre-allocated variable {var_name} for struct {call_type}" - ) - elif isinstance(rval.func, ast.Attribute): - ir_type = ir.PointerType(ir.IntType(64)) - var = builder.alloca(ir_type, name=var_name) - # var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} for map") - else: - logger.info("Unsupported assignment call function type") - continue - elif isinstance(rval, ast.Constant): - if isinstance(rval.value, bool): - ir_type = ir.IntType(1) - var = builder.alloca(ir_type, name=var_name) - var.align = 1 - logger.info(f"Pre-allocated variable {var_name} of type c_bool") - elif isinstance(rval.value, int): - # Assume c_int64 for now - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} of type c_int64") - elif isinstance(rval.value, str): - ir_type = ir.PointerType(ir.IntType(8)) - var = builder.alloca(ir_type, name=var_name) - var.align = 8 - logger.info(f"Pre-allocated variable {var_name} of type string") - else: - logger.info("Unsupported constant type") - continue - elif isinstance(rval, ast.BinOp): - # Assume c_int64 for now - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} of type c_int64") - else: - logger.info("Unsupported assignment value type") - continue - - if has_metadata: - local_sym_tab[var_name] = LocalSymbol(var, ir_type, call_type) - else: - local_sym_tab[var_name] = LocalSymbol(var, ir_type) + handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab) + + allocate_temp_pool(builder, max_temps_needed, local_sym_tab) + return local_sym_tab diff --git a/pythonbpf/helper/__init__.py b/pythonbpf/helper/__init__.py index a7ad1697..007724f7 100644 --- a/pythonbpf/helper/__init__.py +++ b/pythonbpf/helper/__init__.py @@ -1,9 +1,10 @@ -from .helper_utils import HelperHandlerRegistry +from .helper_utils import HelperHandlerRegistry, reset_scratch_pool from .bpf_helper_handler import handle_helper_call from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS __all__ = [ "HelperHandlerRegistry", + "reset_scratch_pool", "handle_helper_call", "ktime", "pid", diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index f5ae9a0a..79cbf266 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -34,6 +34,7 @@ def bpf_ktime_get_ns_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_ktime_get_ns helper function call. @@ -56,6 +57,7 @@ def bpf_map_lookup_elem_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_lookup_elem helper function call. @@ -64,11 +66,17 @@ def bpf_map_lookup_elem_emitter( raise ValueError( f"Map lookup expects exactly one argument (key), got {len(call.args)}" ) - key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) + key_ptr = get_or_create_ptr_from_arg( + func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab + ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + # TODO: I have changed the return type to i64*, as we are + # allocating space for that type in allocate_mem. This is + # temporary, and we will honour other widths later. But this + # allows us to have cool binary ops on the returned value. fn_type = ir.FunctionType( - ir.PointerType(), # Return type: void* + ir.PointerType(ir.IntType(64)), # Return type: void* [ir.PointerType(), ir.PointerType()], # Args: (void*, void*) var_arg=False, ) @@ -91,6 +99,7 @@ def bpf_printk_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """Emit LLVM IR for bpf_printk helper function call.""" if not hasattr(func, "_fmt_counter"): @@ -138,6 +147,7 @@ def bpf_map_update_elem_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_update_elem helper function call. @@ -152,8 +162,12 @@ def bpf_map_update_elem_emitter( value_arg = call.args[1] flags_arg = call.args[2] if len(call.args) > 2 else None - key_ptr = get_or_create_ptr_from_arg(key_arg, builder, local_sym_tab) - value_ptr = get_or_create_ptr_from_arg(value_arg, builder, local_sym_tab) + key_ptr = get_or_create_ptr_from_arg( + func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab + ) + value_ptr = get_or_create_ptr_from_arg( + func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab + ) flags_val = get_flags_val(flags_arg, builder, local_sym_tab) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) @@ -188,6 +202,7 @@ def bpf_map_delete_elem_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_delete_elem helper function call. @@ -197,7 +212,9 @@ def bpf_map_delete_elem_emitter( raise ValueError( f"Map delete expects exactly one argument (key), got {len(call.args)}" ) - key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) + key_ptr = get_or_create_ptr_from_arg( + func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab + ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) # Define function type for bpf_map_delete_elem @@ -225,6 +242,7 @@ def bpf_get_current_pid_tgid_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_current_pid_tgid helper function call. @@ -251,6 +269,7 @@ def bpf_perf_event_output_handler( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): if len(call.args) != 1: raise ValueError( @@ -315,6 +334,7 @@ def invoke_helper(method_name, map_ptr=None): func, local_sym_tab, struct_sym_tab, + map_sym_tab, ) # Handle direct function calls (e.g., print(), ktime()) diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 68ab52cd..284aa686 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -3,7 +3,8 @@ from collections.abc import Callable from llvmlite import ir -from pythonbpf.expr import eval_expr +from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth +from pythonbpf.binary_ops import get_operand_value logger = logging.getLogger(__name__) @@ -34,6 +35,41 @@ def has_handler(cls, helper_name): return helper_name in cls._handlers +class ScratchPoolManager: + """Manage the temporary helper variables in local_sym_tab""" + + def __init__(self): + self._counter = 0 + + @property + def counter(self): + return self._counter + + def reset(self): + self._counter = 0 + logger.debug("Scratch pool counter reset to 0") + + def get_next_temp(self, local_sym_tab): + temp_name = f"__helper_temp_{self._counter}" + self._counter += 1 + + if temp_name not in local_sym_tab: + raise ValueError( + f"Scratch pool exhausted or inadequate: {temp_name}. " + f"Current counter: {self._counter}" + ) + + return local_sym_tab[temp_name].var, temp_name + + +_temp_pool_manager = ScratchPoolManager() # Singleton instance + + +def reset_scratch_pool(): + """Reset the scratch pool counter""" + _temp_pool_manager.reset() + + def get_var_ptr_from_name(var_name, local_sym_tab): """Get a pointer to a variable from the symbol table.""" if local_sym_tab and var_name in local_sym_tab: @@ -41,27 +77,41 @@ def get_var_ptr_from_name(var_name, local_sym_tab): raise ValueError(f"Variable '{var_name}' not found in local symbol table") -def create_int_constant_ptr(value, builder, int_width=64): +def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): """Create a pointer to an integer constant.""" + # Default to 64-bit integer - int_type = ir.IntType(int_width) - ptr = builder.alloca(int_type) - ptr.align = int_type.width // 8 - builder.store(ir.Constant(int_type, value), ptr) + ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) + logger.info(f"Using temp variable '{temp_name}' for int constant {value}") + const_val = ir.Constant(ir.IntType(int_width), value) + builder.store(const_val, ptr) return ptr -def get_or_create_ptr_from_arg(arg, builder, local_sym_tab): +def get_or_create_ptr_from_arg( + func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None +): """Extract or create pointer from the call arguments.""" if isinstance(arg, ast.Name): ptr = get_var_ptr_from_name(arg.id, local_sym_tab) elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): - ptr = create_int_constant_ptr(arg.value, builder) + ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab) else: - raise NotImplementedError( - "Only simple variable names are supported as args in map helpers." + # Evaluate the expression and store the result in a temp variable + val = get_operand_value( + func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab ) + if val is None: + raise ValueError("Failed to evaluate expression for helper arg.") + + # NOTE: We assume the result is an int64 for now + # if isinstance(arg, ast.Attribute): + # return val + ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) + logger.info(f"Using temp variable '{temp_name}' for expression result") + builder.store(val, ptr) + return ptr @@ -224,10 +274,27 @@ def _populate_fval(ftype, node, fmt_parts, exprs): raise NotImplementedError( f"Unsupported integer width in f-string: {ftype.width}" ) - elif ftype == ir.PointerType(ir.IntType(8)): - # NOTE: We assume i8* is a string - fmt_parts.append("%s") - exprs.append(node) + elif isinstance(ftype, ir.PointerType): + target, depth = get_base_type_and_depth(ftype) + if isinstance(target, ir.IntType): + if target.width == 64: + fmt_parts.append("%lld") + exprs.append(node) + elif target.width == 32: + fmt_parts.append("%d") + exprs.append(node) + elif target.width == 8 and depth == 1: + # NOTE: Assume i8* is a string + fmt_parts.append("%s") + exprs.append(node) + else: + raise NotImplementedError( + f"Unsupported pointer target type in f-string: {target}" + ) + else: + raise NotImplementedError( + f"Unsupported pointer target type in f-string: {target}" + ) else: raise NotImplementedError(f"Unsupported field type in f-string: {ftype}") @@ -264,7 +331,20 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta if val: if isinstance(val.type, ir.PointerType): - val = builder.ptrtoint(val, ir.IntType(64)) + target, depth = get_base_type_and_depth(val.type) + if isinstance(target, ir.IntType): + if target.width >= 32: + val = deref_to_depth(func, builder, val, depth) + val = builder.sext(val, ir.IntType(64)) + elif target.width == 8 and depth == 1: + # NOTE: i8* is string, no need to deref + pass + + else: + logger.warning( + "Only int and ptr supported in bpf_printk args. Others default to 0." + ) + val = ir.Constant(ir.IntType(64), 0) elif isinstance(val.type, ir.IntType): if val.type.width < 64: val = builder.sext(val, ir.IntType(64)) diff --git a/tests/failing_tests/assign/retype.py b/tests/failing_tests/assign/retype.py new file mode 100644 index 00000000..b4fc04ed --- /dev/null +++ b/tests/failing_tests/assign/retype.py @@ -0,0 +1,39 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.maps import HashMap + + +# NOTE: This example tries to reinterpret the variable `x` to a different type. +# We do not allow this for now, as stack allocations are typed and have to be +# done in the first basic block. Allowing re-interpretation would require +# re-allocation of stack space (possibly in a new basic block), which is not +# supported in eBPF yet. +# We can allow bitcasts in cases where the width of the types is the same in +# the future. But for now, we do not allow any re-interpretation of variables. + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + last.update(0, 1) + x = last.lookup(0) + x = 20 + if x == 2: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/assign/comprehensive.py b/tests/passing_tests/assign/comprehensive.py new file mode 100644 index 00000000..6e53a3f5 --- /dev/null +++ b/tests/passing_tests/assign/comprehensive.py @@ -0,0 +1,69 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile, struct +from ctypes import c_void_p, c_int64, c_int32, c_uint64 +from pythonbpf.maps import HashMap +from pythonbpf.helper import ktime + + +# NOTE: This is a comprehensive test combining struct, helper, and map features +# Please note that at line 50, though we have used an absurd expression to test +# the compiler, it is recommended to use named variables to reduce the amount of +# scratch space that needs to be allocated. + +@bpf +@struct +class data_t: + pid: c_uint64 + ts: c_uint64 + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + dat = data_t() + dat.pid = 123 + dat.pid = dat.pid + 1 + print(f"pid is {dat.pid}") + tu = 9 + last.update(0, tu) + last.update(1, -last.lookup(0)) + x = last.lookup(0) + print(f"Map value at index 0: {x}") + x = x + c_int32(1) + print(f"x after adding 32-bit 1 is {x}") + x = ktime() - 121 + print(f"ktime - 121 is {x}") + x = last.lookup(0) + x = x + 1 + print(f"x is {x}") + if x == 10: + jat = data_t() + jat.ts = 456 + print(f"Hello, World!, ts is {jat.ts}") + a = last.lookup(0) + print(f"a is {a}") + last.update(9, 9) + last.update(0, last.lookup(last.lookup(0)) + + last.lookup(last.lookup(0)) + last.lookup(last.lookup(0))) + z = last.lookup(0) + print(f"new map val at index 0 is {z}") + else: + a = last.lookup(0) + print("Goodbye, World!") + c = last.lookup(1 - 1) + print(f"c is {c}") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/assign/cst_var_binop.py b/tests/passing_tests/assign/cst_var_binop.py new file mode 100644 index 00000000..957e6783 --- /dev/null +++ b/tests/passing_tests/assign/cst_var_binop.py @@ -0,0 +1,27 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + x = 1 + print(f"Initial x: {x}") + a = 20 + x = a + print(f"Updated x with a: {x}") + x = (x + x) * 3 + if x == 2: + print("Hello, World!") + else: + print(f"Goodbye, World! {x}") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/assign/helper.py b/tests/passing_tests/assign/helper.py new file mode 100644 index 00000000..9809a9c7 --- /dev/null +++ b/tests/passing_tests/assign/helper.py @@ -0,0 +1,34 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.maps import HashMap + +# NOTE: An example of i64** assignment with binops on the RHS + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + last.update(0, 1) + x = last.lookup(0) + print(f"{x}") + x = x + 1 + if x == 2: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/assign/struct_and_helper_binops.py b/tests/passing_tests/assign/struct_and_helper_binops.py new file mode 100644 index 00000000..7e75de6e --- /dev/null +++ b/tests/passing_tests/assign/struct_and_helper_binops.py @@ -0,0 +1,40 @@ +from pythonbpf import bpf, section, bpfglobal, compile, struct +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.helper import ktime + + +@bpf +@struct +class data_t: + pid: c_uint64 + ts: c_uint64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + dat = data_t() + dat.pid = 123 + dat.pid = dat.pid + 1 + print(f"pid is {dat.pid}") + x = ktime() - 121 + print(f"ktime is {x}") + x = 1 + x = x + 1 + print(f"x is {x}") + if x == 2: + jat = data_t() + jat.ts = 456 + print(f"Hello, World!, ts is {jat.ts}") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile()