Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pythonbpf/allocation_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from typing import Any
from pythonbpf.helper import HelperHandlerRegistry
from .expr import VmlinuxHandlerRegistry
from pythonbpf.type_deducer import ctypes_to_ir

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -49,6 +50,15 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
logger.debug(f"Variable {var_name} already allocated, skipping")
return

# When allocating a variable, check if it's a vmlinux struct type
if isinstance(stmt.value, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
stmt.value.id
):
# Handle vmlinux struct allocation
# This requires more implementation
print(stmt.value)
pass

# Determine type and allocate based on rval
if isinstance(rval, ast.Call):
_allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab)
Expand Down
7 changes: 6 additions & 1 deletion pythonbpf/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .maps import maps_proc
from .structs import structs_proc
from .vmlinux_parser import vmlinux_proc
from pythonbpf.vmlinux_parser.vmlinux_exports_handler import VmlinuxHandler
from .expr import VmlinuxHandlerRegistry
from .globals_pass import (
globals_list_creation,
globals_processing,
Expand Down Expand Up @@ -56,10 +58,13 @@ def processor(source_code, filename, module):
logger.info(f"Found BPF function/struct: {func_node.name}")

vmlinux_symtab = vmlinux_proc(tree, module)
if vmlinux_symtab:
handler = VmlinuxHandler.initialize(vmlinux_symtab)
VmlinuxHandlerRegistry.set_handler(handler)

populate_global_symbol_table(tree, module)
license_processing(tree, module)
globals_processing(tree, module)
print("DEBUG:", vmlinux_symtab)
structs_sym_tab = structs_proc(tree, module, bpf_chunks)
map_sym_tab = maps_proc(tree, module, bpf_chunks)
func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)
Expand Down
2 changes: 2 additions & 0 deletions pythonbpf/expr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .type_normalization import convert_to_bool, get_base_type_and_depth
from .ir_ops import deref_to_depth
from .call_registry import CallHandlerRegistry
from .vmlinux_registry import VmlinuxHandlerRegistry

__all__ = [
"eval_expr",
Expand All @@ -11,4 +12,5 @@
"deref_to_depth",
"get_operand_value",
"CallHandlerRegistry",
"VmlinuxHandlerRegistry",
]
24 changes: 21 additions & 3 deletions pythonbpf/expr/expr_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_base_type_and_depth,
deref_to_depth,
)
from .vmlinux_registry import VmlinuxHandlerRegistry

logger: Logger = logging.getLogger(__name__)

Expand All @@ -27,8 +28,12 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
val = builder.load(var)
return val, local_sym_tab[expr.id].ir_type
else:
logger.info(f"Undefined variable {expr.id}")
return None
# Check if it's a vmlinux enum/constant
vmlinux_result = VmlinuxHandlerRegistry.handle_name(expr.id)
if vmlinux_result is not None:
return vmlinux_result

raise SyntaxError(f"Undefined variable {expr.id}")


