Cheap constraints setting:

Step 1: Fix a universally quantified formula F := \forall x. phi(x)

Step 2: Generate positive and negative examples of F

Algorithm:

Step 3: Ask the synthesis engine for a formula (no constraints)

Step 4: 
- If the formula does not fit a positive example P, it will fail to do so at some particular x*. Next time, the synthesis query should make sure that the formula satisfies P at x*.
- If the formula does not fit a negative example N (that is, it holds on N for every x), then next time the synthesis query should require that there is some x* at which the formula is false on N.

Step 5: Rinse and repeat until the synthesized formula fits every positive and every negative example.

 \forall x,y. R(x,y) ⇔ R(y,x)


In [1]:
from z3 import *
import itertools

In [2]:
from qexpr import QExpr, QForAll

Step 1:

In [39]:
# def get_formula():
#     R = Function("R", IntSort(), IntSort(), BoolSort())
#     S = Function("S", IntSort(), IntSort(), IntSort(), BoolSort())
#     x,y,z = Ints('x y z')    
#     #formula = ForAll([x,y], R(x, y) == R(y, x))
#     formula = QForAll(
#         [IntSort(), IntSort(), IntSort()],
#         lambda x, y, z: \
#             And(
#                 R(x, y) == R(y, x), R(x, y) == R(x, z),
#                 Implies(x == y, S(x, y, z))
#             )
#     )
#     return formula, [R, S]

def get_formula():
    """Build the target formula over the ternary relation ``route_tc``.

    The formula is universally quantified over four Int-sorted variables
    ``n, x, y, z`` and is the conjunction of:

    * reflexivity:    route_tc(n, x, x)
    * transitivity:   route_tc(n, x, y) & route_tc(n, y, z) -> route_tc(n, x, z)
    * antisymmetry:   route_tc(n, x, y) & route_tc(n, y, x) -> x = y
    * comparability:  route_tc(n, x, y) & route_tc(n, x, z)
                      -> route_tc(n, y, z) | route_tc(n, z, y)

    Returns:
        A pair ``(formula, relations)`` where ``formula`` is the ``QForAll``
        expression and ``relations`` is the list of relation symbols it
        mentions (here just ``route_tc``).
    """
    route_tc = Function("route_tc", IntSort(), IntSort(), IntSort(), BoolSort())

    # The quantified variables are supplied by QForAll through the lambda's
    # parameters, so no free z3 constants need to be declared here.
    formula = QForAll(
        [IntSort(), IntSort(), IntSort(), IntSort()],
        lambda n, x, y, z: And(
            route_tc(n, x, x),
            Implies(
                And(route_tc(n, x, y), route_tc(n, y, z)),
                route_tc(n, x, z),
            ),
            Implies(
                And(route_tc(n, x, y), route_tc(n, y, x)),
                x == y,
            ),
            Implies(
                And(route_tc(n, x, y), route_tc(n, x, z)),
                Or(route_tc(n, y, z), route_tc(n, z, y)),
            ),
        ),
    )
    return formula, [route_tc]

In [40]:
# Build the target formula and its relation symbols, then display the number
# of entries in formula.sorts alongside the relations.
formula, rels = get_formula()
len(formula.sorts), rels

(4, [route_tc])

Step 2

In [41]:
def get_examples(formula, rels, size=5, num=10, pos=True):
    """Enumerate finite models that satisfy (or violate) ``formula``.

    Every sort of the formula is interpreted over the same universe
    ``0 .. size-1``.  The formula is grounded over those universes and handed
    to a z3 solver; each model found is recorded and then blocked so the next
    check yields a different interpretation of ``rels``.

    Args:
        formula: quantified formula exposing ``sorts`` and ``to_ground_expr``.
        rels: relation symbols whose interpretations are recorded.
        size: number of elements in each universe.
        num: maximum number of examples to produce.
        pos: if true, collect models of the formula; otherwise collect models
            of its negation.

    Returns:
        A list of dicts.  Each dict maps ``"_universes_"`` to the list of
        universes and each relation to a dict from argument tuples to the
        value the model assigns.  The list is shorter than ``num`` when the
        solver stops reporting ``sat``.
    """
    domain = list(range(size))
    universes = [domain for _ in formula.sorts]

    solver = Solver()
    grounded = formula.to_ground_expr(universes)
    solver.add(grounded if pos else Not(grounded))

    found = []
    for _ in range(num):
        if solver.check() != sat:
            break

        model = solver.model()
        interpretation = {"_universes_": universes}
        for rel in rels:
            interpretation[rel] = {
                point: model.eval(rel(*point), model_completion=True)
                for point in itertools.product(domain, repeat=rel.arity())
            }
        found.append(interpretation)

        # Block this exact interpretation so the next model differs from it.
        same_as_found = [
            rel(*point) == value
            for rel in rels
            for point, value in interpretation[rel].items()
        ]
        solver.add(Not(And(*same_as_found)))

    return found

