Skip to content

Commit 28ce14c

Browse files
authored
Merge pull request #24 from pythonbpf/func_refactor
Refactor handle_return
2 parents 552cd35 + 5066cd4 commit 28ce14c

16 files changed

+351
-75
lines changed

pythonbpf/codegen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22
from llvmlite import ir
33
from .license_pass import license_processing
4-
from .functions_pass import func_proc
4+
from .functions import func_proc
55
from .maps import maps_proc
66
from .structs import structs_proc
77
from .globals_pass import globals_processing

pythonbpf/expr_pass.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
from typing import Dict
66

7+
from .type_deducer import ctypes_to_ir, is_ctypes
8+
79
logger: Logger = logging.getLogger(__name__)
810

911

@@ -88,6 +90,48 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
8890
return val, local_sym_tab[arg.id].ir_type
8991

9092

93+
def _handle_ctypes_call(
94+
func,
95+
module,
96+
builder,
97+
expr,
98+
local_sym_tab,
99+
map_sym_tab,
100+
structs_sym_tab=None,
101+
):
102+
"""Handle ctypes type constructor calls."""
103+
if len(expr.args) != 1:
104+
logger.info("ctypes constructor takes exactly one argument")
105+
return None
106+
107+
arg = expr.args[0]
108+
val = eval_expr(
109+
func,
110+
module,
111+
builder,
112+
arg,
113+
local_sym_tab,
114+
map_sym_tab,
115+
structs_sym_tab,
116+
)
117+
if val is None:
118+
logger.info("Failed to evaluate argument to ctypes constructor")
119+
return None
120+
call_type = expr.func.id
121+
expected_type = ctypes_to_ir(call_type)
122+
123+
if val[1] != expected_type:
124+
# NOTE: We are only considering casting to and from int types for now
125+
if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType):
126+
if val[1].width < expected_type.width:
127+
val = (builder.sext(val[0], expected_type), expected_type)
128+
else:
129+
val = (builder.trunc(val[0], expected_type), expected_type)
130+
else:
131+
raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}")
132+
return val
133+
134+
91135
def eval_expr(
92136
func,
93137
module,
@@ -106,6 +150,17 @@ def eval_expr(
106150
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
107151
return _handle_deref_call(expr, local_sym_tab, builder)
108152

153+
if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id):
154+
return _handle_ctypes_call(
155+
func,
156+
module,
157+
builder,
158+
expr,
159+
local_sym_tab,
160+
map_sym_tab,
161+
structs_sym_tab,
162+
)
163+
109164
# delayed import to avoid circular dependency
110165
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
111166

@@ -153,6 +208,10 @@ def eval_expr(
153208
)
154209
elif isinstance(expr, ast.Attribute):
155210
return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
211+
elif isinstance(expr, ast.BinOp):
212+
from pythonbpf.binary_ops import handle_binary_op
213+
214+
return handle_binary_op(expr, builder, None, local_sym_tab)
156215
logger.info("Unsupported expression evaluation")
157216
return None
158217

pythonbpf/functions/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .functions_pass import func_proc
2+
3+
__all__ = ["func_proc"]
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Dict
2+
3+
4+
class StatementHandlerRegistry:
5+
"""Registry for statement handlers."""
6+
7+
_handlers: Dict = {}
8+
9+
@classmethod
10+
def register(cls, stmt_type):
11+
"""Register a handler for a specific statement type."""
12+
13+
def decorator(handler):
14+
cls._handlers[stmt_type] = handler
15+
return handler
16+
17+
return decorator
18+
19+
@classmethod
20+
def __getitem__(cls, stmt_type):
21+
"""Get the handler for a specific statement type."""
22+
return cls._handlers.get(stmt_type, None)

pythonbpf/functions_pass.py renamed to pythonbpf/functions/functions_pass.py

Lines changed: 34 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
from typing import Any
55
from dataclasses import dataclass
66

7-
from .helper import HelperHandlerRegistry, handle_helper_call
8-
from .type_deducer import ctypes_to_ir
9-
from .binary_ops import handle_binary_op
10-
from .expr_pass import eval_expr, handle_expr
7+
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
8+
from pythonbpf.type_deducer import ctypes_to_ir
9+
from pythonbpf.binary_ops import handle_binary_op
10+
from pythonbpf.expr_pass import eval_expr, handle_expr
11+
12+
from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name
13+
1114

1215
logger = logging.getLogger(__name__)
1316

@@ -350,6 +353,27 @@ def handle_if(
350353
builder.position_at_end(merge_block)
351354

352355

356+
def handle_return(builder, stmt, local_sym_tab, ret_type):
357+
logger.info(f"Handling return statement: {ast.dump(stmt)}")
358+
if stmt.value is None:
359+
return _handle_none_return(builder)
360+
elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id):
361+
return _handle_xdp_return(stmt, builder, ret_type)
362+
else:
363+
val = eval_expr(
364+
func=None,
365+
module=None,
366+
builder=builder,
367+
expr=stmt.value,
368+
local_sym_tab=local_sym_tab,
369+
map_sym_tab={},
370+
structs_sym_tab={},
371+
)
372+
logger.info(f"Evaluated return expression to {val}")
373+
builder.ret(val[0])
374+
return True
375+
376+
353377
def process_stmt(
354378
func,
355379
module,
@@ -383,61 +407,12 @@ def process_stmt(
383407
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
384408
)
385409
elif isinstance(stmt, ast.Return):
386-
if stmt.value is None:
387-
builder.ret(ir.Constant(ir.IntType(64), 0))
388-
did_return = True
389-
elif (
390-
isinstance(stmt.value, ast.Call)
391-
and isinstance(stmt.value.func, ast.Name)
392-
and len(stmt.value.args) == 1
393-
):
394-
if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
395-
stmt.value.args[0].value, int
396-
):
397-
call_type = stmt.value.func.id
398-
if ctypes_to_ir(call_type) != ret_type:
399-
raise ValueError(
400-
"Return type mismatch: expected"
401-
f"{ctypes_to_ir(call_type)}, got {call_type}"
402-
)
403-
else:
404-
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
405-
did_return = True
406-
elif isinstance(stmt.value.args[0], ast.BinOp):
407-
# TODO: Should be routed through eval_expr
408-
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
409-
if val is None:
410-
raise ValueError("Failed to evaluate return expression")
411-
if val[1] != ret_type:
412-
raise ValueError(
413-
"Return type mismatch: expected " f"{ret_type}, got {val[1]}"
414-
)
415-
builder.ret(val[0])
416-
did_return = True
417-
elif isinstance(stmt.value.args[0], ast.Name):
418-
if stmt.value.args[0].id in local_sym_tab:
419-
var = local_sym_tab[stmt.value.args[0].id].var
420-
val = builder.load(var)
421-
if val.type != ret_type:
422-
raise ValueError(
423-
"Return type mismatch: expected"
424-
f"{ret_type}, got {val.type}"
425-
)
426-
builder.ret(val)
427-
did_return = True
428-
else:
429-
raise ValueError("Failed to evaluate return expression")
430-
elif isinstance(stmt.value, ast.Name):
431-
if stmt.value.id == "XDP_PASS":
432-
builder.ret(ir.Constant(ret_type, 2))
433-
did_return = True
434-
elif stmt.value.id == "XDP_DROP":
435-
builder.ret(ir.Constant(ret_type, 1))
436-
did_return = True
437-
else:
438-
raise ValueError("Failed to evaluate return expression")
439-
else:
440-
raise ValueError("Unsupported return value")
410+
did_return = handle_return(
411+
builder,
412+
stmt,
413+
local_sym_tab,
414+
ret_type,
415+
)
441416
return did_return
442417

443418

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import logging
2+
import ast
3+
4+
from llvmlite import ir
5+
6+
logger: logging.Logger = logging.getLogger(__name__)
7+
8+
XDP_ACTIONS = {
9+
"XDP_ABORTED": 0,
10+
"XDP_DROP": 1,
11+
"XDP_PASS": 2,
12+
"XDP_TX": 3,
13+
"XDP_REDIRECT": 4,
14+
}
15+
16+
17+
def _handle_none_return(builder) -> bool:
18+
"""Handle return or return None -> returns 0."""
19+
builder.ret(ir.Constant(ir.IntType(64), 0))
20+
logger.debug("Generated default return: 0")
21+
return True
22+
23+
24+
def _is_xdp_name(name: str) -> bool:
25+
"""Check if a name is an XDP action"""
26+
return name in XDP_ACTIONS
27+
28+
29+
def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
30+
"""Handle XDP returns"""
31+
if not isinstance(stmt.value, ast.Name):
32+
return False
33+
34+
action_name = stmt.value.id
35+
36+
if action_name not in XDP_ACTIONS:
37+
raise ValueError(
38+
f"Unknown XDP action: {action_name}. Available: {XDP_ACTIONS.keys()}"
39+
)
40+
return False
41+
42+
value = XDP_ACTIONS[action_name]
43+
builder.ret(ir.Constant(ret_type, value))
44+
logger.debug(f"Generated XDP action return: {action_name} = {value}")
45+
return True

pythonbpf/type_deducer.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
from llvmlite import ir
22

33
# TODO: THIS IS NOT SUPPOSED TO MATCH STRINGS :skull:
4+
mapping = {
5+
"c_int8": ir.IntType(8),
6+
"c_uint8": ir.IntType(8),
7+
"c_int16": ir.IntType(16),
8+
"c_uint16": ir.IntType(16),
9+
"c_int32": ir.IntType(32),
10+
"c_uint32": ir.IntType(32),
11+
"c_int64": ir.IntType(64),
12+
"c_uint64": ir.IntType(64),
13+
"c_float": ir.FloatType(),
14+
"c_double": ir.DoubleType(),
15+
"c_void_p": ir.IntType(64),
16+
# Not so sure about this one
17+
"str": ir.PointerType(ir.IntType(8)),
18+
}
419

520

621
def ctypes_to_ir(ctype: str):
7-
mapping = {
8-
"c_int8": ir.IntType(8),
9-
"c_uint8": ir.IntType(8),
10-
"c_int16": ir.IntType(16),
11-
"c_uint16": ir.IntType(16),
12-
"c_int32": ir.IntType(32),
13-
"c_uint32": ir.IntType(32),
14-
"c_int64": ir.IntType(64),
15-
"c_uint64": ir.IntType(64),
16-
"c_float": ir.FloatType(),
17-
"c_double": ir.DoubleType(),
18-
"c_void_p": ir.IntType(64),
19-
# Not so sure about this one
20-
"str": ir.PointerType(ir.IntType(8)),
21-
}
2222
if ctype in mapping:
2323
return mapping[ctype]
2424
raise NotImplementedError(f"No mapping for {ctype}")
25+
26+
27+
def is_ctypes(ctype: str) -> bool:
28+
return ctype in mapping
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from pythonbpf import bpf, section, bpfglobal, compile
2+
from ctypes import c_void_p, c_int64
3+
4+
5+
@bpf
6+
@section("tracepoint/syscalls/sys_enter_execve")
7+
def hello_world(ctx: c_void_p) -> c_int64:
8+
print("Hello, World!")
9+
return 1 + 1 - 2
10+
11+
12+
@bpf
13+
@bpfglobal
14+
def LICENSE() -> str:
15+
return "GPL"
16+
17+
18+
compile()
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from pythonbpf import bpf, section, bpfglobal, compile
2+
from ctypes import c_void_p, c_int64
3+
4+
5+
@bpf
6+
@section("tracepoint/syscalls/sys_enter_execve")
7+
def hello_world(ctx: c_void_p) -> c_int64:
8+
print("Hello, World!")
9+
a = 2
10+
return a - 2
11+
12+
13+
@bpf
14+
@bpfglobal
15+
def LICENSE() -> str:
16+
return "GPL"
17+
18+
19+
compile()

tests/passing_tests/return/int.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from pythonbpf import bpf, section, bpfglobal, compile
2+
from ctypes import c_void_p, c_int64
3+
4+
5+
@bpf
6+
@section("tracepoint/syscalls/sys_enter_execve")
7+
def hello_world(ctx: c_void_p) -> c_int64:
8+
print("Hello, World!")
9+
return 1
10+
11+
12+
@bpf
13+
@bpfglobal
14+
def LICENSE() -> str:
15+
return "GPL"
16+
17+
18+
compile()

0 commit comments

Comments
 (0)