Skip to content

Commit f08bc99

Browse files
committed
Add _handle_wrapped_return
1 parent 23183da commit f08bc99

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

pythonbpf/functions/return_utils.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77

88
logger: logging.Logger = logging.getLogger(__name__)
99

10+
# TODO: Ideally there should be only 3 cases:
11+
# - Return none
12+
# - Return XDP
13+
# - Return expr
14+
1015
XDP_ACTIONS = {
1116
"XDP_ABORTED": 0,
1217
"XDP_DROP": 1,
@@ -26,7 +31,6 @@ def _handle_none_return(builder) -> bool:
2631
def _handle_typed_constant_return(call_type, return_value, builder, ret_type) -> bool:
2732
"""Handle typed constant return like: return c_int64(42)"""
2833

29-
# call_type = stmt.value.func.id
3034
expected_type = ctypes_to_ir(call_type)
3135

3236
if expected_type != ret_type:
@@ -43,7 +47,6 @@ def _handle_typed_constant_return(call_type, return_value, builder, ret_type) ->
4347
def _handle_binop_return(arg, builder, ret_type, local_sym_tab) -> bool:
4448
"""Handle return with binary operation: return c_int64(x + 1)"""
4549

46-
# result = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
4750
result = handle_binary_op(arg, builder, None, local_sym_tab)
4851

4952
if result is None:
@@ -62,8 +65,6 @@ def _handle_binop_return(arg, builder, ret_type, local_sym_tab) -> bool:
6265
def _handle_variable_return(var_name, builder, ret_type, local_sym_tab) -> bool:
6366
"""Handle return of a variable: return c_int64(my_var)"""
6467

65-
# var_name = stmt.value.args[0].id
66-
6768
if var_name not in local_sym_tab:
6869
raise ValueError(f"Undefined variable in return: {var_name}")
6970

@@ -78,6 +79,38 @@ def _handle_variable_return(var_name, builder, ret_type, local_sym_tab) -> bool:
7879
return True
7980

8081

82+
def _handle_wrapped_return(stmt: ast.Return, builder, ret_type, local_sym_tab) -> bool:
83+
"""Handle wrapped returns: return c_int64(42), return c_int64(x + 1), return c_int64(my_var)"""
84+
85+
if not (
86+
isinstance(stmt.value, ast.Call)
87+
and isinstance(stmt.value.func, ast.Name)
88+
and len(stmt.value.args) == 1
89+
):
90+
return False
91+
92+
arg = stmt.value.args[0]
93+
94+
# Case 1: Constant value - return c_int64(42)
95+
if isinstance(arg, ast.Constant) and isinstance(arg.value, int):
96+
return _handle_typed_constant_return(
97+
stmt.value.func.id, arg.value, builder, ret_type
98+
)
99+
100+
# Case 2: Binary operation - return c_int64(x + 1)
101+
elif isinstance(arg, ast.BinOp):
102+
return _handle_binop_return(arg, builder, ret_type, local_sym_tab)
103+
104+
# Case 3: Variable - return c_int64(my_var)
105+
elif isinstance(arg, ast.Name):
106+
if not arg.id:
107+
raise ValueError("Variable return must have a type, e.g., c_int64")
108+
return _handle_variable_return(arg.id, builder, ret_type, local_sym_tab)
109+
110+
else:
111+
raise ValueError(f"Unsupported return argument type: {type(arg).__name__}")
112+
113+
81114
def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
82115
"""Handle XDP returns"""
83116
if not isinstance(stmt.value, ast.Name):

0 commit comments

Comments
 (0)