In [45]:
# Request up to 20 models over a 10-element universe that satisfy the formula
# (positive examples) and up to 20 that violate it (negative examples).
pos_examples = get_examples(formula, rels, size=10, num=20, pos=True)
neg_examples = get_examples(formula, rels, size=10, num=20, pos=False)

Step 3

In [46]:
from string import Template
from typing import List

class Serializer:
    """Render a formula-synthesis problem over example models as a SyGuS query.

    Each example model becomes a constructor of the ``ModelId`` datatype
    (``m<i>_pos`` / ``m<i>_neg``).  Every relation is emitted as a
    ``define-fun`` that takes the model id plus the relation's arguments and
    tabulates the relation's value in each model.  The function to synthesize,
    ``f``, takes a model id and one Int per quantified variable, and is built
    from ``and``/``or``/``not``, ``true``/``false`` and relation atoms over
    those variables.

    Constraints: for a positive example, ``f`` must hold for every tuple in
    the product of the example's universes; for a negative example, it must
    fail for at least one such tuple.

    ``str(serializer)`` yields the full query text.
    """

    # The z3/qexpr annotations are forward-reference strings so that defining
    # the class does not require evaluating those names.
    def __init__(
        self,
        formula: "QExpr",
        rels: "List[FuncDeclRef]",
        pos_examples: List[dict],
        neg_examples: List[dict],
        seed: int = 1,
    ):
        """Store the problem and prepare the query template.

        Args:
            formula: target formula; only ``formula.arity()`` is used here, as
                the number of Int arguments of ``f``.
            rels: relation symbols; each must provide ``arity()`` and render
                as its name via ``str``.
            pos_examples: models in which the synthesized formula must hold.
            neg_examples: models in which the synthesized formula must fail.
            seed: value emitted for the ``:random-seed`` option (default 1).

        Each example is a dict with a ``"_universes_"`` entry (one iterable of
        elements per argument of ``f``) and, for every relation, a dict from
        argument tuples to truth values.

        Raises:
            KeyError: if an example has no ``"_universes_"`` entry.
        """
        self.formula = formula
        self.rels = rels
        self.seed = seed

        # (model, label, universes) triples; positives precede negatives and
        # an example's position in this list is the index in its model id.
        self.examples = [
            (ex, label, ex["_universes_"])
            for label, group in (("pos", pos_examples), ("neg", neg_examples))
            for ex in group
        ]

        fargs = " ".join([f"(x{i} Int)" for i in range(formula.arity())])

        # {fargs} is filled in right here by the f-string; the $placeholders
        # are substituted later by __str__.
        self.template = Template(f"""
(set-logic ALL)
(set-option :random-seed $seed)

(declare-datatypes ((ModelId 0)) (( $model_universe )) )

;; relation definitions
$relation_defns

(synth-fun f ((m ModelId) {fargs}) Bool
    ;; Non terminals
    (
        (Start Bool)
    )

    ;; Grammar
    (
        (
            Start Bool (
                (and Start Start)
                (or Start Start)
                (not Start)
                $relation_atoms
                true
                false
            )
        )
    )
)

;; Constraints
$constraints

(check-synth)
        """)

    @staticmethod
    def _model_id(index, label):
        """Return the ``ModelId`` constructor name of example ``index``."""
        return f"m{index}_{label}"

    @staticmethod
    def _table_condition(table):
        """Return an SMT-LIB term over ``x0, x1, ...`` encoding ``table``.

        ``table`` maps argument tuples to truth values.  A constant table
        collapses to ``true``/``false``.  Otherwise the smaller side is
        listed: the true tuples as a disjunction when they are strictly fewer
        than the false ones, else the negated disjunction of the false tuples.
        """
        true_argss = [args for args, val in table.items() if val]
        false_argss = [args for args, val in table.items() if not val]

        if not true_argss or not false_argss:
            return "true" if true_argss else "false"

        list_true = len(true_argss) < len(false_argss)
        listed = true_argss if list_true else false_argss

        conjuncts = [
            "(and "
            + "".join(f"(= x{i} {arg}) " for i, arg in enumerate(args))
            + ")"
            for args in listed
        ]
        disjunction = "(or " + " ".join(conjuncts) + ")"
        return disjunction if list_true else "(not " + disjunction + ")"

    def _get_model_universe(self):
        """Return the space-separated constructor list of ``ModelId``."""
        return " ".join(
            f"({self._model_id(i, label)})"
            for i, (_model, label, _universes) in enumerate(self.examples)
        )

    def _get_relation_atoms(self):
        """Return the grammar's relation atoms, one per line.

        Every relation is applied to the model variable ``m`` and to every
        tuple of ``f``'s variables of the right length, e.g. ``(R m x0 x1)``.
        """
        num_fargs = self.formula.arity()
        rel_atoms = []
        for rel in self.rels:
            for arg in itertools.product(range(num_fargs), repeat=rel.arity()):
                variables = " ".join(f"x{i}" for i in arg)
                rel_atoms.append(f"({rel} m {variables})")
        return "\n".join(rel_atoms)

    def _get_relation_defns(self):
        """Return one ``define-fun`` per relation, tabulated per model."""
        relation_defns = []
        for rel in self.rels:
            dummy_args = " ".join(f"(x{i} Int)" for i in range(rel.arity()))

            # Whitespace is spelled out explicitly so the emitted text does
            # not depend on source indentation or trailing-space handling.
            pieces = [
                f"\n            (define-fun {rel} ((m ModelId) {dummy_args}) Bool\n"
                "            (or \n"
                "            "
            ]
            for i, (model, label, _universes) in enumerate(self.examples):
                pieces.append(
                    f"\n                (and (= m {self._model_id(i, label)}) \n"
                    "                "
                )
                pieces.append(self._table_condition(model[rel]))
                pieces.append(")\n")
            pieces.append("))\n")

            relation_defns.append("".join(pieces))
        return "\n".join(relation_defns)

    def _get_constraints(self):
        """Return one ``constraint`` command per example, one per line.

        Positive examples require ``f`` at every tuple of the product of the
        example's universes; negative examples require the negation of that
        conjunction.
        """
        constraints = []
        for i, (_model, label, universes) in enumerate(self.examples):
            model_id = self._model_id(i, label)

            applications = []
            for args in itertools.product(*universes):
                args_str = " ".join(str(arg) for arg in args)
                applications.append(f"(f {model_id} {args_str})")
            conjunction = "(and " + " ".join(applications) + ")"

            if label == "pos":
                constraints.append(f"(constraint {conjunction})")
            else:
                constraints.append(f"(constraint (not {conjunction}))")

        return "\n".join(constraints)

    def __str__(self):
        """Return the complete query text."""
        return self.template.substitute(
            seed=self.seed,
            model_universe=self._get_model_universe(),
            relation_defns=self._get_relation_defns(),
            relation_atoms=self._get_relation_atoms(),
            constraints=self._get_constraints(),
        )

