Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pythonbpf/binary_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import ast
from llvmlite import ir
from logging import Logger
import logging

logger: Logger = logging.getLogger(__name__)


def recursive_dereferencer(var, builder):
Expand All @@ -17,7 +21,7 @@ def recursive_dereferencer(var, builder):


def handle_binary_op(rval, module, builder, var_name, local_sym_tab, map_sym_tab, func):
print(module)
logger.info(f"module {module}")
left = rval.left
right = rval.right
op = rval.op
Expand All @@ -43,7 +47,7 @@ def handle_binary_op(rval, module, builder, var_name, local_sym_tab, map_sym_tab
else:
raise SyntaxError("Unsupported right operand type")

print(f"left is {left}, right is {right}, op is {op}")
logger.info(f"left is {left}, right is {right}, op is {op}")

if isinstance(op, ast.Add):
builder.store(builder.add(left, right), local_sym_tab[var_name].var)
Expand Down
27 changes: 18 additions & 9 deletions pythonbpf/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from pathlib import Path
from pylibbpf import BpfProgram
import tempfile
from logging import Logger
import logging

logger: Logger = logging.getLogger(__name__)

VERSION = "v0.1.3"

Expand All @@ -30,11 +34,11 @@ def find_bpf_chunks(tree):

def processor(source_code, filename, module):
tree = ast.parse(source_code, filename)
print(ast.dump(tree, indent=4))
logger.debug(ast.dump(tree, indent=4))

bpf_chunks = find_bpf_chunks(tree)
for func_node in bpf_chunks:
print(f"Found BPF function/struct: {func_node.name}")
logger.info(f"Found BPF function/struct: {func_node.name}")

structs_sym_tab = structs_proc(tree, module, bpf_chunks)
map_sym_tab = maps_proc(tree, module, bpf_chunks)
Expand All @@ -44,7 +48,10 @@ def processor(source_code, filename, module):
globals_processing(tree, module)


def compile_to_ir(filename: str, output: str):
def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
logging.basicConfig(
level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
with open(filename) as f:
source = f.read()

Expand Down Expand Up @@ -121,7 +128,7 @@ def compile_to_ir(filename: str, output: str):

module.add_named_metadata("llvm.ident", [f"PythonBPF {VERSION}"])

print(f"IR written to {output}")
logger.info(f"IR written to {output}")
with open(output, "w") as f:
f.write(f'source_filename = "{filename}"\n')
f.write(str(module))
Expand All @@ -130,7 +137,7 @@ def compile_to_ir(filename: str, output: str):
return output


def compile() -> bool:
def compile(loglevel=logging.WARNING) -> bool:
# Look one level up the stack to the caller of this function
caller_frame = inspect.stack()[1]
caller_file = Path(caller_frame.filename).resolve()
Expand All @@ -139,7 +146,9 @@ def compile() -> bool:
o_file = caller_file.with_suffix(".o")

success = True
success = compile_to_ir(str(caller_file), str(ll_file)) and success
success = (
compile_to_ir(str(caller_file), str(ll_file), loglevel=loglevel) and success
)

success = bool(
subprocess.run(
Expand All @@ -157,11 +166,11 @@ def compile() -> bool:
and success
)

print(f"Object written to {o_file}")
logger.info(f"Object written to {o_file}")
return success


def BPF() -> BpfProgram:
def BPF(loglevel=logging.WARNING) -> BpfProgram:
caller_frame = inspect.stack()[1]
src = inspect.getsource(caller_frame.frame)
with tempfile.NamedTemporaryFile(
Expand All @@ -174,7 +183,7 @@ def BPF() -> BpfProgram:
f.write(src)
f.flush()
source = f.name
compile_to_ir(source, str(inter.name))
compile_to_ir(source, str(inter.name), loglevel=loglevel)
subprocess.run(
[
"llc",
Expand Down
32 changes: 18 additions & 14 deletions pythonbpf/expr_pass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import ast
from llvmlite import ir
from logging import Logger
import logging

logger: Logger = logging.getLogger(__name__)


def eval_expr(
Expand All @@ -11,22 +15,22 @@ def eval_expr(
map_sym_tab,
structs_sym_tab=None,
):
print(f"Evaluating expression: {ast.dump(expr)}")
logger.info(f"Evaluating expression: {ast.dump(expr)}")
if isinstance(expr, ast.Name):
if expr.id in local_sym_tab:
var = local_sym_tab[expr.id].var
val = builder.load(var)
return val, local_sym_tab[expr.id].ir_type # return value and type
else:
print(f"Undefined variable {expr.id}")
logger.info(f"Undefined variable {expr.id}")
return None
elif isinstance(expr, ast.Constant):
if isinstance(expr.value, int):
return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64)
elif isinstance(expr.value, bool):
return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1)
else:
print("Unsupported constant type")
logger.info("Unsupported constant type")
return None
elif isinstance(expr, ast.Call):
# delayed import to avoid circular dependency
Expand All @@ -35,26 +39,26 @@ def eval_expr(
if isinstance(expr.func, ast.Name):
# check deref
if expr.func.id == "deref":
print(f"Handling deref {ast.dump(expr)}")
logger.info(f"Handling deref {ast.dump(expr)}")
if len(expr.args) != 1:
print("deref takes exactly one argument")
logger.info("deref takes exactly one argument")
return None
arg = expr.args[0]
if (
isinstance(arg, ast.Call)
and isinstance(arg.func, ast.Name)
and arg.func.id == "deref"
):
print("Multiple deref not supported")
logger.info("Multiple deref not supported")
return None
if isinstance(arg, ast.Name):
if arg.id in local_sym_tab:
arg = local_sym_tab[arg.id].var
else:
print(f"Undefined variable {arg.id}")
logger.info(f"Undefined variable {arg.id}")
return None
if arg is None:
print("Failed to evaluate deref argument")
logger.info("Failed to evaluate deref argument")
return None
# Since we are handling only name case, directly take type from sym tab
val = builder.load(arg)
Expand All @@ -72,7 +76,7 @@ def eval_expr(
structs_sym_tab,
)
elif isinstance(expr.func, ast.Attribute):
print(f"Handling method call: {ast.dump(expr.func)}")
logger.info(f"Handling method call: {ast.dump(expr.func)}")
if isinstance(expr.func.value, ast.Call) and isinstance(
expr.func.value.func, ast.Name
):
Expand Down Expand Up @@ -107,15 +111,15 @@ def eval_expr(
attr_name = expr.attr
if var_name in local_sym_tab:
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
print(f"Loading attribute {attr_name} from variable {var_name}")
print(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
metadata = structs_sym_tab[var_metadata]
if attr_name in metadata.fields:
gep = metadata.gep(builder, var_ptr, attr_name)
val = builder.load(gep)
field_type = metadata.field_type(attr_name)
return val, field_type
print("Unsupported expression evaluation")
logger.info("Unsupported expression evaluation")
return None


Expand All @@ -129,7 +133,7 @@ def handle_expr(
structs_sym_tab,
):
"""Handle expression statements in the function body."""
print(f"Handling expression: {ast.dump(expr)}")
logger.info(f"Handling expression: {ast.dump(expr)}")
call = expr.value
if isinstance(call, ast.Call):
eval_expr(
Expand All @@ -142,4 +146,4 @@ def handle_expr(
structs_sym_tab,
)
else:
print("Unsupported expression type")
logger.info("Unsupported expression type")
Loading