Skip to content

Commit e9f3aa2

Browse files
committed
Make handle_return (crude for now)
1 parent d0a8e96 commit e9f3aa2

File tree

1 file changed

+70
-54
lines changed

1 file changed

+70
-54
lines changed

pythonbpf/functions/functions_pass.py

Lines changed: 70 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pythonbpf.binary_ops import handle_binary_op
1010
from pythonbpf.expr_pass import eval_expr, handle_expr
1111

12+
1213
logger = logging.getLogger(__name__)
1314

1415

@@ -350,6 +351,65 @@ def handle_if(
350351
builder.position_at_end(merge_block)
351352

352353

354+
def handle_return(
355+
func, module, builder, stmt, map_sym_tab, local_sym_tab, struct_sym_tab, ret_type
356+
):
357+
if stmt.value is None:
358+
builder.ret(ir.Constant(ir.IntType(64), 0))
359+
return True
360+
elif (
361+
isinstance(stmt.value, ast.Call)
362+
and isinstance(stmt.value.func, ast.Name)
363+
and len(stmt.value.args) == 1
364+
):
365+
if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
366+
stmt.value.args[0].value, int
367+
):
368+
call_type = stmt.value.func.id
369+
if ctypes_to_ir(call_type) != ret_type:
370+
raise ValueError(
371+
"Return type mismatch: expected"
372+
f"{ctypes_to_ir(call_type)}, got {call_type}"
373+
)
374+
else:
375+
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
376+
return True
377+
elif isinstance(stmt.value.args[0], ast.BinOp):
378+
# TODO: Should be routed through eval_expr
379+
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
380+
if val is None:
381+
raise ValueError("Failed to evaluate return expression")
382+
if val[1] != ret_type:
383+
raise ValueError(
384+
f"Return type mismatch: expected {ret_type}, got {val[1]}"
385+
)
386+
builder.ret(val[0])
387+
return True
388+
elif isinstance(stmt.value.args[0], ast.Name):
389+
if stmt.value.args[0].id in local_sym_tab:
390+
var = local_sym_tab[stmt.value.args[0].id].var
391+
val = builder.load(var)
392+
if val.type != ret_type:
393+
raise ValueError(
394+
f"Return type mismatch: expected {ret_type}, got {val.type}"
395+
)
396+
builder.ret(val)
397+
return True
398+
else:
399+
raise ValueError("Failed to evaluate return expression")
400+
elif isinstance(stmt.value, ast.Name):
401+
if stmt.value.id == "XDP_PASS":
402+
builder.ret(ir.Constant(ret_type, 2))
403+
return True
404+
elif stmt.value.id == "XDP_DROP":
405+
builder.ret(ir.Constant(ret_type, 1))
406+
return True
407+
else:
408+
raise ValueError("Failed to evaluate return expression")
409+
else:
410+
raise ValueError("Unsupported return value")
411+
412+
353413
def process_stmt(
354414
func,
355415
module,
@@ -383,60 +443,16 @@ def process_stmt(
383443
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
384444
)
385445
elif isinstance(stmt, ast.Return):
386-
if stmt.value is None:
387-
builder.ret(ir.Constant(ir.IntType(64), 0))
388-
did_return = True
389-
elif (
390-
isinstance(stmt.value, ast.Call)
391-
and isinstance(stmt.value.func, ast.Name)
392-
and len(stmt.value.args) == 1
393-
):
394-
if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
395-
stmt.value.args[0].value, int
396-
):
397-
call_type = stmt.value.func.id
398-
if ctypes_to_ir(call_type) != ret_type:
399-
raise ValueError(
400-
"Return type mismatch: expected"
401-
f"{ctypes_to_ir(call_type)}, got {call_type}"
402-
)
403-
else:
404-
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
405-
did_return = True
406-
elif isinstance(stmt.value.args[0], ast.BinOp):
407-
# TODO: Should be routed through eval_expr
408-
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
409-
if val is None:
410-
raise ValueError("Failed to evaluate return expression")
411-
if val[1] != ret_type:
412-
raise ValueError(
413-
f"Return type mismatch: expected {ret_type}, got {val[1]}"
414-
)
415-
builder.ret(val[0])
416-
did_return = True
417-
elif isinstance(stmt.value.args[0], ast.Name):
418-
if stmt.value.args[0].id in local_sym_tab:
419-
var = local_sym_tab[stmt.value.args[0].id].var
420-
val = builder.load(var)
421-
if val.type != ret_type:
422-
raise ValueError(
423-
f"Return type mismatch: expected {ret_type}, got {val.type}"
424-
)
425-
builder.ret(val)
426-
did_return = True
427-
else:
428-
raise ValueError("Failed to evaluate return expression")
429-
elif isinstance(stmt.value, ast.Name):
430-
if stmt.value.id == "XDP_PASS":
431-
builder.ret(ir.Constant(ret_type, 2))
432-
did_return = True
433-
elif stmt.value.id == "XDP_DROP":
434-
builder.ret(ir.Constant(ret_type, 1))
435-
did_return = True
436-
else:
437-
raise ValueError("Failed to evaluate return expression")
438-
else:
439-
raise ValueError("Unsupported return value")
446+
did_return = handle_return(
447+
func,
448+
module,
449+
builder,
450+
stmt,
451+
map_sym_tab,
452+
local_sym_tab,
453+
structs_sym_tab,
454+
ret_type,
455+
)
440456
return did_return
441457

442458

0 commit comments

Comments
 (0)