Skip to content

Commit f53ca3b

Browse files
committed
Add ctypes in eval_expr
1 parent 02885af commit f53ca3b

File tree

3 files changed

+81
-15
lines changed

3 files changed

+81
-15
lines changed

pythonbpf/expr_pass.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
from typing import Dict
66

7+
from .type_deducer import ctypes_to_ir, is_ctypes
8+
79
logger: Logger = logging.getLogger(__name__)
810

911

@@ -88,6 +90,50 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
8890
return val, local_sym_tab[arg.id].ir_type
8991

9092

93+
def _handle_ctypes_call(
94+
func,
95+
module,
96+
builder,
97+
expr,
98+
local_sym_tab,
99+
map_sym_tab,
100+
structs_sym_tab=None,
101+
):
102+
"""Handle ctypes type constructor calls."""
103+
if len(expr.args) != 1:
104+
logger.info("ctypes constructor takes exactly one argument")
105+
return None
106+
107+
arg = expr.args[0]
108+
val = eval_expr(
109+
func,
110+
module,
111+
builder,
112+
arg,
113+
local_sym_tab,
114+
map_sym_tab,
115+
structs_sym_tab,
116+
)
117+
if val is None:
118+
logger.info("Failed to evaluate argument to ctypes constructor")
119+
return None
120+
call_type = expr.func.id
121+
expected_type = ctypes_to_ir(call_type)
122+
if expected_type is None:
123+
logger.info(f"Unsupported ctypes type: {call_type}")
124+
return None
125+
if val[1] != expected_type:
126+
# NOTE: We are only considering casting to and from int types for now
127+
if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType):
128+
if val[1].width < expected_type.width:
129+
val = (builder.sext(val[0], expected_type), expected_type)
130+
else:
131+
val = (builder.trunc(val[0], expected_type), expected_type)
132+
else:
133+
raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}")
134+
return val
135+
136+
91137
def eval_expr(
92138
func,
93139
module,
@@ -106,6 +152,17 @@ def eval_expr(
106152
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
107153
return _handle_deref_call(expr, local_sym_tab, builder)
108154

155+
if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id):
156+
return _handle_ctypes_call(
157+
func,
158+
module,
159+
builder,
160+
expr,
161+
local_sym_tab,
162+
map_sym_tab,
163+
structs_sym_tab,
164+
)
165+
109166
# delayed import to avoid circular dependency
110167
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
111168

pythonbpf/functions/functions_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,11 @@ def handle_return(builder, stmt, local_sym_tab, ret_type):
359359
return _handle_none_return(builder)
360360
elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id):
361361
return _handle_xdp_return(stmt, builder, ret_type)
362+
elif True:
363+
val = eval_expr(None, None, builder, stmt.value, local_sym_tab, {}, {})
364+
logger.info(f"Evaluated return expression to {val}")
365+
builder.ret(val[0])
366+
return True
362367
elif (
363368
isinstance(stmt.value, ast.Call)
364369
and isinstance(stmt.value.func, ast.Name)

pythonbpf/type_deducer.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
from llvmlite import ir
22

33
# TODO: THIS IS NOT SUPPOSED TO MATCH STRINGS :skull:
4+
mapping = {
5+
"c_int8": ir.IntType(8),
6+
"c_uint8": ir.IntType(8),
7+
"c_int16": ir.IntType(16),
8+
"c_uint16": ir.IntType(16),
9+
"c_int32": ir.IntType(32),
10+
"c_uint32": ir.IntType(32),
11+
"c_int64": ir.IntType(64),
12+
"c_uint64": ir.IntType(64),
13+
"c_float": ir.FloatType(),
14+
"c_double": ir.DoubleType(),
15+
"c_void_p": ir.IntType(64),
16+
# Not so sure about this one
17+
"str": ir.PointerType(ir.IntType(8)),
18+
}
419

520

621
def ctypes_to_ir(ctype: str):
7-
mapping = {
8-
"c_int8": ir.IntType(8),
9-
"c_uint8": ir.IntType(8),
10-
"c_int16": ir.IntType(16),
11-
"c_uint16": ir.IntType(16),
12-
"c_int32": ir.IntType(32),
13-
"c_uint32": ir.IntType(32),
14-
"c_int64": ir.IntType(64),
15-
"c_uint64": ir.IntType(64),
16-
"c_float": ir.FloatType(),
17-
"c_double": ir.DoubleType(),
18-
"c_void_p": ir.IntType(64),
19-
# Not so sure about this one
20-
"str": ir.PointerType(ir.IntType(8)),
21-
}
2222
if ctype in mapping:
2323
return mapping[ctype]
2424
raise NotImplementedError(f"No mapping for {ctype}")
25+
26+
27+
def is_ctypes(ctype: str) -> bool:
28+
return ctype in mapping

0 commit comments

Comments
 (0)