|
4 | 4 | from typing import Any
|
5 | 5 | from dataclasses import dataclass
|
6 | 6 |
|
7 |
| -from .helper import HelperHandlerRegistry, handle_helper_call |
8 |
| -from .type_deducer import ctypes_to_ir |
9 |
| -from .binary_ops import handle_binary_op |
10 |
| -from .expr_pass import eval_expr, handle_expr |
| 7 | +from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call |
| 8 | +from pythonbpf.type_deducer import ctypes_to_ir |
| 9 | +from pythonbpf.binary_ops import handle_binary_op |
| 10 | +from pythonbpf.expr_pass import eval_expr, handle_expr |
| 11 | + |
| 12 | +from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name |
| 13 | + |
11 | 14 |
|
12 | 15 | logger = logging.getLogger(__name__)
|
13 | 16 |
|
@@ -350,6 +353,27 @@ def handle_if(
|
350 | 353 | builder.position_at_end(merge_block)
|
351 | 354 |
|
352 | 355 |
|
| 356 | +def handle_return(builder, stmt, local_sym_tab, ret_type): |
| 357 | + logger.info(f"Handling return statement: {ast.dump(stmt)}") |
| 358 | + if stmt.value is None: |
| 359 | + return _handle_none_return(builder) |
| 360 | + elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id): |
| 361 | + return _handle_xdp_return(stmt, builder, ret_type) |
| 362 | + else: |
| 363 | + val = eval_expr( |
| 364 | + func=None, |
| 365 | + module=None, |
| 366 | + builder=builder, |
| 367 | + expr=stmt.value, |
| 368 | + local_sym_tab=local_sym_tab, |
| 369 | + map_sym_tab={}, |
| 370 | + structs_sym_tab={}, |
| 371 | + ) |
| 372 | + logger.info(f"Evaluated return expression to {val}") |
| 373 | + builder.ret(val[0]) |
| 374 | + return True |
| 375 | + |
| 376 | + |
353 | 377 | def process_stmt(
|
354 | 378 | func,
|
355 | 379 | module,
|
@@ -383,61 +407,12 @@ def process_stmt(
|
383 | 407 | func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
|
384 | 408 | )
|
385 | 409 | elif isinstance(stmt, ast.Return):
|
386 |
| - if stmt.value is None: |
387 |
| - builder.ret(ir.Constant(ir.IntType(64), 0)) |
388 |
| - did_return = True |
389 |
| - elif ( |
390 |
| - isinstance(stmt.value, ast.Call) |
391 |
| - and isinstance(stmt.value.func, ast.Name) |
392 |
| - and len(stmt.value.args) == 1 |
393 |
| - ): |
394 |
| - if isinstance(stmt.value.args[0], ast.Constant) and isinstance( |
395 |
| - stmt.value.args[0].value, int |
396 |
| - ): |
397 |
| - call_type = stmt.value.func.id |
398 |
| - if ctypes_to_ir(call_type) != ret_type: |
399 |
| - raise ValueError( |
400 |
| - "Return type mismatch: expected" |
401 |
| - f"{ctypes_to_ir(call_type)}, got {call_type}" |
402 |
| - ) |
403 |
| - else: |
404 |
| - builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) |
405 |
| - did_return = True |
406 |
| - elif isinstance(stmt.value.args[0], ast.BinOp): |
407 |
| - # TODO: Should be routed through eval_expr |
408 |
| - val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab) |
409 |
| - if val is None: |
410 |
| - raise ValueError("Failed to evaluate return expression") |
411 |
| - if val[1] != ret_type: |
412 |
| - raise ValueError( |
413 |
| - "Return type mismatch: expected " f"{ret_type}, got {val[1]}" |
414 |
| - ) |
415 |
| - builder.ret(val[0]) |
416 |
| - did_return = True |
417 |
| - elif isinstance(stmt.value.args[0], ast.Name): |
418 |
| - if stmt.value.args[0].id in local_sym_tab: |
419 |
| - var = local_sym_tab[stmt.value.args[0].id].var |
420 |
| - val = builder.load(var) |
421 |
| - if val.type != ret_type: |
422 |
| - raise ValueError( |
423 |
| - "Return type mismatch: expected" |
424 |
| - f"{ret_type}, got {val.type}" |
425 |
| - ) |
426 |
| - builder.ret(val) |
427 |
| - did_return = True |
428 |
| - else: |
429 |
| - raise ValueError("Failed to evaluate return expression") |
430 |
| - elif isinstance(stmt.value, ast.Name): |
431 |
| - if stmt.value.id == "XDP_PASS": |
432 |
| - builder.ret(ir.Constant(ret_type, 2)) |
433 |
| - did_return = True |
434 |
| - elif stmt.value.id == "XDP_DROP": |
435 |
| - builder.ret(ir.Constant(ret_type, 1)) |
436 |
| - did_return = True |
437 |
| - else: |
438 |
| - raise ValueError("Failed to evaluate return expression") |
439 |
| - else: |
440 |
| - raise ValueError("Unsupported return value") |
| 410 | + did_return = handle_return( |
| 411 | + builder, |
| 412 | + stmt, |
| 413 | + local_sym_tab, |
| 414 | + ret_type, |
| 415 | + ) |
441 | 416 | return did_return
|
442 | 417 |
|
443 | 418 |
|
|
0 commit comments