Skip to content

Commit 552cd35

Browse files
authored
Merge pull request #20 from pythonbpf/fix-failing-tests
Fix failing tests in tests/
2 parents a0b0ad3 + c7f2955 commit 552cd35

File tree

7 files changed

+133
-27
lines changed

7 files changed

+133
-27
lines changed

pythonbpf/binary_ops.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
def recursive_dereferencer(var, builder):
1010
"""dereference until primitive type comes out"""
1111
# TODO: Not worrying about stack overflow for now
12+
logger.info(f"Dereferencing {var}, type is {var.type}")
1213
if isinstance(var.type, ir.PointerType):
1314
a = builder.load(var)
1415
return recursive_dereferencer(a, builder)
@@ -18,7 +19,7 @@ def recursive_dereferencer(var, builder):
1819
raise TypeError(f"Unsupported type for dereferencing: {var.type}")
1920

2021

21-
def get_operand_value(operand, module, builder, local_sym_tab):
22+
def get_operand_value(operand, builder, local_sym_tab):
2223
"""Extract the value from an operand, handling variables and constants."""
2324
if isinstance(operand, ast.Name):
2425
if operand.id in local_sym_tab:
@@ -29,14 +30,14 @@ def get_operand_value(operand, module, builder, local_sym_tab):
2930
return ir.Constant(ir.IntType(64), operand.value)
3031
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
3132
elif isinstance(operand, ast.BinOp):
32-
return handle_binary_op_impl(operand, module, builder, local_sym_tab)
33+
return handle_binary_op_impl(operand, builder, local_sym_tab)
3334
raise TypeError(f"Unsupported operand type: {type(operand)}")
3435

3536

36-
def handle_binary_op_impl(rval, module, builder, local_sym_tab):
37+
def handle_binary_op_impl(rval, builder, local_sym_tab):
3738
op = rval.op
38-
left = get_operand_value(rval.left, module, builder, local_sym_tab)
39-
right = get_operand_value(rval.right, module, builder, local_sym_tab)
39+
left = get_operand_value(rval.left, builder, local_sym_tab)
40+
right = get_operand_value(rval.right, builder, local_sym_tab)
4041
logger.info(f"left is {left}, right is {right}, op is {op}")
4142

4243
# Map AST operation nodes to LLVM IR builder methods
@@ -61,6 +62,11 @@ def handle_binary_op_impl(rval, module, builder, local_sym_tab):
6162
raise SyntaxError("Unsupported binary operation")
6263

6364

64-
def handle_binary_op(rval, module, builder, var_name, local_sym_tab):
65-
result = handle_binary_op_impl(rval, module, builder, local_sym_tab)
66-
builder.store(result, local_sym_tab[var_name].var)
65+
def handle_binary_op(rval, builder, var_name, local_sym_tab):
66+
result = handle_binary_op_impl(rval, builder, local_sym_tab)
67+
if var_name and var_name in local_sym_tab:
68+
logger.info(
69+
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
70+
)
71+
builder.store(result, local_sym_tab[var_name].var)
72+
return result, result.type

pythonbpf/codegen.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def processor(source_code, filename, module):
4848
globals_processing(tree, module)
4949

5050

51-
def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
51+
def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
5252
logging.basicConfig(
5353
level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
5454
)
@@ -121,7 +121,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
121121
return output
122122

123123

124-
def compile(loglevel=logging.WARNING) -> bool:
124+
def compile(loglevel=logging.INFO) -> bool:
125125
# Look one level up the stack to the caller of this function
126126
caller_frame = inspect.stack()[1]
127127
caller_file = Path(caller_frame.filename).resolve()
@@ -154,7 +154,7 @@ def compile(loglevel=logging.WARNING) -> bool:
154154
return success
155155

156156

