Skip to content

Commit

Permalink
eval: Remove _ExprEvaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Povišer <povik@cutebit.org>
  • Loading branch information
povik committed Feb 28, 2024
1 parent 511bc1d commit 9e8a881
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 67 deletions.
94 changes: 30 additions & 64 deletions fold/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,32 @@ def __call__(self, expr):
raise ValueError("Unknown expression node of type %s" % type(expr))


class _ExprEvaluator(ExprVisitor):
def on_Op_builtin(self, expr):
def index(v, indices):
if len(indices) == 0:
return v
if isinstance(v, list):
return index(v[indices[0]], indices[1:])
if isinstance(v, int) and len(indices) == 1:
return bool((v >> indices[0]) & 1)
raise NotImplementedError


class ConstExprEvaluator(ExprVisitor):
def __init__(self, constvals):
self.constvals = constvals

def on_Var(self, expr):
varname = expr.varname

if varname in self.constvals:
return self.constvals[varname]
else:
raise ast.BadInput("no such constant or variable: '%s'" % varname)

def on_Const(self, expr):
return expr.val

def on_Op(self, expr):
opname = expr.opname
args = [self(node) for node in expr.args]

Expand Down Expand Up @@ -173,77 +197,19 @@ def on_Op_builtin(self, expr):
return int(args[0] != 0 and args[1] != 0)
if opname == "?:":
return args[1] if args[0] else args[2]

# these below need special treatment of zero case if we are
# working with python integers
if opname == "//":
return args[0]//args[1]
if opname == "%":
return args[0]%args[1]

def on_Op(self, expr):
val = self.on_Op_builtin(expr)
if val is not None:
return val

raise NotImplementedError("unimplemented: evaluation of %s" % expr)

def on_unimplemented(self, expr):
raise NotImplementedError

on_Var = on_unimplemented
on_Const = on_unimplemented
on_Special = on_unimplemented


def index(v, indices):
if len(indices) == 0:
return v
if isinstance(v, list):
return index(v[indices[0]], indices[1:])
if isinstance(v, int) and len(indices) == 1:
return bool((v >> indices[0]) & 1)
raise NotImplementedError


class ConstExprEvaluator(_ExprEvaluator):
def __init__(self, constvals):
self.constvals = constvals

def on_Var(self, expr):
varname = expr.varname

if varname in self.constvals:
return self.constvals[varname]
else:
raise ast.BadInput("no such constant or variable: '%s'" % varname)

def on_Const(self, expr):
return expr.val

def on_Op(self, expr):
opname, args = expr.opname, expr.args

if opname == "log2ceil":
return (self(args[0])-1).bit_length()

return (args[0]-1).bit_length()
if opname == "ctz":
v = self(args[0])
v = args[0]
return (v & -v).bit_length() - 1

if opname not in ast.Expr.SPECIAL_OPNAMES:
raise ast.BadInput("function calls are unsupported in constant expressions")

if opname == "[":
return index(self(args[0]), [self(n) for n in args[1:]])
return index(args[0], args[1:])
if opname == "//":
args = [self(a) for a in args]
return (args[0]//args[1]) if args[1] != 0 else 0
if opname == "%":
args = [self(a) for a in args]
return (args[0]%args[1]) if args[1] != 0 else 0

return super().on_Op(expr)
raise ast.BadInput("function calls are unsupported in constant expressions")


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions fold/logic/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys

from .. import ast
from ..eval import _ExprEvaluator
from ..eval import ExprVisitor
from ..utils import product

from .shape import SignalValue, Shape
Expand All @@ -17,7 +17,7 @@
"&&", "||", "%", "?:", "^"]


class CombinatorialEvaluator(_ExprEvaluator):
class CombinatorialEvaluator(ExprVisitor):
def __init__(self, rtl_module):
self.m = rtl_module

Expand Down
2 changes: 1 addition & 1 deletion fold/machinecode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def wrapper(self, arg1, arg2, *operands, **kwargs):

# TODO: check `common_shape` variables versus shape policies
# does the result IR shape match up only by accident?
class ExprEvaluator(baseeval._ExprEvaluator):
class ExprEvaluator(baseeval.ExprVisitor):
def __init__(self, frame):
self.frame = frame

Expand Down

0 comments on commit 9e8a881

Please sign in to comment.