In [35]:
from z3 import *
import time

In [86]:
def n_queen_smt(n = 8):
    """Solve the n-queens puzzle with an integer (SMT) encoding.

    ``Q_i`` is the 1-based column of the queen on row ``i``. The constraints
    keep every column in ``1..n``, make all columns distinct, and forbid two
    queens on a shared diagonal.

    The problem is solved exactly once. The model is printed (or
    ``no solution`` / ``failed to solve``, mirroring z3's ``solve()``),
    followed by the wall-clock duration of ``Solver.check``.

    Returns:
        The ``time`` statistic Z3 reports for that check. If no ``time``
        key is present, the value of the last statistics entry is returned.
    """
    Q = [Int('Q_%i' % (i + 1)) for i in range(n)]
    val_c = [And(1 <= Q[i], Q[i] <= n) for i in range(n)]
    # Distinct() needs at least one argument, so skip it for an empty board.
    col_c = [Distinct(Q)] if Q else []
    # For rows j < i the queens share a diagonal iff Q[i] - Q[j] == +/-(i - j).
    diag_c = [And(i + Q[j] != j + Q[i], i - Q[j] != j - Q[i])
              for i in range(n) for j in range(i)]

    # A single solver serves printing, timing and statistics, so the printed
    # time and the returned statistic describe the same check.
    s = Solver()
    s.add(val_c + col_c + diag_c)
    time_start = time.perf_counter()
    result = s.check()
    time_end = time.perf_counter()

    if result == sat:
        print(s.model())
    elif result == unsat:
        print('no solution')
    else:
        print('failed to solve')
    print('time cost for smt:', time_end - time_start, 's')

    stats = s.statistics()
    # Look the counter up by name instead of relying on its position.
    if 'time' in stats.keys():
        return stats.get_key_value('time')
    return stats[len(stats) - 1][1]

In [87]:
def _sat_only(literals, k):
    """Return a formula that holds iff literals[k] is true and every other literal is false."""
    # Not(Or(~x_k, x_others...)) == x_k AND (NOT x_other for every other literal).
    return Not(Or([Not(lit) if idx == k else lit for idx, lit in enumerate(literals)]))


def _sat_exactly_one(literals):
    """Return a formula that holds iff exactly one of ``literals`` is true."""
    return Or([_sat_only(literals, k) for k in range(len(literals))])


def _sat_at_most_one(literals):
    """Return a formula that holds iff at most one of ``literals`` is true."""
    # Either no literal is true, or exactly one of them is.
    return Or([Not(Or(literals))] + [_sat_only(literals, k) for k in range(len(literals))])


def n_queen_sat(n = 8):
    """Solve the n-queens puzzle with a pure Boolean (SAT) encoding.

    One Boolean per board cell. Every row and every column must contain
    exactly one queen, and every diagonal and anti-diagonal at most one.

    Returns:
        The ``time`` statistic Z3 reports for the check. If no ``time`` key
        is present, the value of the last statistics entry is returned.
    """
    # Q[r * n + c] is true iff a queen stands on row r, column c.
    # The '_' between indices keeps names unique for n >= 10; without it
    # (1, 11) and (11, 1) would both be 'Q_111' and alias one z3 variable.
    Q = [Bool('Q_%i_%i' % (r + 1, c + 1)) for r in range(n) for c in range(n)]
    rows = [[Q[r * n + c] for c in range(n)] for r in range(n)]
    cols = [[Q[r * n + c] for r in range(n)] for c in range(n)]
    # Anti-diagonals: cells with r + c == d, for d in 0 .. 2n-2.
    anti_diags = [[Q[(d - c) * n + c] for c in range(n) if 0 <= d - c < n]
                  for d in range(2 * n - 1)]
    # Main diagonals: cells with r - c == d, for d in -(n-1) .. n-1.
    main_diags = [[Q[(d + c) * n + c] for c in range(n) if 0 <= d + c < n]
                  for d in range(-n + 1, n)]

    row_c = And([_sat_exactly_one(lits) for lits in rows])  # exactly one queen per row
    col_c = And([_sat_exactly_one(lits) for lits in cols])  # exactly one queen per column
    diag_c = And(And([_sat_at_most_one(lits) for lits in anti_diags]),
                 And([_sat_at_most_one(lits) for lits in main_diags]))

    s = Solver()
    s.add(row_c)
    s.add(col_c)
    s.add(diag_c)
    s.check()

    stats = s.statistics()
    # Look the counter up by name instead of relying on its position.
    if 'time' in stats.keys():
        return stats.get_key_value('time')
    return stats[len(stats) - 1][1]

In [88]:
# Benchmark both N-queens encodings on board sizes 1..9, logging the value
# each solver function returns.
# NOTE(review): absolute Windows path; consider making it configurable.
RESULTS_PATH = "E:\\Code\\data.txt"
BOARD_SIZES = range(1, 10)

with open(RESULTS_PATH, 'w', encoding='utf-8') as results_file:
    for n in BOARD_SIZES:
        results_file.write("test:n=" + str(n) + '\n')
        smt_time = n_queen_smt(n)
        results_file.write("smt: " + str(smt_time) + 's\n')
        sat_time = n_queen_sat(n)
        results_file.write("sat: " + str(sat_time) + 's\n')