def _handle_constant_expr(module, builder, expr: ast.Constant):
Expand Down Expand Up @@ -74,6 +79,13 @@ def _handle_attribute_expr(
val = builder.load(gep)
field_type = metadata.field_type(attr_name)
return val, field_type

# Try vmlinux handler as fallback
vmlinux_result = VmlinuxHandlerRegistry.handle_attribute(
expr, local_sym_tab, None, builder
)
if vmlinux_result is not None:
return vmlinux_result
return None


Expand Down Expand Up @@ -130,7 +142,12 @@ def get_operand_value(
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
val = deref_to_depth(func, builder, var, depth)
return val
raise ValueError(f"Undefined variable: {operand.id}")
else:
# Check if it's a vmlinux enum/constant
vmlinux_result = VmlinuxHandlerRegistry.handle_name(operand.id)
if vmlinux_result is not None:
val, _ = vmlinux_result
return val
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
cst = ir.Constant(ir.IntType(64), int(operand.value))
Expand Down Expand Up @@ -332,6 +349,7 @@ def _handle_unary_op(
neg_one = ir.Constant(ir.IntType(64), -1)
result = builder.mul(operand, neg_one)
return result, ir.IntType(64)
return None


# ============================================================================
Expand Down
45 changes: 45 additions & 0 deletions pythonbpf/expr/vmlinux_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import ast


class VmlinuxHandlerRegistry:
"""Registry for vmlinux handler operations"""

_handler = None

@classmethod
def set_handler(cls, handler):
"""Set the vmlinux handler"""
cls._handler = handler

@classmethod
def get_handler(cls):
"""Get the vmlinux handler"""
return cls._handler

@classmethod
def handle_name(cls, name):
"""Try to handle a name as vmlinux enum/constant"""
if cls._handler is None:
return None
return cls._handler.handle_vmlinux_enum(name)

@classmethod
def handle_attribute(cls, expr, local_sym_tab, module, builder):
"""Try to handle an attribute access as vmlinux struct field"""
if cls._handler is None:
return None

if isinstance(expr.value, ast.Name):
var_name = expr.value.id
field_name = expr.attr
return cls._handler.handle_vmlinux_struct_field(
var_name, field_name, module, builder, local_sym_tab
)
return None

@classmethod
def is_vmlinux_struct(cls, name):
"""Check if a name refers to a vmlinux struct"""
if cls._handler is None:
return False
return cls._handler.is_vmlinux_struct(name)
16 changes: 14 additions & 2 deletions pythonbpf/functions/functions_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,13 @@ def process_stmt(


def process_func_body(
module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab
module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
):
"""Process the body of a bpf function"""
# TODO: A lot. We just have print -> bpf_trace_printk for now
Expand Down Expand Up @@ -384,7 +390,13 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
builder = ir.IRBuilder(block)

process_func_body(
module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab
module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
)
return func

Expand Down
11 changes: 11 additions & 0 deletions pythonbpf/helper/printk_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from llvmlite import ir
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,6 +109,16 @@ def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab):
if local_sym_tab and name_node.id in local_sym_tab:
_, var_type, tmp = local_sym_tab[name_node.id]
_populate_fval(var_type, name_node, fmt_parts, exprs)
else:
# Try to resolve through vmlinux registry if not in local symbol table
result = VmlinuxHandlerRegistry.handle_name(name_node.id)
if result:
val, var_type = result
_populate_fval(var_type, name_node, fmt_parts, exprs)
else:
raise ValueError(
f"Variable '{name_node.id}' not found in symbol table or vmlinux"
)


def _process_attr_in_fval(attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab):
Expand Down
11 changes: 9 additions & 2 deletions pythonbpf/maps/maps_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from .maps_utils import MapProcessorRegistry
from .map_types import BPFMapType
from .map_debug_info import create_map_debug_info, create_ringbuf_debug_info
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry


logger: Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,7 +53,7 @@ def _parse_map_params(rval, expected_args=None):
"""Parse map parameters from call arguments and keywords."""

params = {}

handler = VmlinuxHandlerRegistry.get_handler()
# Parse positional arguments
if expected_args:
for i, arg_name in enumerate(expected_args):
Expand All @@ -65,7 +67,12 @@ def _parse_map_params(rval, expected_args=None):
# Parse keyword arguments (override positional)
for keyword in rval.keywords:
if isinstance(keyword.value, ast.Name):
params[keyword.arg] = keyword.value.id
name = keyword.value.id
if handler and handler.is_vmlinux_enum(name):
result = handler.get_vmlinux_enum_value(name)
params[keyword.arg] = result if result is not None else name
else:
params[keyword.arg] = name
elif isinstance(keyword.value, ast.Constant):
params[keyword.arg] = keyword.value.value

Expand Down
90 changes: 90 additions & 0 deletions pythonbpf/vmlinux_parser/vmlinux_exports_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import logging
from llvmlite import ir

from pythonbpf.vmlinux_parser.assignment_info import AssignmentType

logger = logging.getLogger(__name__)


class VmlinuxHandler:
"""Handler for vmlinux-related operations"""

_instance = None

@classmethod
def get_instance(cls):
"""Get the singleton instance"""
if cls._instance is None:
logger.warning("VmlinuxHandler used before initialization")
return None
return cls._instance

@classmethod
def initialize(cls, vmlinux_symtab):
"""Initialize the handler with vmlinux symbol table"""
cls._instance = cls(vmlinux_symtab)
return cls._instance

def __init__(self, vmlinux_symtab):
"""Initialize with vmlinux symbol table"""
self.vmlinux_symtab = vmlinux_symtab
logger.info(
f"VmlinuxHandler initialized with {len(vmlinux_symtab) if vmlinux_symtab else 0} symbols"
)

def is_vmlinux_enum(self, name):
"""Check if name is a vmlinux enum constant"""
return (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.CONSTANT
)

def is_vmlinux_struct(self, name):
"""Check if name is a vmlinux struct"""
return (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.STRUCT
)

def handle_vmlinux_enum(self, name):
"""Handle vmlinux enum constants by returning LLVM IR constants"""
if self.is_vmlinux_enum(name):
value = self.vmlinux_symtab[name]["value"]
logger.info(f"Resolving vmlinux enum {name} = {value}")
return ir.Constant(ir.IntType(64), value), ir.IntType(64)
return None

def get_vmlinux_enum_value(self, name):
"""Handle vmlinux enum constants by returning LLVM IR constants"""
if self.is_vmlinux_enum(name):
value = self.vmlinux_symtab[name]["value"]
logger.info(f"The value of vmlinux enum {name} = {value}")
return value
return None

def handle_vmlinux_struct(self, struct_name, module, builder):
"""Handle vmlinux struct initializations"""
if self.is_vmlinux_struct(struct_name):
# TODO: Implement core-specific struct handling
# This will be more complex and depends on the BTF information
logger.info(f"Handling vmlinux struct {struct_name}")
# Return struct type and allocated pointer
# This is a stub, actual implementation will be more complex
return None
return None

def handle_vmlinux_struct_field(
self, struct_var_name, field_name, module, builder, local_sym_tab
):
"""Handle access to vmlinux struct fields"""
# Check if it's a variable of vmlinux struct type
if struct_var_name in local_sym_tab:
var_info = local_sym_tab[struct_var_name] # noqa: F841
# Need to check if this variable is a vmlinux struct
# This will depend on how you track vmlinux struct types in your symbol table
logger.info(
f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}"
)
# Return pointer to field and field type
return None
return None
28 changes: 23 additions & 5 deletions tests/passing_tests/vmlinux/simple_struct_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
import logging

from pythonbpf import bpf, section, bpfglobal, compile_to_ir, map
from pythonbpf import compile # noqa: F401
from vmlinux import TASK_COMM_LEN # noqa: F401
from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
from ctypes import c_uint64, c_int32, c_int64
from pythonbpf.maps import HashMap

# from vmlinux import struct_uinput_device
# from vmlinux import struct_blk_integrity_iter
from ctypes import c_int64


@bpf
@map
def mymap() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=TASK_COMM_LEN)


@bpf
@map
def mymap2() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=18)


# Instructions to how to run this program
Expand All @@ -16,8 +32,9 @@
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64:
print("Hello, World!")
return c_int64(0)
a = 2 + TASK_COMM_LEN + TASK_COMM_LEN
print(f"Hello, World{TASK_COMM_LEN} and {a}")
return c_int64(TASK_COMM_LEN + 2)


@bpf
Expand All @@ -26,4 +43,5 @@ def LICENSE() -> str:
return "GPL"


compile_to_ir("simple_struct_test.py", "simple_struct_test.ll")
compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG)
# compile()