# Render the synthesis query for the generated examples and print it.
serializer = Serializer(formula, rels, pos_examples, neg_examples)
print(serializer)


(set-logic ALL)
(set-option :random-seed 1)

(declare-datatypes ((ModelId 0)) (( (m0_pos) (m1_pos) (m2_pos) (m3_pos) (m4_pos) (m5_pos) (m6_pos) (m7_pos) (m8_pos) (m9_pos) (m10_pos) (m11_pos) (m12_pos) (m13_pos) (m14_pos) (m15_pos) (m16_pos) (m17_pos) (m18_pos) (m19_pos) (m20_neg) (m21_neg) (m22_neg) (m23_neg) (m24_neg) (m25_neg) (m26_neg) (m27_neg) (m28_neg) (m29_neg) (m30_neg) (m31_neg) (m32_neg) (m33_neg) (m34_neg) (m35_neg) (m36_neg) (m37_neg) (m38_neg) (m39_neg) )) )

;; relation definitions

            (define-fun route_tc ((m ModelId) (x0 Int) (x1 Int) (x2 Int)) Bool
            (or 
            
                (and (= m m0_pos) 
                (or (and (= x0 0) (= x1 0) (= x2 0) ) (and (= x0 0) (= x1 1) (= x2 1) ) (and (= x0 0) (= x1 2) (= x2 2) ) (and (= x0 0) (= x1 3) (= x2 3) ) (and (= x0 0) (= x1 4) (= x2 4) ) (and (= x0 0) (= x1 5) (= x2 5) ) (and (= x0 0) (= x1 6) (= x2 6) ) (and (= x0 0) (= x1 7) (= x2 7) ) (and (= x0 0) (= x1 8) (= x2 8) ) (and (= x0 0) (= x1 9) (= x

In [25]:
from invar_synth.utils.minisy_wrapper import *

In [47]:
# Hand the serialized query to the minisy wrapper. The logged command line
# suggests min_depth/max_depth are forwarded as --min-depth/--max-depth.
minisy = MiniSyWrapper(run_name="pldi_testdrive")
minisy.invoke(str(serializer), min_depth=1, max_depth=6)

Storing minisy stuff at: /home/parth/598mp/src/invar_synth/cegis/run_files/12_pldi_testdrive/q_*.sy
Running /home/parth/598mp/mini-sygus/scripts/minisy /home/parth/598mp/src/invar_synth/cegis/run_files/12_pldi_testdrive/q_1.sy --min-depth=1 --max-depth=6


In [29]:
# Number of positive examples actually produced (get_examples may return
# fewer than requested).
len(pos_examples)

10

In [19]:
def decls_to_string(formula, rels):
    """Return the s-expression of every relation in ``rels``, one per line.

    ``formula`` is accepted for interface compatibility but is not used.
    """
    return "\n".join(rel.sexpr() for rel in rels)

In [20]:
# Display the s-expression declarations of the current relation symbols.
decls_to_string(formula, rels)

'(declare-fun R (Int Int) Bool)'