diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py
index 078adf72..beac470c 100644
--- a/pythonbpf/codegen.py
+++ b/pythonbpf/codegen.py
@@ -55,11 +55,11 @@ def processor(source_code, filename, module):
     for func_node in bpf_chunks:
         logger.info(f"Found BPF function/struct: {func_node.name}")
 
-    vmlinux_proc(tree, module)
+    vmlinux_symtab = vmlinux_proc(tree, module)
     populate_global_symbol_table(tree, module)
     license_processing(tree, module)
     globals_processing(tree, module)
-
+    logger.debug("vmlinux symbol table: %s", vmlinux_symtab)
     structs_sym_tab = structs_proc(tree, module, bpf_chunks)
     map_sym_tab = maps_proc(tree, module, bpf_chunks)
     func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)
diff --git a/pythonbpf/vmlinux_parser/assignment_info.py b/pythonbpf/vmlinux_parser/assignment_info.py
new file mode 100644
index 00000000..465432d5
--- /dev/null
+++ b/pythonbpf/vmlinux_parser/assignment_info.py
@@ -0,0 +1,32 @@
+from enum import Enum, auto
+from typing import Any, Dict, List, Optional, TypedDict
+import llvmlite.ir as ir
+
+from pythonbpf.vmlinux_parser.dependency_node import Field
+
+
+class AssignmentType(Enum):
+    CONSTANT = auto()
+    STRUCT = auto()
+    ARRAY = auto()  # probably won't be used
+    FUNCTION_POINTER = auto()
+    POINTER = auto()  # again, probably won't be used
+
+
+class FunctionSignature(TypedDict):
+    return_type: str
+    param_types: List[str]
+    varargs: bool
+
+
+# The name of the assignment will be in the dict that uses this class
+class AssignmentInfo(TypedDict):
+    value_type: AssignmentType
+    python_type: type
+    value: Optional[Any]
+    pointer_level: Optional[int]
+    signature: Optional[FunctionSignature]  # For function pointers
+    # The key of the dict is the name of the field.
+    # Value is a tuple that contains the global variable representing that field
+    # along with all the information about that field as a Field type.
+    members: Optional[Dict[str, tuple[ir.GlobalVariable, Field]]]  # For structs.
diff --git a/pythonbpf/vmlinux_parser/class_handler.py b/pythonbpf/vmlinux_parser/class_handler.py
index 108fa9fc..a508ff75 100644
--- a/pythonbpf/vmlinux_parser/class_handler.py
+++ b/pythonbpf/vmlinux_parser/class_handler.py
@@ -1,6 +1,7 @@
 import logging
 from functools import lru_cache
 import importlib
+
 from .dependency_handler import DependencyHandler
 from .dependency_node import DependencyNode
 import ctypes
@@ -15,7 +16,11 @@ def get_module_symbols(module_name: str):
     return [name for name in dir(imported_module)], imported_module
 
 
-def process_vmlinux_class(node, llvm_module, handler: DependencyHandler):
+def process_vmlinux_class(
+    node,
+    llvm_module,
+    handler: DependencyHandler,
+):
     symbols_in_module, imported_module = get_module_symbols("vmlinux")
     if node.name in symbols_in_module:
         vmlinux_type = getattr(imported_module, node.name)
