In [186]:
import sqlglot
import cpmpy as cp
import json
import subprocess

from cpmpy import intvar, SolverLookup
from cpmpy.expressions.core import Comparison, Operator
from cpmpy.transformations.to_cnf import to_cnf
from cpmpy.transformations.normalize import simplify_boolean
from cpmpy.tools import mus, mss

In [144]:
stmt_base = "SELECT * FROM vector WHERE nutzart == 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 1000 AND flaeche <= 5000"
stmt_within = "SELECT * FROM vector WHERE nutzart == 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 2000 AND flaeche <= 4000"
stmt_overlap = "SELECT * FROM vector WHERE nutzart == 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 3000 AND flaeche <= 7000"
stmt_disjoint = "SELECT * FROM vector WHERE nutzart == 'Landwirtschaft' AND name IN ('MOL_002', 'MOL_003') AND flaeche > 6000 AND flaeche <= 8000"
# Parse the SQL statement into an Abstract Syntax Tree (AST)
ast_base = sqlglot.parse_one(stmt_base)
ast_within = sqlglot.parse_one(stmt_within)
ast_overlap = sqlglot.parse_one(stmt_overlap)
ast_disjoint = sqlglot.parse_one(stmt_disjoint)

vector_meta = json.loads(
    subprocess.check_output(f'ogrinfo -json /home/gereon/UNI/research/tpctc-data/ALKIS_nutzung_MOL_fixed/nutzung.shp', shell=True)
    .decode('utf-8'))


In [157]:
str_values = [l.this for ai in [ast_base, ast_within, ast_overlap, ast_disjoint] for l in ai.find_all(sqlglot.exp.Literal) if l.is_string]
str_values = list(set(str_values))
str_value_map = { v: i for i, v in enumerate(str_values) }


In [158]:
fields = vector_meta['layers'][0]['fields']
field_types = {field['name']: field['type'] for field in fields}

field_types['flaeche'] = 'Integer'

In [159]:

def where_to_linear_program(ast):
    where = ast.find(sqlglot.exp.Where)
    if not where:
        return None
    return ast_to_linear_program(where.this)

end_expr = {
    sqlglot.exp.GT: ">",
    sqlglot.exp.LT: "<",
    sqlglot.exp.EQ: "==",
    sqlglot.exp.GTE: ">=",
    sqlglot.exp.LTE: "<=",
    sqlglot.exp.NEQ: "!=",
}

op_expr = {
    sqlglot.exp.And: "and",
    sqlglot.exp.Or: "or",
    sqlglot.exp.Not: "not",
}

def ast_to_linear_program(ast):
    print(ast)
    left = ast.this
    right = ast.expression

    match type(ast):
        case expr if expr in end_expr.keys():
            return Comparison(end_expr[type(ast)] , build_term(left), build_term(right))
        case sqlglot.exp.In:
            values = [build_term(value) for value in ast.expressions]
            field = build_term(left)
            return Operator("or", [Comparison("==", field, value) for value in values])
        # case _:
        #     for condition in ast.iter_expressions():
        #         left = condition.left
        #         right = condition.right
        #         param_type = type(condition)
        #
        #         match param_type:
        case sqlglot.exp.And:
            return Operator("and", [ast_to_linear_program(left), ast_to_linear_program(right)])
        case sqlglot.exp.Or:
            return Operator("or", [ast_to_linear_program(left), ast_to_linear_program(right)])
        case sqlglot.exp.Not:
            return Operator("not", [ast_to_linear_program(left)])
        case _:
            raise NotImplementedError(f"Condition type {type(ast)} not supported.")

                # Handle AND conditions recursively
    # return constraints

def build_term(term):
    match type(term):
        case sqlglot.exp.Column:
            field_name = term.name
            field_type = field_types[field_name]
            if field_type in ["Integer", "Real"]:
                return intvar(-9999999, 9999999, name=field_name)
            elif field_type == "String":
                return intvar(0, len(str_value_map), name=field_name)  # Placeholder for string handling
            else:
                raise NotImplementedError(f"Field type {field_type} not supported.")
        case sqlglot.exp.Literal:
            if term.is_string:
                return str_value_map[term.this]
            else:
                return int(term.this)
        case _:
            raise NotImplementedError(f"Term type {type(term)} not supported.")


In [160]:
lp_base = where_to_linear_program(ast_base)
lp_within = where_to_linear_program(ast_within)
lp_overlap = where_to_linear_program(ast_overlap)
lp_disjoint = where_to_linear_program(ast_disjoint)

nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 1000 AND flaeche <= 5000
nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 1000
nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002')
nutzart = 'Landwirtschaft'
name IN ('MOL_001', 'MOL_002')
flaeche > 1000
flaeche <= 5000
nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 2000 AND flaeche <= 4000
nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 2000
nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002')
nutzart = 'Landwirtschaft'
name IN ('MOL_001', 'MOL_002')
flaeche > 2000
flaeche <= 4000
nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 3000 AND flaeche <= 7000
nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002') AND flaeche > 3000
nutzart = 'Landwirtschaft' AND name IN ('MOL_001', 'MOL_002')
nutzart = 'Landwirtschaft'
name IN ('MOL_001', 'MOL_002')
flaeche > 3000
flaeche <= 7000
n

In [138]:
# cp.Model(Operator("->", [lp_base, lp_disjoint]))
def implies(a, b):
    return Operator("or", [Operator("not", [a]), b])

def test_lps(lp1, lp2):
    print(cp.Model(lp1, lp2).solve())
    print(cp.Model(lp1, Operator("not", [lp2])).solve())

