diff --git a/pythonbpf/bpf_helper_handler.py b/pythonbpf/bpf_helper_handler.py index 59b5a86..ceff09a 100644 --- a/pythonbpf/bpf_helper_handler.py +++ b/pythonbpf/bpf_helper_handler.py @@ -113,9 +113,9 @@ def bpf_printk_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, var_type = local_var_metadata[var_name] if var_type in struct_sym_tab: struct_info = struct_sym_tab[var_type] - if field_name in struct_info["fields"]: - field_index = struct_info["fields"][field_name] - field_type = struct_info["field_types"][field_index] + if field_name in struct_info.fields: + field_type = struct_info.field_type( + field_name) if isinstance(field_type, ir.IntType): fmt_parts.append("%lld") exprs.append(value.value) @@ -408,7 +408,7 @@ def bpf_perf_event_output_handler(call, map_ptr, module, builder, func, local_sy data_type = local_var_metadata[data_name] if data_type in struct_sym_tab: struct_info = struct_sym_tab[data_type] - size_val = ir.Constant(ir.IntType(64), struct_info["size"]) + size_val = ir.Constant(ir.IntType(64), struct_info.size) else: raise ValueError( f"Struct type {data_type} for variable {data_name} not found in struct symbol table.") diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py index db5f05f..7eaa713 100644 --- a/pythonbpf/codegen.py +++ b/pythonbpf/codegen.py @@ -3,7 +3,7 @@ from .license_pass import license_processing from .functions_pass import func_proc from .maps_pass import maps_proc -from .structs_pass import structs_proc +from .structs.structs_pass import structs_proc from .globals_pass import globals_processing import os import subprocess diff --git a/pythonbpf/expr_pass.py b/pythonbpf/expr_pass.py index 7f8dbcf..db3ffbb 100644 --- a/pythonbpf/expr_pass.py +++ b/pythonbpf/expr_pass.py @@ -79,12 +79,10 @@ def eval_expr(func, module, builder, expr, local_sym_tab, map_sym_tab, structs_s print(local_var_metadata) if local_var_metadata and var_name in local_var_metadata: metadata = structs_sym_tab[local_var_metadata[var_name]] - if attr_name in metadata["fields"]: - field_idx = metadata["fields"][attr_name] - gep = builder.gep(var_ptr, [ir.Constant(ir.IntType(32), 0), - ir.Constant(ir.IntType(32), field_idx)]) + if attr_name in metadata.fields: + gep = metadata.gep(builder, var_ptr, attr_name) val = builder.load(gep) - field_type = metadata["field_types"][field_idx] + field_type = metadata.field_type(attr_name) return val, field_type print("Unsupported expression evaluation") return None diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions_pass.py index 145f623..0b99546 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions_pass.py @@ -49,21 +49,17 @@ def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab, struc struct_type = local_var_metadata[var_name] struct_info = structs_sym_tab[struct_type] - if field_name in struct_info["fields"]: - field_idx = struct_info["fields"][field_name] - struct_ptr = local_sym_tab[var_name][0] - field_ptr = builder.gep( - struct_ptr, [ir.Constant(ir.IntType(32), 0), - ir.Constant(ir.IntType(32), field_idx)], - inbounds=True) + if field_name in struct_info.fields: + field_ptr = struct_info.gep( + builder, local_sym_tab[var_name][0], field_name) val = eval_expr(func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab) - if isinstance(struct_info["field_types"][field_idx], ir.ArrayType) and val[1] == ir.PointerType(ir.IntType(8)): + if isinstance(struct_info.field_type(field_name), ir.ArrayType) and val[1] == ir.PointerType(ir.IntType(8)): # TODO: Figure it out, not a priority rn # Special case for string assignment to char array - #str_len = struct_info["field_types"][field_idx].count - #assign_string_to_array(builder, field_ptr, val[0], str_len) - #print(f"Assigned to struct field {var_name}.{field_name}") + # str_len = struct_info["field_types"][field_idx].count + # assign_string_to_array(builder, field_ptr, val[0], str_len) + # print(f"Assigned to struct field {var_name}.{field_name}") pass if val is None: print("Failed to evaluate struct field assignment") @@ -138,7 +134,7 @@ def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab, struc print(f"Dereferenced and assigned to {var_name}") elif call_type in structs_sym_tab and len(rval.args) == 0: struct_info = structs_sym_tab[call_type] - ir_type = struct_info["type"] + ir_type = struct_info.ir_type # var = builder.alloca(ir_type, name=var_name) # Null init builder.store(ir.Constant(ir_type, None), @@ -364,7 +360,7 @@ def allocate_mem(module, builder, body, func, ret_type, map_sym_tab, local_sym_t f"Pre-allocated variable {var_name} for deref") elif call_type in structs_sym_tab: struct_info = structs_sym_tab[call_type] - ir_type = struct_info["type"] + ir_type = struct_info.ir_type var = builder.alloca(ir_type, name=var_name) local_var_metadata[var_name] = call_type print( @@ -548,6 +544,8 @@ def _expr_type(e): return found_type or "None" # For string assignment to fixed-size arrays + + def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_length): """ Copy a string (i8*) to a fixed-size array ([N x i8]*) @@ -556,36 +554,39 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l entry_block = builder.block copy_block = builder.append_basic_block("copy_char") end_block = builder.append_basic_block("copy_end") - + # Create loop counter i = builder.alloca(ir.IntType(32)) builder.store(ir.Constant(ir.IntType(32), 0), i) - + # Start the loop builder.branch(copy_block) - + # Copy loop builder.position_at_end(copy_block) idx = builder.load(i) - in_bounds = builder.icmp_unsigned('<', idx, ir.Constant(ir.IntType(32), array_length)) + in_bounds = builder.icmp_unsigned( + '<', idx, ir.Constant(ir.IntType(32), array_length)) builder.cbranch(in_bounds, copy_block, end_block) - + with builder.if_then(in_bounds): # Load character from source src_ptr = builder.gep(source_string_ptr, [idx]) char = builder.load(src_ptr) - + # Store character in target - dst_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx]) + dst_ptr = builder.gep( + target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx]) builder.store(char, dst_ptr) - + # Increment counter next_idx = builder.add(idx, ir.Constant(ir.IntType(32), 1)) builder.store(next_idx, i) - + builder.position_at_end(end_block) - + # Ensure null termination last_idx = ir.Constant(ir.IntType(32), array_length - 1) - null_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx]) + null_ptr = builder.gep( + target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx]) builder.store(ir.Constant(ir.IntType(8), 0), null_ptr) diff --git a/pythonbpf/structs/struct_type.py b/pythonbpf/structs/struct_type.py new file mode 100644 index 0000000..ad00883 --- /dev/null +++ b/pythonbpf/structs/struct_type.py @@ -0,0 +1,31 @@ +from llvmlite import ir + + +class StructType: + def __init__(self, ir_type, fields, size): + self.ir_type = ir_type + self.fields = fields + self.size = size + + def field_idx(self, field_name): + return list(self.fields.keys()).index(field_name) + + def field_type(self, field_name): + return self.fields[field_name] + + def gep(self, builder, ptr, field_name): + idx = self.field_idx(field_name) + return builder.gep(ptr, [ir.Constant(ir.IntType(32), 0), + ir.Constant(ir.IntType(32), idx)], + inbounds=True) + + def field_size(self, field_name): + fld = self.fields[field_name] + if isinstance(fld, ir.ArrayType): + return fld.count * (fld.element.width // 8) + elif isinstance(fld, ir.IntType): + return fld.width // 8 + elif isinstance(fld, ir.PointerType): + return 8 + + raise TypeError(f"Unsupported field type: {fld}") diff --git a/pythonbpf/structs/structs_pass.py b/pythonbpf/structs/structs_pass.py new file mode 100644 index 0000000..0fca41c --- /dev/null +++ b/pythonbpf/structs/structs_pass.py @@ -0,0 +1,97 @@ +import ast +import logging +from llvmlite import ir +from pythonbpf.type_deducer import ctypes_to_ir +from .struct_type import StructType + +logger = logging.getLogger(__name__) + +# TODO: Shall we allow the following syntax: +# struct MyStruct: +# field1: int +# field2: str(32) +# Where int is mapped to c_uint64? +# Shall we just int64, int32 and uint32 similarly? + + +def structs_proc(tree, module, chunks): + """ Process all class definitions to find BPF structs """ + structs_sym_tab = {} + for cls_node in chunks: + if is_bpf_struct(cls_node): + print(f"Found BPF struct: {cls_node.name}") + struct_info = process_bpf_struct(cls_node, module) + structs_sym_tab[cls_node.name] = struct_info + return structs_sym_tab + + +def is_bpf_struct(cls_node): + return any( + isinstance(decorator, ast.Name) and decorator.id == "struct" + for decorator in cls_node.decorator_list + ) + + +def process_bpf_struct(cls_node, module): + """ Process a single BPF struct definition """ + + fields = parse_struct_fields(cls_node) + field_types = list(fields.values()) + total_size = calc_struct_size(field_types) + struct_type = ir.LiteralStructType(field_types) + logger.info(f"Created struct {cls_node.name} with fields {fields.keys()}") + return StructType(struct_type, fields, total_size) + + +def parse_struct_fields(cls_node): + """ Parse fields of a struct class node """ + fields = {} + + for item in cls_node.body: + if isinstance(item, ast.AnnAssign) and \ + isinstance(item.target, ast.Name): + fields[item.target.id] = get_type_from_ann(item.annotation) + else: + logger.error(f"Unsupported struct field: {ast.dump(item)}") + raise TypeError(f"Unsupported field in {ast.dump(cls_node)}") + return fields + + +def get_type_from_ann(annotation): + """ Convert an AST annotation node to an LLVM IR type for struct fields""" + if isinstance(annotation, ast.Call) and \ + isinstance(annotation.func, ast.Name): + if annotation.func.id == "str": + # Char array + # Assumes constant integer argument + length = annotation.args[0].value + return ir.ArrayType(ir.IntType(8), length) + elif isinstance(annotation, ast.Name): + # Int type, written as c_int64, c_uint32, etc. + return ctypes_to_ir(annotation.id) + + raise TypeError(f"Unsupported annotation type: {ast.dump(annotation)}") + + +def calc_struct_size(field_types): + """ Calculate total size of the struct with alignment and padding """ + curr_offset = 0 + for ftype in field_types: + if isinstance(ftype, ir.IntType): + fsize = ftype.width // 8 + alignment = fsize + elif isinstance(ftype, ir.ArrayType): + fsize = ftype.count * (ftype.element.width // 8) + alignment = ftype.element.width // 8 + elif isinstance(ftype, ir.PointerType): + # We won't encounter this rn, but for the future + fsize = 8 + alignment = 8 + else: + raise TypeError(f"Unsupported field type: {ftype}") + + padding = (alignment - (curr_offset % alignment)) % alignment + curr_offset += padding + fsize + + final_padding = (8 - (curr_offset % 8)) % 8 + return curr_offset + final_padding diff --git a/pythonbpf/structs_pass.py b/pythonbpf/structs_pass.py deleted file mode 100644 index 84394e3..0000000 --- a/pythonbpf/structs_pass.py +++ /dev/null @@ -1,68 +0,0 @@ -import ast -from llvmlite import ir -from .type_deducer import ctypes_to_ir - -structs_sym_tab = {} - - -def structs_proc(tree, module, chunks): - for cls_node in chunks: - # Check if this class is a struct - is_struct = False - for decorator in cls_node.decorator_list: - if isinstance(decorator, ast.Name) and decorator.id == "struct": - is_struct = True - break - if is_struct: - print(f"Found BPF struct: {cls_node.name}") - process_bpf_struct(cls_node, module) - continue - return structs_sym_tab - - -def process_bpf_struct(cls_node, module): - struct_name = cls_node.name - field_names = [] - field_types = [] - - for item in cls_node.body: - if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name): - print(f"Field: {item.target.id}, Type: " - f"{ast.dump(item.annotation)}") - field_names.append(item.target.id) - if isinstance(item.annotation, ast.Call) and isinstance(item.annotation.func, ast.Name) and item.annotation.func.id == "str": - # This is a char array with fixed length - # TODO: For now assuming str is always called with constant - field_types.append(ir.ArrayType( - ir.IntType(8), item.annotation.args[0].value)) - else: - field_types.append(ctypes_to_ir(item.annotation.id)) - - curr_offset = 0 - for ftype in field_types: - if isinstance(ftype, ir.IntType): - fsize = ftype.width // 8 - alignment = fsize - elif isinstance(ftype, ir.ArrayType): - fsize = ftype.count * (ftype.element.width // 8) - alignment = ftype.element.width // 8 - elif isinstance(ftype, ir.PointerType): - fsize = 8 - alignment = 8 - else: - print(f"Unsupported field type in struct {struct_name}") - return - padding = (alignment - (curr_offset % alignment)) % alignment - curr_offset += padding - curr_offset += fsize - final_padding = (8 - (curr_offset % 8)) % 8 - total_size = curr_offset + final_padding - - struct_type = ir.LiteralStructType(field_types) - structs_sym_tab[struct_name] = { - "type": struct_type, - "fields": {name: idx for idx, name in enumerate(field_names)}, - "size": total_size, - "field_types": field_types, - } - print(f"Created struct {struct_name} with fields {field_names}")