
# Staged Metaprogramming
I had a post somewehre on this point. That programming z3 in python is staged metaprogramming.
The meta system python can do partial evaluation.




In [1]:
from z3 import *

In [None]:
def mypow(x:int,n:int) -> int:
    assert n >= 0
    if n == 0:
        return 1
    else:
        x * mypow(x,n-1)

# accumulator version?

# string version
# strings are a universal but somewhat structure free rep of code.
Code = str
def mypow2(n:int, x:Code) -> Code:
    if x == 0:
        return "1"
    else:
        f"{x} * {mypow2(x,n-1)}"

mypow = lambda x,n: 1 if n <= 0 else x * mypow(x,n-1)

mypow = Function("mypow", IntSort(), IntSort(), IntSort())
mypow_def = ForAll([x,n], mypow(n, x) == If(n <= 0, 1, x * mypow(n-1, x)))

# Partially evaled
def mypow(x:ExprRef, n:int) -> ExprRef:
    if n == 0:
        return IntVal(1)
    else:
        return x * mypow(x,n-1)




# Quote

Related to the above, we can also build a replication of some of z3's ast inside itself. The `quote` function is a metalevel python notion.

https://arxiv.org/abs/1802.00405
qe hol-light

Quotation is rife with paradox, but I'm not sure that the way I've done it here can be? Since I haven't internalized quote into my logic.



In [None]:
"""
unityped
Z3Expr = Datatype("Z3Expr")
Z3Expr.declare("IntVal", ("val", IntSort()))
Z3Expr.declare("Var", ("name", StringSort()))
Z3Expr.declare("And", ("lhs", Z3Expr), ("rhs", Z3Expr))
Z3Expr.declare("Or", ("lhs", Z3Expr), ("rhs", Z3Expr))
"""
BoolExpr = Datatype("BoolExpr")
BoolExpr.declare("BoolVal", ("val", BoolSort()))
BoolExpr.declare("Var", ("name", StringSort()))
BoolExpr.declare("And", ("lhs", BoolExpr), ("rhs", BoolExpr))
BoolExpr.declare("Or", ("lhs", BoolExpr), ("rhs", BoolExpr))
BoolExpr.declare("Not", ("arg", BoolExpr))
BoolExpr = BoolExpr.create()
IntExpr = Datatype("IntExpr")
IntExpr.declare("IntVal", ("val", IntSort()))
IntExpr.declare("Var", ("name", StringSort()))
IntExpr.declare("Add", ("lhs", IntExpr), ("rhs", IntExpr))
IntExpr.declare("Sub", ("lhs", IntExpr), ("rhs", IntExpr))
IntExpr = IntExpr.create()



def quote(e : ExprRef)-> ExprRef:
    match e.decl().name():
        case "IntVal": return Z3Expr.IntVal(e)
        #case "Var": return Z3Expr.Var(e)
        case "and": return Z3Expr.And(quote(e.arg(0)), quote(e.arg(1)))
        case "or": return Z3Expr.Or(quote(e.arg(0)), quote(e.arg(1)))
        case _: raise Exception(f"Unknown decl: {e.decl().name()}")

# eval = Function(Z3Expr, ArraySort(StringSort(), BoolSort()), BoolSort())
# eval_def = ForAll([e,env], eval(e) == If( ... ))

# Pattern matching on z3 asts


In [13]:
from typing import Optional, Iterable, Dict
def match_(t : AstRef, pat : AstRef, vars : Iterable[ExprRef] = []) -> Optional[Dict[ExprRef,ExprRef]]:
    subst = {}
    todo = [(t,pat)]
    while todo:
        t,pat = todo.pop()
        if pat in vars or is_var(pat): # allow var as pattern?
            if pat in subst:
                if subst[pat].eq(t):
                    pass
                else:
                    return None
            else:
                subst[pat] = t
        elif isinstance(t, QuantifierRef) or isinstance(pat, QuantifierRef):
            raise NotImplementedError
        else:
            thead, targs = t.decl(), t.children()
            phead, pargs = pat.decl(), pat.children()
            if thead != phead or len(targs) != len(pargs): # check sorts here? The decl might check that
                return None
            todo.extend(zip(targs, pargs))
    return subst

E = DeclareSort("Expr")
foo = Function("foo", E, E, E)
bar = Function("bar", E, E)
a,b,c = Consts("a b c", E)
x,y,z = Consts("x y z", E)
vars = {x,y,z}

assert match_(foo(x, bar(y)), foo(a, bar(b)), vars) == None
assert match_(foo(a, bar(b)), foo(x, bar(y)), vars) == {x:a, y:b}
assert match_(foo(a, bar(b)), foo(a, bar(b)), vars) == {}
assert match_(foo(a, bar(b)), foo(x,x), vars) == None
assert match_(foo(a, bar(a)), foo(x,bar(x)), vars) == {x:a}
match_(foo(a, bar(a)), Lambda([x],foo(x,bar(x))).body()) == {Var(0, E):a}