test_lps(lp_base, lp_overlap)
test_lps(lp_base, lp_disjoint)



True
True
False
True


In [196]:
m = cp.Model(lp_base.implies(lp_overlap))
m.solve(solver="exact")
m.status()

ExitStatus.FEASIBLE (0.0046045780181884766 seconds)

In [195]:
s = SolverLookup.get("exact")
print(s.transform(m.constraints))

[(~BV384) -> (sum([-1, -1, -1, -1] * [BV384, BV387, BV388, BV389]) >= -3), (BV384) -> (sum([nutzart]) == 1), (~BV384) -> (sum([1, -4] * [nutzart, BV392]) <= 0), (~BV384) -> (sum([1, -2] * [nutzart, BV392]) >= 0), sum([1, 1] * [BV384, BV392]) <= 1, (BV387) -> ((BV385) + (BV386) >= 1), sum([-1, 1] * [BV387, BV385]) <= 0, sum([-1, 1] * [BV387, BV386]) <= 0, (BV385) -> (sum([name]) == 3), (~BV385) -> (sum([1, -2] * [name, BV393]) <= 2), (~BV385) -> (sum([1, -4] * [name, BV393]) >= 0), sum([1, 1] * [BV385, BV393]) <= 1, (BV386) -> (sum([name]) == 2), (~BV386) -> (sum([1, -3] * [name, BV394]) <= 1), (~BV386) -> (sum([1, -3] * [name, BV394]) >= 0), sum([1, 1] * [BV386, BV394]) <= 1, (BV388) -> (sum([flaeche]) >= 1001), (~BV388) -> (sum([flaeche]) <= 1000), (BV389) -> (sum([flaeche]) <= 5000), (~BV389) -> (sum([flaeche]) >= 5001), (~BV387) -> (sum([-1, -1, -1, -1] * [BV384, BV387, BV388, BV389]) >= -3), (~BV390) -> (sum([-1, -1, -1, -1] * [BV384, BV387, BV388, BV389]) >= -3), (BV390) -> (sum([

In [197]:
m.solve()
m.status()

ExitStatus.FEASIBLE (0.019007798000000003 seconds)

In [201]:
def get_implies(lp_a, lp_b):
    model = cp.Model(Operator("not", [lp_a]), lp_b)
    return not model.solve()

In [203]:
print(get_implies(lp_base, lp_overlap))
print(get_implies(lp_base, lp_disjoint))
print(get_implies(lp_base, lp_within))

False
False
True


In [200]:
ma = cp.Model(Operator("not", [lp_base]), lp_within)
ma.solve(solver="exact")
ma.status()

ExitStatus.UNSATISFIABLE (0.003277301788330078 seconds)

In [198]:
simplify_boolean(to_cnf([lp_base.implies(lp_overlap)]))

[(not([and([BV414, BV417, BV418, BV419])])) or (BV420),
 (not([nutzart == 1])) or (BV414),
 (BV414) -> (nutzart == 1),
 (not([(BV415) or (BV416)])) or (BV417),
 or([~BV417, BV415, BV416]),
 (not([name == 3])) or (BV415),
 (BV415) -> (name == 3),
 (not([name == 2])) or (BV416),
 (BV416) -> (name == 2),
 (not([flaeche > 1000])) or (BV418),
 (BV418) -> (flaeche > 1000),
 (not([flaeche <= 5000])) or (BV419),
 (BV419) -> (flaeche <= 5000),
 (not([nutzart == 1])) or (BV420),
 (BV420) -> (nutzart == 1),
 (not([and([BV421, BV424, BV425, BV426])])) or (BV429),
 (not([nutzart == 1])) or (BV421),
 (BV421) -> (nutzart == 1),
 (not([(BV422) or (BV423)])) or (BV424),
 or([~BV424, BV422, BV423]),
 (not([name == 3])) or (BV422),
 (BV422) -> (name == 3),
 (not([name == 2])) or (BV423),
 (BV423) -> (name == 2),
 (not([flaeche > 1000])) or (BV425),
 (BV425) -> (flaeche > 1000),
 (not([flaeche <= 5000])) or (BV426),
 (BV426) -> (flaeche <= 5000),
 (not([(BV427) or (BV428)])) or (BV429),
 or([~BV429, BV427

In [163]:
lp_overlap

and([nutzart == 1, (name == 3) or (name == 2), flaeche > 3000, flaeche <= 7000])

In [117]:
ast_base.find(sqlglot.exp.Where).this

And(
  this=And(
    this=And(
      this=EQ(
        this=Column(
          this=Identifier(this=nutzart, quoted=False)),
        expression=Literal(this='Landwirtschaft', is_string=True)),
      expression=In(
        this=Column(
          this=Identifier(this=name, quoted=False)),
        expressions=[
          Literal(this='MOL_001', is_string=True),
          Literal(this='MOL_002', is_string=True)])),
    expression=GT(
      this=Column(
        this=Identifier(this=flaeche, quoted=False)),
      expression=Literal(this=1000, is_string=False))),
  expression=LTE(
    this=Column(
      this=Identifier(this=flaeche, quoted=False)),
    expression=Literal(this=5000, is_string=False)))

In [164]:
a = cp.intvar(1,10)

# m = cp.Model(Operator("->", [a > 7, a > 6]))
m = cp.Model((a > 6).implies(a > 7))

print(m)

m.solve()

Constraints:
    (IV8 > 6) -> (IV8 > 7)
Objective: None


True