157-
def BPF(loglevel=logging.WARNING) -> BpfProgram:
157+
def BPF(loglevel=logging.INFO) -> BpfProgram:
158158
caller_frame = inspect.stack()[1]
159159
src = inspect.getsource(caller_frame.frame)
160160
with tempfile.NamedTemporaryFile(

pythonbpf/functions_pass.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def handle_assign(
232232
else:
233233
logger.info("Unsupported assignment call function type")
234234
elif isinstance(rval, ast.BinOp):
235-
handle_binary_op(rval, module, builder, var_name, local_sym_tab)
235+
handle_binary_op(rval, builder, var_name, local_sym_tab)
236236
else:
237237
logger.info("Unsupported assignment value type")
238238

@@ -384,24 +384,49 @@ def process_stmt(
384384
)
385385
elif isinstance(stmt, ast.Return):
386386
if stmt.value is None:
387-
builder.ret(ir.Constant(ir.IntType(32), 0))
387+
builder.ret(ir.Constant(ir.IntType(64), 0))
388388
did_return = True
389389
elif (
390390
isinstance(stmt.value, ast.Call)
391391
and isinstance(stmt.value.func, ast.Name)
392392
and len(stmt.value.args) == 1
393-
and isinstance(stmt.value.args[0], ast.Constant)
394-
and isinstance(stmt.value.args[0].value, int)
395393
):
396-
call_type = stmt.value.func.id
397-
if ctypes_to_ir(call_type) != ret_type:
398-
raise ValueError(
399-
"Return type mismatch: expected"
400-
f"{ctypes_to_ir(call_type)}, got {call_type}"
401-
)
402-
else:
403-
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
394+
if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
395+
stmt.value.args[0].value, int
396+
):
397+
call_type = stmt.value.func.id
398+
if ctypes_to_ir(call_type) != ret_type:
399+
raise ValueError(
400+
"Return type mismatch: expected"
401+
f"{ctypes_to_ir(call_type)}, got {call_type}"
402+
)
403+
else:
404+
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
405+
did_return = True
406+
elif isinstance(stmt.value.args[0], ast.BinOp):
407+
# TODO: Should be routed through eval_expr
408+
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
409+
if val is None:
410+
raise ValueError("Failed to evaluate return expression")
411+
if val[1] != ret_type:
412+
raise ValueError(
413+
"Return type mismatch: expected " f"{ret_type}, got {val[1]}"
414+
)
415+
builder.ret(val[0])
404416
did_return = True
417+
elif isinstance(stmt.value.args[0], ast.Name):
418+
if stmt.value.args[0].id in local_sym_tab:
419+
var = local_sym_tab[stmt.value.args[0].id].var
420+
val = builder.load(var)
421+
if val.type != ret_type:
422+
raise ValueError(
423+
"Return type mismatch: expected"
424+
f"{ret_type}, got {val.type}"
425+
)
426+
builder.ret(val)
427+
did_return = True
428+
else:
429+
raise ValueError("Failed to evaluate return expression")
405430
elif isinstance(stmt.value, ast.Name):
406431
if stmt.value.id == "XDP_PASS":
407432
builder.ret(ir.Constant(ret_type, 2))
@@ -454,6 +479,9 @@ def allocate_mem(
454479
continue
455480
var_name = target.id
456481
rval = stmt.value
482+
if var_name in local_sym_tab:
483+
logger.info(f"Variable {var_name} already allocated")
484+
continue
457485
if isinstance(rval, ast.Call):
458486
if isinstance(rval.func, ast.Name):
459487
call_type = rval.func.id
@@ -566,7 +594,7 @@ def process_func_body(
566594
)
567595

568596
if not did_return:
569-
builder.ret(ir.Constant(ir.IntType(32), 0))
597+
builder.ret(ir.Constant(ir.IntType(64), 0))
570598

571599

572600
def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab):

tests/failing_tests/direct_assign.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@
44

55
from ctypes import c_void_p, c_int64
66

7+
# NOTE: I have decided to not fix this example for now.
8+
# The issue is in line 31, where we are passing an expression.
9+
# The update helper expects a pointer type. But the problem is
10+
# that we must allocate the space for said pointer in the first
11+
# basic block. As that usage is in a different basic block, we
12+
# are unable to cast the expression to a pointer type. (as we never
13+
# allocated space for it).
14+
# Shall we change our space allocation logic? That allows users to
15+
# spam the same helper with the same args, and still run out of
16+
# stack space. So we consider this usage invalid for now.
17+
# Might fix it later.
18+
719

820
@bpf
921
@map
@@ -14,12 +26,12 @@ def count() -> HashMap:
1426
@bpf
1527
@section("xdp")
1628
def hello_world(ctx: c_void_p) -> c_int64:
17-
prev = count().lookup(0)
29+
prev = count.lookup(0)
1830
if prev:
19-
count().update(0, prev + 1)
31+
count.update(0, prev + 1)
2032
return XDP_PASS
2133
else:
22-
count().update(0, 1)
34+
count.update(0, 1)
2335

2436
return XDP_PASS
2537

tests/failing_tests/named_arg.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from pythonbpf import bpf, map, section, bpfglobal, compile
2+
from pythonbpf.helper import XDP_PASS
3+
from pythonbpf.maps import HashMap
4+
5+
from ctypes import c_void_p, c_int64
6+
7+
# NOTE: This example exposes the problems with our typing system.
8+
# We can't do steps on line 25 and 27.
9+
# prev is of type i64**. For prev + 1, we deref it down to i64
10+
# To assign it back to prev, we need to go back to i64**.
11+
# We cannot allocate space for the intermediate type now.
12+
# We probably need to track the ref/deref chain for each variable.
13+
14+
@bpf
15+
@map
16+
def count() -> HashMap:
17+
return HashMap(key=c_int64, value=c_int64, max_entries=1)
18+
19+
20+
@bpf
21+
@section("xdp")
22+
def hello_world(ctx: c_void_p) -> c_int64:
23+
prev = count.lookup(0)
24+
if prev:
25+
prev = prev + 1
26+
count.update(0, prev)
27+
return XDP_PASS
28+
else:
29+
count.update(0, 1)
30+
31+
return XDP_PASS
32+
33+
34+
@bpf
35+
@bpfglobal
36+
def LICENSE() -> str:
37+
return "GPL"
38+
39+
40+
compile()
File renamed without changes.

tests/passing_tests/var_rval.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import logging
2+
3+
from pythonbpf import compile, bpf, section, bpfglobal
4+
from ctypes import c_void_p, c_int64
5+
6+
7+
@bpf
8+
@section("sometag1")
9+
def sometag(ctx: c_void_p) -> c_int64:
10+
a = 1 - 1
11+
return c_int64(a)
12+
13+
14+
@bpf
15+
@bpfglobal
16+
def LICENSE() -> str:
17+
return "GPL"
18+
19+
20+
compile(loglevel=logging.INFO)

0 commit comments

Comments
 (0)