True

In [8]:
Var(0, IntSort())

AttributeError: 'ArithRef' object has no attribute 'body'

# Simp
We can use match_ to do a simp routine, knuckledragger tactic.


# Lambda Eval

I've complained a bit before that it's crazy python doesn't have a good lambda library.
Well, it kind of does.


Lambda eval.
Z3 does do capture avoinding substitution. It can take care of de bruijn indices.
vs z3 simplify

Locally nameless matching. FreshConst + substitute makes this easy

Then we can use regular python pattern matching for lambda matching. Pretty cool!
I don't understand how the property mechanism works.

https://chargueraud.org/research/2009/ln/main.pdf 

In [51]:
"""
z3.ExprRef.head = property(lambda self: self.decl().kind())
z3.ExprRef.args = property(lambda self: [self.arg(i) for i in range(self.num_args())])
z3.ExprRef.__match_args__ = ["head", "args"]
z3.QuantifierRef.open_term = property(lambda self: vars = FreshConst() (return vars, subst(self.body, []))) 
z3.QuantifierRef.__match_args__ = ["open_term"]

z3.QuantifierRef.__matmul__ = lambda self, other: z3.substitute(self.body, zip([z3.Var(n) for n in range(len(other)) , other]))
"""
def open_binder(l : QuantifierRef):
    vars = []
    nvars = l.num_vars()
    for i in range(nvars):
        sort = l.var_sort(i)
        vars.append((Var(nvars - i - 1, sort) , FreshConst(sort, prefix=l.var_name(i))))
    return [x for (_,x) in vars], substitute(l.body(), vars)



open_binder(Lambda([x, y, z], foo(x,bar(bar(y)))))
open_binder(ForAll([x, y, z], foo(x,bar(bar(y))) == bar(z)))

def instan(f : QuantifierRef, *x):
    nvars = f.num_vars()
    assert len(x) == nvars
    vars = []
    for i in range(nvars):
        sort = f.var_sort(i)
        vars.append((Var(nvars - i - 1, sort), x[i]))
    return substitute(f.body(), *vars)

instan(ForAll([x, y, z], foo(x,bar(bar(y))) == bar(z)), a, b, bar(c))

# instn is substutie_vars?
# how much of this is also available from cvc5

In [53]:
help_simplify()

