Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pythonbpf/codegen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from llvmlite import ir
from .license_pass import license_processing
from .functions_pass import func_proc
from .functions import func_proc
from .maps import maps_proc
from .structs import structs_proc
from .globals_pass import globals_processing
Expand Down
59 changes: 59 additions & 0 deletions pythonbpf/expr_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
from typing import Dict

from .type_deducer import ctypes_to_ir, is_ctypes

logger: Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -88,6 +90,48 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
return val, local_sym_tab[arg.id].ir_type


def _handle_ctypes_call(
func,
module,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
"""Handle ctypes type constructor calls."""
if len(expr.args) != 1:
logger.info("ctypes constructor takes exactly one argument")
return None

arg = expr.args[0]
val = eval_expr(
func,
module,
builder,
arg,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if val is None:
logger.info("Failed to evaluate argument to ctypes constructor")
return None
call_type = expr.func.id
expected_type = ctypes_to_ir(call_type)

if val[1] != expected_type:
# NOTE: We are only considering casting to and from int types for now
if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType):
if val[1].width < expected_type.width:
val = (builder.sext(val[0], expected_type), expected_type)
else:
val = (builder.trunc(val[0], expected_type), expected_type)
else:
raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}")
return val


def eval_expr(
func,
module,
Expand All @@ -106,6 +150,17 @@ def eval_expr(
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
return _handle_deref_call(expr, local_sym_tab, builder)

if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id):
return _handle_ctypes_call(
func,
module,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)

# delayed import to avoid circular dependency
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call

Expand Down Expand Up @@ -153,6 +208,10 @@ def eval_expr(
)
elif isinstance(expr, ast.Attribute):
return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
elif isinstance(expr, ast.BinOp):
from pythonbpf.binary_ops import handle_binary_op

return handle_binary_op(expr, builder, None, local_sym_tab)
logger.info("Unsupported expression evaluation")
return None

Expand Down
3 changes: 3 additions & 0 deletions pythonbpf/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .functions_pass import func_proc

__all__ = ["func_proc"]
22 changes: 22 additions & 0 deletions pythonbpf/functions/func_registry_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Dict


class StatementHandlerRegistry:
"""Registry for statement handlers."""

_handlers: Dict = {}

@classmethod
def register(cls, stmt_type):
"""Register a handler for a specific statement type."""

def decorator(handler):
cls._handlers[stmt_type] = handler
return handler

return decorator

@classmethod
def __getitem__(cls, stmt_type):
"""Get the handler for a specific statement type."""
return cls._handlers.get(stmt_type, None)
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from typing import Any
from dataclasses import dataclass

from .helper import HelperHandlerRegistry, handle_helper_call
from .type_deducer import ctypes_to_ir
from .binary_ops import handle_binary_op
from .expr_pass import eval_expr, handle_expr
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.binary_ops import handle_binary_op
from pythonbpf.expr_pass import eval_expr, handle_expr

from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -350,6 +353,27 @@ def handle_if(
builder.position_at_end(merge_block)


def handle_return(builder, stmt, local_sym_tab, ret_type):
logger.info(f"Handling return statement: {ast.dump(stmt)}")
if stmt.value is None:
return _handle_none_return(builder)
elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id):
return _handle_xdp_return(stmt, builder, ret_type)
else:
val = eval_expr(
func=None,
module=None,
builder=builder,
expr=stmt.value,
local_sym_tab=local_sym_tab,
map_sym_tab={},
structs_sym_tab={},
)
logger.info(f"Evaluated return expression to {val}")
builder.ret(val[0])
return True


def process_stmt(
func,
module,
Expand Down Expand Up @@ -383,61 +407,12 @@ def process_stmt(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
)
elif isinstance(stmt, ast.Return):
if stmt.value is None:
builder.ret(ir.Constant(ir.IntType(64), 0))
did_return = True
elif (
isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and len(stmt.value.args) == 1
):
if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
stmt.value.args[0].value, int
):
call_type = stmt.value.func.id
if ctypes_to_ir(call_type) != ret_type:
raise ValueError(
"Return type mismatch: expected"
f"{ctypes_to_ir(call_type)}, got {call_type}"
)
else:
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
did_return = True
elif isinstance(stmt.value.args[0], ast.BinOp):
# TODO: Should be routed through eval_expr
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
if val is None:
raise ValueError("Failed to evaluate return expression")
if val[1] != ret_type:
raise ValueError(
"Return type mismatch: expected " f"{ret_type}, got {val[1]}"
)
builder.ret(val[0])
did_return = True
elif isinstance(stmt.value.args[0], ast.Name):
if stmt.value.args[0].id in local_sym_tab:
var = local_sym_tab[stmt.value.args[0].id].var
val = builder.load(var)
if val.type != ret_type:
raise ValueError(
"Return type mismatch: expected"
f"{ret_type}, got {val.type}"
)
builder.ret(val)
did_return = True
else:
raise ValueError("Failed to evaluate return expression")
elif isinstance(stmt.value, ast.Name):
if stmt.value.id == "XDP_PASS":
builder.ret(ir.Constant(ret_type, 2))
did_return = True
elif stmt.value.id == "XDP_DROP":
builder.ret(ir.Constant(ret_type, 1))
did_return = True
else:
raise ValueError("Failed to evaluate return expression")
else:
raise ValueError("Unsupported return value")
did_return = handle_return(
builder,
stmt,
local_sym_tab,
ret_type,
)
return did_return


Expand Down
45 changes: 45 additions & 0 deletions pythonbpf/functions/return_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import logging
import ast

from llvmlite import ir

logger: logging.Logger = logging.getLogger(__name__)

XDP_ACTIONS = {
"XDP_ABORTED": 0,
"XDP_DROP": 1,
"XDP_PASS": 2,
"XDP_TX": 3,
"XDP_REDIRECT": 4,
}


def _handle_none_return(builder) -> bool:
"""Handle return or return None -> returns 0."""
builder.ret(ir.Constant(ir.IntType(64), 0))
logger.debug("Generated default return: 0")
return True


def _is_xdp_name(name: str) -> bool:
"""Check if a name is an XDP action"""
return name in XDP_ACTIONS


def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
"""Handle XDP returns"""
if not isinstance(stmt.value, ast.Name):
return False

action_name = stmt.value.id

if action_name not in XDP_ACTIONS:
raise ValueError(
f"Unknown XDP action: {action_name}. Available: {XDP_ACTIONS.keys()}"
)
return False

value = XDP_ACTIONS[action_name]
builder.ret(ir.Constant(ret_type, value))
logger.debug(f"Generated XDP action return: {action_name} = {value}")
return True
34 changes: 19 additions & 15 deletions pythonbpf/type_deducer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
from llvmlite import ir

# TODO: THIS IS NOT SUPPOSED TO MATCH STRINGS :skull:
mapping = {
"c_int8": ir.IntType(8),
"c_uint8": ir.IntType(8),
"c_int16": ir.IntType(16),
"c_uint16": ir.IntType(16),
"c_int32": ir.IntType(32),
"c_uint32": ir.IntType(32),
"c_int64": ir.IntType(64),
"c_uint64": ir.IntType(64),
"c_float": ir.FloatType(),
"c_double": ir.DoubleType(),
"c_void_p": ir.IntType(64),
# Not so sure about this one
"str": ir.PointerType(ir.IntType(8)),
}


def ctypes_to_ir(ctype: str):
mapping = {
"c_int8": ir.IntType(8),
"c_uint8": ir.IntType(8),
"c_int16": ir.IntType(16),
"c_uint16": ir.IntType(16),
"c_int32": ir.IntType(32),
"c_uint32": ir.IntType(32),
"c_int64": ir.IntType(64),
"c_uint64": ir.IntType(64),
"c_float": ir.FloatType(),
"c_double": ir.DoubleType(),
"c_void_p": ir.IntType(64),
# Not so sure about this one
"str": ir.PointerType(ir.IntType(8)),
}
if ctype in mapping:
return mapping[ctype]
raise NotImplementedError(f"No mapping for {ctype}")


def is_ctypes(ctype: str) -> bool:
return ctype in mapping
18 changes: 18 additions & 0 deletions tests/passing_tests/return/binop_const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pythonbpf import bpf, section, bpfglobal, compile
from ctypes import c_void_p, c_int64


@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
print("Hello, World!")
return 1 + 1 - 2


@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"


compile()
19 changes: 19 additions & 0 deletions tests/passing_tests/return/binop_var.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pythonbpf import bpf, section, bpfglobal, compile
from ctypes import c_void_p, c_int64


@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
print("Hello, World!")
a = 2
return a - 2


@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"


compile()
18 changes: 18 additions & 0 deletions tests/passing_tests/return/int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pythonbpf import bpf, section, bpfglobal, compile
from ctypes import c_void_p, c_int64


@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
print("Hello, World!")
return 1


@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"


compile()
Loading