@@ -25,7 +30,10 @@ def process_vmlinux_post_ast(
 
 
 def process_vmlinux_post_ast(
-    elem_type_class, llvm_handler, handler: DependencyHandler, processing_stack=None
+    elem_type_class,
+    llvm_handler,
+    handler: DependencyHandler,
+    processing_stack=None,
 ):
     # Initialize processing stack on first call
     if processing_stack is None:
@@ -46,7 +54,7 @@ def process_vmlinux_post_ast(
         logger.debug(f"Node {current_symbol_name} already processed and ready")
         return True
 
-    # XXX:Check it's use. It's probably not being used.
+    # XXX:Check its use. It's probably not being used.
     if current_symbol_name in processing_stack:
         logger.debug(
             f"Dependency already in processing stack for {current_symbol_name}, skipping"
@@ -98,6 +106,7 @@ def process_vmlinux_post_ast(
         [elem_type, elem_bitfield_size] = elem_temp_list
         local_module_name = getattr(elem_type, "__module__", None)
         new_dep_node.add_field(elem_name, elem_type, ready=False)
+
        if local_module_name == ctypes.__name__:
             # TODO: need to process pointer to ctype and also CFUNCTYPES here recursively. Current processing is a single dereference
             new_dep_node.set_field_bitfield_size(elem_name, elem_bitfield_size)
@@ -226,7 +235,10 @@ def process_vmlinux_post_ast(
                     else str(elem_type)
                 )
                 process_vmlinux_post_ast(
-                    elem_type, llvm_handler, handler, processing_stack
+                    elem_type,
+                    llvm_handler,
+                    handler,
+                    processing_stack,
                 )
                 new_dep_node.set_field_ready(elem_name, True)
             else:
@@ -237,7 +249,7 @@ def process_vmlinux_post_ast(
         else:
             raise ImportError("UNSUPPORTED Module")
 
-    logging.info(
+    logger.info(
         f"{current_symbol_name} processed and handler readiness {handler.is_ready}"
     )
     return True
diff --git a/pythonbpf/vmlinux_parser/dependency_node.py b/pythonbpf/vmlinux_parser/dependency_node.py
index e266761b..dd413ad4 100644
--- a/pythonbpf/vmlinux_parser/dependency_node.py
+++ b/pythonbpf/vmlinux_parser/dependency_node.py
@@ -18,6 +18,22 @@ class Field:
     value: Any = None
     ready: bool = False
 
+    def __hash__(self):
+        """
+        Hash by object identity, consistent with the identity-based
+        __eq__ below. Field attributes (e.g. bitfield_size, value) are
+        mutated after construction, so a value-based hash would break
+        dict lookups; identity is stable for the object's lifetime.
+        """
+        return hash(id(self))
+
+    def __eq__(self, other):
+        """
+        Identity-based equality: two Field objects are equal only if
+        they are the same object.
+        """
+        return self is other
+
     def set_ready(self, is_ready: bool = True) -> None:
         """Set the readiness state of this field."""
         self.ready = is_ready
diff --git a/pythonbpf/vmlinux_parser/import_detector.py b/pythonbpf/vmlinux_parser/import_detector.py
index 972b1ff2..6df7a980 100644
--- a/pythonbpf/vmlinux_parser/import_detector.py
+++ b/pythonbpf/vmlinux_parser/import_detector.py
@@ -1,9 +1,9 @@
 import ast
 import logging
-from typing import List, Tuple, Any
 import importlib
 import inspect
 
+from .assignment_info import AssignmentInfo, AssignmentType
 from .dependency_handler import DependencyHandler
 from .ir_gen import IRGenerator
 from .class_handler import process_vmlinux_class
@@ -11,7 +11,7 @@
 logger = logging.getLogger(__name__)
 
 
-def detect_import_statement(tree: ast.AST) -> List[Tuple[str, ast.ImportFrom]]:
+def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]:
     """
     Parse AST and detect import statements from vmlinux.
 
@@ -82,7 +82,7 @@ def vmlinux_proc(tree: ast.AST, module):
     # initialise dependency handler
     handler = DependencyHandler()
     # initialise assignment dictionary of name to type
-    assignments: dict[str, tuple[type, Any]] = {}
+    assignments: dict[str, AssignmentInfo] = {}
 
     if not import_statements:
         logger.info("No vmlinux imports found")
@@ -128,20 +128,35 @@ def vmlinux_proc(tree: ast.AST, module):
                 f"{imported_name} not found as ClassDef or Assign in vmlinux"
             )
 
-    IRGenerator(module, handler)
+    IRGenerator(module, handler, assignments)
     return assignments
 
 
-def process_vmlinux_assign(node, module, assignments: dict[str, tuple[type, Any]]):
-    # Check if this is a simple assignment with a constant value
+def process_vmlinux_assign(node, module, assignments: dict[str, AssignmentInfo]):
+    """Process assignments from vmlinux module."""
+    # Only handle single-target assignments
     if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
         target_name = node.targets[0].id
+
+        # Handle constant value assignments
         if isinstance(node.value, ast.Constant):
-            assignments[target_name] = (type(node.value.value), node.value.value)
+            # Record the constant and its metadata as an AssignmentInfo entry
+            assignments[target_name] = AssignmentInfo(
+                value_type=AssignmentType.CONSTANT,
+                python_type=type(node.value.value),
+                value=node.value.value,
+                pointer_level=None,
+                signature=None,
+                members=None,
+            )
             logger.info(
                 f"Added assignment: {target_name} = {node.value.value!r} of type {type(node.value.value)}"
             )
+
+        # Handle other assignment types that we may need to support
         else:
-            raise ValueError(f"Unsupported assignment type for {target_name}")
+            logger.warning(
+                f"Unsupported assignment type for {target_name}: {ast.dump(node.value)}"
+            )
     else:
         raise ValueError("Not a simple assignment")
diff --git a/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py b/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py
index cacd2e71..960671e1 100644
--- a/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py
+++ b/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py
@@ -1,5 +1,7 @@
 import ctypes
 import logging
+
+from ..assignment_info import AssignmentInfo, AssignmentType
 from ..dependency_handler import DependencyHandler
 from .debug_info_gen import debug_info_generation
 from ..dependency_node import DependencyNode
@@ -10,11 +12,14 @@ class IRGenerator:
 
     # get the assignments dict and add this stuff to it.
-    def __init__(self, llvm_module, handler: DependencyHandler, assignment=None):
+    def __init__(self, llvm_module, handler: DependencyHandler, assignments):
         self.llvm_module = llvm_module
         self.handler: DependencyHandler = handler
         self.generated: list[str] = []
         self.generated_debug_info: list = []
+        # Use struct_name and field_name as key instead of Field object
+        self.generated_field_names: dict[str, dict[str, ir.GlobalVariable]] = {}
+        self.assignments: dict[str, AssignmentInfo] = assignments
         if not handler.is_ready:
             raise ImportError(
                 "Semantic analysis of vmlinux imports failed. Cannot generate IR"
             )
@@ -67,10 +72,41 @@ def struct_processor(self, struct, processing_stack=None):
                         f"Warning: Dependency {dependency} not found in handler"
                     )
 
-            # Actual processor logic here after dependencies are resolved
+            # Generate IR first to populate field names
            self.generated_debug_info.append(
                 (struct, self.gen_ir(struct, self.generated_debug_info))
             )
+
+            # Fill the assignments dictionary with struct information
+            if struct.name not in self.assignments:
+                # Create a members dictionary for AssignmentInfo
+                members_dict = {}
+                for field_name, field in struct.fields.items():
+                    # Look up the global variable generated for this field
+                    if (
+                        struct.name in self.generated_field_names
+                        and field_name in self.generated_field_names[struct.name]
+                    ):
+                        field_global_variable = self.generated_field_names[struct.name][
+                            field_name
+                        ]
+                        members_dict[field_name] = (field_global_variable, field)
+                    else:
+                        raise ValueError(
+                            f"llvm global name not found for struct field {field_name}"
+                        )
+
+                # Add struct to assignments dictionary
+                self.assignments[struct.name] = AssignmentInfo(
+                    value_type=AssignmentType.STRUCT,
+                    python_type=struct.ctype_struct,
+                    value=None,
+                    pointer_level=None,
+                    signature=None,
+                    members=members_dict,
+                )
+                logger.info(f"Added struct assignment info for {struct.name}")
+
             self.generated.append(struct.name)
         finally:
@@ -85,6 +122,11 @@ def gen_ir(self, struct, generated_debug_info):
             struct, self.llvm_module, generated_debug_info
         )
         field_index = 0
+
+        # Make sure the struct has an entry in our field names dictionary
+        if struct.name not in self.generated_field_names:
+            self.generated_field_names[struct.name] = {}
+
         for field_name, field in struct.fields.items():
             # does not take arrays and similar types into consideration yet.
             if field.ctype_complex_type is not None and issubclass(
@@ -94,6 +136,18 @@ def gen_ir(self, struct, generated_debug_info):
                 containing_type = field.containing_type
                 if containing_type.__module__ == ctypes.__name__:
                     containing_type_size = ctypes.sizeof(containing_type)
+                    if array_size == 0:
+                        field_co_re_name = self._struct_name_generator(
+                            struct, field, field_index, True, 0, containing_type_size
+                        )
+                        globvar = ir.GlobalVariable(
+                            self.llvm_module, ir.IntType(64), name=field_co_re_name
+                        )
+                        globvar.linkage = "external"
+                        globvar.set_metadata("llvm.preserve.access.index", debug_info)
+                        self.generated_field_names[struct.name][field_name] = globvar
+                        field_index += 1
+                        continue
                     for i in range(0, array_size):
                         field_co_re_name = self._struct_name_generator(
                             struct, field, field_index, True, i, containing_type_size
@@ -103,6 +157,7 @@ def gen_ir(self, struct, generated_debug_info):
                         )
                         globvar.linkage = "external"
                         globvar.set_metadata("llvm.preserve.access.index", debug_info)
+                        self.generated_field_names[struct.name][field_name] = globvar
                         field_index += 1
             elif field.type_size is not None:
                 array_size = field.type_size
@@ -120,6 +175,7 @@ def gen_ir(self, struct, generated_debug_info):
                     )
                     globvar.linkage = "external"
                     globvar.set_metadata("llvm.preserve.access.index", debug_info)
+                    self.generated_field_names[struct.name][field_name] = globvar
                     field_index += 1
             else:
                 field_co_re_name = self._struct_name_generator(
@@ -131,6 +187,7 @@ def gen_ir(self, struct, generated_debug_info):
                 )
                 globvar.linkage = "external"
                 globvar.set_metadata("llvm.preserve.access.index", debug_info)
+                self.generated_field_names[struct.name][field_name] = globvar
         return debug_info
 
     def _struct_name_generator(
diff --git a/tests/passing_tests/vmlinux/simple_struct_test.py b/tests/passing_tests/vmlinux/simple_struct_test.py
index f47076f2..c9390c84 100644
--- a/tests/passing_tests/vmlinux/simple_struct_test.py
+++ b/tests/passing_tests/vmlinux/simple_struct_test.py
@@ -1,4 +1,4 @@
-from pythonbpf import bpf, section, bpfglobal, compile_to_ir, compile
+from pythonbpf import bpf, section, bpfglobal, compile_to_ir
 from vmlinux import TASK_COMM_LEN  # noqa: F401
 from vmlinux import struct_trace_event_raw_sys_enter  # noqa: F401
 
@@ -27,4 +27,3 @@ def LICENSE() -> str:
 
 
 compile_to_ir("simple_struct_test.py", "simple_struct_test.ll")
-compile()