algebraic_number_evaluator (bool) simplify/evaluate expressions containing (algebraic) irrational numbers. (default: true)
arith_ineq_lhs (bool) rewrite inequalities so that right-hand-side is a constant. (default: false)
arith_lhs (bool) all monomials are moved to the left-hand-side, and the right-hand-side is just a constant. (default: false)
bit2bool (bool) try to convert bit-vector terms of size 1 into Boolean terms (default: true)
blast_distinct (bool) expand a distinct predicate into a quadratic number of disequalities (default: false)
blast_distinct_threshold (unsigned int) when blast_distinct is true, only distinct expressions with less than this number of arguments are blasted (default: 4294967295)
blast_eq_value (bool) blast (some) Bit-vector equalities into bits (default: false)
blast_select_store (bool) eagerly replace all (select (store ..) ..) term by an if-then-else term (default: false)
bv_extract_prop (bool) attempt to partially propagate extraction inwards (default: f

In [48]:
list(range(3,0,-1))

[3, 2, 1]

In [1]:
# It isn't all roses. Python's scoping rules are bonkers. It's intrinsically an imperative language.
match 42:
    case a:
        pass
print(a)

42


In [None]:
def unify(a : AstRef, b : AstRef, vars):
    # or take in an existential equation.
    eqs = [(a,b)]
    subst = []
    sig = [] # miller unify?
    while eqs:
        a,b = eqs.pop()
        if a in vars:
            subst.append((a,b))
            map(lambda x: substitute(subst), eqs)
        elif b in vars:
            subst.append((b,a))
            map(lambda x: substitute(subst), eqs)
        


convert z3 ast to egglog for conditional simplification.

simp taking in z3 formula

https://microsoft.github.io/z3guide/programming/Example%20Programs/Formula%20Simplification/ examples of simplifying z3 expressions.


rules = []
def simp1(t, rules):
    for r in rules:
        match r:
            case ForAll(vars,l == r):
                subst = unify(l, r, vars)
                if subst != None:

def simp():





Term orderings


<Order.GT: 4>

ast vector and ast map are interesting.

class UnionFind():
    pass

The thing that kills me is too many design decisions. The great thing about just taking the z3 ast is it just is what it is. It is a pretty goodc design alkbeit imperfetc. In fact, it's so good it kind of inspired this whole line of thinking on my part.


https://www.philipzucker.com/programming-and-interactive-proving-with-z3py/

https://stackoverflow.com/questions/76270483/is-there-a-way-to-draw-bussproof-style-tree-diagram-in-jupyter-notebook



In [54]:
from dataclasses import dataclass

@dataclass(frozen=True)
class Theorem():
    pass
__Theorem = Theorem
Theorem = None

@dataclass(frozen=True)
class Axiom(__Theorem):
    thm : BoolRef
#Axiom = None. # Actually. Go ahead

@dataclass(frozen=True)
class Lemma(__Theorem):
    thm : BoolRef
    by : list[Theorem]
    admit : bool
__Lemma = Lemma #This is silly? You can still get at it.
Lemma = None

def lemma(fm : BoolRef, by = [], admit = False):
    if admit:
        return __Lemma(fm, by, True)
    else:
        s = Solver()
        for n, h in enumerate(by):
            assert isinstance(h, __Theorem)
            s.add(h.thm)
            s.assert_and_track(h, f"by_{n}")
         
        s.assert_and_track(Not(fm), "goal")
        s.set("unsat_core", True)
        s.set("timeout", 1000)
        res = s.check()
        if res == unsat:
            # check the unsat core. Check it include goals
            core = s.unsat_core()
            return __Lemma(fm, by, False)
        elif res == sat:
            raise Exception("Lemma failed to prove", s.get_model())



In [None]:
Theorem

# Converting to TPTP


In [42]:
from z3 import *
def z3_sort_tptp(sort : SortRef):
    match sort.name():
        case "Int":
            return "$int"
        case "Bool":
            return "$o"
        case "Real":
            return "$real"
        case "Array":
            return "({} > {})".format(z3_sort_tptp(sort.domain()), z3_sort_tptp(sort.range()))
        case x:
            return x

assert z3_sort_tptp(IntSort()) == "$int"
assert z3_sort_tptp(BoolSort()) == "$o"
assert z3_sort_tptp(ArraySort(ArraySort(BoolSort(),IntSort()), IntSort())) == "(($o > $int) > $int)"

In [53]:
def collect_sig(e : ExprRef):
    sig = set()
    todo = [e]
    while todo:
        e = todo.pop()
        if is_var(e):
            continue
        elif isinstance(e, QuantifierRef):
            todo.extend(e.body())
        else:    
          sig.add(e.decl())
          todo.extend(e.children())
    return sig

E = DeclareSort("E")
x,y,z = Ints("x y z")
z,w = Consts("z w", E)
collect_sig(And(z == w, w == z))

{==, And, w, z}

In [57]:
def sig_tptp_decl(decl):
    name = decl.name()
    if name in ["and", "or", "=>", "=", "ite", "not"]:
        return None
    return "thf(typedecl, type, {} : {}).".format(name, z3_sort_tptp(decl.sort()))

for decl in collect_sig(And(z == w, w == z)):
    print(decl)
    print(sig_tptp_decl(decl))


Z3Exception: sort mismatch

In [54]:


def z3_to_tptp(expr : ExprRef):
    children = list(map(z3_to_tptp,expr.children()))
    head = expr.decl().name()
    if isinstance(expr, IntNumRef):
        return str(expr.as_string())
    elif isinstance(expr,QuantifierRef):
        vars, body = open_binder(expr)
        if expr.is_forall():
            return "(![{}] : {})"
        elif expr.is_exists():
            return "(?[{}] : {})"
        elif expr.is_lambda():
            return "(^[{}] : {})"
    if head == "true":
        return "$true"
    elif head == "false":
        return "$false"
    elif head == "and":
        return "({})".format(" & ".join(children))
    elif head == "or":
        return "({})".format(" | ".join(children))
    elif head == "=":
        return "({} = {})".format(children[0], children[1])
    elif head == "=>":
        return "({} => {})".format(children[0], children[1])
    elif head == "not":
        return "~({})".format(children[0])
    else:
        if len(children) == 0:
            return head 
        return f"{head}({', '.join(children)})"



x, y = Ints("x y")
z3_to_tptp(And(x == 1, y == 2))

'((x = 1) & (y = 2))'

In [55]:
def z3_problem_to_tptp(probs : list[BoolRef]):
    tptp = []
    sig = {d for prob in probs for d in collect_sig(prob)}
    tptp.extend(map(sig_tptp_decl, sig))
    tptp.extend(map(z3_to_tptp, probs))
    return tptp

x, y, z = Ints("x y z")
z3_problem_to_tptp([x == y, y == z])

    


AttributeError: 'FuncDeclRef' object has no attribute 'sort'

In [112]:
Exists([x], x == x).is_exists()
substitute_vars(Lambda([x,y], x + y).body(), IntVal(3))
Select(Lambda([x], x + 1), 10)

AttributeError: 'QuantifierRef' object has no attribute 'domain_n'

In [69]:
x.decl().name()

'x'