In [48]:
def add(a = 20, b = 7):
    """Add two non-negative integers through a bit-level SAT encoding.

    Both operands are pinned to Boolean literals (most significant bit first).
    The sum bits are constrained by full-adder equations and recovered from
    the Z3 model.

    Prints the sum and returns it, or returns ``None`` if the solver finds no
    model.

    Raises:
        ValueError: if ``a`` or ``b`` is negative. Binary formatting would
            otherwise emit a '-' sign that gets read as a 1 bit.
    """
    # One extra bit so that a + b never overflows the register.
    max_len = max(len("{0:b}".format(a)), len("{0:b}".format(b))) + 1
    if a < 0 or b < 0:
        raise ValueError('add() expects non-negative integers, got %r and %r' % (a, b))
    format_str = '{0:0' + str(max_len) + 'b}'
    a_bin = format_str.format(a)
    b_bin = format_str.format(b)
    # Index 0 is the most significant bit.
    A = [Bool('a_%i' % (i + 1)) for i in range(max_len)]
    B = [Bool('b_%i' % (i + 1)) for i in range(max_len)]
    # C[i] is the carry out of bit i and C[i + 1] the carry into it, so
    # C[max_len] feeds the LSB and C[0] leaves the MSB.
    C = [Bool('c_%i' % i) for i in range(max_len + 1)]
    D = [Bool('d_%i' % (i + 1)) for i in range(max_len)]
    A_c = And([A[i] if a_bin[i] == '1' else Not(A[i]) for i in range(max_len)])
    B_c = And([B[i] if b_bin[i] == '1' else Not(B[i]) for i in range(max_len)])
    # Sum bit: nested equivalences (XNORs) compose to a ^ b ^ carry_in.
    D_c = And([D[i] == (A[i] == (B[i] == C[i + 1])) for i in range(max_len)])
    # Carry out is the majority of the three input bits.
    Carry_c = And([C[i] == Or(And(A[i], B[i]), And(A[i], C[i + 1]), And(B[i], C[i + 1]))
                   for i in range(max_len)])
    s = Solver()
    # No carry into the LSB and no carry out of the MSB.
    s.add(A_c, B_c, D_c, Carry_c, Not(C[0]), Not(C[max_len]))
    if s.check() != sat:
        return None
    model = s.model()
    bits = ''.join('1' if is_true(model.eval(D[i], model_completion=True)) else '0'
                   for i in range(max_len))
    total = int(bits, 2)
    print(total)
    return total

In [61]:
def minus(a = 20, b = 7):
    """Subtract ``b`` from ``a`` for non-negative integers via a SAT encoding.

    ``a - b = d`` is encoded as ``a = b + d`` using full-adder constraints
    over ``d``'s bits (most significant bit first). Overflow out of the MSB is
    forbidden, so no non-negative ``d`` exists when ``a < b``.

    Prints and returns the difference, or returns ``None`` when the
    constraints are unsatisfiable (i.e. ``a < b``).

    Raises:
        ValueError: if ``a`` or ``b`` is negative. Binary formatting would
            otherwise emit a '-' sign that gets read as a 1 bit.
    """
    # d <= a, so a's width (or b's, if wider) is enough; no extra bit needed.
    max_len = max(len("{0:b}".format(a)), len("{0:b}".format(b)))
    if a < 0 or b < 0:
        raise ValueError('minus() expects non-negative integers, got %r and %r' % (a, b))
    format_str = '{0:0' + str(max_len) + 'b}'
    a_bin = format_str.format(a)
    b_bin = format_str.format(b)
    # Index 0 is the most significant bit.
    A = [Bool('a_%i' % (i + 1)) for i in range(max_len)]
    B = [Bool('b_%i' % (i + 1)) for i in range(max_len)]
    # C[i] is the carry out of bit i and C[i + 1] the carry into it.
    C = [Bool('c_%i' % i) for i in range(max_len + 1)]
    D = [Bool('d_%i' % (i + 1)) for i in range(max_len)]
    A_c = And([A[i] if a_bin[i] == '1' else Not(A[i]) for i in range(max_len)])
    B_c = And([B[i] if b_bin[i] == '1' else Not(B[i]) for i in range(max_len)])
    # Sum bit of b + d must equal a's bit: a == d ^ b ^ carry_in.
    D_c = And([A[i] == (D[i] == (B[i] == C[i + 1])) for i in range(max_len)])
    # Carry out is the majority of d, b and the incoming carry.
    Carry_c = And([C[i] == Or(And(D[i], B[i]), And(D[i], C[i + 1]), And(B[i], C[i + 1]))
                   for i in range(max_len)])
    s = Solver()
    # No carry into the LSB and no overflow out of the MSB.
    s.add(A_c, B_c, D_c, Carry_c, Not(C[0]), Not(C[max_len]))
    if s.check() != sat:
        return None
    model = s.model()
    bits = ''.join('1' if is_true(model.eval(D[i], model_completion=True)) else '0'
                   for i in range(max_len))
    difference = int(bits, 2)
    print(difference)
    return difference